SMART: Robust and Efficient Fine-Tuning for Pre-Trained Natural Language Models through Principled Regularized Optimization
Abstract
Transfer learning has fundamentally changed the landscape of natural language processing
(NLP). Many state-of-the-art models are first pre-trained on a large text corpus and then fine-
tuned on downstream tasks. However, due to limited data resources from downstream tasks and
the extremely high complexity of pre-trained models, aggressive fine-tuning often causes the
fine-tuned model to overfit the training data of downstream tasks and fail to generalize to unseen
data. To address such an issue in a principled manner, we propose a new learning framework
for robust and efficient fine-tuning for pre-trained models to attain better generalization perfor-
mance. The proposed framework contains two important ingredients: 1. Smoothness-inducing
regularization, which effectively manages the complexity of the model; 2. Bregman proximal point
optimization, which is an instance of trust-region methods and can prevent aggressive updating.
Our experiments show that the proposed framework achieves new state-of-the-art performance on
a number of NLP tasks including GLUE, SNLI, SciTail and ANLI. Moreover, it also outperforms the
state-of-the-art T5 model, which is the largest pre-trained model containing 11 billion parameters,
on GLUE. 1
1 Introduction
The success of natural language processing (NLP) techniques relies on huge amounts of labeled
data in many applications. However, large amounts of labeled data are usually prohibitive or
expensive to obtain. To address this issue, researchers have resorted to transfer learning.
Transfer learning considers the scenario, where we have limited labeled data from the target
domain for a certain task, but we have relevant tasks with a large amount of data from different
domains (also known as out-of-domain data). The goal is to transfer the knowledge from the high-
resource domains to the low-resource target domain. Here we are particularly interested in the
popular two-stage transfer learning framework (Pan and Yang, 2009). The first stage is pre-training,
where a high-capacity model is trained for the out-of-domain high-resource relevant tasks. The
* Work was done during Haoming Jiang’s internship at Microsoft Dynamics 365 AI. Haoming Jiang and Tuo Zhao
are affiliated with Georgia Institute of Technology. Pengcheng He and Weizhu Chen are affiliated with Microsoft
Dynamics 365 AI. Xiaodong Liu and Jianfeng Gao are affiliated with Microsoft Research. Emails: [email protected],
{penhe,wzchen}@microsoft.com, {xiaodl,jfgao}@microsoft.com, [email protected].
1 https://github.com/namisan/mt-dnn
second stage is fine-tuning, where the high-capacity model is adapted to the low-resource task in
the target domain.
For many applications in NLP, most popular transfer learning methods choose to pre-train a
large language model, e.g., ELMo (Peters et al., 2018), GPT (Radford et al., 2019) and BERT (Devlin
et al., 2019). Such a language model can capture general semantic and syntactic information that
can be further used in downstream NLP tasks. The language model is particularly attractive,
because it can be trained in a completely unsupervised manner with huge amounts of unlabeled
data, which are extremely cheap to obtain from the internet nowadays. The resulting extremely large
multi-domain text corpus allows us to train huge language models. To the best of our knowledge,
by far the largest language model, T5, has an enormous size of about 11 billion parameters (Raffel
et al., 2019).
For the second fine-tuning stage, researchers adapt the pre-trained language model to the
target task/domain. They usually replace the top layer of the language model by a task/domain-
specific sub-network, and then continue to train the new model with the limited data of the
target task/domain. Such a fine-tuning approach accounts for the low-resource issue in the target
task/domain, and has achieved state-of-the-art performance in many popular NLP benchmarks
(Devlin et al., 2019; Liu et al., 2019c; Yang et al., 2019; Lan et al., 2019; Dong et al., 2019; Raffel
et al., 2019).
Due to the limited data from the target task/domain and the extremely high complexity of
the pre-trained model, aggressive fine-tuning often makes the adapted model overfit the training
data of the target task/domain and therefore does not generalize well to unseen data. To mitigate
this issue, the fine-tuning methods often rely on hyper-parameter tuning heuristics. For example,
Howard and Ruder (2018) use a heuristic learning rate schedule and gradually unfreeze the layers
of the language model to improve the fine-tuning performance; Peters et al. (2019) instead suggest
adapting only certain layers and freezing the others; Houlsby et al. (2019); Stickland
and Murray (2019) propose to add additional layers to the pre-trained model and fine-tune both of
them or only the additional layers. However, these methods require significant tuning efforts.
To fully harness the power of fine-tuning in a more principled manner, we propose a new
learning framework for robust and efficient fine-tuning on the pre-trained language models
through regularized optimization techniques. Specifically, our framework consists of two important
ingredients for preventing overfitting:
(I) To effectively control the extremely high complexity of the model, we propose a Smoothness-
inducing Adversarial Regularization technique. Our proposed regularization is motivated by local
shift sensitivity in the existing literature on robust statistics. Such regularization encourages the output
of the model not to change much when a small perturbation is injected into the input. Therefore, it
enforces the smoothness of the model, and effectively controls its capacity (Mohri et al., 2018).
(II) To prevent aggressive updating, we propose a class of Bregman Proximal Point Optimization
methods. Our proposed optimization methods introduce a trust-region-type regularization (Conn
et al., 2000) at each iteration, and then update the model only within a small neighborhood of
the previous iterate. Therefore, they can effectively prevent aggressive updating and stabilize the
fine-tuning process.
We compare our proposed method with several state-of-the-art competitors proposed in Zhu
et al. (2020); Liu et al. (2019b,c); Lan et al. (2019); Raffel et al. (2019) and show that our proposed
method significantly improves the training stability and generalization, and achieves comparable
or better performance on multiple NLP tasks. We highlight that our single model with 356M
parameters (without any ensemble) can achieve three state-of-the-art results on GLUE, even com-
pared with all existing ensemble models and the T5 model (Raffel et al., 2019), which contains 11
billion parameters. Furthermore, we also demonstrate that the proposed framework complements
state-of-the-art fine-tuning methods (Liu et al., 2019b) and outperforms the T5 model.
We summarize our contributions as follows: 1. We introduce the smoothness-inducing adversar-
ial regularization and proximal point optimization into large scale language model fine-tuning; 2.
We achieve state-of-the-art results on several popular NLP benchmarks (e.g., GLUE, SNLI, SciTail,
and ANLI).
Notation: We use f (x; θ) to denote a mapping f associated with the parameter θ from input
sentences x to an output space, where the output is a multi-dimensional probability simplex for
classification tasks and a scalar for regression tasks. $\Pi_{\mathcal{A}}$ denotes the projection operator onto the set
$\mathcal{A}$. $\mathcal{D}_{\mathrm{KL}}(P\,\|\,Q) = \sum_k p_k \log(p_k/q_k)$ denotes the KL-divergence of two discrete distributions $P$ and $Q$
with the associated parameters $p_k$ and $q_k$, respectively.
2 Background
The transformer models were originally proposed in Vaswani et al. (2017) for neural machine
translation. Their superior performance motivated Devlin et al. (2019) to propose a bidirectional
transformer-based language model named BERT. Specifically, Devlin et al. (2019) pre-trained the
BERT model using a large corpus without any human annotation through unsupervised learning
tasks. BERT motivated many follow-up works to further improve the pre-training by introducing
new unsupervised learning tasks (Yang et al., 2019; Dong et al., 2019; Joshi et al., 2020), enlarging
model size (Lan et al., 2019; Raffel et al., 2019), enlarging training corpora (Liu et al., 2019c; Yang
et al., 2019; Raffel et al., 2019) and multi-tasking (Liu et al., 2019a,b).
The pre-trained language model is then adapted to downstream tasks and further fine-tuned.
Specifically, the top layer of the language model can be replaced by a task-specific layer, and the
model is then further trained on downstream tasks. To prevent overfitting, existing heuristics include
using a small learning rate or a triangular learning rate schedule, training for a small number of iterations,
and other fine-tuning tricks mentioned in Howard and Ruder (2018); Peters et al. (2019); Houlsby et al.
(2019); Stickland and Murray (2019).
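As a concrete illustration of this adaptation step, the sketch below attaches a task-specific head on top of a pre-trained encoder; the `FineTuneClassifier` wrapper, the pooled-output interface, and the toy encoder are our own illustrative assumptions, not a specific library API.

```python
import torch
import torch.nn as nn

class FineTuneClassifier(nn.Module):
    """Pre-trained encoder with its top layer replaced by a task-specific head."""
    def __init__(self, encoder, hidden_size, num_labels):
        super().__init__()
        self.encoder = encoder                          # pre-trained language model (placeholder)
        self.head = nn.Linear(hidden_size, num_labels)  # new task-specific layer

    def forward(self, embeds):
        pooled = self.encoder(embeds)   # assumed to return a pooled sentence representation
        return self.head(pooled)        # task-specific logits

# Continue training the whole model (or only the head) on the limited target-task data.
encoder = nn.Sequential(nn.Linear(32, 64), nn.Tanh())      # toy stand-in for a pre-trained encoder
model = FineTuneClassifier(encoder, hidden_size=64, num_labels=3)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-5)  # small learning rate, a common heuristic
```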
Our proposed regularization technique is related to several existing works (Miyato et al., 2018;
Zhang et al., 2019; Shu et al., 2018). These works consider similar regularization techniques, but
target at other applications with different motivations, e.g., semi-supervised learning, unsupervised
domain adaptation and harnessing adversarial examples in image classification.
Our proposed optimization technique covers a large class of Bregman proximal point methods
in existing literature on optimization, including vanilla proximal point method (Rockafellar, 1976),
generalized proximal point method (Teboulle, 1997; Eckstein, 1993), accelerated proximal point
method, and other variants (Güler, 1991, 1992; Parikh et al., 2014).
A closely related fine-tuning method is FreeLB (Zhu et al., 2020), which adapts robust adversarial
training to fine-tuning. In contrast, our framework focuses on local smoothness, leading to
a significant performance improvement. More discussion and comparisons are provided in Section 4.
3 The Proposed Method: SMART

3.1 Smoothness-Inducing Adversarial Regularization

In the fine-tuning stage, we minimize the following regularized objective:
$$\min_{\theta}\ \mathcal{F}(\theta) = \mathcal{L}(\theta) + \lambda_s \mathcal{R}_s(\theta), \qquad (1)$$
where $\mathcal{L}(\theta) = \frac{1}{n}\sum_{i=1}^{n} \ell(f(x_i;\theta), y_i)$ is the empirical loss on the training data $\{(x_i, y_i)\}_{i=1}^{n}$ of the target task,
$\ell(\cdot,\cdot)$ is the loss function depending on the target task, $\lambda_s > 0$ is a tuning parameter, and $\mathcal{R}_s(\theta)$
is the smoothness-inducing adversarial regularizer. Here we define $\mathcal{R}_s(\theta)$ as
$$\mathcal{R}_s(\theta) = \frac{1}{n}\sum_{i=1}^{n} \max_{\|\widetilde{x}_i - x_i\|_p \le \epsilon} \ell_s\big(f(\widetilde{x}_i;\theta), f(x_i;\theta)\big),$$
where $\epsilon > 0$ is a tuning parameter. Note that for classification tasks, $f(\cdot;\theta)$ outputs a probability
simplex and $\ell_s$ is chosen as the symmetrized KL-divergence, i.e.,
$$\ell_s(P, Q) = \mathcal{D}_{\mathrm{KL}}(P\,\|\,Q) + \mathcal{D}_{\mathrm{KL}}(Q\,\|\,P).$$
For regression tasks, $f(\cdot;\theta)$ outputs a scalar and $\ell_s$ is chosen as the squared loss, i.e., $\ell_s(p, q) = (p-q)^2$.
Note that the computation of $\mathcal{R}_s(\theta)$ involves a maximization problem and can be solved efficiently
by projected gradient ascent.
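For concreteness, the following is a minimal PyTorch-style sketch (not the authors' released implementation) of the symmetrized KL loss $\ell_s$ and of approximating the inner maximization in $\mathcal{R}_s(\theta)$ by projected gradient ascent; `model` is assumed to map input embeddings to logits (perturbations are typically applied in embedding space for text), and the default values of `eps`, `eta`, and `sigma` are only illustrative.

```python
import torch
import torch.nn.functional as F

def ls_symmetrized_kl(logits_p, logits_q):
    """Symmetrized KL-divergence between the distributions induced by two logit tensors
    (used as l_s for classification tasks)."""
    p, q = F.softmax(logits_p, dim=-1), F.softmax(logits_q, dim=-1)
    log_p, log_q = F.log_softmax(logits_p, dim=-1), F.log_softmax(logits_q, dim=-1)
    return ((p * (log_p - log_q)).sum(-1) + (q * (log_q - log_p)).sum(-1)).mean()

def smoothness_regularizer(model, embeds, eps=1e-5, eta=1e-3, sigma=1e-5, steps=1):
    """Approximate R_s by maximizing l_s over an l_inf ball of radius eps around the
    input embeddings via a few steps of projected gradient ascent."""
    clean_logits = model(embeds)                           # f(x_i; theta)
    noise = torch.randn_like(embeds) * sigma               # random initialization of the perturbation
    for _ in range(steps):
        noise.requires_grad_(True)
        adv_logits = model(embeds + noise)                 # f(x_i + delta; theta)
        inner_loss = ls_symmetrized_kl(adv_logits, clean_logits.detach())
        grad, = torch.autograd.grad(inner_loss, noise)
        step = eta * grad / (grad.abs().max() + 1e-12)     # crude normalized ascent step
        noise = (noise + step).clamp(-eps, eps).detach()   # project onto the l_inf ball
    adv_logits = model(embeds + noise)
    return ls_symmetrized_kl(adv_logits, clean_logits)     # differentiable w.r.t. theta
```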
We remark that the proposed smoothness-inducing adversarial regularizer was first used in
Miyato et al. (2018) for semi-supervised learning with $p = 2$, then in Shu et al. (2018) for
unsupervised domain adaptation with $p = 2$, and more recently in Zhang et al. (2019) for harnessing
adversarial examples in image classification with $p = \infty$. To the best of our knowledge, we are
the first to apply such a regularizer to the fine-tuning of pre-trained language models.
2 The complete name of our proposed method is SMAR$^3$T$^2$, but we use SMART for notational simplicity.
The smoothness-inducing adversarial regularizer essentially measures the local Lipschitz
continuity of $f$ under the metric $\ell_s$. More precisely, the output of $f$ does not change
much if we inject a small perturbation ($\ell_p$ norm bounded by $\epsilon$) to $x_i$. Therefore, by minimizing
the objective in (1), we can encourage $f$ to be smooth within the neighborhoods of all $x_i$'s. Such a
smoothness-inducing property is particularly helpful to prevent overfitting and improve generalization
on a low-resource target domain for a certain task. An illustration is provided in Figure 1.
Note that the idea of measuring the local Lipschitz continuity is similar to the local shift
sensitivity criterion in the existing literature on robust statistics, which dates back to the 1960s (Hampel,
1974; Huber, 2011). This criterion has been used to characterize the dependence of an estimator on
the value of one of the sample points.
Figure 1: Decision boundaries learned without (a) and with (b) smoothness-inducing adversarial
regularization, respectively. The red dotted line in (b) represents the decision boundary in (a). As
can be seen, the output of $f$ in (b) does not change much within the neighborhood of training data
points.
3.2 Bregman Proximal Point Optimization

We propose to solve (1) using a class of Bregman proximal point optimization methods, starting from the
pre-trained model $f(\cdot;\theta_0)$. Specifically, at the $(t+1)$-th iteration, the vanilla Bregman proximal point (VBPP) method takes
$$\theta_{t+1} = \mathop{\mathrm{argmin}}_{\theta}\ \mathcal{F}(\theta) + \mu \mathcal{D}_{\mathrm{Breg}}(\theta, \theta_t), \qquad (2)$$
where $\mu > 0$ is a tuning parameter, and $\mathcal{D}_{\mathrm{Breg}}(\cdot,\cdot)$ is the Bregman divergence defined as
$$\mathcal{D}_{\mathrm{Breg}}(\theta, \theta_t) = \frac{1}{n}\sum_{i=1}^{n} \ell_s\big(f(x_i;\theta), f(x_i;\theta_t)\big),$$
where $\ell_s$ is defined in Section 3.1. As can be seen, when $\mu$ is large, the Bregman divergence at
each iteration of the VBPP method essentially serves as a strong regularizer and prevents $\theta_{t+1}$
from deviating too much from the previous iterate $\theta_t$. This is also known as the trust-region-
type iteration in the existing optimization literature (Conn et al., 2000). Consequently, the Bregman
proximal point method can effectively retain the knowledge of the out-of-domain data in the pre-
trained model $f(\cdot;\theta_0)$. Since each subproblem (2) of VBPP does not admit a closed-form solution,
we need to solve it using SGD-type algorithms such as ADAM. Note that we do not need to solve
each subproblem until convergence. A small number of iterations is sufficient to output a reliable
initial solution for solving the next subproblem.
Moreover, the Bregman proximal point method is capable of adapting to the information ge-
ometry (see more details in Raskutti and Mukherjee (2015)) of machine learning models and
achieving better computational performance than the standard proximal point method (i.e.,
$\mathcal{D}_{\mathrm{Breg}}(\theta, \theta_t) = \|\theta - \theta_t\|_2^2$) in many applications.
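As a concrete illustration, the Bregman divergence term can be computed as in the sketch below, which reuses the `ls_symmetrized_kl` helper from the earlier sketch; keeping a frozen copy of the model at $\theta_t$ (e.g., via `copy.deepcopy`) is one simple way, under our assumptions, to evaluate $f(\cdot;\theta_t)$ as a constant during backpropagation.

```python
import copy
import torch

def bregman_divergence(model, prev_model, embeds):
    """D_Breg(theta, theta_t): average l_s between the predictions of the current model
    (theta) and a frozen snapshot of the previous iterate (theta_t) on the same inputs."""
    with torch.no_grad():
        prev_logits = prev_model(embeds)   # f(x_i; theta_t), constant w.r.t. theta
    cur_logits = model(embeds)             # f(x_i; theta)
    return ls_symmetrized_kl(cur_logits, prev_logits)

# Each VBPP subproblem then minimizes  L(theta) + lambda_s * R_s(theta) + mu * D_Breg(theta, theta_t)
# with a few ADAM steps, after snapshotting theta_t:
#   prev_model = copy.deepcopy(model)
```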
Acceleration by Momentum. Similar to other optimization methods in the existing literature, we can
accelerate the Bregman proximal point method by introducing an additional momentum to the
update. Specifically, at the $(t+1)$-th iteration, the momentum Bregman proximal point (MBPP)
method takes
$$\theta_{t+1} = \mathop{\mathrm{argmin}}_{\theta}\ \mathcal{F}(\theta) + \mu \mathcal{D}_{\mathrm{Breg}}(\theta, \widetilde{\theta}_t), \qquad (3)$$
where $\widetilde{\theta}_t = (1-\beta)\theta_t + \beta\widetilde{\theta}_{t-1}$ is the exponential moving average of the iterates and $\beta \in (0, 1)$ is the
momentum parameter. We summarize the resulting method, SMART, in Algorithm 1.
Algorithm 1 SMART: We use the smoothness-inducing adversarial regularizer with p = ∞ and the
momentum Bregman proximal point method.
Notation: For simplicity, we denote $g_i(\widetilde{x}_i, \bar{\theta}_s) = \frac{1}{|\mathcal{B}|}\sum_{x_i \in \mathcal{B}} \nabla_{\widetilde{x}}\, \ell_s(f(x_i; \bar{\theta}_s), f(\widetilde{x}_i; \bar{\theta}_s))$, AdamUpdate$_{\mathcal{B}}$
denotes the ADAM update rule for optimizing (3) using the mini-batch $\mathcal{B}$, and $\Pi_{\mathcal{A}}$ denotes the
projection onto $\mathcal{A}$.
Input: $T$: the total number of iterations; $\mathcal{X}$: the dataset; $\theta_0$: the parameters of the pre-trained model;
$S$: the total number of iterations for solving (2); $\sigma^2$: the variance of the random initialization
for the $\widetilde{x}_i$'s; $T_{\widetilde{x}}$: the number of iterations for updating the $\widetilde{x}_i$'s; $\eta$: the learning rate for updating the $\widetilde{x}_i$'s; $\beta$: the
momentum parameter.
1: $\widetilde{\theta}_1 \leftarrow \theta_0$
2: for $t = 1, \ldots, T$ do
3:   $\bar{\theta}_1 \leftarrow \theta_{t-1}$
4:   for $s = 1, \ldots, S$ do
5:     Sample a mini-batch $\mathcal{B}$ from $\mathcal{X}$
6:     For all $x_i \in \mathcal{B}$, initialize $\widetilde{x}_i \leftarrow x_i + \nu_i$ with $\nu_i \sim \mathcal{N}(0, \sigma^2 I)$
7:     for $m = 1, \ldots, T_{\widetilde{x}}$ do
8:       $\widetilde{g}_i \leftarrow g_i(\widetilde{x}_i, \bar{\theta}_s) \,/\, \|g_i(\widetilde{x}_i, \bar{\theta}_s)\|_\infty$
9:       $\widetilde{x}_i \leftarrow \Pi_{\|\widetilde{x}_i - x_i\|_\infty \le \epsilon}(\widetilde{x}_i + \eta\, \widetilde{g}_i)$
10:    end for
11:    $\bar{\theta}_{s+1} \leftarrow$ AdamUpdate$_{\mathcal{B}}(\bar{\theta}_s)$
12:  end for
13:  $\theta_t \leftarrow \bar{\theta}_S$
14:  $\widetilde{\theta}_{t+1} \leftarrow (1-\beta)\bar{\theta}_S + \beta\,\widetilde{\theta}_t$
15: end for
Output: $\theta_T$
For the implementation, we set $\beta = 0.99$ for the first 10% of the updates ($t \le 0.1T$) and $\beta = 0.999$ for the remaining
updates ($t > 0.1T$) following Tarvainen and Valpola (2017). Lastly, we simply set $S = 1$, $T_{\widetilde{x}} = 1$ in
Algorithm 1.
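Putting the pieces together, below is a self-contained PyTorch sketch of one SMART iteration with $S = 1$ and $T_{\widetilde{x}} = 1$, following the structure of Algorithm 1. A toy two-layer classifier stands in for the pre-trained language model and all hyper-parameter values are illustrative; the authors' actual implementation is in the mt-dnn repository linked above.

```python
import copy
import torch
import torch.nn as nn
import torch.nn.functional as F

def sym_kl(logits_p, logits_q):
    """Symmetrized KL-divergence between the distributions induced by two logit tensors."""
    p, q = F.softmax(logits_p, dim=-1), F.softmax(logits_q, dim=-1)
    lp, lq = F.log_softmax(logits_p, dim=-1), F.log_softmax(logits_q, dim=-1)
    return ((p * (lp - lq)).sum(-1) + (q * (lq - lp)).sum(-1)).mean()

# Toy stand-in for a pre-trained encoder plus task head operating on input embeddings.
model = nn.Sequential(nn.Linear(32, 64), nn.Tanh(), nn.Linear(64, 3))
theta_tilde = copy.deepcopy(model)                    # EMA iterate (Algorithm 1, line 1)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

# Illustrative hyper-parameter values (see Section 8 for the ranges reported in the paper).
lam_s, mu, eps, eta, sigma, beta = 1.0, 1.0, 1e-4, 1e-3, 1e-5, 0.99

def smart_step(embeds, labels):
    """One SMART iteration with S = 1 and T_x~ = 1."""
    # Smoothness-inducing adversarial regularizer (Algorithm 1, lines 6-10).
    clean_logits = model(embeds)
    noise = (torch.randn_like(embeds) * sigma).requires_grad_(True)
    inner = sym_kl(model(embeds + noise), clean_logits.detach())
    g, = torch.autograd.grad(inner, noise)
    noise = (noise + eta * g / (g.abs().max() + 1e-12)).clamp(-eps, eps).detach()
    r_s = sym_kl(model(embeds + noise), clean_logits)

    # Bregman proximal term D_Breg(theta, theta~_t) against the EMA snapshot.
    with torch.no_grad():
        prox_logits = theta_tilde(embeds)
    d_breg = sym_kl(clean_logits, prox_logits)

    # ADAM update on the regularized objective (Algorithm 1, line 11).
    loss = F.cross_entropy(clean_logits, labels) + lam_s * r_s + mu * d_breg
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

    # Momentum (EMA) update of theta~ (Algorithm 1, line 14).
    with torch.no_grad():
        for p_ema, p in zip(theta_tilde.parameters(), model.parameters()):
            p_ema.mul_(beta).add_(p, alpha=1 - beta)
    return loss.item()

# Usage on random data: a batch of 8 "sentence embeddings" of dimension 32 and 3 classes.
x, y = torch.randn(8, 32), torch.randint(0, 3, (8,))
print(smart_step(x, y))
```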
baseline results, which are denoted by BERTReImp .
• RoBERTa (Liu et al., 2019c): This is the RoBERTaLARGE released by authors, and we present the
reported results on the GLUE dev.
• PGD, FreeAT, FreeLB (Zhu et al., 2020): They are three adversarial training approaches built on
top of the RoBERTaLARGE .
• SMART: our proposed method as described in Section 3. We use both the BERTBASE model
(SMARTBERT ) and the RoBERTaLARGE model (SMARTRoBERTa ) as the pre-trained model to evaluate
the effectiveness of SMART.
The main results are reported in Table 1. This table can be clustered into two groups based on
different pretrained models: the BERTBASE model (the first group) and the RoBERTaLARGE model
(the second group). The detailed discussions are as follows.
For a fair comparison, we reproduced the BERT baseline (BERTReImp ), since several results
on the GLUE development set were missing. Our reimplemented BERT baseline is even stronger
than the originally reported results in Devlin et al. (2019). For instance, the reimplemented model
obtains 84.5% (vs. 84.4%) accuracy on the MNLI in-domain development set. On SST-2,
BERTReImp outperforms BERT by 0.2% (92.9% vs. 92.7%) accuracy. All these results demonstrate
the fairness of our baselines.
Table 1: Main results on GLUE development set. The best result on each task produced by a single
model is in bold and “-” denotes the missed result.
Compared with the two strong baselines BERT and RoBERTa 7 , SMART, including SMARTBERT
and SMARTRoBERTa , consistently outperforms them across all 8 GLUE tasks by a large margin.
Compared with BERT, SMARTBERT obtains 85.6% (vs. 84.5%) and 86.0% (vs. 84.4%) accuracy,
a 1.1% and 1.6% absolute improvement, on the MNLI in-domain and out-domain settings,
respectively. Even compared with the state-of-the-art model RoBERTa, SMARTRoBERTa improves by 0.8%
7 In our experiments, we use BERT to refer to the BERTBASE model, which has 110 million parameters, and RoBERTa
to refer to the RoBERTaLARGE model, which has 356 million parameters, unless stated otherwise.
Model /#Train CoLA SST MRPC STS-B QQP MNLI-m/mm QNLI RTE WNLI AX Score #param
8.5k 67k 3.7k 7k 364k 393k 108k 2.5k 634
Human Performance 66.4 97.8 86.3/80.8 92.7/92.6 59.5/80.4 92.0/92.8 91.2 93.6 95.9 - 87.1 -
Ensemble Models
RoBERTa1 67.8 96.7 92.3/89.8 92.2/91.9 74.3/90.2 90.8/90.2 98.9 88.2 89.0 48.7 88.5 356M
FreeLB2 68.0 96.8 93.1/90.8 92.4/92.2 74.8/90.3 91.1/90.7 98.8 88.7 89.0 50.1 88.8 356M
ALICE3 69.2 97.1 93.6/91.5 92.7/92.3 74.4/90.7 90.7/90.2 99.2 87.3 89.7 47.8 89.0 340M
ALBERT4 69.1 97.1 93.4/91.2 92.5/92.0 74.2/90.5 91.3/91.0 99.2 89.2 91.8 50.2 89.4 235M∗
MT-DNN-SMART† 69.5 97.5 93.7/91.6 92.9/92.5 73.9/90.2 91.0/90.8 99.2 89.7 94.5 50.2 89.9 356M
Single Model
BERTLARGE 5 60.5 94.9 89.3/85.4 87.6/86.5 72.1/89.3 86.7/85.9 92.7 70.1 65.1 39.6 80.5 335M
MT-DNN6 62.5 95.6 90.0/86.7 88.3/87.7 72.4/89.6 86.7/86.0 93.1 75.5 65.1 40.3 82.7 335M
T58 70.8 97.1 91.9/89.2 92.5/92.1 74.6/90.4 92.0/91.7 96.7 92.5 93.2 53.1 89.7 11,000M
SMARTRoBERTa 65.1 97.5 93.7/91.6 92.9/92.5 74.0/90.1 91.0/90.8 95.4 87.9 91.88 50.2 88.4 356M
Table 2: GLUE test set results scored using the GLUE evaluation server. The state-of-the-art
results are in bold. All the results were obtained from https://gluebenchmark.com/leaderboard
on December 5, 2019. SMART uses the classification objective on QNLI. Model references: 1 Liu
et al. (2019c); 2 Zhu et al. (2020); 3 Wang et al. (2019); 4 Lan et al. (2019); 5 Devlin et al. (2019);
6 Liu et al. (2019b); 7 Raffel et al. (2019) and 8 He et al. (2019), Kocijan et al. (2019). ∗ ALBERT
uses a model similar in size, architecture and computation cost to a 3,000M BERT (though it has
dramatically fewer parameters due to parameter sharing). † Mixed results from ensemble and
single of MT-DNN-SMART and with data augmentation.
(91.1% vs. 90.2%) on the MNLI in-domain development set. Interestingly, on the MNLI task, the
performance of SMART in the out-domain setting is better than in the in-domain setting, e.g., (86.0%
vs. 85.6%) for SMARTBERT and (91.3% vs. 91.1%) for SMARTRoBERTa , showing that our proposed
approach alleviates the domain-shifting issue. Furthermore, on the small tasks, the improvement of
SMART is even larger. For example, compared with BERT, SMARTBERT obtains 71.2% (vs. 63.5%)
on RTE and 59.1% (vs. 54.7%) on CoLA in terms of accuracy, which are 7.7% and 4.4% absolute
improvements on RTE and CoLA, respectively; similarly, SMARTRoBERTa outperforms RoBERTa by
5.4% (92.0% vs. 86.6%) on RTE and by 2.6% (70.6% vs. 68.0%) on CoLA.
We also compare SMART with a range of models that use adversarial training, such as
FreeLB. From the bottom rows in Table 1, SMART outperforms PGD and FreeAT across all
8 GLUE tasks. Compared with the current state-of-the-art adversarial training model, FreeLB,
SMART outperforms it on 6 out of the 8 GLUE tasks (MNLI, RTE, QNLI, MRPC, SST-2
and STS-B), showing the effectiveness of our model.
Table 2 summarizes the current state-of-the-art models on the GLUE leaderboard. SMART
obtains a competitive result compared with T5 (Raffel et al., 2019), which is the leading model on
the GLUE leaderboard. T5 has 11 billion parameters, while SMART only has 356 million. Compared with
this extremely large model (T5) and the ensemble models (e.g., ALBERT, ALICE), SMART, which is a
single model, still sets new state-of-the-art results on SST-2, MRPC and STS-B. By combining with
the multi-task learning framework (MT-DNN), MT-DNN-SMART obtains a new state-of-the-art result on
GLUE, pushing the GLUE benchmark to 89.9%. More discussion will be provided in Section 5.3.
5 Experiment – Analysis and Extension
In this section, we first analyze the effectiveness of each component of the proposed method. We
also study whether the proposed method is complementary to multi-task learning. We further
extend SMART to domain adaptation and use both SNLI (Bowman et al., 2015) and SciTail (Khot
et al., 2018) to evaluate its effectiveness. Finally, we verify the robustness of the proposed method
on ANLI (Nie et al., 2019).
Table 3: Ablation study of SMART on 5 GLUE tasks. Note that all models used the BERTBASE model
as their encoder.
The results are reported in Table 3. As expected, removing either component (the smoothness-inducing
regularization or the Bregman proximal point method) of SMART results in a performance drop. For
example, on MNLI, removing the smoothness-inducing regularization leads to a 0.8% (85.6% vs. 84.8%) performance
drop, while removing the Bregman proximal point optimization results in a performance drop of
0.2% (85.6% vs. 85.4%). This demonstrates that these two components complement each other.
Interestingly, all three proposed models outperform the BERT baseline model, demonstrating the
effectiveness of each module. Moreover, we observe that the generalization performance benefits
more from SMART on small datasets (i.e., RTE and MRPC) by preventing overfitting.
same and the other two annotations are the same; 4) 3/1/1 three annotations are the same and the
other two annotations are different.
Figure 2 summarizes the results in terms of both accuracy and KL-divergence:
$$-\frac{1}{n}\sum_{i=1}^{n}\sum_{j=1}^{3} p_j(x_i)\log\big(f_j(x_i)\big).$$
For a given sample $x_i$, the KL-divergence evaluates the similarity between the model prediction
$\{f_j(x_i)\}_{j=1}^{3}$ and the annotation distribution $\{p_j(x_i)\}_{j=1}^{3}$. We observe that SMARTRoBERTa outperforms
RoBERTa across all the settings. Further, for samples with a high degree of ambiguity (low degree of agreement),
SMARTRoBERTa obtains an even larger improvement, showing its robustness to ambiguity.
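As a small worked example of this metric (a sketch with illustrative shapes and values, not tied to any released evaluation script), the quantity above can be computed from the per-class annotation frequencies and the model's predicted probabilities:

```python
import numpy as np

def annotation_divergence(annot_dist, model_probs, eps=1e-12):
    """Average of -sum_j p_j(x_i) * log f_j(x_i) over the n samples, where row i of
    annot_dist is the empirical distribution of the human labels for sample x_i and
    row i of model_probs is the model's predicted distribution f(x_i)."""
    return float(-(annot_dist * np.log(model_probs + eps)).sum(axis=1).mean())

# Two illustrative samples with 3 classes (entailment / neutral / contradiction).
annot = np.array([[0.8, 0.2, 0.0],    # 4/1/0 agreement among five annotators
                  [0.6, 0.2, 0.2]])   # 3/1/1 agreement
preds = np.array([[0.7, 0.2, 0.1],
                  [0.5, 0.3, 0.2]])
print(annotation_divergence(annot, preds))
```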
Figure 2: Accuracy (top) and KL-Divergence (bottom) for each annotation-agreement group (All, 5/0/0, 4/1/0, 3/2/0, 3/1/1).
SMART 8 , and then adapted the training data on each task on top of the shared embeddings. We
also include a baseline which fine-tuned each task on the publicly released MT-DNN checkpoint 9 ,
which is indicated as MT-DNN-SMARTv0 .
We observe that both MT-DNN and SMART consistently outperform the BERT model on all
five GLUE tasks. Furthermore, SMART outperforms MT-DNN on MNLI, QNLI, and MRPC, while
it obtains worse results on RTE and SST, showing that MT-DNN is a strong counterpart for SMART.
By combining these two models, MT-DNN-SMARTv0 enjoys the advantages of both and thus improves
the final results. For example, it achieves 85.7% (+0.1%) on MNLI and 80.2% (+1.1%) on RTE
compared with the best results of MT-DNN and SMART, demonstrating that these two techniques
are orthogonal. Lastly, we also trained SMART jointly and then fine-tuned it on each task as in Liu
et al. (2019b). We observe that MT-DNN-SMART outperforms MT-DNN-SMARTv0 and MT-DNN
across all 5 tasks (except MT-DNN on SST), showing that SMART improves the generalization of
MTL.
MRPC, while MT-DNN was trained on the whole GLUE tasks except CoLA.
9 It is from: https://github.com/namisan/mt-dnn. Note that we did not use the complicated answer module, e.g., SAN (Liu et al., 2018).
Model 0.1% 1% 10% 100%
SNLI Dataset (Dev Accuracy%)
#Training Data 549 5,493 54,936 549,367
BERT 52.5 78.1 86.7 91.0
MT-DNN 82.1 85.2 88.4 91.5
MT-DNN-SMART 82.7 86.0 88.7 91.6
SciTail Dataset (Dev Accuracy%)
#Training Data 23 235 2,359 23,596
BERT 51.2 82.2 90.5 94.3
MT-DNN 81.9 88.3 91.1 95.8
MT-DNN-SMART 82.3 88.6 91.3 96.1
Method | Dev: R1 R2 R3 All | Test: R1 R2 R3 All
MNLI + SNLI + ANLI + FEVER
BERTLARGE (Nie et al., 2019) - - - - 57.4 48.3 43.5 49.3
XLNetLARGE (Nie et al., 2019) - - - - 67.6 50.7 48.3 55.1
RoBERTaLARGE (Nie et al., 2019) - - - - 73.8 48.9 44.4 53.7
SMARTRoBERTa-LARGE 74.5 50.9 47.6 57.1 72.4 49.8 50.3 57.1
ANLI
RoBERTaLARGE (Nie et al., 2019) - - - - 71.3 43.3 43.0 51.9
SMARTRoBERTa-LARGE 74.2 49.5 49.2 57.1 72.4 50.3 49.5 56.9
outperforms the BERTLARGE model. A similar observation holds on SciTail and in the BERTLARGE
model setting. We see that incorporating SMART into MT-DNN achieves new state-of-the-art
results on both SNLI and SciTail, pushing the benchmarks to 91.7% on SNLI and 95.2% on SciTail.
5.6 Robustness
One important property of a machine learning model is its robustness to adversarial attacks. We
test our model on the adversarial natural language inference (ANLI) dataset (Nie et al., 2019).
We evaluate the performance of SMART on each subset (i.e., R1, R2, R3) of the ANLI development and test
sets. The results are presented in Table 6. Table 6 shows the results of training on the combined NLI
data, i.e., ANLI (Nie et al., 2019) + MNLI (Williams et al., 2018) + SNLI (Bowman et al., 2015) + FEVER
(Thorne et al., 2018), and of training on the ANLI data only. In the combined data setting, we observe that
SMARTRoBERTa-LARGE obtains the best performance compared with all the strong baselines, pushing
the benchmark to 57.1%. In the case of the RoBERTaLARGE baseline, SMARTRoBERTa-LARGE outperforms it by a
Model Dev Test
SNLI Dataset (Accuracy%)
BERTBASE 91.0 90.8
BERTBASE +SRL(Zhang et al., 2018) - 90.3
MT-DNNBASE 91.4 91.1
SMARTBERT-BASE 91.4 91.1
MT-DNN-SMARTBASEv0 91.7 91.4
MT-DNN-SMARTBASE 91.7 91.5
BERTLARGE +SRL(Zhang et al., 2018) - 91.3
BERTLARGE 91.7 91.0
MT-DNNLARGE 92.2 91.6
MT-DNN-SMARTLARGEv0 92.6 91.7
SciTail Dataset (Accuracy%)
GPT (Radford et al., 2018) - 88.3
BERTBASE 94.3 92.0
MT-DNNBASE 95.8 94.1
SMARTBERT-BASE 94.8 93.2
MT-DNN-SMARTBASEv0 96.0 94.0
MT-DNN-SMARTBASE 96.1 94.2
BERTLARGE 95.7 94.4
MT-DNNLARGE 96.3 95.0
SMARTBERT-LARGE 96.2 94.7
MT-DNN-SMARTLARGEv0 96.6 95.2
3.4% absolute improvement on dev and a 7.4% absolute improvement on test, indicating the robust-
ness of SMART. We observe that in the ANLI-only setting, SMARTRoBERTa-LARGE outperforms the
strong RoBERTaLARGE baseline by a large margin of +5.2% (57.1% vs. 51.9%).
6 Conclusion
We propose a robust and efficient computation framework, SMART, for fine-tuning large scale
pre-trained natural language models in a principled manner. The framework effectively alleviates
the overfitting and aggressive-updating issues in the fine-tuning stage. SMART includes two
important ingredients: 1) smoothness-inducing adversarial regularization; 2) Bregman proximal point
optimization. Our empirical results suggest that SMART improves the performance on many NLP
benchmarks (e.g., GLUE, SNLI, SciTail, ANLI) with the state-of-the-art pre-trained models (e.g.,
BERT, MT-DNN, RoBERTa). We also demonstrate that the proposed framework is applicable to
domain adaptation and results in a significant performance improvement. Our proposed fine-
tuning framework can be generalized to solve other transfer learning problems. We will explore
this direction as future work.
Acknowledgments
We thank Jade Huang, Niao He, Chris Meek, Liyuan Liu, Yangfeng Ji, Pengchuan Zhang, Oleksandr
Polozov, Chenguang Zhu and Kevin Duh for valuable discussions and comments, and the Microsoft
Research Technology Engineering team for setting up GPU machines. We also thank the anonymous
reviewers for valuable discussions.
References
Bar-Haim, R., Dagan, I., Dolan, B., Ferro, L. and Giampiccolo, D. (2006). The second PAS-
CAL recognising textual entailment challenge. In Proceedings of the Second PASCAL Challenges
Workshop on Recognising Textual Entailment.
Bentivogli, L., Dagan, I., Dang, H. T., Giampiccolo, D. and Magnini, B. (2009). The fifth PASCAL
recognizing textual entailment challenge. In Proceedings of the Text Analysis Conference (TAC 2009).
Bowman, S., Angeli, G., Potts, C. and Manning, C. D. (2015). A large annotated corpus for
learning natural language inference. In Proceedings of the 2015 Conference on Empirical Methods
in Natural Language Processing.
Cer, D., Diab, M., Agirre, E., Lopez-Gazpio, I. and Specia, L. (2017). Semeval-2017 task 1: Semantic
textual similarity multilingual and crosslingual focused evaluation. In Proceedings of the 11th
International Workshop on Semantic Evaluation (SemEval-2017).
Conn, A. R., Gould, N. I. and Toint, P. L. (2000). Trust region methods, vol. 1. Siam.
Dagan, I., Glickman, O. and Magnini, B. (2006). The pascal recognising textual entailment
challenge. In Proceedings of the First International Conference on Machine Learning Challenges:
Evaluating Predictive Uncertainty Visual Object Classification, and Recognizing Textual Entailment.
MLCW’05, Springer-Verlag, Berlin, Heidelberg.
http://dx.doi.org/10.1007/11736790_9
Devlin, J., Chang, M.-W., Lee, K. and Toutanova, K. (2019). Bert: Pre-training of deep bidirectional
transformers for language understanding. In Proceedings of the 2019 Conference of the North
American Chapter of the Association for Computational Linguistics: Human Language Technologies,
Volume 1 (Long and Short Papers).
Dong, L., Yang, N., Wang, W., Wei, F., Liu, X., Wang, Y., Gao, J., Zhou, M. and Hon, H.-W.
(2019). Unified language model pre-training for natural language understanding and generation.
In Advances in Neural Information Processing Systems, 13042–13054.
Eckstein, J. (1993). Nonlinear proximal point algorithms using bregman functions, with applica-
tions to convex programming. Mathematics of Operations Research, 18 202–226.
Giampiccolo, D., Magnini, B., Dagan, I. and Dolan, B. (2007). The third PASCAL recognizing
textual entailment challenge. In Proceedings of the ACL-PASCAL Workshop on Textual Entailment
and Paraphrasing. Association for Computational Linguistics, Prague.
https://www.aclweb.org/anthology/W07-1401
Güler, O. (1991). On the convergence of the proximal point algorithm for convex minimization.
SIAM Journal on Control and Optimization, 29 403–419.
Güler, O. (1992). New proximal point algorithms for convex minimization. SIAM Journal on
Optimization, 2 649–664.
Hampel, F. R. (1974). The influence curve and its role in robust estimation. Journal of the american
statistical association, 69 383–393.
He, P., Liu, X., Chen, W. and Gao, J. (2019). A hybrid neural network model for commonsense
reasoning. In Proceedings of the First Workshop on Commonsense Inference in Natural Language
Processing.
Houlsby, N., Giurgiu, A., Jastrzebski, S., Morrone, B., De Laroussilhe, Q., Gesmundo, A.,
Attariyan, M. and Gelly, S. (2019). Parameter-efficient transfer learning for nlp. In International
Conference on Machine Learning.
Howard, J. and Ruder, S. (2018). Universal language model fine-tuning for text classification. In
Proceedings of the 56th Annual Meeting of the Association for Computational Linguistics (Volume 1:
Long Papers).
Joshi, M., Chen, D., Liu, Y., Weld, D. S., Zettlemoyer, L. and Levy, O. (2020). Spanbert: Im-
proving pre-training by representing and predicting spans. Transactions of the Association for
Computational Linguistics, 8 64–77.
Khot, T., Sabharwal, A. and Clark, P. (2018). SciTail: A textual entailment dataset from science
question answering. In AAAI.
Kingma, D. and Ba, J. (2014). Adam: A method for stochastic optimization. arXiv preprint
arXiv:1412.6980.
Kocijan, V., Cretu, A.-M., Camburu, O.-M., Yordanov, Y. and Lukasiewicz, T. (2019). A surpris-
ingly robust trick for the winograd schema challenge. In Proceedings of the 57th Annual Meeting
of the Association for Computational Linguistics.
Lan, Z., Chen, M., Goodman, S., Gimpel, K., Sharma, P. and Soricut, R. (2019). Albert: A lite bert
for self-supervised learning of language representations. arXiv preprint arXiv:1909.11942.
Levesque, H., Davis, E. and Morgenstern, L. (2012). The winograd schema challenge. In Thirteenth
International Conference on the Principles of Knowledge Representation and Reasoning.
Liu, L., Jiang, H., He, P., Chen, W., Liu, X., Gao, J. and Han, J. (2020a). On the variance of
the adaptive learning rate and beyond. In Proceedings of the Eighth International Conference on
Learning Representations (ICLR 2020).
Liu, X., Duh, K. and Gao, J. (2018). Stochastic answer networks for natural language inference.
arXiv preprint arXiv:1804.07888.
Liu, X., Gao, J., He, X., Deng, L., Duh, K. and Wang, Y.-Y. (2015). Representation learning using
multi-task deep neural networks for semantic classification and information retrieval. In Pro-
ceedings of the 2015 Conference of the North American Chapter of the Association for Computational
Linguistics: Human Language Technologies.
Liu, X., He, P., Chen, W. and Gao, J. (2019a). Improving multi-task deep neural networks via
knowledge distillation for natural language understanding. arXiv preprint arXiv:1904.09482.
Liu, X., He, P., Chen, W. and Gao, J. (2019b). Multi-task deep neural networks for natural language
understanding. In Proceedings of the 57th Annual Meeting of the Association for Computational
Linguistics. Association for Computational Linguistics, Florence, Italy.
https://www.aclweb.org/anthology/P19-1441
Liu, X., Wang, Y., Ji, J., Cheng, H., Zhu, X., Awa, E., He, P., Chen, W., Poon, H., Cao, G. and
Gao, J. (2020b). The microsoft toolkit of multi-task deep neural networks for natural language
understanding. arXiv preprint arXiv:2002.07972.
Liu, Y., Ott, M., Goyal, N., Du, J., Joshi, M., Chen, D., Levy, O., Lewis, M., Zettlemoyer, L. and
Stoyanov, V. (2019c). Roberta: A robustly optimized bert pretraining approach. arXiv preprint
arXiv:1907.11692.
Miyato, T., Maeda, S.-i., Koyama, M. and Ishii, S. (2018). Virtual adversarial training: a regulariza-
tion method for supervised and semi-supervised learning. IEEE transactions on pattern analysis
and machine intelligence, 41 1979–1993.
Mohri, M., Rostamizadeh, A. and Talwalkar, A. (2018). Foundations of machine learning. MIT
press.
Nie, Y., Williams, A., Dinan, E., Bansal, M., Weston, J. and Kiela, D. (2019). Adversarial nli: A
new benchmark for natural language understanding. arXiv preprint arXiv:1910.14599.
Pan, S. J. and Yang, Q. (2009). A survey on transfer learning. IEEE Transactions on knowledge and
data engineering, 22 1345–1359.
Parikh, N., Boyd, S. et al. (2014). Proximal algorithms. Foundations and Trends® in Optimization, 1
127–239.
Peters, M. E., Neumann, M., Iyyer, M., Gardner, M., Clark, C., Lee, K. and Zettlemoyer, L.
(2018). Deep contextualized word representations. In Proceedings of NAACL-HLT.
Peters, M. E., Ruder, S. and Smith, N. A. (2019). To tune or not to tune? Adapting pretrained
representations to diverse tasks. In Proceedings of the 4th Workshop on Representation Learning for NLP (RepL4NLP-2019).
Radford, A., Narasimhan, K., Salimans, T. and Sutskever, I. (2018). Improving language understanding
by generative pre-training.
Radford, A., Wu, J., Child, R., Luan, D., Amodei, D. and Sutskever, I. (2019). Language models
are unsupervised multitask learners. OpenAI Blog, 1.
Raffel, C., Shazeer, N., Roberts, A., Lee, K., Narang, S., Matena, M., Zhou, Y., Li, W. and Liu, P. J.
(2019). Exploring the limits of transfer learning with a unified text-to-text transformer. arXiv
preprint arXiv:1910.10683.
Rajpurkar, P., Zhang, J., Lopyrev, K. and Liang, P. (2016). SQuAD: 100,000+ questions for machine
comprehension of text. In Proceedings of the 2016 Conference on Empirical Methods in Natural
Language Processing. Association for Computational Linguistics, Austin, Texas.
https://www.aclweb.org/anthology/D16-1264
Raskutti, G. and Mukherjee, S. (2015). The information geometry of mirror descent. IEEE
Transactions on Information Theory, 61 1451–1457.
Rockafellar, R. T. (1976). Monotone operators and the proximal point algorithm. SIAM journal on
control and optimization, 14 877–898.
Shu, R., Bui, H. H., Narui, H. and Ermon, S. (2018). A dirt-t approach to unsupervised domain
adaptation. arXiv preprint arXiv:1802.08735.
Socher, R., Perelygin, A., Wu, J., Chuang, J., Manning, C. D., Ng, A. and Potts, C. (2013).
Recursive deep models for semantic compositionality over a sentiment treebank. In Proceedings
of the 2013 conference on empirical methods in natural language processing.
Stickland, A. C. and Murray, I. (2019). Bert and pals: Projected attention layers for efficient
adaptation in multi-task learning. In International Conference on Machine Learning.
Tarvainen, A. and Valpola, H. (2017). Mean teachers are better role models: Weight-averaged con-
sistency targets improve semi-supervised deep learning results. In Advances in neural information
processing systems.
Thorne, J., Vlachos, A., Christodoulopoulos, C. and Mittal, A. (2018). Fever: a large-scale
dataset for fact extraction and verification. arXiv preprint arXiv:1803.05355.
Vaswani, A., Shazeer, N., Parmar, N., Uszkoreit, J., Jones, L., Gomez, A. N., Kaiser, Ł. and
Polosukhin, I. (2017). Attention is all you need. In Advances in neural information processing
systems.
Wang, A., Singh, A., Michael, J., Hill, F., Levy, O. and Bowman, S. R. (2018). Glue: A multi-task
benchmark and analysis platform for natural language understanding. In Proceedings of the 2018
EMNLP Workshop BlackboxNLP: Analyzing and Interpreting Neural Networks for NLP, 353–355.
Wang, W., Bi, B., Yan, M., Wu, C., Bao, Z., Peng, L. and Si, L. (2019). Structbert: Incorporat-
ing language structures into pre-training for deep language understanding. arXiv preprint
arXiv:1908.04577.
Warstadt, A., Singh, A. and Bowman, S. R. (2019). Neural network acceptability judgments.
Transactions of the Association for Computational Linguistics, 7 625–641.
Williams, A., Nangia, N. and Bowman, S. (2018). A broad-coverage challenge corpus for sentence
understanding through inference. In Proceedings of the 2018 Conference of the North American
Chapter of the Association for Computational Linguistics: Human Language Technologies, Volume 1
(Long Papers). Association for Computational Linguistics.
http://aclweb.org/anthology/N18-1101
Wolf, T., Debut, L., Sanh, V., Chaumond, J., Delangue, C., Moi, A., Cistac, P., Rault, T., Louf,
R., Funtowicz, M. and Brew, J. (2019). Huggingface’s transformers: State-of-the-art natural
language processing. ArXiv, abs/1910.03771.
Yang, Z., Dai, Z., Yang, Y., Carbonell, J., Salakhutdinov, R. R. and Le, Q. V. (2019). Xlnet:
Generalized autoregressive pretraining for language understanding. In Advances in neural
information processing systems.
Zhang, H., Yu, Y., Jiao, J., Xing, E., El Ghaoui, L. and Jordan, M. (2019). Theoretically principled
trade-off between robustness and accuracy. In International Conference on Machine Learning.
Zhang, Z., Wu, Y., Li, Z., He, S. and Zhao, H. (2018). I know what you want: Semantic learning for
text comprehension.
Zhu, C., Cheng, Y., Gan, Z., Sun, S., Goldstein, T. and Liu, J. (2020). Freelb: Enhanced adversarial
training for natural language understanding.
https://openreview.net/forum?id=BygzbyHFvB
7 Datasets
Table 8: Summary of the four benchmarks: GLUE, SNLI, SciTail and ANLI.
The GLUE benchmark, SNLI, SciTail and ANLI are briefly introduced below.
Detailed descriptions can be found in Wang et al. (2018); Bowman et al. (2015); Khot et al.
(2018); Nie et al. (2019). Table 8 summarizes the information of these tasks.
• GLUE. The General Language Understanding Evaluation (GLUE) benchmark is a collection of
nine natural language understanding (NLU) tasks. As shown in Table 8, it includes question answer-
ing Rajpurkar et al. (2016), linguistic acceptability Warstadt et al. (2019), sentiment analysis Socher
et al. (2013), text similarity Cer et al. (2017), paraphrase detection Dolan and Brockett (2005), and
natural language inference (NLI) Dagan et al. (2006); Bar-Haim et al. (2006); Giampiccolo et al.
(2007); Bentivogli et al. (2009); Levesque et al. (2012); Williams et al. (2018). The diversity of the
tasks makes GLUE very suitable for evaluating the generalization and robustness of NLU models.
• SNLI. The Stanford Natural Language Inference (SNLI) dataset contains 570k human-annotated
sentence pairs, in which the premises are drawn from the captions of the Flickr30k corpus and
the hypotheses are manually annotated (Bowman et al., 2015). This is the most widely used entailment
dataset for NLI. The dataset is used only for domain adaptation in this study.
• SciTail. This is a textual entailment dataset derived from a science question answering (SciQ)
dataset Khot et al. (2018). The task involves assessing whether a given premise entails a given
hypothesis. In contrast to other entailment datasets mentioned previously, the hypotheses in
SciTail are created from science questions while the corresponding answer candidates and premises
come from relevant web sentences retrieved from a large corpus. As a result, these sentences are
linguistically challenging and the lexical similarity of premise and hypothesis is often high, thus
making SciTail particularly difficult. The dataset is used only for domain adaptation in this study.
• ANLI. The Adversarial Natural Language Inference (ANLI, Nie et al. (2019)) dataset is a new large-
scale NLI benchmark collected via an iterative, adversarial human-and-model-in-the-loop
procedure. In particular, the data is selected to be difficult for state-of-the-art models, including
BERT and RoBERTa.
8 Hyperparameters
In general, the performance of our method is not very sensitive to the choice of hyper-parameters,
as detailed below.
• We only observed slight differences in model performance when $\lambda_s \in [1, 10]$, $\mu \in [1, 10]$ and
$\epsilon \in [10^{-5}, 10^{-4}]$. When $\lambda_s \ge 100$, $\mu \ge 100$ or $\epsilon \ge 10^{-3}$, the regularization is unreasonably
strong. When $\lambda_s \le 0.1$, $\mu \le 0.1$ or $\epsilon \le 10^{-6}$, the regularization is unreasonably weak.
• $p = \infty$ makes the size of the perturbation constraint the same regardless of the number of
dimensions, and its projection reduces to an element-wise clip (see the short sketch after this list).
For $p = 2$, the adversarial perturbation is sensitive to the number of dimensions
(a higher dimension usually requires a larger perturbation), especially for sentences with
different lengths. As a result, $p = \infty$ requires less tuning effort. For other values of
$p$, the associated projections are computationally inefficient.
• We set $\beta = 0.99$ for the first 10% of the updates ($t \le 0.1T$) and $\beta = 0.999$ for the rest of the
updates ($t > 0.1T$) following Tarvainen and Valpola (2017), which works well in practice.
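To make the computational point above concrete, the following short sketch (illustrative only) contrasts the $\ell_\infty$ projection, which is an element-wise clip, with the $\ell_2$ projection, which rescales the perturbation by its norm (taken over the whole tensor here for simplicity):

```python
import torch

def project_linf(delta, eps):
    """Projection onto {delta : ||delta||_inf <= eps}: element-wise clipping."""
    return delta.clamp(-eps, eps)

def project_l2(delta, eps):
    """Projection onto {delta : ||delta||_2 <= eps}: rescale if the norm exceeds eps."""
    norm = delta.norm(p=2)
    return delta if norm <= eps else delta * (eps / norm)

delta = torch.randn(16, 768) * 1e-3           # a perturbation on 768-dimensional embeddings
print(project_linf(delta, 1e-5).abs().max())  # bounded element-wise by eps
print(project_l2(delta, 1e-5).norm(p=2))      # bounded in l2 norm by eps
```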