Plastic Learning with Deep Fourier Features

Alex Lewandowski, Dale Schuurmans†‡⋆, Marlos C. Machado†⋆
†Department of Computing Science, University of Alberta; ‡Google DeepMind; ⋆Canada CIFAR AI Chair
Correspondence to: Alex Lewandowski <[email protected]>.

Abstract

Deep neural networks can struggle to learn continually in the face of non-stationarity. This phenomenon is known as loss of plasticity. In this paper, we identify underlying principles that lead to plastic algorithms. In particular, we provide theoretical results showing that linear function approximation, as well as a special case of deep linear networks, do not suffer from loss of plasticity. We then propose deep Fourier features, which are the concatenation of a sine and cosine in every layer, and we show that this combination provides a dynamic balance between the trainability obtained through linearity and the effectiveness obtained through the nonlinearity of neural networks. Deep networks composed entirely of deep Fourier features are highly trainable and sustain their trainability over the course of learning. Our empirical results show that continual learning performance can be drastically improved by replacing ReLU activations with deep Fourier features. These results hold for different continual learning scenarios (e.g., label noise, class incremental learning, pixel permutations) on all major supervised learning datasets used for continual learning research, such as CIFAR10, CIFAR100, and tiny-ImageNet.

1 Introduction

Continual learning is a problem setting that moves past some of the rigid assumptions found in supervised, semi-supervised, and unsupervised learning (Ring, 1994; Thrun, 1998). In particular, the continual learning setting involves learning from data sampled from a changing, non-stationary distribution rather than from a fixed distribution. A performant continual learning algorithm faces a trade-off due to its limited capacity: it should avoid forgetting what was previously learned while also being able to adapt to new incoming data, an ability known as plasticity (Parisi et al., 2019). Current approaches that use neural networks for continual learning are not yet capable of making this trade-off because of catastrophic forgetting (Kirkpatrick et al., 2017) and loss of plasticity (Dohare et al., 2021; Lyle et al., 2023; Dohare et al., 2024). Even in supervised learning, the training of neural networks remains an active area of theoretical research (Jacot et al., 2018; Yang et al., 2023; Kunin et al., 2024), which suggests that much is left to be understood about training neural networks continually. Compared to the relatively well-understood setting of supervised learning, even the formalization of the continual learning problem is an active research area (Kumar et al., 2023a; Abel et al., 2024; Liu et al., 2023). Given these uncertainties surrounding current practice, we take a step back to better understand the inductive biases used to build algorithms for continual learning.

One fundamental capability expected from a continual learning algorithm is a sustained ability to update its predictions on new data. Recent work has identified the phenomenon of loss of plasticity in neural networks, in which stochastic gradient-based training becomes less effective when faced with data from a changing, non-stationary distribution (Dohare et al., 2024). Several methods have been proposed to address loss of plasticity in neural networks, with their success demonstrated empirically in both supervised and reinforcement learning (Ash and Adams, 2020; Lyle et al., 2022; 2023; Lee et al., 2024). Empirical work has also shown that the plasticity of neural networks is sensitive to different components of the training process, such as the activation function (Abbas et al., 2023). However, little is known about what is required for learning with sustained plasticity.

The goal of this paper is to identify a basic continual learning algorithm that does not lose plasticity, in theory or in practice, rather than to mitigate loss of plasticity in existing neural network architectures. In particular, we investigate how the nonlinearity of neural networks affects loss of plasticity. While loss of plasticity is a well-documented phenomenon in neural networks, previous empirical observations suggest that linear function approximation can learn continually without suffering from loss of plasticity (Dohare et al., 2021; 2024). In this paper, we prove that linear function approximation does not suffer from loss of plasticity and can sustain its learning ability on a sequence of tasks. We then extend our analysis to a special case of deep linear networks, which are an interesting intermediate case between deep nonlinear networks and linear function approximation: deep linear networks are linear in representation but nonlinear in gradient dynamics (Saxe et al., 2014). We provide theoretical and empirical evidence that general deep linear networks also do not suffer from loss of plasticity. The plasticity of deep linear networks is surprising because it suggests that, for sustaining plasticity, the nonlinear dynamics of deep linear networks are more similar to the linear dynamics of linear function approximation than to the nonlinear dynamics of deep nonlinear networks.

Figure 1: A neural network with deep Fourier features in every layer approximately embeds a deep linear network. A single layer using deep Fourier features linearly combines the inputs, $x$ (shown as $x_1, \dots, x_4$), to compute the pre-activations, $z$ (shown as $z_1, z_2$), and each pre-activation is mapped to both a cos unit and a sin unit (Left). For each pre-activation, either the sin unit (Middle) or the cos unit (Right) is well-approximated by a linear function.

Given this seemingly natural advantage of linearity for continual learning, as well as its inherent limitation to learning only linear representations, we explore how nonlinear networks can better emulate the dynamics of deep linear networks to sustain plasticity. We hypothesize that, to learn continually, a neural network must strike a balance between too much linearity, which sacrifices deep representations, and too much nonlinearity, which leads to loss of plasticity. We show that several existing components partially satisfy this hypothesis, such as the concatenated ReLU (Shang et al., 2016), leaky-ReLU activations (Xu et al., 2015), and residual connections (He et al., 2016), but fail to strike this balance. Our results build on previous work that identified unit saturation (Abbas et al., 2023) and unit linearization (Lyle et al., 2024) as problems in continually training neural networks with common activation functions. In particular, we generalize these phenomena to unit sign entropy. We show that linear networks have high unit sign entropy, meaning that each hidden unit is positive on approximately half of the inputs. In contrast, deep nonlinear networks with most activation functions tend to have low unit sign entropy, which indicates saturation or linearization.

Periodic activation functions (Parascandolo et al., 2017), such as the sinusoid (sin), are a notable exception: they have high unit sign entropy yet still suffer from loss of plasticity. Thus, in addition to high unit sign entropy, we show that the network's activation function should be well-approximated by a linear function. We propose deep Fourier features as a means of approximating linearity dynamically: every pre-activation is connected to two units, one of which is always well-approximated by a linear function. Specifically, deep Fourier features concatenate a sine and a cosine activation in each hidden layer. The resulting network is nonlinear while also approximately embedding a deep linear network that uses all of its parameters. Deep Fourier features differ from previous approaches that use Fourier features only in the input layer (Tancik et al., 2020; Li and Pathak, 2021; Yang et al., 2022) or that use a fixed Fourier feature basis (Rahimi and Recht, 2007; Konidaris et al., 2011). We demonstrate that networks using these shallow Fourier features still exhibit loss of plasticity. Only by using deep Fourier features in every layer is the network capable of sustaining and improving trainability in a continual learning setting. Using tiny-ImageNet (Le and Yang, 2015), CIFAR10, and CIFAR100 (Krizhevsky, 2009), we show that deep Fourier features can be used as a drop-in replacement for improving trainability in commonly used neural network architectures. Furthermore, deep Fourier features achieve superior generalization performance when combined with regularization, because their trainability allows for much higher regularization strengths.
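As a minimal sketch of the mechanism just described, assuming that a deep Fourier feature layer simply maps each pre-activation $z_i$ to the pair $(\sin z_i, \cos z_i)$ with no additional scaling, the following dependency-free Python builds that concatenation and, for a given pre-activation, names the unit that is locally closer to linear. The helper names `deep_fourier_features` and `more_linear_unit` are illustrative and are not taken from the paper's code.

```python
import math

def deep_fourier_features(z):
    """Concatenate sin and cos of every pre-activation: [sin(z), cos(z)].

    The output has twice as many units as there are pre-activations. Any
    scaling or width adjustment used in practice is deliberately omitted.
    """
    return [math.sin(v) for v in z] + [math.cos(v) for v in z]

def more_linear_unit(v):
    """Name the unit ("sin" or "cos") that is locally closer to linear at v.

    sin'' = -sin and cos'' = -cos, so the unit whose value is smaller in
    magnitude is also the one with smaller curvature at v.
    """
    return "sin" if abs(math.sin(v)) <= abs(math.cos(v)) else "cos"

# A layer with two pre-activations yields four units.
print(deep_fourier_features([0.0, 1.0]))
print([more_linear_unit(v) for v in (0.0, 0.5, 1.0, 1.5)])
```

Because $\sin^2 z + \cos^2 z = 1$, the smaller of $|\sin z|$ and $|\cos z|$ is at most $1/\sqrt{2}$. The selected unit therefore has curvature of magnitude at most $1/\sqrt{2}$ and slope of magnitude at least $1/\sqrt{2}$, which is one way to make precise the claim that one of the two units is always well-approximated by a linear function.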

2 Problem Setting

We define a deep network, $f_\theta$, with a parameter set, $\theta = \{\mathbf{W}_l, \mathbf{b}_l\}_{l=1}^{L}$, as a sequence of layers, in which each layer applies a linear transformation followed by an element-wise activation function, $\phi$, in each hidden layer. The output of the network, $f_\theta(x) := h_L(x)$, is defined recursively by $h_l = [h_{l,1}, \dots, h_{l,w}] = [\phi(z_{l,1}), \dots, \phi(z_{l,w})] = \phi(z_l)$ and $z_l = \mathbf{W}_l h_{l-1} + \mathbf{b}_l$, where $w$ is the width of the network and $h_0 = x$. We refer to a particular element of the hidden layer's output, $h_{l,i}$, as a unit. The deep network is a deep linear network when the activation function is the identity, $\phi(z) = z$. Linear function approximation is equivalent to a linear network with $L = 1$.
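The recursion above can be sketched in plain Python as follows. This is an illustrative implementation, not the paper's code; it assumes, following the prose, that $\phi$ is applied only in the hidden layers so that the output layer is linear, and it stores each $\mathbf{W}_l$ as a list of rows. Passing the identity as $\phi$ gives a deep linear network, and a single $(\mathbf{W}, \mathbf{b})$ pair gives linear function approximation.

```python
def forward(params, x, phi):
    """Evaluate f_theta(x) for params = [(W_1, b_1), ..., (W_L, b_L)].

    Hidden layers compute h_l = phi(W_l h_{l-1} + b_l) with h_0 = x; the
    final layer is left linear (an assumption based on phi being applied
    "in each hidden layer"). Each W_l is a list of rows.
    """
    h = list(x)
    for l, (W, b) in enumerate(params, start=1):
        # Pre-activation z_l = W_l h_{l-1} + b_l
        z = [sum(w * hj for w, hj in zip(row, h)) + bi for row, bi in zip(W, b)]
        # Apply phi element-wise in hidden layers only
        h = z if l == len(params) else [phi(zi) for zi in z]
    return h

identity = lambda z: z        # gives a deep linear network
relu = lambda z: max(0.0, z)  # a common nonlinear choice

params = [([[1.0, 2.0], [3.0, 4.0]], [0.5, -0.5]),
          ([[1.0, -1.0]], [0.25])]
print(forward(params, [1.0, -1.0], identity))
print(forward(params, [1.0, -1.0], relu))
```

With the identity, this two-layer network collapses to the single affine map $\mathbf{W}_2\mathbf{W}_1 x + \mathbf{W}_2\mathbf{b}_1 + \mathbf{b}_2$, which is why a deep linear network can only represent linear functions.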

The problem setting that we consider is continual supervised learning without task boundaries. At each iteration, a minibatch of $M$ observation-target pairs, $\{x_i, y_i\}_{i=1}^{M}$, is used to update the parameters $\theta$ of a neural network $f_\theta$ using a variant of stochastic gradient descent. The learning problem is continual because the distribution from which the data is sampled, $p(x, y)$, changes over time. For simplicity, we assume that this non-stationarity changes the distribution over input-target pairs every $T$ iterations. The data is sampled from a single distribution for $T$ steps, and we refer to this temporarily stationary problem as a task, $\tau$. The distribution over observations and targets that defines task $\tau$ is denoted by $p_\tau$.

We focus our theoretical analysis on the problem of loss of trainability, in which we evaluate the neural network at the end of each task using samples from the most recent task distribution, $p_\tau$, as is commonly done in previous work (Lyle et al., 2023). Loss of trainability refers to the problem in which the neural network is unable to carry its initial performance on the first task over to later tasks. Specifically, we denote the optimisation objective by $J_\tau(\theta) = \mathbb{E}_{(x,y) \sim p_\tau}\big[\ell(f_\theta(x), y)\big]$, for some loss function $\ell$ and task-specific data distribution $p_\tau$. We use $t$ to denote the iteration count of the learning algorithm, so the current task number can be written as $\tau(t) = \lfloor t / T \rfloor$.
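A small sketch of the bookkeeping in this setting: computing $\tau(t) = \lfloor t/T \rfloor$, flagging the last iteration of each task (where we assume the end-of-task evaluation on $p_\tau$ takes place), and estimating $J_\tau(\theta)$ by a Monte Carlo average over samples from $p_\tau$. The function names and the squared-error loss are hypothetical choices for illustration.

```python
def task_index(t, T):
    """tau(t) = floor(t / T) for iteration t = 0, 1, 2, ..."""
    return t // T

def is_end_of_task(t, T):
    """True on the last iteration of a task (assumed evaluation point)."""
    return (t + 1) % T == 0

def estimate_objective(f, loss, samples):
    """Monte Carlo estimate of J_tau(theta) = E_{(x,y) ~ p_tau}[loss(f(x), y)]."""
    return sum(loss(f(x), y) for x, y in samples) / len(samples)

def squared_error(prediction, target):
    return (prediction - target) ** 2

# With T = 3, iterations 0-2 belong to task 0, 3-5 to task 1, and so on.
print([task_index(t, 3) for t in range(7)])
print([t for t in range(9) if is_end_of_task(t, 3)])
```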

3 Trainability and Linearity

In this section, we show that, unlike nonlinear networks, linear networks do not suffer from loss of trainability. That is, if the number of iterations in each task is sufficiently large, a linear network sustains trainability on every task in the sequence. We then show theoretically that a special case of deep linear networks also does not suffer from loss of trainability, and we empirically validate the theoretical findings in more general settings. These results provide a theoretical basis for previous work that uses a linear baseline in loss of plasticity experiments.

3.1 Trainability of Linear Function Approximation

We first prove that loss of trainability does not occur with linear function approximation, $f_\theta(x) = \mathbf{W}_1 x + \mathbf{b}_1$. We prove this by showing that any sequence of tasks can be learned with a large enough number of iterations per task. In particular, the suboptimality gap on the $\tau$-th task can be upper bounded by a quantity that is independent of the solution found on the first $\tau - 1$ tasks. Linear function approximation avoids loss of trainability because the optimisation problem on each task is convex (Agrawal et al., 2021; Boyd and Vandenberghe, 2004) and, under the strong convexity assumption below, has a unique global optimum, $\theta_\tau^\star$. We now state the theorem, which we prove in Appendix B.

Theorem 1.

Let $\theta^{(\tau T)}$ denote the linear weights learned at the end of the $\tau$-th task, with the corresponding unique global minimum for task $\tau$ being denoted by $\theta_\tau^\star$. Assuming the objective function is $\mu$-strongly convex, the suboptimality gap for gradient descent on the $\tau$-th task is

$$J_\tau\big(\theta^{(\tau T)}\big) - J_\tau(\theta_\tau^\star) < \frac{2D(1 - \alpha\mu)^{T}}{\alpha T \big(1 - (1 - \alpha\mu)^{T}\big)},$$

where each task lasts for $T$ iterations, $D$ is the assumed bound on the parameters at the global minimum for every task, and $\alpha$ is the step size.
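To see how the bound in Theorem 1 behaves, the sketch below evaluates its right-hand side exactly as stated. It assumes $0 < \alpha\mu < 1$, so that $(1 - \alpha\mu)^T$ shrinks geometrically with $T$; the function name `theorem1_bound` is ours.

```python
def theorem1_bound(D, alpha, mu, T):
    """Right-hand side of the Theorem 1 suboptimality bound, as stated.

    Assumes 0 < alpha * mu < 1, so rho = (1 - alpha * mu) ** T lies in (0, 1)
    and decays geometrically with the number of iterations per task, T.
    """
    if not 0.0 < alpha * mu < 1.0:
        raise ValueError("expected 0 < alpha * mu < 1")
    rho = (1.0 - alpha * mu) ** T
    return 2.0 * D * rho / (alpha * T * (1.0 - rho))

for T in (10, 100, 1000):
    print(T, theorem1_bound(D=1.0, alpha=0.1, mu=1.0, T=T))
```

For example, with $D = 1$, $\alpha = 0.1$, and $\mu = 1$, the bound is about 1.07 at $T = 10$ and decreases rapidly as $T$ grows. None of its arguments depend on $\tau$ or on the parameters reached on earlier tasks, which is the sense in which the guarantee holds uniformly across the task sequence.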

In addition to convexity, we assume that the objective function is $\mu$-strongly convex, $\nabla_\theta^2 J_\tau(\theta) \succ \mu \mathbf{I}$, where $\nabla_\theta^2 J_\tau(\theta)$ denotes the Hessian. This assumption is often satisfied in the continual learning problem outlined in Section 2 (see Appendix A.1 for more discussion). Lastly, we assume that the parameters at the global optimum for every task are bounded: $\|\theta_\tau^\star\|_2 < D$. This is true for regression problems if the observations and targets are bounded. In classification tasks, the global optimum can be at infinity because output functions such as the sigmoid and the softmax approach their extreme values only as the logits grow without bound. In this case, we constrain the parameters to the set $\{\theta : \|\theta\|_2 < D\}$ and project the optimum onto this set. Intuitively, this theorem states that if the problem is bounded and effectively strongly convex due to a finite number of iterations, then the optimisation dynamics are well-behaved for every task in the bounded set. In particular, the error on each task can be upper bounded by a quantity independent of the initialization found on previous tasks. Thus, given enough iterations, linear function approximation can learn continually without loss of trainability.
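The following toy simulation illustrates the theorem's message under simplifying assumptions of our own: a two-dimensional least-squares problem without a bias term, whose inputs are chosen so that the Hessian is $0.75\,\mathbf{I}$ (hence $\mu = 0.75$), a new hypothetical target weight vector for every task, and projected gradient descent onto the closed ball $\{\theta : \|\theta\|_2 \le D\}$. The parameters are carried over from task to task and never reset, yet the loss at the end of every task is driven to numerical zero. This is a sketch, not the paper's experimental setup.

```python
import math
import random

# Toy inputs with (1/n) * sum_i x_i x_i^T = 0.75 * I, so every task's
# least-squares loss is 0.75-strongly convex (no bias term, for simplicity).
X = [(1.0, 0.0), (0.0, 1.0), (1.0, 1.0), (1.0, -1.0)]

def loss_and_grad(w, targets):
    """J(w) = (1/2n) sum_i (w . x_i - y_i)^2 and its gradient."""
    n = len(X)
    loss, grad = 0.0, [0.0, 0.0]
    for (x1, x2), y in zip(X, targets):
        r = w[0] * x1 + w[1] * x2 - y
        loss += 0.5 * r * r / n
        grad[0] += r * x1 / n
        grad[1] += r * x2 / n
    return loss, grad

def project(w, D):
    """Euclidean projection onto the ball {w : ||w||_2 <= D}."""
    norm = math.hypot(w[0], w[1])
    return list(w) if norm <= D else [D * w[0] / norm, D * w[1] / norm]

def run_tasks(num_tasks=20, T=100, alpha=0.5, D=10.0, seed=0):
    """Projected gradient descent over a sequence of tasks, never resetting w."""
    rng = random.Random(seed)
    w = [0.0, 0.0]
    end_of_task_losses = []
    for _ in range(num_tasks):
        # Hypothetical task: a new target weight vector well inside the ball.
        w_star = [rng.uniform(-D / 3, D / 3) for _ in range(2)]
        targets = [w_star[0] * x1 + w_star[1] * x2 for x1, x2 in X]
        for _ in range(T):
            _, g = loss_and_grad(w, targets)
            w = project([w[0] - alpha * g[0], w[1] - alpha * g[1]], D)
        end_of_task_losses.append(loss_and_grad(w, targets)[0])
    return end_of_task_losses

print(max(run_tasks()))
```

Here each gradient step contracts the distance to the current optimum by the factor $1 - \alpha\mu = 0.625$, and the projection cannot increase that distance because the optimum lies inside the ball, so where the previous task left the parameters only affects how many iterations are needed.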

3.2 Trainability of Deep Linear Networks

We now provide evidence that, similar to linear function approximation, deep linear networks also do not suffer from loss of trainability. Deep linear networks differ from deep nonlinear networks by not using nonlinear activation functions in their hidden layers (Bernacchia et al., 2018; Ziyin et al., 2022). This means that a deep linear network can only represent linear functions. At the same time, its gradient update dynamics are nonlinear and non-convex, similar to deep nonlinear neural networks (Saxe et al., 2014). Our central claim here is that deep linear networks under gradient descent dynamics avoid parameter configurations that would lead to loss of trainability.

To simplify notation, and without loss of generality, we combine the weights and biases into a single parameter for each layer of the deep linear network, $\theta = \{\theta_1, \dots, \theta_L\}$, so that $f_\theta(x) = \theta_L \theta_{L-1} \cdots \theta_1 x$. We denote the product of weight matrices, or simply the product matrix, as $\bar{\theta} = \theta_L \theta_{L-1} \cdots \theta_1$, which allows us to write the deep linear network in terms of the product matrix: $f_\theta(x) = \bar{\theta} x$. The problem setup we use for the deep linear analysis follows previous work (Huh, 2020), and we provide additional technical details on the optimisation dynamics of deep linear networks in Appendix A.3.

The gradient of the loss function with respect to the parameters of a deep linear network can be written in terms of the gradient with respect to the product matrix $\bar{\theta}$ (Bah et al., 2022):

$$\nabla_{\theta_j} J(\theta) = \theta_{j+1}^\top \theta_{j+2}^\top \cdots \theta_L^\top \, \nabla_{\bar{\theta}} J(\bar{\theta}) \, \theta_1^\top \theta_2^\top \cdots \theta_{j-1}^\top,$$

where the term $\nabla_{\bar{\theta}} J(\bar{\theta})$ is the gradient of the loss with respect to the product matrix, treating it as if it were linear function approximation. The gradient is nonlinear because the gradient with respect to the parameter at one layer is coupled to the values of the parameters at the other layers. Nevertheless, the gradient dynamics of the individual parameters can be combined to yield the dynamics of the product matrix (Arora et al., 2018):
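The product matrix and the layer-wise gradient formula can be checked numerically. The sketch below, in dependency-free Python, assumes square $d \times d$ layers, a single input $x$, and the squared loss $J = \tfrac{1}{2}\|\bar{\theta}x - y\|^2$, for which $\nabla_{\bar{\theta}} J = (\bar{\theta}x - y)x^\top$. It compares the formula above against central finite differences; all helper names are illustrative.

```python
import random

def matmul(A, B):
    """Product of matrices stored as lists of rows."""
    return [[sum(a * b for a, b in zip(row, col)) for col in zip(*B)] for row in A]

def transpose(A):
    return [list(col) for col in zip(*A)]

def chain(mats):
    """Left-to-right product mats[0] @ mats[1] @ ... @ mats[-1]."""
    out = mats[0]
    for M in mats[1:]:
        out = matmul(out, M)
    return out

def product_matrix(thetas):
    """theta_bar = theta_L ... theta_1, where thetas = [theta_1, ..., theta_L]."""
    return chain(thetas[::-1])

def loss(thetas, x, y):
    """J = 0.5 * ||theta_bar x - y||^2, with x and y as column vectors."""
    pred = matmul(product_matrix(thetas), x)
    return 0.5 * sum((p[0] - t[0]) ** 2 for p, t in zip(pred, y))

def grad_product(thetas, x, y):
    """Gradient with respect to the product matrix: (theta_bar x - y) x^T."""
    pred = matmul(product_matrix(thetas), x)
    residual = [[p[0] - t[0]] for p, t in zip(pred, y)]
    return matmul(residual, transpose(x))

def layer_grad(thetas, j, G):
    """theta_{j+1}^T ... theta_L^T G theta_1^T ... theta_{j-1}^T (j is 1-indexed)."""
    left = [transpose(t) for t in thetas[j:]]
    right = [transpose(t) for t in thetas[:j - 1]]
    return chain(left + [G] + right)

def finite_diff_grad(thetas, j, x, y, eps=1e-6):
    """Central finite differences of J with respect to every entry of theta_j."""
    W = thetas[j - 1]
    grad = [[0.0] * len(W[0]) for _ in W]
    for a in range(len(W)):
        for b in range(len(W[0])):
            original = W[a][b]
            W[a][b] = original + eps
            up = loss(thetas, x, y)
            W[a][b] = original - eps
            down = loss(thetas, x, y)
            W[a][b] = original
            grad[a][b] = (up - down) / (2 * eps)
    return grad

rng = random.Random(0)
d, L = 3, 4
thetas = [[[rng.gauss(0.0, 0.7) for _ in range(d)] for _ in range(d)] for _ in range(L)]
x = [[rng.gauss(0.0, 1.0)] for _ in range(d)]
y = [[rng.gauss(0.0, 1.0)] for _ in range(d)]
G = grad_product(thetas, x, y)
for j in range(1, L + 1):
    analytic, numeric = layer_grad(thetas, j, G), finite_diff_grad(thetas, j, x, y)
    err = max(abs(u - v) for ru, rv in zip(analytic, numeric) for u, v in zip(ru, rv))
    print(f"layer {j}: max |formula - finite difference| = {err:.2e}")
```

Since $J$ is quadratic in any single entry of $\theta_j$, the central difference is exact up to floating-point rounding, so any sizeable discrepancy would point to an error in the formula rather than in the check.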

$$\bar{\nabla}_\theta J(\theta) = P_{\bar{\theta}} \nabla_{\bar{\theta}} J(\bar{\theta}).$$

The dynamics involve a preconditioner, $P_{\bar{\theta}}$, that accelerates optimisation (Arora et al., 2018), which we empirically demonstrate in Section 3.3. On the left-hand side of the equation, we use $\bar{\nabla}_\theta J(\theta)$ to denote the combined effect of the gradients for each layer on the dynamics of the product matrix.¹ This means that the effective gradient dynamics of the deep network are related to the dynamics of linear function approximation through a preconditioner. While the dynamics are nonlinear and non-convex, they are remarkably similar to those of linear function approximation, which is convex.

¹Note that we use $\bar{\nabla}$ because $\bar{\nabla} J(\theta)$ is not the gradient of any function of $\bar{\theta}$; see the discussion by Arora et al. (2018).
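In the simplest case the preconditioner can be written down explicitly. Assuming scalar layers with a balanced initialization, $\theta_1 = \dots = \theta_L = a > 0$, one gradient step of size $\eta$ on every layer changes the product $\bar{\theta} = a^L$ by approximately $-\eta P_{\bar{\theta}} \nabla_{\bar{\theta}} J$ with $P_{\bar{\theta}} = L|\bar{\theta}|^{2 - 2/L}$, which is the scalar form of the end-to-end dynamics of Arora et al. (2018). The sketch below compares the exact update with this first-order prediction; it illustrates only that special case and is not a general implementation.

```python
def product_after_step(a, L, g, eta):
    """Exact product after one gradient step on L equal scalar layers.

    Every layer equals a > 0, so theta_bar = a**L and, with g = dJ/d(theta_bar),
    the chain rule gives the gradient g * a**(L - 1) for each layer.
    """
    a_new = a - eta * g * a ** (L - 1)
    return a_new ** L

def preconditioned_prediction(theta_bar, L, g, eta):
    """First-order prediction theta_bar - eta * P * g with P = L * |theta_bar|**(2 - 2/L)."""
    P = L * abs(theta_bar) ** (2.0 - 2.0 / L)
    return theta_bar - eta * P * g

a, L, g, eta = 1.2, 3, 0.5, 1e-4
theta_bar = a ** L
print(product_after_step(a, L, g, eta) - theta_bar)          # exact change
print(preconditioned_prediction(theta_bar, L, g, eta) - theta_bar)  # predicted change
```

With $L = 1$ the preconditioner reduces to $P_{\bar{\theta}} = 1$, recovering plain linear function approximation, while for $L > 1$ and $|\bar{\theta}| \ge 1$ it enlarges the effective step size, which is one concrete reading of the acceleration mentioned above.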

We now provide evidence to suggest that, despite deep linear networks being nonlinear in their gradient dynamics, they do not suffer from loss of trainability. We prove this for a special case of deep diagonal linear networks, and provide empirical evidence to support this claim in general deep linear networks.

Theorem 2.

Let $f_\theta(x) = \theta_L \theta_{L-1} \cdots \theta_1 x$ be a deep diagonal linear network, where $\theta_l = \text{Diag}(\theta_{l,1}, \dots, \theta_{l,d})$. Then, a deep diagonal linear network converges on a sequence of tasks under the same conditions as for convergence on a single task (i.e., the conditions in Arora et al., 2019).

Theorem 2 states that a deep diagonal linear network, a special case of general deep linear networks, can converge to a solution on each task within a sequence of tasks. The proof, provided in Appendix B, shows that the minimum singular value of the product matrix stays greater than zero, $\sigma_{\min}(\bar{\theta}) > 0$. Hence, deep diagonal linear networks do not suffer from loss of trainability. This result provides further evidence suggesting that linearity might be an effective inductive bias for learning continually.
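A minimal simulation of this setting, under assumptions of our own: each hypothetical task asks a deep diagonal linear network with $L = 4$ layers and $d = 3$ coordinates to match a positive diagonal target $c$ under the loss $\tfrac{1}{2}\sum_i (\bar{\theta}_{ii} - c_i)^2$; the factors start balanced at one, and the targets are kept positive so the product never has to change sign. Since $\bar{\theta}$ is diagonal, its singular values are $|\bar{\theta}_{ii}| = |\prod_l \theta_{l,i}|$, so the sketch tracks $\sigma_{\min}(\bar{\theta})$ over every step of every task.

```python
import random

def diag_products(theta):
    """Diagonal of theta_bar = theta_L ... theta_1 for diagonal layers."""
    products = []
    for i in range(len(theta[0])):
        p = 1.0
        for layer in theta:
            p *= layer[i]
        products.append(p)
    return products

def train_diagonal_sequence(num_tasks=10, T=1000, L=4, d=3, eta=0.05, seed=0):
    """Gradient descent on a deep diagonal linear network across a task sequence.

    theta[l][i] is the i-th diagonal entry of layer l + 1. Each hypothetical task
    uses the loss 0.5 * sum_i (theta_bar_ii - c_i)^2 with positive targets c_i.
    Returns end-of-task losses and the smallest sigma_min(theta_bar) observed.
    """
    rng = random.Random(seed)
    theta = [[1.0] * d for _ in range(L)]  # balanced initialization
    end_losses, min_sigma = [], float("inf")
    for _ in range(num_tasks):
        c = [rng.uniform(0.5, 2.0) for _ in range(d)]
        for _ in range(T):
            p = diag_products(theta)
            grads = [[0.0] * d for _ in range(L)]
            for l in range(L):
                for i in range(d):
                    others = 1.0  # product of the other layers' i-th entries
                    for k in range(L):
                        if k != l:
                            others *= theta[k][i]
                    grads[l][i] = (p[i] - c[i]) * others
            for l in range(L):
                for i in range(d):
                    theta[l][i] -= eta * grads[l][i]
            # Singular values of a diagonal matrix are |diagonal entries|.
            min_sigma = min(min_sigma, min(abs(v) for v in diag_products(theta)))
        p = diag_products(theta)
        end_losses.append(0.5 * sum((pi - ci) ** 2 for pi, ci in zip(p, c)))
    return end_losses, min_sigma

losses, sigma = train_diagonal_sequence()
print(max(losses), sigma)
```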

While the analysis considers a special case of deep linear networks, namely deep diagonal networks, we note that this is a common setting for the analysis of deep linear networks more generally (Nacson et al., 2022; Even et al., 2023). In particular, the analysis is motivated by the fact that, under certain conditions, the evolution of the deep linear network parameters can be analyzed through the independent singular mode dynamics (Braun et al., 2022), which simplifies the analysis of deep linear networks to deep diagonal linear networks.

3.3 Empirical Evidence For Trainability of General Deep Linear Networks

Figure 2: Trainability on a linearly separable task. Higher opacity corresponds to deeper networks, with depths in {1, 2, 4, 8, 16}. Deep linear networks sustain trainability on new tasks, with some additional depth improving trainability. Nonlinear networks using ReLU suffer from loss of trainability at any depth, even on this simple sequence of linearly separable problems.

In the previous section, we proved that a special case of deep linear networks does not suffer from loss of trainability. We now provide additional empirical evidence that general deep linear networks do not suffer from loss of trainability. To do so, we use a linearly separable subset of the MNIST dataset (LeCun et al., 1998), in which the labels of each image are randomized every 100 epochs. Because the data is linearly separable, even a linear baseline can fit it if given enough iterations. While MNIST is a simple classification problem, memorizing random labels highlights the difficulties associated with maintaining trainability (see Lyle et al., 2023; Kumar et al., 2023b). We emphasize that the goal here is merely to validate that linear networks remain trainable in continual learning. We also provide results with traditional nonlinear neural networks on the same problem, showing that they suffer from loss of trainability even in this simple setting. Later, in Section 5, we extend our investigation of loss of trainability to larger-scale benchmarks.

In Figure 2, we see that deep linear networks with depths from 1 to 16 can sustain trainability. Deep nonlinear networks, implemented as multi-layer perceptrons with ReLU activations, quickly reach a much higher accuracy on the first few tasks. However, due to loss of trainability, deep nonlinear networks of any depth eventually perform worse than the corresponding deep linear network. With additional epochs, the linear networks could achieve perfect accuracy on this task because it is linearly separable. We deliberately use a comparatively small number of epochs to show that, with some additional layers, a deep linear network is able to improve its trainability as new tasks are encountered.

4 Combining Linearity and Nonlinearity

In the previous section, we provided empirical and theoretical evidence that linearity provides an effective inductive bias for learning continually by avoiding loss of trainability. However, linear methods are generally not as performant as deep nonlinear networks, meaning that their sustained performance can be inadequate on complex tasks. Even deep linear networks have only linear representational power, despite the fact that the gradient dynamics are nonlinear and can lead to accelerated learning. We now seek to answer the following question:

How can the sustained trainability of linear methods be combined with the expressive power of learned nonlinear representations?

To answer this question, we first seek to better understand the effects of replacing linear activation functions with nonlinear ones in deep networks for continual learning. We observe that deep linear networks have diversity in their hidden units, which can be induced in nonlinear activation functions by adding linearity through a weighted linear component, an idea we refer to as $\alpha$-linearization. To dynamically balance linearity and nonlinearity, we propose to use deep Fourier features for every layer in a network. We prove that such a network approximately embeds a deep linear network, a property we refer to as adaptive linearity. We demonstrate that this adaptively-linear network is plastic, maintaining trainability even on non-linearly-separable problems.

4.1 Adding Linearity to Nonlinear Activation Functions

Deep nonlinear networks can learn expressive representations because of their nonlinear activation functions, but these nonlinearities can also lead to issues with trainability. Although several components of common network architectures incorporate linearity, the way in which linearity is used does not avoid loss of trainability. One example is the piecewise linearity of the ReLU activation function (Shang et al., 2016), $\texttt{ReLU}(x) = \max(0, x)$, which can become saturated and prevent gradient propagation if $\texttt{ReLU}(x) = 0$ for most inputs $x$. While saturation is generally not a problem when learning on a single distribution, it has been noted as problematic when learning from changing distributions, for example, in reinforcement learning (Abbas et al., 2023).

A potential solution to saturation is to use a non-saturating activation function. Two noteworthy examples of non-saturating activation functions are a periodic activation like $\sin(x)$ (Parascandolo et al.,, 2017) and $\texttt{leaky-ReLU}_{\alpha}(x)=\alpha x+(1-\alpha)\texttt{ReLU}(x)$ (Xu et al.,, 2015), both of which are zero only on a set of measure zero. Surprisingly, using leaky-ReLU leads to a related issue, “unit linearization” (Lyle et al.,, 2024), in which the activation is only positive (or only negative). Unlike saturated units, linearized units provide non-zero gradients, but they render the unit effectively linear, limiting the expressive power of the learned representation. While unit linearization seems to suggest that loss of trainability can occur due to linearity, it is important to note that a “linearized unit” is not the same as a linear unit. A linearized unit provides mostly positive (or mostly negative) outputs, whereas a linear unit outputs both positive and negative values.

We generalize the ideas behind unit saturation and unit linearization to unit sign entropy, a metric that also applies to activation functions beyond saturating and piecewise linear functions, such as periodic activation functions. Intuitively, it measures the diversity of a hidden unit's activations across inputs.

Definition 1 (Unit Sign Entropy).

The entropy, $\mathbb{H}$, of the unit’s sign, $\text{sgn}(h(x))$, on a distribution of inputs to the network, $p(x)$, is the binary entropy $\mathbb{H}\left(\text{sgn}(h(x))\right)=-\rho\log_{2}\rho-(1-\rho)\log_{2}(1-\rho)$, where $\rho=P_{x\sim p(x)}\left(\text{sgn}(h(x))=1\right)$.

The maximum value of unit sign entropy is 1, which occurs when the unit is positive on half the inputs. Conversely, a low sign entropy is associated with the aforementioned issues of saturation and linearization. For example, a low sign entropy for a unit in a deep network using ReLU activations means that the unit is almost always positive ($P\left(\text{sgn}(h(x))=1\right)\approx 1$, meaning it is linearized) or almost never positive ($P\left(\text{sgn}(h(x))=1\right)\approx 0$, meaning it is saturated).
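A minimal sketch of how unit sign entropy could be estimated from a batch of activations, assuming NumPy, treating the sign as the binary event $h(x)>0$ (so a ReLU output of exactly zero counts as non-positive), and measuring entropy in bits so that the maximum is 1. The function `unit_sign_entropy` and the toy units are hypothetical illustrations, not the authors' code.

```python
import numpy as np


def unit_sign_entropy(activations: np.ndarray) -> np.ndarray:
    """Per-unit sign entropy in bits for a (num_inputs, num_units) activation matrix.

    Assumption: the sign is treated as the binary event h(x) > 0, so exact zeros
    (e.g. a saturated ReLU) are grouped with negative values.
    """
    rho = np.mean(activations > 0, axis=0)  # empirical P(h(x) > 0) for each unit
    entropy = np.zeros(rho.shape, dtype=float)
    mixed = (rho > 0) & (rho < 1)  # units with rho in {0, 1} have zero entropy
    r = rho[mixed]
    entropy[mixed] = -(r * np.log2(r) + (1 - r) * np.log2(1 - r))
    return entropy


# Toy example with three ReLU-style units on standard-normal inputs.
rng = np.random.default_rng(0)
x = rng.standard_normal((1000, 1))
h = np.concatenate(
    [
        np.maximum(0.0, x),         # positive on roughly half the inputs
        np.maximum(0.0, x + 10.0),  # almost always positive: linearized
        np.maximum(0.0, x - 10.0),  # almost never positive: saturated
    ],
    axis=1,
)
print(unit_sign_entropy(h))  # first entry close to 1.0, the other two exactly 0.0
```

Both failure modes, linearization and saturation, show up as zero entropy, while the unit that changes sign on about half the inputs sits near the maximum of 1 bit.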

Using unit sign entropy, we investigate how the leak parameter of the leaky-ReLU activation function influences training as pure linearity ($\alpha=1$) is traded off for pure nonlinearity ($\alpha=0$). The idea of mixing linearity and nonlinearity can also be generalized to an arbitrary activation function, which we refer to as the $\alpha$-linearization of an activation function.

Figure 3: Trainability on a linearly separable task with $\alpha$-linearization. Darker lines correspond to higher values of $\alpha$. Unit sign entropy increases as $\alpha$ increases (inset), leading to sustained trainability for $\alpha$-ReLU.
Definition 2 ($\alpha$-linearization).

The $\alpha$-linearization of an activation function $\phi$ is defined as $\phi_{\alpha}(x)=\alpha x+(1-\alpha)\phi(x)$.
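Definition 2 can be written as a higher-order function. The sketch below assumes NumPy and restricts $\alpha$ to $[0,1]$, the range studied in the text; the names `alpha_linearize` and `relu` are illustrative, not from the paper's code.

```python
import numpy as np


def relu(x):
    return np.maximum(0.0, x)


def alpha_linearize(phi, alpha):
    """Return phi_alpha(x) = alpha * x + (1 - alpha) * phi(x), as in Definition 2."""
    if not 0.0 <= alpha <= 1.0:
        raise ValueError("alpha is expected to lie in [0, 1]")

    def phi_alpha(x):
        return alpha * x + (1.0 - alpha) * phi(x)

    return phi_alpha


# alpha = 0 recovers phi, alpha = 1 is the identity; alpha-ReLU is the leaky-ReLU above.
alpha_relu = alpha_linearize(relu, 0.1)
alpha_sin = alpha_linearize(np.sin, 0.5)
x = np.linspace(-3.0, 3.0, 7)
print(alpha_relu(x))
print(alpha_sin(x))
```

Because $\alpha$-ReLU with this parameterization is exactly $\texttt{leaky-ReLU}_{\alpha}$, the same wrapper covers both the ReLU and sin case studies.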

A natural hypothesis is that, as $\alpha$ increases from $0$ to $1$ and the network becomes more linear, loss of trainability is mitigated. We emphasize that $\alpha$-linearization is primarily a tool for gaining insight through empirical investigation; it is not a solution to loss of trainability. This is because any benefits of $\alpha$-linearization depend on tuning $\alpha$, and even optimal tuning can lead to overly linear representations and slow training compared to nonlinear networks.

Empirical Evidence for $\alpha$-linear Plasticity

To understand the trainability issues introduced by nonlinearity, we present a case study using sin and ReLU with different values of the linearization parameter, $\alpha$. We use the same experimental setup as in Section 3.3. Referring to the results in Figure 3, we see that both ReLU and sin activation functions are able to sustain trainability for larger values of $\alpha$. This supports the hypothesis: a larger $\alpha$ provides more linearity to the network, allowing it to sustain trainability. For $\alpha$-ReLU, we also verify that the unit sign entropy increases for larger values of $\alpha$ (inset plot). The fact that the periodic sin activation function has a high unit sign entropy despite losing trainability is particularly interesting, and we return to this in Section 4.2. Note that, while trainability can be sustained, it is generally lower than that of the nonlinear networks for large values of $\alpha$.

4.2 Adaptive-linearity by Concatenating Sinusoid Activation Functions

Using the insight that linearity promotes unit sign entropy, we explore an alternative approach to sustaining trainability. As we found in Section 4.1, linearity can sustain trainability but requires tuning $\alpha$, and even optimal tuning can lead to slow learning from overly linear representations. Our approach is motivated by concatenated ReLU activations (Shang et al.,, 2016; Abbas et al.,, 2023), $\texttt{CReLU}(z)=[\texttt{ReLU}(z),\texttt{ReLU}(-z)]$, which avoid the problems from saturated units but do not avoid the problem of low unit sign entropy. In particular, we propose using a pair of activation functions such that one of them is always approximately linear, with a bounded error.

One way to dynamically balance the linearities and nonlinearities of a network is to use periodic activation functions. Due to their periodicity, the properties of the activation function can recur as the magnitude of the pre-activations grows, rather than staying constant, linear, or saturating. But, as we saw in Figure 3, a single periodic activation function like sin is not enough. Instead, we propose to use deep Fourier features, meaning that every layer in the network uses Fourier features. This is a notable departure from previous work, which considers only shallow Fourier features in the first layer (Rahimi and Recht,, 2007; Tancik et al.,, 2020). In particular, each unit is a concatenation of a two-element sinusoid basis on the same pre-activation, $\texttt{Fourier}(z)=\left[\sin(z),\cos(z)\right]$. The advantage of this approach is that a network with deep Fourier features maintains approximate linearity in some of its units.
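A minimal NumPy sketch of a fully connected network with deep Fourier features, where every hidden layer (not just the input layer) applies $\texttt{Fourier}(z)=[\sin(z),\cos(z)]$. The helper names, the choice to emit `width // 2` pre-activations per layer so that the concatenation restores `width` features, and the $1/\sqrt{\text{fan-in}}$ Gaussian initialization are all illustrative assumptions rather than the authors' ResNet-18 implementation.

```python
import numpy as np


def fourier(z):
    """Deep Fourier feature activation: [sin(z), cos(z)] on the same pre-activation."""
    return np.concatenate([np.sin(z), np.cos(z)], axis=-1)


def init_fourier_mlp(rng, in_dim, width, depth, out_dim):
    """Hypothetical fully connected network with a Fourier activation in every hidden layer.

    Each hidden linear layer emits width // 2 pre-activations; concatenating sin and cos
    brings the representation back to `width` features. The 1/sqrt(fan_in) Gaussian
    initialization is an illustrative assumption.
    """
    assert width % 2 == 0, "width must be even so sin and cos halves match"
    params, fan_in = [], in_dim
    for _ in range(depth):
        W = rng.standard_normal((fan_in, width // 2)) / np.sqrt(fan_in)
        params.append((W, np.zeros(width // 2)))
        fan_in = width
    W_out = rng.standard_normal((fan_in, out_dim)) / np.sqrt(fan_in)
    params.append((W_out, np.zeros(out_dim)))
    return params


def forward(params, x):
    h = x
    for W, b in params[:-1]:
        h = fourier(h @ W + b)  # every hidden layer, not just the first, is Fourier
    W_out, b_out = params[-1]
    return h @ W_out + b_out    # linear read-out


rng = np.random.default_rng(0)
params = init_fourier_mlp(rng, in_dim=784, width=256, depth=3, out_dim=10)
logits = forward(params, rng.standard_normal((32, 784)))
print(logits.shape)  # (32, 10)
```

Since sin and cos share one pre-activation, each pair satisfies $\sin^2+\cos^2=1$; the pairing, rather than either sinusoid alone, is what gives every pre-activation an approximately linear partner.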

Proposition 1.

For any $z$, there exists a linear function, $\texttt{L}_{z}(x)=a(z)x+b(z)$, such that either $|\sin(x)-\texttt{L}_{z}(x)|\leq c$ or $|\cos(x)-\texttt{L}_{z}(x)|\leq c$, for $c=\sqrt{2}\pi^{2}/2^{8}$ and all $x\in\left[z-\pi/4,\,z+\pi/4\right]$.

An intuitive description of this is provided in Figure 1. The advantage of using two sinusoids over a single sinusoid is that whenever $\cos(z)$ is near a critical point, $\frac{d}{dz}\cos(z)=-\sin(z)\approx 0$, the sine is near a zero crossing, where it is approximately linear with slope $\frac{d}{dz}\sin(z)=\cos(z)\approx\pm 1$ (and vice versa). The argument follows from an analysis of the Taylor series remainder, showing that half the units in a deep Fourier layer can be approximated by a linear function with a small error of $c=\sqrt{2}\pi^{2}/2^{8}\approx 0.05$. While we found that two sinusoids are sufficient, the approximation error can be further reduced by concatenating additional sinusoids, at the expense of reducing the effective width of the layer.
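As a quick arithmetic check of the constant in Proposition 1, together with the complementary derivatives at a critical point $z=k\pi$ of the cosine (for integer $k$):

```latex
c = \frac{\sqrt{2}\,\pi^{2}}{2^{8}}
  \approx \frac{1.4142 \times 9.8696}{256}
  \approx \frac{13.958}{256}
  \approx 0.0545,
\qquad
\left.\frac{d}{dz}\cos(z)\right|_{z=k\pi} = -\sin(k\pi) = 0,
\quad
\left.\frac{d}{dz}\sin(z)\right|_{z=k\pi} = \cos(k\pi) = (-1)^{k}.
```

Where one sinusoid is flat, the other has unit-magnitude slope, which is the sense in which one member of each pair stays close to linear.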

Because each pre-activation is connected to a unit that is approximately linear, we can conclude that a network composed of deep Fourier features approximately embeds a deep linear network.

Corollary 1.

A network parameterized by $\theta$, with deep Fourier features, approximately embeds a deep linear network parameterized by $\theta$ with a bounded error.

Figure 4: Trainability on a non-linearly-separable task. Deep Fourier features improve and sustain their trainability when other networks cannot.

Notice that piecewise linear activations also embed a deep linear network, but these embedded deep linear networks do not use the same parameter set. For example, the deep linear network embedded by a ReLU network does not depend on any of the parameters used to compute a ReLU unit that is zero. Although the leaky-ReLU function involves every parameter, the embedded deep linear network vanishes because the leak parameter satisfies $\alpha<1$ (and is typically small), so the embedded deep linear network is scaled by $\alpha^{L}$, where $L$ is the depth of the network, which shrinks geometrically with depth.
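A toy illustration of the $\alpha^{L}$ scaling, assuming identity weight matrices and strictly negative inputs so that every unit stays in its leak regime. Under the parameterization $\texttt{leaky-ReLU}_{\alpha}(x)=\alpha x+(1-\alpha)\texttt{ReLU}(x)$, each layer then multiplies its input by exactly $\alpha$. This is a deliberately simplified sketch, not a general proof, and the helper names are hypothetical.

```python
import numpy as np


def leaky_relu(x, alpha):
    # Parameterization used in the text: alpha * x + (1 - alpha) * ReLU(x)
    return alpha * x + (1.0 - alpha) * np.maximum(0.0, x)


def deep_leaky_stack(x, alpha, depth):
    """Apply `depth` leaky-ReLU layers with identity weights (toy setting)."""
    h = x
    for _ in range(depth):
        h = leaky_relu(h, alpha)
    return h


alpha, depth = 0.01, 10
x = np.array([-1.0, -5.0])
out = deep_leaky_stack(x, alpha, depth)
print(out / x)  # each ratio is alpha ** depth = 1e-20, up to floating-point rounding
```

With a common leak of $\alpha=0.01$ and only ten layers, the linear path is already attenuated by twenty orders of magnitude, whereas positive inputs pass through unchanged.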

Empirical Evidence for Nonlinear Plasticity

We now consider an experimental setup similar to that of Sections 3.3 and 4.1, except that we make the problem non-linearly-separable by assigning random labels to the entire dataset. Each task is more difficult because it involves memorizing more labels, and the effect of the non-stationarity is also stronger because more datapoints are randomly relabeled. As a result, the deep linear network can no longer fit a single task well. Referring to Figure 4, the $\alpha$-linear activation functions can sustain and even improve their trainability, albeit very slowly. In contrast, using deep Fourier features within the network enables the network to easily memorize all the labels for 100 tasks. Deep Fourier features even surpass the initial trainability of the other nonlinear baselines, CReLU and shallow Fourier features followed by ReLU. This is surprising, because deep nonlinear networks at initialization are often considered a gold standard for trainability.
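A minimal sketch of the random-label non-stationarity, assuming that each task keeps the same inputs and draws a fresh, uniformly random label for every example. The helper `random_label_tasks` and the uniform sampling scheme are our assumptions for illustration, not the authors' code.

```python
import numpy as np


def random_label_tasks(num_examples, num_classes, num_tasks, seed=0):
    """Yield one label vector per task, resampling every label uniformly at random.

    Hypothetical helper: the inputs stay fixed across tasks, only the labels change.
    """
    rng = np.random.default_rng(seed)
    for _ in range(num_tasks):
        yield rng.integers(0, num_classes, size=num_examples)


tasks = list(random_label_tasks(num_examples=1000, num_classes=10, num_tasks=3))
print([labels[:5] for labels in tasks])
```

Because the labels carry no information about the inputs, fitting a task amounts to memorization, and each new task invalidates the previous fit.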

5 Experiments

Our experiments demonstrate the benefits of the adaptive linearity provided by deep Fourier features. While trainability was the primary focus of our theoretical results and empirical case studies, we show that these findings generalize to other problems in continual learning. In particular, we demonstrate that networks composed of deep Fourier features can learn under diminishing levels of label noise and in class-incremental learning, in addition to sustaining trainability on random labels. Our main results cover all of the major continual supervised learning settings considered in the plasticity literature and build on the standard ResNet-18 architecture, which is widely used in practice (He et al.,, 2016).

Datasets and Non-stationarities

Our experiments use the common image classification datasets for continual learning, namely tiny-ImageNet (Le and Yang,, 2015), CIFAR10, and CIFAR100 (Krizhevsky,, 2009). We augment these datasets with commonly used non-stationarities to create continual learning problems, where each non-stationarity turns the dataset into a sequence of tasks. Specifically, following recent work on continual learning (Lee et al.,, 2024), we consider diminishing levels of label noise on each dataset: we start with half the data corrupted by label noise and reduce the noise to clean labels over 10 tasks. Additionally, for the datasets with a larger number of classes, tiny-ImageNet and CIFAR100, we also consider the class-incremental setting: the first task involves only five classes, and five new classes are added to the existing pool of classes at the beginning of each task (Van de Ven et al.,, 2022). Additional results and details on the datasets and non-stationarities can be found in Appendix C.
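The two non-stationarities can be sketched as simple schedules. This is a hedged illustration under stated assumptions: the label noise is assumed to decay linearly from 0.5 to 0 across the 10 tasks, corrupted labels are assumed to be resampled uniformly over classes, and classes are assumed to be added in index order. All helper names are hypothetical.

```python
import numpy as np


def label_noise_schedule(num_tasks=10, start=0.5):
    """Fraction of corrupted labels per task, decreasing to 0 (linear decay is assumed)."""
    return np.linspace(start, 0.0, num_tasks)


def corrupt_labels(labels, noise_fraction, num_classes, rng):
    """Resample a random `noise_fraction` of the labels uniformly over classes.

    A resampled label can coincide with the original one, so the realized
    fraction of wrong labels is slightly below `noise_fraction`.
    """
    noisy = labels.copy()
    n_noisy = int(round(noise_fraction * len(labels)))
    idx = rng.choice(len(labels), size=n_noisy, replace=False)
    noisy[idx] = rng.integers(0, num_classes, size=n_noisy)
    return noisy


def class_incremental_schedule(num_classes, initial=5, step=5):
    """Class indices available at each task (assumes classes taken in index order)."""
    return [list(range(k)) for k in range(initial, num_classes + 1, step)]


rng = np.random.default_rng(0)
clean = rng.integers(0, 100, size=50_000)  # stand-in for CIFAR100 training labels
for task, frac in enumerate(label_noise_schedule()):
    noisy = corrupt_labels(clean, frac, num_classes=100, rng=rng)
    # ... train on (inputs, noisy) for this task ...
    print(task, round(float(frac), 3), np.mean(noisy != clean))

print(len(class_incremental_schedule(200)))  # 40 tasks for tiny-ImageNet
print(len(class_incremental_schedule(100)))  # 20 tasks for CIFAR100
```

With 200 classes, five-at-a-time growth gives the forty growing-subset tasks on tiny-ImageNet referred to in the class-incremental results.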

Figure 5: Training a ResNet-18 continually with diminishing label noise. Deep Fourier features are particularly performant on complex tasks like tiny-ImageNet. Despite having approximately half the number of parameters, networks with deep Fourier features surpass the baselines on CIFAR100 and are on par with spectral regularization on CIFAR10.

Architecture and Baselines

We compare a ResNet-18 using only deep Fourier features against a standard ResNet-18 with ReLU activations. The network with deep Fourier features has fewer parameters because it concatenates two different activation functions on each pre-activation, halving the effective width compared to the network with ReLU activations. This gives the ReLU baseline an advantage in parameter count. We also include the prominent baselines previously proposed to mitigate loss of plasticity: L2 regularization towards zero, L2 regularization towards the initialization (Kumar et al., 2023b, ), spectral regularization (Lewandowski et al.,, 2024), Concatenated ReLU (Shang et al.,, 2016; Abbas et al.,, 2023), Dormant Neuron Recycling (ReDO, Sokar et al.,, 2023), Shrink and Perturb (Ash and Adams,, 2020), and Streaming Elastic Weight Consolidation (S-EWC, Kirkpatrick et al.,, 2017; Elsayed and Mahmood,, 2024).
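To make the parameter-count remark concrete, here is a back-of-the-envelope count for a fully connected stand-in. It is not the ResNet-18 used in the experiments (whose convolutional layers are not modeled here) and assumes each Fourier layer keeps the same output width as the ReLU layer by emitting half as many pre-activations; `mlp_param_count` is a hypothetical helper.

```python
def mlp_param_count(in_dim, width, depth, out_dim, fourier=False):
    """Weights plus biases of a fully connected network with `depth` hidden layers.

    With fourier=True, each hidden linear layer emits width // 2 pre-activations and
    [sin, cos] restores `width` features, which roughly halves the parameter count.
    """
    units = width // 2 if fourier else width  # pre-activations per hidden layer
    count, fan_in = 0, in_dim
    for _ in range(depth):
        count += fan_in * units + units
        fan_in = width  # representation width seen by the next layer
    count += fan_in * out_dim + out_dim  # linear read-out
    return count


relu_params = mlp_param_count(784, 256, 3, 10)
fourier_params = mlp_param_count(784, 256, 3, 10, fourier=True)
print(relu_params, fourier_params, fourier_params / relu_params)
```

Only the shared read-out layer keeps the ratio slightly above one half, which is why the text describes the Fourier network as having approximately half the parameters.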

5.1 Main results

Our main result demonstrates that adaptive linearity is an effective inductive bias for continual learning. In this set of experiments, we consider the problem of sustaining test accuracy on a sequence of tasks. In addition to remaining trainable, methods must also sustain their generalization.

Diminishing Label Noise

In Figure 5, we can clearly see the benefits of deep Fourier features in the diminishing label noise setting. At the end of training on ten tasks with diminishing levels of label noise, the network with deep Fourier features was always among the methods with the highest test accuracy on the uncorrupted test set. On the first of ten tasks, deep Fourier features could occasionally overfit to the corrupted labels, leading to initially low test accuracy. However, as the label noise diminished on later tasks, the network with deep Fourier features continued to learn and corrected its earlier, poorly generalizing predictions. In contrast, the improvements achieved by the other methods that we considered were often marginal compared to the baseline ReLU network. Two exceptions are: (i) networks with CReLU activations, which underperformed relative to the baseline network, and (ii) Shrink and Perturb, which was the best-performing baseline method for diminishing label noise. Interestingly, the performance benefit of deep Fourier features is most prominent on more complex datasets, like tiny-ImageNet.

Figure 6: Class-incremental learning results on tiny-ImageNet (left) and CIFAR100 (right). On both datasets, deep Fourier features substantially improve over most baselines.

Class-Incremental Learning

Deep Fourier features are also effective in the class-incremental setting, where later tasks involve training on a larger subset of the classes. The network is evaluated at the end of each task on the entire test set. As the network is trained on later tasks, its test-set performance increases because it has access to a larger subset of the training data. In Figure 6, we see that deep Fourier features largely outperform the baselines in this setting, particularly on tiny-ImageNet, in which the first forty tasks involve training on a growing subset of the dataset and the last forty “tasks” involve training to convergence on the full dataset.² Not only are deep Fourier features quicker to learn on earlier continual learning tasks, but they are also able to improve their generalization performance by subsequently training on the full dataset. On CIFAR100, the difference between methods is not as prominent, but deep Fourier features are still among the top-performing methods.

² We use quotation marks to characterize the last forty tasks because they are, in fact, a single task, as the data distribution stops changing after the first forty tasks. We call them “tasks” because of the number of iterations for which they are trained.

5.2 Sensitivity Analysis

In the previous sections, we used deep Fourier features in combination with spectral regularization to achieve high generalization performance. However, the theoretical analysis and case studies that we presented earlier concerned trainability. We now present a sensitivity analysis to understand the relationship between trainability and generalization. Using a ResNet-18 with different activation functions, we varied the regularization strength from no regularization to high degrees of regularization. In Figure 7, we can see that deep Fourier features indeed have a high degree of trainability, sustaining trainability at different levels of regularization. However, without any regularization, deep Fourier features have a tendency to overfit. Overfitting is a known issue for shallow Fourier features (e.g., when using Fourier features only for the input layer, Mavor-Parker et al.,, 2024). However, deep Fourier features are able to use their high trainability to learn effectively even when highly regularized. Thus, while trainability does not always lead to good generalization, the trainability provided by adaptive linearity still provides a useful inductive bias for continual learning.

Figure 7: Sensitivity analysis on tiny-ImageNet, CIFAR10, and CIFAR100. Networks with deep Fourier features are highly trainable but tend to overfit without regularization, leading to high training accuracy but low test accuracy. Because deep Fourier features are highly trainable, they can train with much higher regularization strengths, ultimately leading to better generalization.

6 Conclusion

In this paper, we proved that linear function approximation and a special case of deep linear networks are effective inductive biases for learning continually without loss of trainability. We then investigated the issues that arise from using nonlinear activation functions, namely the lack of unit sign entropy. Motivated by the effectiveness of linearity in sustaining trainability, we proposed deep Fourier features to approximately embed a deep linear network inside a deep nonlinear network. We found that deep Fourier features dynamically balance the trainability afforded by linearity and the effectiveness of nonlinearity, thus providing an effective inductive bias for learning continually. Experimentally, we demonstrated that networks with deep Fourier features provide benefits for continual learning across every dataset we considered. Importantly, networks with deep Fourier features are effective plastic learners because their trainability allows for higher regularization strengths, which lead to improved and sustained generalization over the course of learning.

References

  • Abbas et al., (2023) Abbas, Z., Zhao, R., Modayil, J., White, A., and Machado, M. C. (2023). Loss of plasticity in continual deep reinforcement learning. In Conference on Lifelong Learning Agents.
  • Abel et al., (2024) Abel, D., Barreto, A., Van Roy, B., Precup, D., van Hasselt, H. P., and Singh, S. (2024). A definition of continual reinforcement learning. Advances in Neural Information Processing Systems.
  • Agrawal et al., (2021) Agrawal, A., Barratt, S., and Boyd, S. (2021). Learning convex optimization models. Journal of Automatica Sinica, 8(8):1355–1364.
  • Arora et al., (2019) Arora, S., Cohen, N., Golowich, N., and Hu, W. (2019). A convergence analysis of gradient descent for deep linear neural networks. In International Conference on Learning Representations.
  • Arora et al., (2018) Arora, S., Cohen, N., and Hazan, E. (2018). On the optimization of deep networks: Implicit acceleration by overparameterization. In International Conference on Machine Learning.
  • Ash and Adams, (2020) Ash, J. T. and Adams, R. P. (2020). On Warm-Starting Neural Network Training. In Advances in Neural Information Processing Systems.
  • Ba et al., (2016) Ba, J. L., Kiros, J. R., and Hinton, G. E. (2016). Layer normalization. CoRR, abs/1607.06450v1.
  • Bah et al., (2022) Bah, B., Rauhut, H., Terstiege, U., and Westdickenberg, M. (2022). Learning deep linear neural networks: Riemannian gradient flows and convergence to global minimizers. Information and Inference: A Journal of the IMA.
  • Bernacchia et al., (2018) Bernacchia, A., Lengyel, M., and Hennequin, G. (2018). Exact natural gradient in deep linear networks and its application to the nonlinear case. Advances in Neural Information Processing Systems.
  • Boyd and Vandenberghe, (2004) Boyd, S. P. and Vandenberghe, L. (2004). Convex optimization. Cambridge University Press.
  • Braun et al., (2022) Braun, L., Dominé, C., Fitzgerald, J., and Saxe, A. (2022). Exact learning dynamics of deep linear networks with prior knowledge. Advances in Neural Information Processing Systems.
  • Chou et al., (2024) Chou, H.-H., Gieshoff, C., Maly, J., and Rauhut, H. (2024). Gradient descent for deep matrix factorization: Dynamics and implicit bias towards low rank. Applied and Computational Harmonic Analysis.
  • Cohen et al., (2017) Cohen, G., Afshar, S., Tapson, J., and Van Schaik, A. (2017). Emnist: Extending mnist to handwritten letters. In International Joint Conference on Neural Networks (IJCNN).
  • Dohare et al., (2024) Dohare, S., Hernandez-Garcia, J. F., Lan, Q., Rahman, P., Mahmood, A. R., and Sutton, R. S. (2024). Loss of plasticity in deep continual learning. Nature, 632(8026):768–774.
  • Dohare et al., (2021) Dohare, S., Sutton, R. S., and Mahmood, A. R. (2021). Continual backprop: Stochastic gradient descent with persistent randomness. CoRR, abs/2108.06325v3.
  • Elsayed and Mahmood, (2024) Elsayed, M. and Mahmood, A. R. (2024). Addressing loss of plasticity and catastrophic forgetting in continual learning. In International Conference on Learning Representations.
  • Even et al., (2023) Even, M., Pesme, S., Gunasekar, S., and Flammarion, N. (2023). (s)GD over diagonal linear networks: Implicit bias, large stepsizes and edge of stability. In Advances in Neural Information Processing Systems.
  • Garrigos and Gower, (2023) Garrigos, G. and Gower, R. M. (2023). Handbook of Convergence Theorems for (Stochastic) Gradient Methods. CoRR, abs/2301.11235v3.
  • Glorot and Bengio, (2010) Glorot, X. and Bengio, Y. (2010). Understanding the difficulty of training deep feedforward neural networks. In International Conference on Artificial Intelligence and Statistics.
  • He et al., (2015) He, K., Zhang, X., Ren, S., and Sun, J. (2015). Delving deep into rectifiers: Surpassing human-level performance on imagenet classification. In International Conference on Computer Vision.
  • He et al., (2016) He, K., Zhang, X., Ren, S., and Sun, J. (2016). Deep residual learning for image recognition. In Conference on Computer Vision and Pattern Recognition.
  • Huh, (2020) Huh, D. (2020). Curvature-corrected learning dynamics in deep neural networks. In International Conference on Machine Learning.
  • Ioffe and Szegedy, (2015) Ioffe, S. and Szegedy, C. (2015). Batch normalization: Accelerating deep network training by reducing internal covariate shift. In International Conference on Machine Learning.
  • Jacot et al., (2018) Jacot, A., Gabriel, F., and Hongler, C. (2018). Neural tangent kernel: Convergence and generalization in neural networks. Advances in Neural Information Processing Systems.
  • Kingma and Ba, (2015) Kingma, D. P. and Ba, J. (2015). Adam: A Method for Stochastic Optimization. In International Conference on Learning Representations.
  • Kirkpatrick et al., (2017) Kirkpatrick, J., Pascanu, R., Rabinowitz, N., Veness, J., Desjardins, G., Rusu, A. A., Milan, K., Quan, J., Ramalho, T., Grabska-Barwinska, A., et al. (2017). Overcoming catastrophic forgetting in neural networks. Proceedings of the National Academy of Sciences, 114(13):3521–3526.
  • Kleinman et al., (2024) Kleinman, M., Achille, A., and Soatto, S. (2024). Critical learning periods emerge even in deep linear networks. In International Conference on Learning Representations.
  • Konidaris et al., (2011) Konidaris, G., Osentoski, S., and Thomas, P. (2011). Value function approximation in reinforcement learning using the fourier basis. In AAAI Conference on Artificial Intelligence.
  • Krizhevsky, (2009) Krizhevsky, A. (2009). Learning multiple layers of features from tiny images. Technical report, University of Toronto.
  • Kumar et al., (2023a) Kumar, S., Marklund, H., Rao, A., Zhu, Y., Jeon, H. J., Liu, Y., and Van Roy, B. (2023a). Continual Learning as Computationally Constrained Reinforcement Learning. CoRR, abs/2307.04345.
  • Kumar et al., (2023b) Kumar, S., Marklund, H., and Roy, B. V. (2023b). Maintaining plasticity via regenerative regularization. CoRR, abs/2308.11958v1.
  • Kunin et al., (2024) Kunin, D., Raventós, A., Dominé, C., Chen, F., Klindt, D., Saxe, A., and Ganguli, S. (2024). Get rich quick: exact solutions reveal how unbalanced initializations promote rapid feature learning. CoRR, abs/2406.06158v1.
  • Le and Yang, (2015) Le, Y. and Yang, X. (2015). Tiny imagenet visual recognition challenge.
  • LeCun et al., (1998) LeCun, Y., Cortes, C., and Burges, C. (1998). MNIST handwritten digit database. ATT Labs [Online]. Available: http://yann.lecun.com/exdb/mnist.
  • Lee et al., (2024) Lee, H., Cho, H., Kim, H., Kim, D., Min, D., Choo, J., and Lyle, C. (2024). Slow and steady wins the race: Maintaining plasticity with hare and tortoise networks. In International Conference on Machine Learning.
  • Lee et al., (2019) Lee, J., Xiao, L., Schoenholz, S., Bahri, Y., Novak, R., Sohl-Dickstein, J., and Pennington, J. (2019). Wide neural networks of any depth evolve as linear models under gradient descent. Advances in Neural Information Processing Systems.
  • Lewandowski et al., (2024) Lewandowski, A., Kumar, S., Schuurmans, D., György, A., and Machado, M. C. (2024). Learning Continually by Spectral Regularization. CoRR, abs/2406.06811v1.
  • Li and Pathak, (2021) Li, A. C. and Pathak, D. (2021). Functional regularization for reinforcement learning via learned fourier features. In Advances in Neural Information Processing Systems.
  • Liu et al., (2023) Liu, Y., Kuang, X., and Roy, B. V. (2023). A Definition of Non-Stationary Bandits. CoRR, abs/2302.12202v2.
  • Lyle et al., (2022) Lyle, C., Rowland, M., and Dabney, W. (2022). Understanding and preventing capacity loss in reinforcement learning. In International Conference on Learning Representations.
  • Lyle et al., (2024) Lyle, C., Zheng, Z., Khetarpal, K., van Hasselt, H., Pascanu, R., Martens, J., and Dabney, W. (2024). Disentangling the Causes of Plasticity Loss in Neural Networks. CoRR, abs/2402.18762v1.
  • Lyle et al., (2023) Lyle, C., Zheng, Z., Nikishin, E., Avila Pires, B., Pascanu, R., and Dabney, W. (2023). Understanding plasticity in neural networks. In International Conference on Machine Learning.
  • Mavor-Parker et al., (2024) Mavor-Parker, A. N., Sargent, M. J., Barry, C., Griffin, L., and Lyle, C. (2024). Frequency and Generalisation of Periodic Activation Functions in Reinforcement Learning. CoRR, abs/2407.06756v1.
  • Nacson et al., (2022) Nacson, M. S., Ravichandran, K., Srebro, N., and Soudry, D. (2022). Implicit bias of the step size in linear diagonal neural networks. In International Conference on Machine Learning.
  • Parascandolo et al., (2017) Parascandolo, G., Huttunen, H., and Virtanen, T. (2017). Taming the waves: sine as activation function in deep neural networks.
  • Parisi et al., (2019) Parisi, G. I., Kemker, R., Part, J. L., Kanan, C., and Wermter, S. (2019). Continual lifelong learning with neural networks: A review. Neural networks, 113:54–71.
  • Rahimi and Recht, (2007) Rahimi, A. and Recht, B. (2007). Random features for large-scale kernel machines. Advances in Neural Information Processing Systems.
  • Ring, (1994) Ring, M. B. (1994). Continual learning in reinforcement environments. The University of Texas at Austin.
  • Saxe et al., (2014) Saxe, A., McClelland, J., and Ganguli, S. (2014). Exact solutions to the nonlinear dynamics of learning in deep linear neural networks. In International Conference on Learning Representations.
  • Shang et al., (2016) Shang, W., Sohn, K., Almeida, D., and Lee, H. (2016). Understanding and improving convolutional neural networks via concatenated rectified linear units. In International Conference on Machine Learning.
  • Sokar et al., (2023) Sokar, G., Agarwal, R., Castro, P. S., and Evci, U. (2023). The dormant neuron phenomenon in deep reinforcement learning. In International Conference on Machine Learning.
  • Tancik et al., (2020) Tancik, M., Srinivasan, P., Mildenhall, B., Fridovich-Keil, S., Raghavan, N., Singhal, U., Ramamoorthi, R., Barron, J., and Ng, R. (2020). Fourier features let networks learn high frequency functions in low dimensional domains. Advances in Neural Information Processing Systems.
  • Thrun, (1998) Thrun, S. (1998). Lifelong learning algorithms. In Learning to Learn, pages 181–209. Springer.
  • Van de Ven et al., (2022) Van de Ven, G. M., Tuytelaars, T., and Tolias, A. S. (2022). Three types of incremental learning. Nature Machine Intelligence, 4(12):1185–1197.
  • Xiao et al., (2017) Xiao, H., Rasul, K., and Vollgraf, R. (2017). Fashion-mnist: a novel image dataset for benchmarking machine learning algorithms. CoRR, abs/1708.07747.
  • Xu et al., (2015) Xu, B., Wang, N., Chen, T., and Li, M. (2015). Empirical Evaluation of Rectified Activations in Convolutional Network. CoRR, abs/1505.00853v2.
  • Yang et al., (2022) Yang, G., Ajay, A., and Agrawal, P. (2022). Overcoming the spectral bias of neural value approximation. In International Conference on Learning Representations.
  • Yang et al., (2023) Yang, G., Simon, J. B., and Bernstein, J. (2023). A Spectral Condition for Feature Learning. CoRR, abs/2310.17813v2.
  • Ziyin et al., (2022) Ziyin, L., Li, B., and Meng, X. (2022). Exact solutions of a deep linear network. Advances in Neural Information Processing Systems.

Appendix A Additional Details

A.1 Assumptions for Trainability of Linear Function Approximation

Neither the squared loss nor the cross-entropy loss is $\mu$-strongly convex in general. However, the assumption is satisfied under reasonable conditions in practice and in the problem settings considered in this paper.

For regression, denote the features for task $\tau$ as $X_{\tau}\in\mathbb{R}^{d\times N}$, where $d$ is the feature dimension and $N$ is the sample size. For linear function approximation, the Hessian is the outer product of the data matrix with itself, $\nabla^{2}_{\theta}J^{reg}_{\tau}(\theta)=X_{\tau}X_{\tau}^{\top}\in\mathbb{R}^{d\times d}$. Thus, the squared loss is strongly convex if the data matrix has full row rank, $\mathrm{rank}(X_{\tau})=d$, in which case $\mu$ is the smallest eigenvalue of $X_{\tau}X_{\tau}^{\top}$. This condition is satisfied in the high-dimensional image classification problems that we consider.
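As a minimal sketch of the full-rank condition, the snippet below forms $X_{\tau}X_{\tau}^{\top}$ for two hypothetical $2\times 4$ data matrices and computes its smallest eigenvalue in closed form (valid only for the symmetric $2\times 2$ case). Following the text, constant factors in the Hessian are dropped. Full-rank data yields a positive smallest eigenvalue, which can serve as $\mu$, while data whose second feature is a multiple of the first yields zero.

```python
import math

def hessian_linear_regression(X):
    """Return H = X X^T for a d x N data matrix X stored as a list of d rows."""
    d = len(X)
    return [[sum(a * b for a, b in zip(X[i], X[j])) for j in range(d)] for i in range(d)]

def min_eig_sym2x2(H):
    """Smallest eigenvalue of a symmetric 2x2 matrix [[a, b], [b, c]]."""
    a, b, c = H[0][0], H[0][1], H[1][1]
    mean = (a + c) / 2
    radius = math.sqrt(((a - c) / 2) ** 2 + b ** 2)
    return mean - radius

# d = 2 features, N = 4 samples (one sample per column); values are illustrative.
X_full = [[1.0, 0.0, 2.0, -1.0],
          [0.0, 1.0, 1.0, 3.0]]
X_deficient = [[1.0, 2.0, 3.0, 4.0],
               [2.0, 4.0, 6.0, 8.0]]  # second row = 2 * first row, so rank 1

mu_full = min_eig_sym2x2(hessian_linear_regression(X_full))
mu_deficient = min_eig_sym2x2(hessian_linear_regression(X_deficient))
print(mu_full, mu_deficient)
```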

For binary classification with the logistic loss, the Hessian involves an additional diagonal matrix that depends on the prediction for each data point,

$$\nabla^{2}_{\theta}J^{class}_{\tau}(\theta)=X_{\tau}D_{\tau}X_{\tau}^{\top}\in\mathbb{R}^{d\times d},$$

where $D_{\tau}=\mathrm{Diag}(p_{1,\tau},\dots,p_{N,\tau})$, $p_{i,\tau}=2\sigma(f_{\theta}(x_{i,\tau}))\left(1-\sigma(f_{\theta}(x_{i,\tau}))\right)$, and $\sigma$ is the sigmoid function. If a prediction becomes sufficiently confident, $\sigma(f_{\theta}(x_{i,\tau}))=1$ (or $0$), then $p_{i,\tau}=0$ and the Hessian can become rank deficient. However, because each task is budgeted only a finite number of iterations, the predictions remain bounded away from $0$ and $1$.
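The sketch below evaluates the diagonal entry $p=2\sigma(z)(1-\sigma(z))$ of $D_{\tau}$ for increasingly confident logits $z=f_{\theta}(x)$, using the factor of 2 from the definition above; the logit values are illustrative. The weights shrink toward zero as $|z|$ grows but stay strictly positive for any finite logit, so bounded logits keep $D_{\tau}$ positive definite and $X_{\tau}D_{\tau}X_{\tau}^{\top}$ full rank whenever $X_{\tau}$ has full row rank.

```python
import math

def sigmoid(z):
    return 1.0 / (1.0 + math.exp(-z))

def curvature_weight(z):
    """Diagonal entry p = 2 * sigma(z) * (1 - sigma(z)) of D_tau for a logit z."""
    s = sigmoid(z)
    return 2.0 * s * (1.0 - s)

logits = [0.0, 2.0, 5.0, 10.0]
weights = [curvature_weight(z) for z in logits]
print(weights)  # decreasing toward 0 as the prediction grows more confident, but positive
```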

A.2 Related Work Regarding Trainability of Deep Linear Networks

Some authors have suggested that deep linear networks suffer from a related issue, namely critical learning periods (Kleinman et al.,, 2024). Unlike the loss of trainability studied in this work, where the entire network is trained, these critical learning periods arise from winner-take-all dynamics induced by artificially introduced defects in one half of the linear network, for which the other half compensates.

Finally, we note that previous work has found that gradient dynamics have a low-rank bias for deep linear networks (Chou et al.,, 2024). An important assumption in this line of work is that the weights are initialized identically across layers up to scale, $\theta_{j}=\alpha\theta_{1}$. Our analysis instead assumes that the initialization uses small random values, such as those produced by common neural network initialization schemes (Glorot and Bengio,, 2010; He et al.,, 2015).

A.3 Details For Deep Linear Setup

To simplify notation, and without loss of generality, we consider a deep linear network without bias terms, $\theta=\{\theta_{1},\dots,\theta_{L}\}$, with $f_{\theta}(x)=\theta_{L}\theta_{L-1}\cdots\theta_{1}x$. We denote the product of weight matrices, or simply the product matrix, as $\bar{\theta}=\theta_{L}\theta_{L-1}\cdots\theta_{1}$, which allows us to write the deep linear network in terms of the product matrix: $f_{\theta}(x)=\bar{\theta}x$. The problem setup we use for the deep linear analysis follows previous work (Huh,, 2020), and we detail it in the remainder of this section.
We consider the squared error, $J_{\tau}(\theta)=\mathbb{E}_{(x,y)\sim p_{\tau}}\left[\|y-\bar{\theta}x\|_{2}^{2}\right]$, and, to simplify the analysis, we assume that the observations are whitened, $\Sigma_{x}=\mathbb{E}\left[xx^{\top}\right]=\mathbf{I}$, focusing on the case where the targets $y$ change during continual learning. Then, up to an additive constant that does not depend on $\theta$, we can write the squared error as

$$J_{\tau}(\theta)=\mathrm{Tr}\left[\Delta_{\tau}\Delta_{\tau}^{\top}\right],$$

where $\Delta_{\tau}=\theta^{\star}_{\tau}-\bar{\theta}$ is the difference between the optimal linear predictor and the product matrix, and the optimal linear predictor is $\theta^{\star}_{\tau}=\mathbb{E}_{(x,y)\sim p_{\tau}}[yx^{\top}]\Sigma_{x}^{-1}=\Sigma_{yx,\tau}$, where the last equality uses $\Sigma_{x}=\mathbf{I}$.
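To illustrate the reduction to the trace form, the following sketch uses a hypothetical task with four 2-D inputs whose empirical second moment is exactly the identity, and scalar targets chosen for illustration. It computes $\theta^{\star}_{\tau}=\mathbb{E}[yx^{\top}]$ and checks that the squared error minus $\mathrm{Tr}[\Delta_{\tau}\Delta_{\tau}^{\top}]$ equals the same constant, $\mathbb{E}[\|y\|^{2}]-\|\theta^{\star}_{\tau}\|^{2}$, for several choices of $\bar{\theta}$.

```python
def mean(values):
    return sum(values) / len(values)

# Whitened inputs: the empirical second moment E[x x^T] equals the 2x2 identity.
xs = [(1.0, 1.0), (1.0, -1.0), (-1.0, 1.0), (-1.0, -1.0)]
ys = [3.0, 1.0, 0.5, -2.0]  # hypothetical scalar targets for one task

# Optimal linear predictor: theta* = E[y x^T] Sigma_x^{-1} = E[y x^T] because Sigma_x = I.
theta_star = [mean([y * x[k] for x, y in zip(xs, ys)]) for k in range(2)]

def squared_error(theta_bar):
    return mean([(y - (theta_bar[0] * x[0] + theta_bar[1] * x[1])) ** 2 for x, y in zip(xs, ys)])

def trace_term(theta_bar):
    # Tr[Delta Delta^T] for the 1x2 row Delta = theta* - theta_bar.
    return sum((s - t) ** 2 for s, t in zip(theta_star, theta_bar))

const = mean([y ** 2 for y in ys]) - sum(s ** 2 for s in theta_star)
gaps = [squared_error(tb) - trace_term(tb) for tb in [(0.0, 0.0), (1.0, -1.0), (0.3, 2.0)]]
print(theta_star, const, gaps)  # every gap equals const up to rounding
```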

Convergence guarantees for gradient descent on general deep linear networks require an assumption on the deficiency margin, which ensures that the solution found by the deep linear network, in terms of the product matrix, is full rank (Arora et al.,, 2019). That is, the deep linear network converges if the minimum singular value of the product matrix stays positive, $\sigma_{min}(\bar{\theta})>0$.

We now show that a diagonal linear network maintains a positive minimum singular value under continual learning. This is a simplified setting for analysis in which the weight matrices are diagonal, so the input, hidden, and output dimensions are all equal. Let $f_{\theta}(x)$ be a diagonal linear network, defined by a set of diagonal weight matrices, $\theta_{l}=\mathrm{Diag}(\theta_{l,1},\dots,\theta_{l,d})$. The output of the diagonal linear network is given by the product of the diagonal matrices, $f_{\theta}(x)=\theta_{L}\theta_{L-1}\cdots\theta_{1}x$.
Then the product matrix is also diagonal, and its diagonal entries are the products of the corresponding parameters of each layer, $\bar{\theta}=\mathrm{Diag}\left(\prod_{l=1}^{L}\theta_{l,1},\dots,\prod_{l=1}^{L}\theta_{l,d}\right):=\mathrm{Diag}(\bar{\theta}_{1},\dots,\bar{\theta}_{d})$. The minimum singular value of a diagonal matrix is the minimum absolute value of its diagonal entries, $\sigma_{min}(\bar{\theta})=\min_{i}|\bar{\theta}_{i}|$. Thus, we must show that no diagonal entry of the product matrix is ever zero.
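A small sketch of these definitions, assuming a hypothetical three-layer diagonal network with $d=3$ where each layer is stored as its list of diagonal entries: it computes $\bar{\theta}_{i}=\prod_{l}\theta_{l,i}$, checks that a layer-by-layer forward pass equals $\bar{\theta}\odot x$, and reads off $\sigma_{min}(\bar{\theta})=\min_{i}|\bar{\theta}_{i}|$.

```python
import math

def product_diagonal(layers):
    """Diagonal entries of theta_bar = theta_L ... theta_1 when every layer is diagonal."""
    return [math.prod(layer[i] for layer in layers) for i in range(len(layers[0]))]

def forward(layers, x):
    """Apply theta_1 first, then theta_2, ..., theta_L; each layer is stored as its diagonal."""
    h = list(x)
    for layer in layers:
        h = [w * v for w, v in zip(layer, h)]
    return h

layers = [[0.5, -1.2, 2.0],    # theta_1 (hypothetical values)
          [1.5, 0.3, -0.4],    # theta_2
          [-0.8, 2.0, 1.0]]    # theta_3
x = [1.0, 2.0, -3.0]

theta_bar = product_diagonal(layers)
sigma_min = min(abs(v) for v in theta_bar)   # smallest singular value of Diag(theta_bar)
print(theta_bar, sigma_min)
```

Setting any single $\theta_{l,i}$ to zero makes $\bar{\theta}_{i}$, and hence $\sigma_{min}(\bar{\theta})$, zero; the lemmas below characterize when gradient descent can keep it there.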

Lemma 1.

Consider a deep diagonal linear network, $f_{\theta}(x)=\theta_{L}\theta_{L-1}\cdots\theta_{1}x$ with $\theta_{l}=\mathrm{Diag}(\theta_{l,1},\dots,\theta_{l,d})$. Then, under gradient descent dynamics, $\theta^{(t)}_{l,i}=\theta^{(t)}_{l',i}$ iff $\theta^{(0)}_{l,i}=\theta^{(0)}_{l',i}$, for $l'\neq l$.

The proofs of this lemma and the next can be found in Appendix B. The first lemma states that two parameters initialized to different values, as with a random initialization, will never take the same value under gradient descent. Conversely, if the parameters are initialized identically, they will keep the same value under gradient descent. In particular, two parameters initialized to different values will never be simultaneously zero.
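A numerical illustration of Lemma 1 (not a proof), under assumptions of our own: a single diagonal coordinate, $L=3$ layers, the per-coordinate loss $J(w)=(s^{\star}-w_{1}w_{2}w_{3})^{2}$ suggested by the trace form above, and a target $s^{\star}$ that changes every 400 steps to mimic continual learning. All numbers, and the helper names `gd_step` and `train`, are hypothetical. Layers initialized to distinct values never coincide along the trajectory, whereas two layers initialized identically stay exactly equal, because their gradients $\prod_{j\neq l}w_{j}$ coincide at every step.

```python
import math

def gd_step(w, target, lr):
    """One gradient step on J(w) = (target - w_1 * ... * w_L)**2 for one diagonal coordinate."""
    residual = target - math.prod(w)
    # dJ/dw_l = -2 * residual * prod_{j != l} w_j, so a descent step adds lr * 2 * residual * prod.
    return [w_l + lr * 2.0 * residual * math.prod(w[:l] + w[l + 1:]) for l, w_l in enumerate(w)]

def train(w0, targets, steps_per_task=400, lr=0.01):
    """Run gradient descent task by task, recording the weights after every step."""
    w, history = list(w0), [list(w0)]
    for target in targets:  # each new target plays the role of a new task
        for _ in range(steps_per_task):
            w = gd_step(w, target, lr)
            history.append(w)
    return history

targets = [1.0, -0.5, 1.5]
distinct = train([0.2, 0.6, 1.0], targets)  # all three layers initialized differently
tied = train([0.5, 0.5, 1.0], targets)      # layers 1 and 2 initialized identically

min_gap = min(abs(w[a] - w[b]) for w in distinct for a, b in [(0, 1), (0, 2), (1, 2)])
print(min_gap, all(w[0] == w[1] for w in tied))
```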

Lemma 2.

Denote a deep diagonal linear network as $f_{\theta}(x)=\mathrm{Diag}(\bar{\theta}_{1},\dots,\bar{\theta}_{d})x$, where $\bar{\theta}_{i}=\prod_{l=1}^{L}\theta_{l,i}$. Then, under gradient descent dynamics, $\bar{\theta}^{(t)}_{i}=\bar{\theta}^{(t+1)}_{i}=0$ iff two (or more) components are zero, $\theta^{(t)}_{l,i}=\theta^{(t)}_{l',i}=0$, for $l'\neq l$.
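The contrast in Lemma 2 can be seen in a single coordinate with $L=3$, again assuming the hypothetical loss $J(w)=(s^{\star}-w_{1}w_{2}w_{3})^{2}$ and a nonzero target, so that the residual is nonzero (if the residual were zero, nothing would move). With one zero component, the gradient of that component is $\prod_{j\neq l}w_{j}\neq 0$, so the product leaves zero after one step; with two zero components, every partial product $\prod_{j\neq l}w_{j}$ vanishes and the weights never move.

```python
import math

def gd_step(w, target, lr):
    """One gradient step on J(w) = (target - w_1 * ... * w_L)**2 for one diagonal coordinate."""
    residual = target - math.prod(w)
    return [w_l + lr * 2.0 * residual * math.prod(w[:l] + w[l + 1:]) for l, w_l in enumerate(w)]

one_zero = gd_step([0.0, 0.5, 0.7], target=1.0, lr=0.1)  # a single zero component

two_zeros = [0.0, 0.0, 0.7]                               # two zero components
for _ in range(100):
    two_zeros = gd_step(two_zeros, target=1.0, lr=0.1)

print(math.prod(one_zero), two_zeros)
```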

While the analysis considers a special case of deep linear networks, namely deep diagonal networks, we note that this is a common setting for the analysis of deep linear networks more generally (Nacson et al.,, 2022; Even et al.,, 2023). In particular, the analysis is motivated by the fact that, under certain conditions, the evolution of the deep linear network parameters can be analyzed through independent singular mode dynamics (Saxe et al.,, 2014), which reduce the analysis of deep linear networks to that of deep diagonal linear networks. The target function being learned, $y^{\star}(x)=\theta^{\star}x$, is represented in terms of its singular value decomposition, $\theta^{\star}=U^{\star}S^{\star}V^{\star\top}=\sum_{i=1}^{r}s_{i}u_{i}v_{i}^{\top}$.
We also assume that the neural network has a fixed hidden dimension, so that $\theta_{1}\in\mathbb{R}^{d\times d_{in}}$, $\theta_{L}\in\mathbb{R}^{d_{out}\times d}$, and $\theta_{l}\in\mathbb{R}^{d\times d}$ for $1<l<L$; and we apply the singular value decomposition to the function approximator's parameters, $\theta_{l}=U_{l}S_{l}V_{l}^{\top}$.
To simplify the product of weight matrices, we assume $V_{l+1}=U_{l}$, $V_{1}=V^{\star}$, and $U_{L}=U^{\star}$. As a result, the squared error loss can be expressed entirely in terms of the singular values, $\|y^{\star}(x)-\prod_{l=L}^{1}\theta_{l}x\|^{2}\propto\|S^{\star}-\prod_{l=L}^{1}S_{l}\|^{2}$, which is equivalent to our analysis of the deep diagonal network because the matrix of singular values is diagonal. These decoupled learning dynamics are closely approximated by networks initialized with small random weights, and they persist under gradient flow (Huh,, 2020).
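The sketch below checks this simplification numerically for a hypothetical $2\times 2$, three-layer example, writing each factorization as $\theta_{l}=U_{l}S_{l}V_{l}^{\top}$ and using rotation matrices as the orthogonal factors (all values are illustrative). With the aligned factors $V_{1}=V^{\star}$, $V_{l+1}=U_{l}$, $U_{L}=U^{\star}$, the product collapses to $U^{\star}(\prod_{l}S_{l})V^{\star\top}$, so the Frobenius error $\|\theta^{\star}-\bar{\theta}\|_{F}^{2}$, which equals the expected squared error under whitened inputs and noise-free targets, reduces to $\|S^{\star}-\prod_{l}S_{l}\|^{2}$.

```python
import math

def matmul(A, B):
    return [[sum(A[i][k] * B[k][j] for k in range(len(B))) for j in range(len(B[0]))]
            for i in range(len(A))]

def transpose(A):
    return [list(row) for row in zip(*A)]

def rotation(angle):
    c, s = math.cos(angle), math.sin(angle)
    return [[c, -s], [s, c]]

def diag(values):
    return [[values[0], 0.0], [0.0, values[1]]]

def frobenius_sq(A):
    return sum(v * v for row in A for v in row)

U_star, V_star = rotation(0.7), rotation(-0.4)
S_star = [3.0, 0.5]                                   # target singular values
theta_star = matmul(matmul(U_star, diag(S_star)), transpose(V_star))

# L = 3 layers with aligned singular vectors: V_1 = V*, V_{l+1} = U_l, U_L = U*.
U = [rotation(1.1), rotation(-2.3), U_star]
V = [V_star, U[0], U[1]]
S = [[1.2, 0.9], [0.8, 1.1], [1.5, 0.7]]              # per-layer singular values
layers = [matmul(matmul(U[l], diag(S[l])), transpose(V[l])) for l in range(3)]

theta_bar = layers[0]
for layer in layers[1:]:
    theta_bar = matmul(layer, theta_bar)              # theta_bar = theta_3 theta_2 theta_1

prod_s = [S[0][i] * S[1][i] * S[2][i] for i in range(2)]
diff = [[a - b for a, b in zip(r1, r2)] for r1, r2 in zip(theta_star, theta_bar)]
lhs = frobenius_sq(diff)
rhs = sum((s - p) ** 2 for s, p in zip(S_star, prod_s))
print(lhs, rhs)  # equal up to floating-point rounding
```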

Appendix B Proofs

Proof of Theorem 1.

We first present the result for two tasks and then generalize it to an arbitrary number of tasks. Let the linear weights learned on the first task be $\theta^{(T)}$, with the corresponding unique global minimum denoted by $\theta^{\star}_{1}$. The solution found on the first task is used as the initialization for the second task, on which training ends at $\theta^{(2T)}$, with the corresponding unique global minimum denoted by $\theta^{\star}_{2}$. We start from the known suboptimality gap for gradient descent on the second task (Garrigos and Gower,, 2023):

$$J_{2}(\theta^{(2T)})-J_{2}(\theta_{2}^{\star})<\frac{\|\theta_{2}^{\star}-\theta^{(T)}\|^{2}}{\alpha T}. \tag{1}$$

We upper bound the distance from the initialization on the second task, $\theta^{(T)}$, to the optimum, $\theta^{\star}_{2}$, by

$$\|\theta_{2}^{\star}-\theta^{(T)}\|^{2}<\|\theta_{2}^{\star}-\theta_{1}^{\star}\|^{2}+\|\theta_{1}^{\star}-\theta^{(T)}\|^{2}<\|\theta_{2}^{\star}-\theta_{1}^{\star}\|^{2}+(1-\alpha\mu)^{T}\|\theta_{1}^{\star}-\theta_{0}\|^{2}, \tag{2}$$

where the last inequality uses the linear convergence of gradient descent (for a sufficiently small step size $\alpha$) on the $\mu$-strongly convex objective of the first task, $\|\theta_{1}^{\star}-\theta^{(T)}\|^{2}\leq(1-\alpha\mu)^{T}\|\theta_{1}^{\star}-\theta_{0}\|^{2}$. We then upper bound the suboptimality gap on the second task by a quantity independent of $\theta^{(T)}$:

$$J_{2}(\theta^{(2T)})-J_{2}(\theta_{2}^{\star})<\frac{\|\theta_{2}^{\star}-\theta^{(T)}\|^{2}}{\alpha T}<\frac{\|\theta_{2}^{\star}-\theta_{1}^{\star}\|^{2}+(1-\alpha\mu)^{T}\|\theta_{1}^{\star}-\theta_{0}\|^{2}}{\alpha T}, \tag{3}$$

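which implies that the bound on the second task does not depend on the parameters learned on the first task, $\theta^{(T)}$, but only on the distance between the two task optima and on the initial distance $\|\theta_{1}^{\star}-\theta_{0}\|^{2}$, which is discounted by $(1-\alpha\mu)^{T}$. The same argument extends to an arbitrary number of tasks: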

$$J_{\tau}(\theta^{(\tau T)})-J_{\tau}(\theta_{\tau}^{\star})<\frac{\sum_{k=1}^{\tau}(1-\alpha\mu)^{T(\tau-k)}\|\theta_{k}^{\star}-\theta_{k-1}^{\star}\|^{2}}{\alpha T}<\frac{2D}{\alpha T\left(1-(1-\alpha\mu)^{T}\right)}, \tag{4}$$

where we denote $\theta_{0}^{\star}=\theta_{0}$. The last inequality follows from our assumption that the squared distance between consecutive task solutions is bounded, $\|\theta_{k}^{\star}-\theta_{k-1}^{\star}\|^{2}<2D$, together with a geometric sum in $(1-\alpha\mu)^{T}$. ∎
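A small numerical sanity check of this argument on one instance (not a proof), under hypothetical choices of our own: each task is the quadratic $J_{k}(\theta)=\sum_{i}\frac{c_{i}}{2}(\theta_{i}-\theta^{\star}_{k,i})^{2}$ with Hessian $\mathrm{Diag}(1,4)$, so $\mu=1$ and the smoothness constant is 4; the step size is $\alpha=0.2\leq 1/4$; each task gets $T=10$ steps; and the optima alternate between two points over 50 tasks. The sketch checks on every task the linear-convergence inequality used in (2) and the gap bound (1), and shows that the per-task suboptimality gap does not grow with the number of tasks.

```python
def gd_on_task(theta, curv, opt, lr, steps):
    """Gradient descent on J(theta) = sum_i curv[i] / 2 * (theta[i] - opt[i])**2.

    Returns the final iterate and the squared distance to opt after every step.
    """
    dists = [sum((t - o) ** 2 for t, o in zip(theta, opt))]
    for _ in range(steps):
        theta = [t - lr * c * (t - o) for t, c, o in zip(theta, curv, opt)]
        dists.append(sum((t - o) ** 2 for t, o in zip(theta, opt)))
    return theta, dists

def continual_run(optima, curv, lr, steps, theta0):
    """Train on each task in turn, warm-starting from the previous task's solution."""
    mu = min(curv)
    theta = list(theta0)
    gaps, eq1_bounds, contraction_ok = [], [], True
    for opt in optima:
        theta, dists = gd_on_task(theta, curv, opt, lr, steps)
        # Linear convergence used in (2): ||theta_t - opt||^2 <= (1 - lr*mu)^t ||theta_0 - opt||^2.
        contraction_ok &= all(d <= (1 - lr * mu) ** t * dists[0] + 1e-12
                              for t, d in enumerate(dists))
        gaps.append(sum(c / 2 * (t - o) ** 2 for t, c, o in zip(theta, curv, opt)))  # J - J*
        eq1_bounds.append(dists[0] / (lr * steps))  # right-hand side of (1)
    return gaps, eq1_bounds, contraction_ok

curv = [1.0, 4.0]  # Hessian Diag(1, 4): mu = 1, smoothness constant 4
lr = 0.2           # step size at most 1 / 4
optima = [[2.0, 0.0] if k % 2 == 0 else [-2.0, 2.0] for k in range(50)]  # alternating optima
gaps, eq1_bounds, contraction_ok = continual_run(optima, curv, lr, steps=10, theta0=[0.0, 0.0])
print(contraction_ok, max(gaps[:10]), max(gaps[10:]))
```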

Proof of Lemma 1.

($\Rightarrow$) We first prove the lemma in the forward direction:

Assuming that $\theta^{(t)}_{l,i}=\theta^{(t)}_{l',i}$ for $l'\neq l$, we will show that $\theta^{(t-1)}_{l,i}=\theta^{(t-1)}_{l',i}$.

Writing the gradient update for $\theta_{l,i}^{(t)}$ with a fixed step size $\alpha$, we have that

\begin{align}
\theta^{(t)}_{l,i} &= \theta^{(t-1)}_{l,i}-\alpha\nabla_{\theta_{l,i}}J(\theta) \tag{5}\\
&= \theta^{(t-1)}_{l,i}-\alpha\nabla_{f_{\theta}}\ell(f_{\theta}(x),y)\,\nabla_{\theta_{l,i}}f_{\theta}(x) \tag{6}\\
&= \theta^{(t-1)}_{l,i}-\alpha\nabla_{f_{\theta}}\ell(f_{\theta}(x),y)\,\nabla_{\theta_{l,i}}\prod_{j=1}^{L}\theta_{j,i}^{(t-1)}x \tag{7}\\
&= \theta^{(t-1)}_{l,i}-\alpha\nabla_{f_{\theta}}\ell(f_{\theta}(x),y)\prod_{j\neq l}\theta_{j,i}^{(t-1)}x. \tag{8}
\end{align}

Similarly, the gradient update for $\theta_{l',i}$ is

$$
\theta^{(t)}_{l',i} = \theta^{(t-1)}_{l',i} - \alpha \nabla_{f_\theta} \ell(f_\theta(x), y) \prod_{j \neq l'} \theta^{(t-1)}_{j,i}\, x. \qquad (10)
$$
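As a quick numerical check of the partial derivative used in (8) and (10), the sketch below compares the analytic gradient of a single diagonal entry with a central finite difference. It assumes a squared-error loss $\ell(f, y) = \frac{1}{2}(f - y)^2$ (so that $\nabla_{f_\theta}\ell = f_\theta(x) - y$) and a scalar input $x$; the function names and values are illustrative, not from the paper.

```python
import math

def forward(thetas, x):
    """Output of one diagonal entry: prod_j theta_j * x."""
    return math.prod(thetas) * x

def loss(thetas, x, y):
    """Squared-error loss 0.5 * (f - y)^2 (an assumption for this sketch)."""
    return 0.5 * (forward(thetas, x) - y) ** 2

def analytic_grad(thetas, x, y, l):
    """Gradient from (8): (f - y) * prod_{j != l} theta_j * x."""
    dl_df = forward(thetas, x) - y
    others = math.prod(t for j, t in enumerate(thetas) if j != l)
    return dl_df * others * x

def finite_diff_grad(thetas, x, y, l, h=1e-6):
    """Central finite difference of the loss with respect to theta_l."""
    up = list(thetas)
    up[l] += h
    down = list(thetas)
    down[l] -= h
    return (loss(up, x, y) - loss(down, x, y)) / (2 * h)

thetas = [0.9, -1.3, 0.7, 1.1]  # one diagonal entry across L = 4 layers
x, y = 0.5, 2.0
for l in range(len(thetas)):
    print(l, analytic_grad(thetas, x, y, l), finite_diff_grad(thetas, x, y, l))
```

Because the loss is quadratic in each individual $\theta_{l,i}$, the central difference agrees with the analytic gradient up to floating-point rounding.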

Using our assumption that $\theta^{(t)}_{l,i} = \theta^{(t)}_{l',i}$, we set the two updates equal to each other:

$$
\theta^{(t-1)}_{l,i} - \alpha \nabla_{f_\theta} \ell(f_\theta(x), y) \prod_{j \neq l} \theta^{(t-1)}_{j,i}\, x = \theta^{(t-1)}_{l',i} - \alpha \nabla_{f_\theta} \ell(f_\theta(x), y) \prod_{j \neq l'} \theta^{(t-1)}_{j,i}\, x. \qquad (12)
$$

To simplify, note that both sub-products contain every factor other than $\theta^{(t-1)}_{l,i}$ and $\theta^{(t-1)}_{l',i}$, together with exactly one of these two. Writing $P := \prod_{j \neq l, l'} \theta^{(t-1)}_{j,i}$, the LHS is

$$
\theta^{(t-1)}_{l,i} - \alpha \nabla_{f_\theta} \ell(f_\theta(x), y)\, \theta^{(t-1)}_{l',i}\, P\, x \qquad (13)
$$

and, similarly, the RHS is

$$
\theta^{(t-1)}_{l',i} - \alpha \nabla_{f_\theta} \ell(f_\theta(x), y)\, \theta^{(t-1)}_{l,i}\, P\, x. \qquad (14)
$$

Subtracting (14) from (13) and defining $c := \alpha \nabla_{f_\theta} \ell(f_\theta(x), y)\, P\, x$ gives

$$
\left(\theta^{(t-1)}_{l,i} - \theta^{(t-1)}_{l',i}\right)(1 + c) = 0. \qquad (15)
$$

Thus, except in the degenerate case $c = -1$ (a single, measure-zero choice of the step size $\alpha$), $\theta^{(t-1)}_{l',i} = \theta^{(t-1)}_{l,i}$.

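($\Leftarrow$) The reverse direction follows by running the above argument in reverse: if $\theta^{(t-1)}_{l,i} = \theta^{(t-1)}_{l',i}$, then the sub-products $\prod_{j \neq l} \theta^{(t-1)}_{j,i}$ and $\prod_{j \neq l'} \theta^{(t-1)}_{j,i}$ coincide, so the updates (8) and (10) are identical and $\theta^{(t)}_{l,i} = \theta^{(t)}_{l',i}$. ∎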
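The lemma can be illustrated with a small simulation of plain gradient descent on one diagonal entry of depth $L = 3$, again assuming a squared-error loss (the values of $x$, $y$, and $\alpha$ are arbitrary). Factors that start equal remain bitwise equal. The sketch also checks the one-step identity $\theta^{(t)}_{l,i} - \theta^{(t)}_{l',i} = (\theta^{(t-1)}_{l,i} - \theta^{(t-1)}_{l',i})(1 + c)$, with $c = \alpha \nabla_{f_\theta}\ell(f_\theta(x), y) \prod_{j \neq l, l'} \theta^{(t-1)}_{j,i}\, x$, which follows from subtracting (10) from (8).

```python
import math

def gd_step(thetas, x, y, lr):
    """One simultaneous gradient step on 0.5 * (prod(thetas) * x - y)^2."""
    residual = math.prod(thetas) * x - y
    new = []
    for l, t in enumerate(thetas):
        others = math.prod(s for j, s in enumerate(thetas) if j != l)
        new.append(t - lr * residual * others * x)
    return new

x, y, lr = 1.0, 1.5, 0.05

# Forward check: factors that start equal stay (bitwise) equal.
thetas = [0.8, 0.8, 1.2]
history = [thetas]
for _ in range(500):
    thetas = gd_step(thetas, x, y, lr)
    history.append(thetas)
print(all(h[0] == h[1] for h in history))  # True

# One-step identity: new difference = old difference * (1 + c),
# where c uses the product of the factors other than t[0] and t[2].
t = [0.8, 1.1, 1.3]
new = gd_step(t, x, y, lr)
c = lr * (math.prod(t) * x - y) * t[1] * x
print(new[0] - new[2], (t[0] - t[2]) * (1 + c))
```

The two printed differences agree up to rounding, and the pair that starts unequal, $(\theta_{1}, \theta_{3})$, never becomes equal in this run.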

Proof of Lemma 2.

($\Rightarrow$) We first prove the forward direction.

Assuming that $\bar{\theta}^{(t+1)}_i = \bar{\theta}^{(t)}_i = 0$, we will show that $\theta^{(t)}_{l,i} = \theta^{(t)}_{l',i} = 0$ for some $l \neq l'$.

We proceed by contradiction. Since $\bar{\theta}^{(t)}_i = 0$, at least one component is zero; assume that only a single component is zero, that is, $\theta^{(t)}_{l',i} = 0$ and $\theta^{(t)}_{l,i} \neq 0$ for all $l \neq l'$. We will show that the gradient update then ensures $\bar{\theta}^{(t+1)}_i \neq 0$.

First, consider the update to $\theta^{(t)}_{l',i}$:

$$
\begin{aligned}
\theta^{(t+1)}_{l',i} &= \theta^{(t)}_{l',i} - \alpha \nabla_{f_\theta} \ell(f_\theta(x), y) \prod_{j \neq l'} \theta^{(t)}_{j,i}\, x && (16)\\
&= -\alpha \nabla_{f_\theta} \ell(f_\theta(x), y) \prod_{j \neq l'} \theta^{(t)}_{j,i}\, x. && (17)
\end{aligned}
$$

Because we assumed that $\theta^{(t)}_{l,i} \neq 0$ for all $l \neq l'$, we have $\prod_{j \neq l'} \theta^{(t)}_{j,i} \neq 0$. Thus, provided that the loss gradient $\nabla_{f_\theta} \ell(f_\theta(x), y)$ and the input $x$ are nonzero, $\theta^{(t+1)}_{l',i} \neq 0$.

Next, consider the update to $\theta^{(t)}_{l,i}$ for $l \neq l'$:

$$
\begin{aligned}
\theta^{(t+1)}_{l,i} &= \theta^{(t)}_{l,i} - \alpha \nabla_{f_\theta} \ell(f_\theta(x), y) \prod_{j \neq l} \theta^{(t)}_{j,i}\, x && (18)\\
&= \theta^{(t)}_{l,i}, && (19)
\end{aligned}
$$

where the last line follows from the fact that $\prod_{j \neq l} \theta^{(t)}_{j,i} = 0$, because this sub-product contains the factor $\theta^{(t)}_{l',i} = 0$.

Thus, we have shown that $\theta^{(t+1)}_{l,i} \neq 0$ for all $l$, and hence $\bar{\theta}^{(t+1)}_i \neq 0$, which contradicts our assumption that $\bar{\theta}^{(t+1)}_i = 0$.

($\Leftarrow$) The reverse direction follows directly from the assumption. If two components are both equal to zero, $\theta^{(t)}_{l,i} = \theta^{(t)}_{l',i} = 0$, then every sub-product $\prod_{j \neq k} \theta^{(t)}_{j,i}$ is zero, since each omits at most one of the two zero factors. Hence every gradient with respect to $\theta_{k,i}$ vanishes, the parameters are unchanged, and the entire product $\bar{\theta}^{(t+1)}_i = \prod_{j=1}^{L} \theta^{(t+1)}_{j,i}$ remains zero. ∎
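A minimal sketch of both directions of the lemma, under the same squared-error assumption and with arbitrary illustrative values: with a single zero factor, one gradient step already makes the product nonzero, while with two zero factors every gradient vanishes and the parameters never move.

```python
import math

def gd_step(thetas, x, y, lr):
    """One simultaneous gradient step on 0.5 * (prod(thetas) * x - y)^2."""
    residual = math.prod(thetas) * x - y
    return [
        t - lr * residual * math.prod(s for j, s in enumerate(thetas) if j != l) * x
        for l, t in enumerate(thetas)
    ]

x, y, lr = 1.0, 2.0, 0.1

one_zero = [0.0, 0.7, 1.2]        # a single zero factor
after = gd_step(one_zero, x, y, lr)
print(math.prod(after))           # nonzero: the zero factor is pushed away from 0

two_zeros = [0.0, 0.0, 1.2]       # two zero factors
frozen = two_zeros
for _ in range(100):
    frozen = gd_step(frozen, x, y, lr)
print(frozen, math.prod(frozen))  # [0.0, 0.0, 1.2] 0.0: every gradient is zero
```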

Proof of Theorem 2.

We now show that a diagonal linear network maintains a positive minimum singular value under continual learning. This is a simplified setting for analysis: we assume that the weight matrices are diagonal, and thus the input, hidden, and output dimensions are all equal. Let $f_\theta(x)$ be a diagonal linear network, defined by a set of diagonal weight matrices $\theta_l = \mathrm{Diag}(\theta_{l,1}, \dots, \theta_{l,d})$. The output of the diagonal linear network is the product of the diagonal matrices applied to the input, $f_\theta(x) = \theta_L \theta_{L-1} \cdots \theta_1 x$.
Then the product matrix is also a diagonal matrix, whose diagonal entries are the products of the parameters of each layer, $\bar{\theta} = \mathrm{Diag}\left(\prod_{l=1}^{L} \theta_{l,1}, \dots, \prod_{l=1}^{L} \theta_{l,d}\right) := \mathrm{Diag}(\bar{\theta}_1, \dots, \bar{\theta}_d)$. The minimum singular value of a diagonal matrix is the minimum absolute value of its diagonal entries, $\sigma_{\min}(\bar{\theta}) = \min_i |\bar{\theta}_i|$. Thus, we must show that no diagonal entry of the product matrix ever becomes zero.

This follows immediately from Lemma 1 and Lemma 2. Taken together, these two lemmas state that, with a random initialization and under gradient dynamics, no diagonal entry $\bar{\theta}_i$ of a diagonal linear network will have more than one of its factors $\theta_{1,i}, \dots, \theta_{L,i}$ equal to zero. This means that the minimum singular value of the product matrix will never be zero. Thus, we have shown that a diagonal linear network trained with gradient descent, if initialized appropriately, will be able to converge on any given task in a sequence. ∎
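To illustrate the theorem, the sketch below trains a diagonal linear network with $L = 2$ and $d = 3$ on a short, hypothetical sequence of regression tasks (inputs fixed to $x_i = 1$, squared-error loss, arbitrary targets that flip sign between tasks) and records the smallest $\sigma_{\min}(\bar{\theta}) = \min_i |\bar{\theta}_i|$ seen during training. The initialization is fixed rather than random so the run is reproducible; the factors of each diagonal entry are unequal, as a random initialization would make them almost surely. This is an illustration for one instance, not a proof.

```python
import math

def train_task(layers, target, lr=0.05, steps=3000):
    """Gradient descent on 0.5 * sum_i (prod_l theta_{l,i} - y_i)^2 with x_i = 1.

    `layers` is a list of L lists, each holding the d diagonal entries of one
    layer; it is updated in place. Returns the smallest |theta_bar_i| seen
    during training and the final loss.
    """
    L, d = len(layers), len(layers[0])
    min_sigma = math.inf
    for _ in range(steps):
        for i in range(d):  # a diagonal network decouples across coordinates
            column = [layers[l][i] for l in range(L)]
            residual = math.prod(column) - target[i]
            for l in range(L):
                others = math.prod(c for j, c in enumerate(column) if j != l)
                layers[l][i] = column[l] - lr * residual * others
        theta_bar = [math.prod(layers[l][i] for l in range(L)) for i in range(d)]
        min_sigma = min(min_sigma, min(abs(v) for v in theta_bar))  # sigma_min of Diag(theta_bar)
    loss = 0.5 * sum((b - y) ** 2 for b, y in zip(theta_bar, target))
    return min_sigma, loss

# Fixed initialization with unequal factors: L = 2 layers, d = 3 entries each.
layers = [[1.0, 0.9, 1.1], [0.5, 0.4, 0.6]]
tasks = [[0.8, -0.5, 0.3], [-0.6, 0.7, -0.9], [0.4, -0.8, 0.6], [-0.9, 0.2, -0.4]]
for k, target in enumerate(tasks):
    min_sigma, loss = train_task(layers, target)
    print(f"task {k}: smallest sigma_min = {min_sigma:.3g}, final loss = {loss:.2e}")
```

Each diagonal entry passes close to zero when its target changes sign, but only the smaller factor crosses zero while the larger one stays away from it, so the product is never stuck at zero and every task in this sequence is fit.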

Proof of Proposition 1.

We prove this by bounding the remainder of a Taylor series on the given interval. Due to the periodicity of $\sin(z)$ and $\cos(z)$, we can consider $z \in [-\pi, \pi]$ without loss of generality. We can further consider two cases: either $z \in [-\pi, -3\pi/4] \cup [-\pi/4, \pi/4] \cup [3\pi/4, \pi]$ or $z \in [-3\pi/4, -\pi/4] \cup [\pi/4, 3\pi/4]$. In the first case, $z$ is near a critical point of $\cos(z)$, and in the second case, $z$ is near a critical point of $\sin(z)$.

We focus on a particular subcase, where $z \in [-\pi/4, \pi/4]$, which is close to a critical point of $\cos(z)$ but far from a critical point of $\sin(z)$ (the other cases follow a similar argument).

Because $z \in [-\pi/4, \pi/4]$, by Taylor's theorem it follows that $\sin(z) = z + R_{1,0}(z)$, where $R_{1,0}(z) = \frac{\sin^{(2)}(c)}{2} z^2$ is the first-degree Taylor remainder centered at $a = 0$, for some $c$ between $0$ and $z$ (and hence $c \in [-\pi/4, \pi/4]$). In the case of a sinusoid, this remainder can be upper-bounded: $|R_{1,0}(z)| = \left|\frac{-\sin(c)}{2} z^2\right| \le \frac{1}{2\sqrt{2}} (\pi/4)^2$, using the facts that $|z| \le \pi/4$ and $|\sin(c)| \le 1/\sqrt{2}$.

Thus, when $z$ is close to a critical point of $\cos(z)$, $\sin(z)$ is approximately linear. A similar argument holds for the other case: when $z$ is close to a critical point of $\sin(z)$, $\cos(z)$ is approximately linear, and the error incurred is the same. ∎
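The sketch below checks the linearization error numerically on a grid. Combining $|\sin(c)| \le 1/\sqrt{2}$ with $|z| \le \pi/4$ gives the bound $\frac{1}{2\sqrt{2}}(\pi/4)^2 \approx 0.218$; the actual worst-case error, attained at $|z| = \pi/4$, is $\pi/4 - \sin(\pi/4) \approx 0.078$. The second case is checked through the offset $u = z - \pi/2$, for which $\cos(z) = -\sin(u)$ and the linear approximation is $\pi/2 - z = -u$. The grid size is an arbitrary choice.

```python
import math

HALF_WIDTH = math.pi / 4
BOUND = HALF_WIDTH ** 2 / (2 * math.sqrt(2))  # (1 / (2 sqrt 2)) * (pi/4)^2

def max_linearization_error(n=10_001):
    """Largest deviation from linearity on a grid of offsets u in [-pi/4, pi/4].

    Case 1: sin(u) vs. u, near the critical point u = 0 of cos.
    Case 2: cos(z) vs. pi/2 - z, near the critical point z = pi/2 of sin,
            with z = pi/2 + u.
    """
    err_sin, err_cos = 0.0, 0.0
    for k in range(n):
        u = -HALF_WIDTH + 2 * HALF_WIDTH * k / (n - 1)
        err_sin = max(err_sin, abs(math.sin(u) - u))
        z = math.pi / 2 + u
        err_cos = max(err_cos, abs(math.cos(z) - (math.pi / 2 - z)))
    return err_sin, err_cos

err_sin, err_cos = max_linearization_error()
print(err_sin, err_cos, BOUND)  # both errors ~0.0783, bound ~0.218
```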

Proof of Corollary 1.

We prove this claim using induction.

Base case: We want to show that a single layer that outputs Fourier features embeds a deep linear network (of depth one). By Proposition 1, for each pre-activation at least one of its two units (the sine or the cosine) is approximately linear. Because each pre-activation feeds an approximately-linear unit, the single layer approximately embeds a linear network that uses all of its parameters.
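The property behind the base case can be sketched as follows, assuming a hypothetical dense layer $z = Wx + b$ whose outputs are the concatenation $[\sin(z), \cos(z)]$. Since $\sin^2(z) + \cos^2(z) = 1$, the slope $|\cos(z_k)|$ of the sine unit and the slope $|\sin(z_k)|$ of the cosine unit cannot both be small, so for every pre-activation at least one unit has slope magnitude at least $1/\sqrt{2}$. The function names and random weights are purely illustrative.

```python
import math
import random

def fourier_layer(x, weights, biases):
    """Hypothetical deep Fourier layer: z = W x + b, output [sin(z), cos(z)].

    Each pre-activation z_k yields two units, so the output has twice as
    many features as there are pre-activations.
    """
    z = [sum(w * xi for w, xi in zip(row, x)) + b for row, b in zip(weights, biases)]
    return [math.sin(zk) for zk in z] + [math.cos(zk) for zk in z], z

def most_linear_slope(zk):
    """Larger of |d/dz sin z| = |cos z| and |d/dz cos z| = |sin z|."""
    return max(abs(math.cos(zk)), abs(math.sin(zk)))

rng = random.Random(0)
d_in, units = 4, 8
weights = [[rng.gauss(0.0, 1.0) for _ in range(d_in)] for _ in range(units)]
biases = [rng.gauss(0.0, 1.0) for _ in range(units)]
x = [rng.uniform(-1.0, 1.0) for _ in range(d_in)]

features, z = fourier_layer(x, weights, biases)
print(len(features))                           # 16 = 2 * units
print(min(most_linear_slope(zk) for zk in z))  # never below 1/sqrt(2) ~ 0.707
```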

Induction step: Assume that a deep Fourier network of depth $L-1$ embeds a deep linear network; we prove that adding an additional deep Fourier layer retains the embedded deep linear network. There are two cases to consider, corresponding to the units of the additional deep Fourier layer that are approximately linear and the other units that are not.

Case 1 (approximately-linear units): For the additional deep Fourier layer, the set of approximately-linear units already embeds a deep linear network. Because linearity is closed under composition, the composition of the additional deep Fourier layer and the deep Fourier network of depth $L-1$ simply adds an additional linear layer to the embedded deep linear network, increasing its depth to $L$.

Case 2 (other units): For the units that are not well-approximated by a linear function, we can treat them as if they were separate inputs to the deep Fourier network of depth $L-1$. The network's parameters associated with those inputs are, by the inductive hypothesis, already embedded in the deep linear network.

Note that case 1 embeds the parameters of the additional deep Fourier layer into the deep linear network. Case 2 states that the parameters of the network associated with the nonlinear units of the additional deep Fourier layer are already embedded in the deep linear network, by the inductive hypothesis.

Thus, a neural network composed of deep Fourier layers embeds a deep linear network. ∎

Appendix C Empirical Details

All of our experiments use 10 seeds, and we report the standard error of the mean in the figures. The optimizer used for all experiments was Adam, and after a sweep on each of the datasets over the learning rates $[0.005, 0.001, 0.0005]$, we found that $\alpha = 0.0005$ was the most performant.

We used the Adam optimizer (Kingma and Ba, 2015) for all experiments, settling on the default learning rate of $0.001$ after evaluating $[0.005, 0.001, 0.0005]$. Results are presented with the standard error of the mean, indicated by shaded regions, based on 10 random seeds.

Dataset specifications and non-stationarity conditions:

  • For MNIST, Fashion MNIST, and EMNIST: we use a random sample of 25600 of the observations and a batch size of 256 (unless otherwise indicated, such as in the linearly separable experiment).

  • For CIFAR10 and CIFAR100: the full 50000 images for training, 1000 test images for validation, and the rest for testing. The batch size used was 250. Random label non-stationarity: 20 epochs per task, 30 tasks total. Label noise non-stationarity: 80 epochs, 10 tasks. Class incremental learning: 6000 iterations per task, 80 tasks. Note that the datasets on different tasks in the class incremental setting can have different sizes, and so epochs are not comparable.

  • tiny-ImageNet: all 100000 images for training, 10000 for validation, and 10000 for testing, as per the predetermined split. The batch size used was 250. Random label non-stationarity: 20 epochs per task, 30 tasks total. Pixel permutation non-stationarity: 60 epochs, 100 tasks. Class incremental learning: 10000 iterations per task, 80 tasks. Note that the datasets on different tasks in the class incremental setting can have different sizes, and so epochs are not comparable.
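As a rough guide to training length, the sketch below converts the schedules above into total optimizer steps for the settings whose per-task budget is stated explicitly (random labels and class incremental learning). It assumes each epoch is `train_size // batch_size` full batches with no partial final batch; the `Schedule` type and `total_iterations` helper are illustrative.

```python
from dataclasses import dataclass

@dataclass
class Schedule:
    """One non-stationarity setting, budgeted in epochs or in iterations per task."""
    name: str
    num_tasks: int
    epochs_per_task: int = 0
    iterations_per_task: int = 0

def total_iterations(s: Schedule, train_size: int, batch_size: int) -> int:
    """Total optimizer steps, assuming train_size // batch_size full batches per epoch."""
    per_task = s.iterations_per_task or s.epochs_per_task * (train_size // batch_size)
    return per_task * s.num_tasks

cifar = [Schedule("random labels", 30, epochs_per_task=20),
         Schedule("class incremental", 80, iterations_per_task=6000)]
tiny_imagenet = [Schedule("random labels", 30, epochs_per_task=20),
                 Schedule("class incremental", 80, iterations_per_task=10000)]

for s in cifar:
    print("CIFAR", s.name, total_iterations(s, 50_000, 250))
for s in tiny_imagenet:
    print("tiny-ImageNet", s.name, total_iterations(s, 100_000, 250))
```

For example, CIFAR random labels gives 200 steps per epoch, 4000 per task, and 120000 in total.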

Neural Network Architectures

For tiny-ImageNet, CIFAR10, CIFAR100, and SVHN2, we used a standard ResNet-18 with batch normalization and a standard tiny Vision Transformer. The smaller datasets use an MLP with different widths and depths, as specified in the scaling section.

Appendix D Additional Experiments

These additional experiments validate the benefits of adaptive linearity as a means of improving trainability. The experiments use the following datasets for continual supervised learning: MNIST (LeCun et al., 1998), Fashion MNIST (Xiao et al., 2017), and EMNIST (Cohen et al., 2017). We focus primarily on the problem of trainability, and thus consider random label non-stationarity, in which the labels are randomly assigned to each observation and must be memorized on each task. This type of non-stationarity makes sustaining trainability particularly difficult in continual learning (Lyle et al., 2023; Kumar et al., 2023b). We compare our adaptively-linear network against a nonlinear feed-forward neural network of the same depth with ReLU activations. Because the adaptively-linear network uses a concatenation of two different activation functions, it has half the width of the nonlinear network and fewer parameters, which gives the nonlinear baseline an advantage.
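One way to generate random label non-stationarity is sketched below, assuming labels are redrawn i.i.d. uniformly over the classes at the start of each task while the inputs stay fixed; `random_label_tasks` is a hypothetical helper, and the sizes mirror the 25600-observation subsets described in Appendix C.

```python
import random

def random_label_tasks(num_examples, num_classes, num_tasks, seed=0):
    """Yield one label vector per task, with labels redrawn uniformly at random.

    The inputs stay fixed across tasks, so every task requires memorizing a
    fresh, input-independent labelling of the same observations.
    """
    rng = random.Random(seed)
    for _ in range(num_tasks):
        yield [rng.randrange(num_classes) for _ in range(num_examples)]

tasks = list(random_label_tasks(num_examples=25_600, num_classes=10, num_tasks=3))
print(len(tasks), len(tasks[0]))  # 3 25600
```

Because each labelling is independent of the inputs, nothing learned on one task transfers to the next, which isolates the network's ability to keep fitting new targets.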

Figure 8: Trainability across different datasets and epochs per task. Nonlinear networks lose their trainability, whereas adaptively-linear networks improve and sustain their trainability.

D.1 Adaptively-Linear Networks are Highly Trainable

The main result of this appendix is presented in Figure 8. Across different datasets, almost-linear networks are highly trainable, either achieving high accuracy and maintaining it on easier tasks, such as MNIST, or improving their trainability on new tasks, such as on Fashion MNIST. In contrast, the nonlinear network suffered from loss of trainability in each of the problems that we studied. This is not surprising, as loss of trainability is a well-documented issue for nonlinear networks without some additional method designed to mitigate it (Dohare et al., 2021; Lyle et al., 2022; Kumar et al., 2023b; Elsayed and Mahmood, 2024).

D.2 Methods for Improving Trainability

Given that a nonlinear network is unable to maintain its trainability in isolation, we investigate whether recently proposed methods for mitigating loss of trainability are able to make up for the difference in performance between an adaptively linear network and a nonlinear network. We investigate two categories of mitigators for loss of plasticity: (i) regularization and (ii) normalization layers.

Regularization

Figure 9: Hyperparameter Sensitivity Analysis. Adaptively-linear networks seem to not benefit from regularization. While nonlinear networks are more trainable with regularization, their performance is still worse than the adaptively-linear network.

Loss of plasticity is commonly observed in nonlinear networks that are not regularized. We therefore compare the performance of the nonlinear network and the adaptively-linear network with varying regularization strengths. In particular, we use the recently proposed L2 regularization towards the initialization (Kumar et al., 2023b) because, unlike standard L2 regularization, it does not pull the parameters towards zero. In Figure 9, we find that regularization does improve the trainability of nonlinear networks, validating previous empirical findings. However, almost-linear networks do not benefit substantially from regularization, and an almost-linear network with a smaller regularization strength always outperformed the nonlinear network.
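A minimal sketch of L2 regularization towards the initialization, written as a plain SGD step on $J(\theta) + \lambda \lVert \theta - \theta_0 \rVert_2^2$ for illustration (the experiments use Adam); `l2_init_step` and its arguments are hypothetical names. With the task gradient set to zero, the penalty alone pulls the parameters back to $\theta_0$ rather than to zero.

```python
def l2_init_step(params, grads, init_params, lr, strength):
    """One SGD step on J(theta) + strength * ||theta - theta_0||^2.

    The penalty gradient is 2 * strength * (theta - theta_0), which pulls each
    parameter back towards its initial value rather than towards zero.
    """
    return [
        p - lr * (g + 2.0 * strength * (p - p0))
        for p, g, p0 in zip(params, grads, init_params)
    ]

init = [0.5, -1.0, 2.0]
params = [1.5, 0.0, 2.0]
zero_grads = [0.0, 0.0, 0.0]
for _ in range(200):
    params = l2_init_step(params, zero_grads, init, lr=0.1, strength=1.0)
print(params)  # approaches init = [0.5, -1.0, 2.0], not zero
```

With `lr=0.1` and `strength=1.0`, the distance to the initialization shrinks by a factor of 0.8 per step; a standard L2 penalty would instead shrink the parameters themselves towards zero.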

Layer Normalization

Training deep neural networks typically involves normalization layers, either Batch Normalization (Ioffe and Szegedy, 2015) or Layer Normalization (Ba et al., 2016). Recently, layer normalization was shown to be an effective mitigator of loss of trainability (Lyle et al., 2024). We investigate whether adding normalization layers improves trainability for both the nonlinear and the adaptively-linear network. Figure 10 shows that layer normalization increases performance, but loss of trainability can still occur with a nonlinear network. In addition to standard Layer Normalization, we also tried a linearized version of LayerNorm that applies a stop-gradient to the standard deviation to maintain linearity: with the standard deviation treated as a constant in the backward pass, the normalization reduces to mean-centering followed by a fixed rescaling. This linearized variant improved training speed in some instances.
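
To make the stop-gradient concrete, the sketch below writes out the backward pass (vector-Jacobian product) of layer normalization by hand, once with gradients flowing through the standard deviation and once with the standard deviation treated as a constant. It is a minimal illustration under our own assumptions: a single input vector, no learnable gain or bias, and an EPS constant of 1e-5. The function names are hypothetical, and this is not necessarily how the linearized LayerNorm was implemented in our experiments.

```python
import math
from typing import List, Tuple

EPS = 1e-5  # assumed numerical-stability constant inside the square root


def layer_norm_forward(x: List[float]) -> Tuple[List[float], float]:
    """Normalize x to zero mean and (approximately) unit variance; return (y, std)."""
    n = len(x)
    mean = sum(x) / n
    var = sum((xi - mean) ** 2 for xi in x) / n
    std = math.sqrt(var + EPS)
    return [(xi - mean) / std for xi in x], std


def layer_norm_vjp(g: List[float], y: List[float], std: float) -> List[float]:
    """dL/dx for standard LayerNorm, given the upstream gradient g = dL/dy.

    Gradient flows through both the mean and the standard deviation.
    """
    n = len(g)
    g_mean = sum(g) / n
    gy_mean = sum(gi * yi for gi, yi in zip(g, y)) / n
    return [(gi - g_mean - yi * gy_mean) / std for gi, yi in zip(g, y)]


def linearized_layer_norm_vjp(g: List[float], std: float) -> List[float]:
    """dL/dx when a stop-gradient is applied to the standard deviation.

    With std held constant, x -> (x - mean(x)) / std is linear in x, so the
    backward pass is only mean-centering followed by a fixed rescaling.
    """
    n = len(g)
    g_mean = sum(g) / n
    return [(gi - g_mean) / std for gi in g]


if __name__ == "__main__":
    x = [1.0, 2.0, 4.0, -1.0]
    y, std = layer_norm_forward(x)
    g = [0.5, -1.0, 0.25, 2.0]  # arbitrary upstream gradient
    print(layer_norm_vjp(g, y, std))
    print(linearized_layer_norm_vjp(g, std))
```

The forward outputs of the two variants are identical; only the backward pass differs, by the term that routes gradient through the standard deviation. An upstream gradient proportional to y, which would only rescale the normalized output, is almost entirely removed by the standard version but passed through by the linearized one.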

Figure 10: Comparison of trainability with Layer Normalization. Nonlinear networks are more trainable with Layer Normalization, but adaptively-linear networks learn faster and achieve better accuracy, particularly with linearized Layer Norm.

D.3 Scaling Properties of Almost-Linear Networks

Figure 11: Scaling Neural Network Width and Depth. (Top) Owing to the concatenation in their activation function, adaptively-linear networks scale particularly well with width. (Bottom) Deeper adaptively-linear networks also achieve better average end-of-task performance.
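
The width effect attributed to concatenation can be made concrete with a small sketch. For illustration only, we assume that the concatenating activation is the deep Fourier feature activation (a sine and a cosine of each pre-activation); this section does not state which activation the networks in Figure 11 use, and the helper names are hypothetical. The sketch simply shows that a layer with n pre-activations emits 2n features, so the following dense layer receives twice as many inputs as in a ReLU network of the same nominal width.

```python
import math
from typing import List


def deep_fourier_features(z: List[float]) -> List[float]:
    """Concatenate sin and cos of each pre-activation: n inputs -> 2n outputs."""
    return [math.sin(v) for v in z] + [math.cos(v) for v in z]


def relu(z: List[float]) -> List[float]:
    """Elementwise ReLU: n inputs -> n outputs."""
    return [max(0.0, v) for v in z]


def dense_param_count(fan_in: int, fan_out: int) -> int:
    """Weights plus biases of a fully connected layer."""
    return fan_in * fan_out + fan_out


def hidden_layer_params(width: int, concatenating: bool) -> int:
    """Parameters of a width-to-width dense layer placed after the activation."""
    fan_in = 2 * width if concatenating else width
    return dense_param_count(fan_in, width)


if __name__ == "__main__":
    z = [0.0, 0.5, -1.0]
    print(len(relu(z)), len(deep_fourier_features(z)))
    for width in (64, 128, 256):
        print(width, hidden_layer_params(width, False), hidden_layer_params(width, True))
```

Under this reading, increasing the nominal width of a concatenating network grows both its feature count and the next layer's fan-in twice as fast as in a ReLU network.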

Width Scaling

Another recently proposed source of linearity is width: as a neural network becomes wider, its parameter dynamics approach those of a linear model, and they become linear in the infinite-width limit (Lee et al., 2019). We investigate whether increasing width can close the gap in trainability between the nonlinear network and the almost-linear network. Figure 11 (Top) shows that adaptively-linear networks scale particularly well with width, whereas width seems to have little effect on the trainability of nonlinear networks. Thus, our results suggest that increasing width alone does not necessarily improve the trainability of a nonlinear network, at least for the widths we considered.
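
For reference, the linear model in question is the first-order Taylor expansion of the network f around its initial parameters θ_0; Lee et al. (2019) show that, as width grows, gradient-descent training of f is increasingly well approximated by training this linearization. The notation below is ours:

```latex
f^{\mathrm{lin}}(x;\theta) = f(x;\theta_0) + \nabla_\theta f(x;\theta_0)^\top (\theta - \theta_0)
```

Because f^lin is affine in θ, it is a linear function approximator with fixed features ∇_θ f(x; θ_0), and under the squared loss its training is a convex least-squares problem.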

Depth Scaling

In supervised learning, neural networks tend to improve with depth, which allows them to represent more complex functions. We investigate whether increasing the depth of almost-linear networks leads to similar improvements in continual learning. Figure 11 (Bottom) shows that adaptively-linear networks do improve with additional depth, but the improvement is less pronounced than that obtained by scaling width.