Published as a conference paper at ICLR 2020
REFORMER: THE EFFICIENT TRANSFORMER
Nikita Kitaev
U.C. Berkeley & Google Research
kitaev@cs.berkeley.edu
Łukasz Kaiser
Google Research
{lukaszkaiser,levskaya}@google.com
Anselm Levskaya
Google Research
ABSTRACT
Large Transformer models routinely achieve state-of-the-art results on a number
of tasks but training these models can be prohibitively costly, especially on long
sequences. We introduce two techniques to improve the efficiency of Transformers. For one, we replace dot-product attention by one that uses locality-sensitive hashing, changing its complexity from O(L^2) to O(L log L), where L is the length of the sequence. Furthermore, we use reversible residual layers instead of the standard residuals, which allows storing activations only once in the training process instead of N times, where N is the number of layers. The resulting model, the Reformer, performs on par with Transformer models while being much more memory-efficient and much faster on long sequences.
1 INTRODUCTION
The Transformer architecture (Vaswani et al., 2017) is widely used in natural language processing
and yields state-of-the-art results on a number of tasks. To obtain these results, researchers have
resorted to training ever larger Transformer models. The number of parameters exceeds 0.5B per
layer in the largest configuration reported in (Shazeer et al., 2018) while the number of layers goes
up to 64 in (Al-Rfou et al., 2018). Transformer models are also used on increasingly long sequences.
Up to 11 thousand tokens of text in a single example were processed in (Liu et al., 2018) and when
processing other modalities, like music (Huang et al., 2018) and images (Parmar et al., 2018), even
longer sequences are commonplace. These large-scale long-sequence models yield great results but
strain resources to the point where some argue that this trend is breaking NLP research¹. Many
large Transformer models can only realistically be trained in large industrial research laboratories
and such models trained with model parallelism cannot even be fine-tuned on a single GPU as their
memory requirements demand a multi-accelerator hardware setup even for a single training step.
Do large Transformer models fundamentally require such huge resources or are they simply inefficient? Consider the following calculation: the 0.5B parameters used in the largest reported Transformer layer account for 2GB of memory. Activations for 64K tokens with embedding size 1024 and batch size 8 account for 64K × 1K × 8 = 0.5B floats, requiring another 2GB of memory. If our memory use were only per-layer, then we should fairly easily fit a large Transformer even on sequences of length 64K on a single accelerator. Further, the whole corpus used to train BERT only requires 17GB to store. Why is it then that we cannot even fine-tune these models on single
machines?
The above estimate includes only per-layer memory and input activations cost and does not take into
account the following major sources of memory use in the Transformer.
• Memory in a model with N layers is N times larger than in a single-layer model, because activations need to be stored for back-propagation.
• Since the depth d_ff of intermediate feed-forward layers is often much larger than the depth d_model of attention activations, it accounts for a large fraction of memory use.
• Attention on sequences of length L is O(L^2) in both computational and memory complexity, so even a single sequence of 64K tokens can exhaust accelerator memory.
∗ Equal contribution
¹ https://hackingsemantics.xyz/2019/leaderboards/
We introduce the Reformer model which solves these problems using the following techniques:
• Reversible layers, first introduced in Gomez et al. (2017), enable storing only a single copy of activations in the whole model, so the N factor disappears.
• Splitting activations inside feed-forward layers and processing them in chunks removes the d_ff factor and saves memory inside feed-forward layers.
• Approximate attention computation based on locality-sensitive hashing replaces the O(L^2) factor in attention layers with O(L log L) and so allows operating on long sequences.
We study these techniques and show that they have negligible impact on the training process com-
pared to the standard Transformer. Splitting activations in fact only affects the implementation; it is
numerically identical to the layers used in the Transformer. Applying reversible residuals instead of
the standard ones does change the model but has a negligible effect on training in all configurations
we experimented with. Finally, locality-sensitive hashing in attention is a more major change that
can influence the training dynamics, depending on the number of concurrent hashes used. We study
this parameter and find a value which is both efficient to use and yields results very close to full
attention.
We experiment on a synthetic task, a text task (enwik8) with sequences of length 64K and an image
generation task (imagenet-64 generation) with sequences of length 12K. In both cases we show that
Reformer matches the results obtained with full Transformer but runs much faster, especially on the
text task, and with orders of magnitude better memory efficiency.
2 LOCALITY-SENSITIVE HASHING ATTENTION
Dot-product attention. The standard attention used in the Transformer is the scaled dot-product attention (Vaswani et al., 2017). The input consists of queries and keys of dimension d_k, and values of dimension d_v. The dot products of the query with all keys are computed, divided by √d_k, and a softmax function is applied to obtain the weights on the values. In practice, the attention function on a set of queries is computed simultaneously, packed together into a matrix Q. Assuming the keys and values are also packed together into matrices K and V, the matrix of outputs is defined as:
\mathrm{Attention}(Q, K, V) = \mathrm{softmax}\left(\frac{QK^T}{\sqrt{d_k}}\right) V \qquad (1)
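To make Equation (1) concrete, here is a minimal NumPy sketch of single-head scaled dot-product attention; the function name and the shapes in the docstring are illustrative and not taken from the paper's released code.

```python
import numpy as np

def dot_product_attention(Q, K, V):
    """Scaled dot-product attention as in Equation (1).

    Q: [length_q, d_k], K: [length_k, d_k], V: [length_k, d_v].
    Returns an array of shape [length_q, d_v].
    """
    d_k = Q.shape[-1]
    scores = Q @ K.T / np.sqrt(d_k)                  # [length_q, length_k]
    scores -= scores.max(axis=-1, keepdims=True)     # subtract the max for numerical stability
    weights = np.exp(scores)
    weights /= weights.sum(axis=-1, keepdims=True)   # softmax over the keys
    return weights @ V
```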
Multi-head attention. In the Transformer, instead of performing a single attention function with d_model-dimensional keys, values and queries, one linearly projects the queries, keys and values h times with different, learned linear projections to d_k, d_k and d_v dimensions, respectively. Attention is applied to each of these projected versions of queries, keys and values in parallel, yielding d_v-dimensional output values. These are concatenated and once again projected, resulting in the final
values. This mechanism is known as multi-head attention.
Memory-efficient attention. To calculate the memory use of the attention mechanism, let us focus on the attention computation from Equation 1. Let us assume that Q, K and V all have the shape [batch_size, length, d_model]. The main issue is the term QK^T, which has the shape [batch_size, length, length]. In the experimental section we train a model on sequences of length 64K; in this case, even at a batch size of 1, this is a 64K × 64K matrix, which in 32-bit floats would take 16GB of memory. This is impractical and has hindered the use of the Transformer for long sequences. But it is important to note that the QK^T matrix does not need to be fully materialized in memory. The attention can indeed be computed for each query q_i separately, only calculating softmax(q_i K^T / √d_k) V once in memory, and then re-computing it on the backward pass when needed for gradients. This way of computing attention may be less efficient but it only uses memory proportional to length. We use this memory-efficient implementation of attention to run the full-attention baselines presented in the experimental section.
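As a hedged sketch of this per-query strategy (a hypothetical helper, not the released implementation), the loop below never materializes the full length × length score matrix; only one row of scores exists at a time, and a framework with rematerialization would recompute it on the backward pass.

```python
import numpy as np

def memory_efficient_attention(Q, K, V):
    """Attention computed one query at a time: peak memory is O(length),
    not O(length^2), because the full score matrix is never stored."""
    d_k = Q.shape[-1]
    out = np.empty((Q.shape[0], V.shape[-1]))
    for i, q in enumerate(Q):                 # process a single query row
        scores = q @ K.T / np.sqrt(d_k)       # [length_k] scores for this query only
        scores -= scores.max()
        w = np.exp(scores)
        out[i] = (w / w.sum()) @ V
    return out
```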
Where do Q, K, V come from? The multi-head attention described above operates on keys, queries and values, but usually we are only given a single tensor of activations A of the shape [batch_size, length, d_model], e.g., coming from embedding the tokens in a sentence into vectors.
Figure 1: An angular locality sensitive hash uses random rotations of spherically projected points to
establish buckets by an argmax over signed axes projections. In this highly simplified 2D depiction,
two points x and y are unlikely to share the same hash buckets (above) for the three different angular
hashes unless their spherical projections are close to one another (below).
To build Q, K and V from A, the Transformer uses 3 different linear layers projecting A into Q, K
and V with different parameters. For models with LSH attention, we want queries and keys (Q and
K) to be identical. This is easily achieved by using the same linear layer to go from A to Q and
K, and a separate one for V. We call a model that behaves like this a shared-QK Transformer. It
turns out that sharing QK does not affect the performance of the Transformer, even if we additionally
normalize the length of the keys K, as we show in the experimental Section 5.
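A minimal sketch of the shared-QK projection (the weight names and sizes below are placeholders, not the paper's code): a single matrix maps A to both Q and K, while V gets its own projection.

```python
import numpy as np

rng = np.random.default_rng(0)
length, d_model, d_k, d_v = 16, 1024, 64, 64

A = rng.normal(size=(length, d_model))     # activations, e.g. token embeddings
W_qk = rng.normal(size=(d_model, d_k))     # one projection shared by queries and keys
W_v = rng.normal(size=(d_model, d_v))      # a separate projection for values

Q = A @ W_qk
K = A @ W_qk                               # identical to Q by construction (shared-QK)
V = A @ W_v
```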
Hashing attention. For the LSH attention, we start with two tensors, Q=K and V, of the shape [batch_size, length, d_model]. We keep the multi-head mechanism intact and focus on the attention computation from Equation 1. As already mentioned, the main issue is the term QK^T, which has the shape [batch_size, length, length]. But note that we are actually only interested in softmax(QK^T). Since softmax is dominated by the largest elements, for each query q_i we only need to focus on the keys in K that are closest to q_i. For example, if K is of length 64K, for each q_i we could only consider a small subset of, say, the 32 or 64 closest keys. That is much more efficient, but how can we find the nearest neighbors among the keys?
Locality sensitive hashing. The problem of finding nearest neighbors quickly in high-dimensional spaces can be solved by locality-sensitive hashing (LSH). A hashing scheme that assigns each vector x to a hash h(x) is called locality-sensitive if nearby vectors get the same hash with high probability
and distant ones do not. In our case, we actually only require that nearby vectors get the same hash
with high probability and that hash-buckets are of similar size with high probability.
We achieve this by employing random projections as follows (see Figure 1). To get b hashes, we first fix a random matrix R of size [d_k, b/2]. We then define h(x) = arg max([xR; −xR]) where [u; v] denotes the concatenation of two vectors. This method is a known LSH scheme (Andoni et al.,
2015) and is easy to implement and apply to batches of vectors.
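The random-rotation hash can be written in a few lines; the sketch below assumes a NumPy setting and a hypothetical helper name.

```python
import numpy as np

def angular_lsh(x, n_buckets, rng):
    """h(x) = argmax([xR; -xR]) with a fixed random R of shape [d_k, n_buckets // 2],
    following the scheme of Andoni et al. (2015). x may be a single vector or a batch."""
    d_k = x.shape[-1]
    R = rng.normal(size=(d_k, n_buckets // 2))          # random rotation, fixed per hash round
    rotated = x @ R                                     # [..., n_buckets // 2]
    return np.argmax(np.concatenate([rotated, -rotated], axis=-1), axis=-1)
```

Each call with a freshly drawn R corresponds to one hashing round; vectors that point in similar directions land in the same bucket with high probability.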
LSH attention. Knowing our LSH scheme and the general idea of hashing attention, we will now formalize the LSH attention we use in this paper. We first rewrite the equation for normal attention, (1), for a single query position i at a time:
o_i = \sum_{j \in \mathcal{P}_i} \exp(q_i \cdot k_j - z(i, \mathcal{P}_i))\, v_j \quad \text{where } \mathcal{P}_i = \{ j : i \geq j \} \qquad (2)
We introduce the notation P_i to represent the set that the query at position i attends to, and z to denote the partition function (i.e. the normalizing term in the softmax). For clarity, we also omit scaling by √d_k.
For batching purposes we typically perform attention over a larger set P̃_i = {0, 1, ..., l} ⊇ P_i while masking out elements not in P_i:
o_i = \sum_{j \in \widetilde{\mathcal{P}}_i} \exp(q_i \cdot k_j - m(j, \mathcal{P}_i) - z(i, \mathcal{P}_i))\, v_j \quad \text{where } m(j, \mathcal{P}_i) = \begin{cases} \infty & \text{if } j \notin \mathcal{P}_i \\ 0 & \text{otherwise} \end{cases} \qquad (3)
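Written as code, Equation (3) is just a masked softmax over the candidate set. The sketch below uses a boolean membership mask in place of the additive term m(j, P_i) and is purely illustrative.

```python
import numpy as np

def masked_single_query_attention(q_i, K, V, allowed):
    """o_i from Equation (3): attend over all candidate positions, giving
    positions outside P_i (allowed[j] == False) zero weight via a -inf score."""
    scores = K @ q_i                          # q_i . k_j for every candidate j (sqrt(d_k) scaling omitted, as in the text)
    scores = np.where(allowed, scores, -np.inf)
    scores -= scores.max()                    # stabilizes the exponentials
    w = np.exp(scores)
    return (w / w.sum()) @ V                  # dividing by the sum applies the z(i, P_i) normalizer
```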
Figure 2: Simplified depiction of LSH Attention showing the hash-bucketing, sorting, and chunking
steps and the resulting causal attentions. (a-d) Attention matrices for these varieties of attention.
Now we turn to LSH attention, which we can think of in terms of restricting the set P_i of target items a query position i can attend to, by only allowing attention within a single hash bucket.
\mathcal{P}_i = \{ j : h(q_i) = h(k_j) \} \qquad (4)
Figure 2(a-b) shows a schematic comparison of full-attention with a hashed variant. Part (a) depicts
that the attention matrix for full attention is typically sparse, but the computation does not take
advantage of this sparsity. In (b), the queries and keys have been sorted according to their hash
bucket. Since similar items fall in the same bucket with high probability, the full attention pattern
can be approximated by only allowing attention within each bucket.
Hash buckets in this formulation tend to be uneven in size, which makes it difficult to batch across buckets. Moreover, the number of queries and the number of keys within a bucket may be unequal; in fact, it is possible for a bucket to contain many queries but no keys. To alleviate these issues, we first ensure that h(k_j) = h(q_j) by setting k_j = q_j / ‖q_j‖. Next, we sort the queries by bucket number and, within each bucket, by sequence position; this defines a permutation where i ↦ s_i after sorting.
In the sorted attention matrix, pairs from the same bucket will cluster near the diagonal (as depicted
in Figure 2c). We can follow a batching approach where chunks of m consecutive queries (after
sorting) attend to each other, and one chunk back (Figure 2d). Following our earlier notation, this
corresponds to setting:
\widetilde{\mathcal{P}}_i = \left\{ j : \left\lfloor \frac{s_i}{m} \right\rfloor - 1 \leq \left\lfloor \frac{s_j}{m} \right\rfloor \leq \left\lfloor \frac{s_i}{m} \right\rfloor \right\} \qquad (5)
If max_i |P_i| < m, then P_i ⊆ P̃_i. In practice we set m = 2l / n_buckets (where l is the sequence length). The average bucket size is l / n_buckets, and we assume that the probability of a bucket growing to twice that size is sufficiently low. The overall process of LSH attention is summarized in Figure 2.
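Putting the pieces together, the sketch below is a much simplified single-round, single-head version of this bucket-sort-and-chunk procedure (a hypothetical function, not the Trax implementation); it ignores the special cases around self-attention and empty buckets that the paper handles explicitly.

```python
import numpy as np

def lsh_attention_sketch(Q, V, n_buckets, m, rng):
    """Single-round LSH attention for one head, shared-QK, no batching.

    Keys are normalized queries (k_j = q_j / ||q_j||). Positions are sorted by
    (bucket, position); each chunk of m sorted queries attends to its own chunk
    and the previous one, restricted to equal buckets and earlier positions.
    """
    length, d_k = Q.shape
    K = Q / np.linalg.norm(Q, axis=-1, keepdims=True)

    R = rng.normal(size=(d_k, n_buckets // 2))                  # random rotation for the hash
    buckets = np.argmax(np.concatenate([Q @ R, -(Q @ R)], axis=-1), axis=-1)

    order = np.argsort(buckets * length + np.arange(length))    # sort by bucket, then position
    sq, sk, sv = Q[order], K[order], V[order]
    spos, sbuck = np.arange(length)[order], buckets[order]

    out = np.zeros_like(V)
    for c in range(0, length, m):                               # one chunk of sorted queries
        q_idx = np.arange(c, min(c + m, length))
        k_idx = np.arange(max(0, c - m), min(c + m, length))    # own chunk plus one chunk back
        scores = sq[q_idx] @ sk[k_idx].T / np.sqrt(d_k)
        same_bucket = sbuck[q_idx][:, None] == sbuck[k_idx][None, :]
        causal = spos[k_idx][None, :] < spos[q_idx][:, None]    # strictly earlier positions only
        scores = np.where(same_bucket & causal, scores, -1e9)   # mask everything else
        w = np.exp(scores - scores.max(axis=-1, keepdims=True))
        out[order[q_idx]] = (w / w.sum(axis=-1, keepdims=True)) @ sv[k_idx]
    return out
```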
Multi-round LSH attention. With hashing, there is always a small probability that similar items nevertheless fall in different buckets. This probability can be reduced by doing multiple rounds of hashing with n_rounds distinct hash functions {h^(1), h^(2), ...}, such that:
\mathcal{P}_i = \bigcup_{r=1}^{n_{\mathrm{rounds}}} \mathcal{P}_i^{(r)} \quad \text{where } \mathcal{P}_i^{(r)} = \left\{ j : h^{(r)}(q_i) = h^{(r)}(q_j) \right\} \qquad (6)
The multi-round case essentially involves performing LSH attention n_rounds times in parallel; the details of the procedure are described in Appendix A.
Causal masking for shared-QK attention. In a Transformer decoder, masking (denoted by m(j, P_i) in Equation 3) is used to prevent positions from attending into the future. To implement
masking in LSH attention, we associate every query/key vector with a position index, re-order the
position indices using the same permutations used to sort the query/key vectors, and then use a
comparison operation to compute the mask.
Table 1: Memory and time complexity of attention variants. We write l for length, b for batch size, n_h for the number of heads, n_c for the number of LSH chunks, and n_r for the number of hash repetitions.

Attention Type       Memory Complexity                          Time Complexity
Scaled Dot-Product   max(b n_h l d_k, b n_h l^2)                max(b n_h l d_k, b n_h l^2)
Memory-Efficient     max(b n_h l d_k, b n_h l^2)                max(b n_h l d_k, b n_h l^2)
LSH Attention        max(b n_h l d_k, b n_h l n_r (4l/n_c)^2)   max(b n_h l d_k, b n_h n_r l (4l/n_c)^2)
Table 2: Accuracies on the duplication task of a 1-layer Transformer model with full attention and
with locality-sensitive hashing attention using different numbers of parallel hashes.
Train \ Eval      Full Attention   LSH-8   LSH-4   LSH-2   LSH-1
Full Attention    100%             94.8%   92.5%   76.9%   52.5%
LSH-4             0.8%             100%    99.9%   99.4%   91.9%
LSH-2             0.8%             100%    99.9%   98.1%   86.8%
LSH-1             0.8%             99.9%   99.6%   94.8%   77.9%
While attention to the future is not allowed, typical implementations of the Transformer do allow a position to attend to itself. Such behavior is undesirable in a shared-QK formulation because the dot-product of a query vector with itself will almost always be greater than the dot product of a query vector with a vector at another position. We therefore modify the masking to forbid a token from attending to itself, except in situations where a token has no other valid attention targets (e.g. the first token in a sequence).
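A minimal illustration of how the mask can be computed from re-ordered position indices (the permutation below is a made-up example): strict inequality forbids both future positions and the token itself, matching the shared-QK rule above.

```python
import numpy as np

length = 8
positions = np.arange(length)
perm = np.array([3, 0, 6, 1, 7, 2, 5, 4])        # example permutation from bucket sorting
spos = positions[perm]                            # position indices carried through the sort

# mask[i, j] is True when sorted query i may attend to sorted key j:
# strictly earlier positions only, so attention to the future and to self is forbidden.
mask = spos[None, :] < spos[:, None]

# Fallback used in the paper: a token with no valid target (here, the first token)
# is allowed to attend to itself.
no_target = ~mask.any(axis=1)
mask[no_target, np.arange(length)[no_target]] = True
```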
2.1 ANALYSIS ON A SYNTHETIC TASK
To verify the performance of LSH attention and study its behavior, we start with the following
synthetic task: duplicate a sequence of symbols. In this task, each training and testing example has the form 0w0w, where w ∈ {1, ..., N}∗ is a sequence of symbols ranging from 1 to N (we use N = 127 in our experiments). An example with the word w of length 3 is given below.
Example: 0 19 113 72 0 19 113 72
To study LSH attention, we train a language model on examples of the above form where each w is of length 511 (so the whole input 0w0w is of length 1024). As this is a language modeling task, we always predict the next symbol given all the previous ones, but we mask the loss and accuracy to
only consider positions in the second half of the input, i.e., those that can actually be predicted.
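A data generator for this task takes only a few lines (a sketch; the constants mirror the setup in the text, and the helper name is made up):

```python
import numpy as np

def duplication_example(rng, w_len=511, n_symbols=127):
    """Return one 0w0w example of total length 2 * (w_len + 1), as described above."""
    w = rng.integers(1, n_symbols + 1, size=w_len)      # symbols drawn from {1, ..., N}
    return np.concatenate([[0], w, [0], w])

rng = np.random.default_rng(0)
example = duplication_example(rng)                       # length 1024; loss is only on the second half
```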
The above task can be solved perfectly (to accuracy 100% and loss 0) by a 1-layer Transformer
model. Note, though, that it requires non-local attention lookups, so it cannot be solved by any
model relying on sparse attention with a limited span. To make it easy and fast to train but similar
to models used in NLP, we use a 1-layer Transformer with d_model = d_ff = 256 and 4 heads. We train it for 150K steps in 4 different settings: with full attention, LSH attention with n_rounds = 1, n_rounds = 2 and n_rounds = 4.
From the results summarized in Table 2 we see that a model trained with full attention can be imme-
diately used with LSH attention, but at some loss of accuracy. When trained from scratch with LSH
attention, the model trained with 4 hashes achieves almost perfect accuracy as well. Interestingly,
the accuracy becomes perfect when evaluated with 8 hashes. It goes down when evaluated with 2 or
1 hashes. Models trained with fewer hashes show worse results but even the model trained with just 1
hash performs almost perfectly when evaluated with 8 hashes.
3 REVERSIBLE TRANSFORMER
As the above section shows, the complexity of attention can be reduced from quadratic in length to linear, provided an approximation is acceptable. But it is clear from Table 1 that each field starts with a b · n_h · l term: the b · n_h · l · d_k, or alternatively b · l · d_model, cost cannot be avoided. Indeed, the activations before each layer are already of size b · l · d_model, so the memory use of the whole model with n_l layers is at least b · l · d_model · n_l. Even worse: inside the feed-forward layers of the Transformer this goes up to b · l · d_ff · n_l. In a big Transformer it is usual to set d_ff = 4K and n_l = 16, so with l = 64K this again would use an impractical 16GB of memory.
In this section, we show how to reduce this cost by first dealing with the n_l part of the term using reversible layers and then showing how chunking can allow us to handle the d_ff problem. The
effects of each of these approaches on memory and time complexity are summarized in Table 3.
RevNets. Reversible residual networks were introduced by Gomez et al. (2017) where it was shown that they can replace ResNets for image classification. The main idea is to allow the activations at
any given layer to be recovered from the activations at the following layer, using only the model
parameters. Rather than having to checkpoint intermediate values for use in the backward pass,
layers can be reversed one-by-one as back-propagation proceeds from the output of the network to
its input. Whereas a normal residual layer performs a function x ↦ y that operates on a single input and produces a single output and has the form y = x + F(x), a reversible layer works on pairs of inputs/outputs: (x_1, x_2) ↦ (y_1, y_2), and follows the equations:
y_1 = x_1 + F(x_2) \qquad y_2 = x_2 + G(y_1) \qquad (7)
A layer can be reversed by subtracting (rather than adding) the residuals:
x_2 = y_2 - G(y_1) \qquad x_1 = y_1 - F(x_2) \qquad (8)
Reversible Transformer. We apply the RevNet idea to the Transformer by combining the attention
and feed-forward layers inside the revnet block. In the notation above, F becomes an attention layer
while G becomes the feed-forward layer. Note that Layer Normalization (Ba et al., 2016) is moved
inside the residual blocks.
Y_1 = X_1 + \mathrm{Attention}(X_2) \qquad Y_2 = X_2 + \mathrm{FeedForward}(Y_1) \qquad (9)
The reversible Transformer does not need to store activations in each layer and so gets rid of the n_l term. In Section 5 we show that it performs the same as the normal Transformer when using the same number of parameters; we achieve this by having both x_1 and x_2 have size d_model.
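The reversible block of Equations (7)-(9) can be sketched as follows, with F and G standing for the attention and feed-forward sub-layers (here just placeholder callables); inverse recovers the inputs from the outputs, which is what lets the backward pass avoid storing per-layer activations.

```python
import numpy as np

class ReversibleBlock:
    """y1 = x1 + F(x2), y2 = x2 + G(y1); inputs are recoverable from outputs (Eq. 7-8).
    In the Reformer, F is the attention layer and G the feed-forward layer (Eq. 9)."""

    def __init__(self, F, G):
        self.F, self.G = F, G

    def forward(self, x1, x2):
        y1 = x1 + self.F(x2)
        y2 = x2 + self.G(y1)
        return y1, y2

    def inverse(self, y1, y2):
        x2 = y2 - self.G(y1)          # subtract the residuals in reverse order
        x1 = y1 - self.F(x2)
        return x1, x2

# Toy check with arbitrary residual functions standing in for attention / feed-forward:
block = ReversibleBlock(F=np.tanh, G=np.sin)
x1, x2 = np.random.randn(4, 8), np.random.randn(4, 8)
y1, y2 = block.forward(x1, x2)
r1, r2 = block.inverse(y1, y2)
assert np.allclose(r1, x1) and np.allclose(r2, x2)    # activations recovered from the outputs
```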
Chunking. While reversibility covers the n_l term, the thicker layers can still use a lot of memory. The feed-forward layer in particular can use intermediate vectors of dimensionality d_ff = 4K or higher. However, computations in feed-forward layers are completely independent across positions in a sequence, so the computation can be split into c chunks:
Y_2 = \left[ Y_2^{(1)}; \ldots; Y_2^{(c)} \right] = \left[ X_2^{(1)} + \mathrm{FeedForward}(Y_1^{(1)}); \ldots; X_2^{(c)} + \mathrm{FeedForward}(Y_1^{(c)}) \right] \qquad (10)
This layer is typically batched by performing operations for all positions in parallel, but operating
on one chunk at a time can reduce memory. The reverse computation in (8) and the backward pass
are also chunked. In addition to the feed-forward layers, for models with large vocabulary (more
than d_model word types) we also chunk the log-probabilities at the output and calculate the loss for
sections of the sequence at a time.
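A sketch of the chunked feed-forward computation in Equation (10): positions are processed in c slices of the length axis, so the d_ff-sized intermediate activations exist only for one chunk at a time. The weight names, the ReLU nonlinearity and the helper itself are illustrative assumptions, not the paper's code.

```python
import numpy as np

def chunked_feed_forward(Y1, X2, W1, W2, n_chunks):
    """Compute Y2 = X2 + FeedForward(Y1) over n_chunks slices of the length axis,
    so the [chunk_len, d_ff] intermediate never covers the whole sequence at once."""
    outputs = []
    for y1_chunk, x2_chunk in zip(np.array_split(Y1, n_chunks),
                                  np.array_split(X2, n_chunks)):
        hidden = np.maximum(y1_chunk @ W1, 0.0)        # [chunk_len, d_ff] activations
        outputs.append(x2_chunk + hidden @ W2)         # residual add per chunk
    return np.concatenate(outputs, axis=0)
```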
Chunking, large batches and parameter reuse. With chunking and reversible layers the memory
we use for activations in the whole network is independent of the number of layers. The same is
not true for parameters, though, as their number grows with the number of layers. This problem can be remedied because we can swap layer parameters to and from CPU memory when a layer is not computing. In a standard Transformer this would be inefficient because memory transfer to
CPU is slow. The batch size multiplied by length in Reformer is much larger though and therefore
the amount of compute done with the parameters amortizes the cost of their transfer.
Table 3: Memory and time complexity of Transformer variants. We write d_model and d_ff for model depth and assume d_ff ≥ d_model; b stands for batch size, l for length, n_l for the number of layers. We assume n_c = l/32 so 4l/n_c = 128, and we write c = 128^2.

Model Type                       Memory Complexity                    Time Complexity
Transformer                      max(b l d_ff, b n_h l^2) n_l         (b l d_ff + b n_h l^2) n_l
Reversible Transformer           max(b l d_ff, b n_h l^2)             (b n_h l d_ff + b n_h l^2) n_l
Chunked Reversible Transformer   max(b l d_model, b n_h l^2)          (b n_h l d_ff + b n_h l^2) n_l
LSH Transformer                  max(b l d_ff, b n_h l n_r c) n_l     (b l d_ff + b n_h n_r l c) n_l
Reformer                         max(b l d_model, b n_h l n_r c)      (b l d_ff + b n_h n_r l c) n_l
4 RELATED WORK
The Transformer model introduced in (Vaswani et al., 2017) has been used widely in natural lan-
guage tasks and further extended to model diverse data such as music scores (Huang et al., 2018),
and images (Parmar et al., 2018; Ramachandran et al., 2019). Most notably, this model class has
been applied successfully in the self-supervised training of extremely large language models (Devlin
et al., 2018; Radford et al., 2019).
Given the enormous computational requirements of state of the art sequence models, there has been
increasing interest in finding methods to reduce the memory footprint and computational requirements of Transformer models. In addition to standard methods such as precision reduction and gradient checkpointing (Sohoni et al., 2019), more efficient versions of the Transformer model's
self-attention mechanism (Sukhbaatar et al., 2019a;b) have also recently been explored.
In particular, leveraging sparsity in the attention layers has proved fruitful. OpenAI introduced the
sparse Transformer (Child et al., 2019) which exploits a factorized sparse representation of atten-
tion. Using product-key attention to increase the key space has also been used to reduce memory
requirements in the feed-forward layers with no loss in performance (Lample et al., 2019).
Locality-sensitive hashing (LSH) has, to our knowledge, not been directly applied to Transformer
attention layers before. But previous work using external memory with neural networks has dealt
with memories of large sizes. The original implementation of memory networks (Weston et al.,
2014) and later work on scaling it (Bordes et al., 2015; Chandar et al., 2016) used memory with size
in the millions. The cost of doing so is that the memory must be fixed prior to training. Moreover,
since during the beginning of training the model is unlikely to query the memory correctly, strong
supervision is used to encourage the model to query memory locations that are useful. These hints
are either given as additional supervising information by the task or determined heuristically as in
Hill et al. (2015). The requirement that the memory be fixed before training has been removed in Santoro
et al. (2016) at the cost of memory size and later alleviated by Rae et al. (2016). The last paper
considered memory lookups with approximate nearest neighbors including both LSH and random
kd-trees, but only for lookups in external memory.
5 EXPERIMENTS
In this section we present experimental results demonstrating the techniques described above. We
analyze the techniques one-by-one to make clear which combinations have impact on performance.
We start by showing that reversible layers and shared query-key spaces do not impact performance,
then proceed to analyze hashing attention and finally the full Reformer model.
We ran our experiments on the imagenet64 and enwik8-64K tasks, where the latter is a variant
of enwik8 that is chunked into subsequences of 2^16 = 64K tokens. We use 3-layer models for our ablations so as to make it tractable to compare with the regular Transformer, which has high memory usage and performs full O(l^2) attention. All experiments have d_model = 1024, d_ff = 4096, n_heads = 8, and a total batch size of 8 sequences. We used the Adafactor optimizer (Shazeer
& Stern, 2018) for training these models. We also evaluate on the WMT 2014 English-to-German
translation task, following the hyperparameters of Vaswani et al. (2017). Training for all experiments
Figure 3: Effect of shared query-key space (left) and reversibility (right) on performance on enwik8
and imagenet64 training. The curves show bits per dim on held-out data.
Table 4: BLEU scores on newstest2014 for WMT English-German (En-De). We additionally report
detokenized BLEU scores as computed by sacreBLEU (Post, 2018).
                                                               sacreBLEU
Model                                                          BLEU   Uncased³   Cased⁴
Vaswani et al. (2017), base model                              27.3
Vaswani et al. (2017), big                                     28.4
Ott et al. (2018), big                                         29.3
Reversible Transformer (base, 100K steps)                      27.6   27.4       26.9
Reversible Transformer (base, 500K steps, no weight sharing)   28.0   27.9       27.4
Reversible Transformer (big, 300K steps, no weight sharing)    29.1   28.9       28.4
was parallelized across 8 devices (8 GPUs or 8 TPU v3 cores). Code for training our models is made publicly available.²
Effect of sharing QK. We first consider the effect of shared-QK attention on a regular Transformer model. Shared-QK attention sets k_j = q_j / ‖q_j‖ and prevents tokens from attending to themselves
(except when no other context is available). In the left part of Figure 3, we plot perplexity curves
for both regular and shared-QK attention. A shared query-key space does not perform worse than
regular attention; in fact, for enwik8 it appears to train slightly faster. In other words, we are not
sacrificing accuracy by switching to shared-QK attention.
Effect of reversible layers. In the two plots on the right in Figure 3, we compare a regular Transformer per Vaswani et al. (2017) with the reversible one described in Section 3. The two models have
identical parameter counts, and the learning curves likewise appear to be nearly the same. These
results show that the memory savings in the reversible Transformer do not come at the expense of
accuracy.
Reversible layers in machine translation.We also evaluate reversible layers in the context of an
encoder-decoder Transformer model for machine translation from English to German. We start by
making both the encoder and the decoder fully reversible in the Transformer-base architecture, and
² https://github.com/google/trax/tree/master/trax/models/reformer
³ BLEU+case.lc+lang.en-de+numrefs.1+smooth.exp+test.wmt14/full+tok.intl+version.1.4.3
⁴ BLEU+case.mixed+lang.en-de+numrefs.1+smooth.exp+test.wmt14/full+tok.intl+version.1.4.3
Figure 4: LSH attention performance as a function of hashing rounds on imagenet64.
Figure 5: Left: LSH attention performance as a function of number of layers on enwik8. Right: Speed of attention evaluation as a function of input length for full- and LSH-attention.
see that the resulting model performs comparably to Vaswani et al. (2017) when trained for 100K
steps. We also evaluate training for a greater number of steps and with a larger model. Reformer
models are very memory-efficient, so for the latter two experiments we do not need to save mem-
ory by sharing embedding and output projection weight matrices throughout the model. Results
are shown in Table 4. We do not apply LSH attention in this setting because examples are single
sentences, and sentences tend to be relatively short. Our typical LSH attention configuration uses
chunks of 128 tokens after hashing and sorting, whereas the examples in the WMT14 test set are all
shorter than 128 tokens.
LSH attention in Transformer. LSH attention is an approximation for full attention that, as evidenced in Figure 4, becomes more accurate as the number of hashes increases. At n_rounds = 8, it
already almost matches full attention. The computational cost of a model grows with the number
of hashes, so this hyperparameter can be adjusted depending on the available compute budget. Ad-
ditionally, as in Table 2, the number of hashes can be increased at evaluation time to produce more
accurate results. On the right half of Figure 5, we plot the speed of different attention types vs. the
sequence length, while holding the total number of tokens fixed. We see that while regular attention becomes slower at longer sequence length, LSH attention speed remains flat.
Large Reformer models. To verify that the Reformer can indeed fit large models on a single core
and train fast on long sequences, we train up to 20-layer big Reformers on enwik8 and imagenet64.
As can be seen in Figure 5, these models fit into memory and train. We were not able to train Trans-
former baselines in this case as they are too slow and memory-hungry, but we see clear improvement
with the number of layers. A 12-layer model on enwik8 trained for 20K steps with a dropout rate
of 0.1 achieves 1.19 bits/dim on the test set. We also trained a 12-layer Reformer model for longer, with further tuning and improvements, and reached 1.05 bits/dim on the enwik8 test set.
6 CONCLUSION
Reformer combines the modeling capacity of a Transformer with an architecture that can be executed
efficiently on long sequences and with small memory use even for models with a large number of
layers. We believe that this will help large, richly-parameterized Transformer models become more
widespread and accessible. Also, the ability to handle long sequences opens the way for the use
of the Reformer on many generative tasks. In addition to generating very long coherent text, the
Reformer can bring the power of Transformer models to other domains like time-series forecasting,
music, image and video generation.
REFERENCES
Rami Al-Rfou, Dokook Choe, Noah Constant, Mandy Guo, and Llion Jones. Character-level language modeling with deeper self-attention. CoRR, abs/1808.04444, 2018. URL http://arxiv.org/abs/1808.04444.

Alexandr Andoni, Piotr Indyk, Thijs Laarhoven, Ilya P. Razenshteyn, and Ludwig Schmidt. Practical and optimal LSH for angular distance. CoRR, abs/1509.02897, 2015. URL http://arxiv.org/abs/1509.02897.

Jimmy Lei Ba, Jamie Ryan Kiros, and Geoffrey E Hinton. Layer normalization. arXiv preprint arXiv:1607.06450, 2016. URL http://arxiv.org/abs/1607.06450.

Antoine Bordes, Nicolas Usunier, Sumit Chopra, and Jason Weston. Large-scale simple question answering with memory networks. CoRR, abs/1506.02075, 2015. URL http://arxiv.org/abs/1506.02075.

Sarath Chandar, Sungjin Ahn, Hugo Larochelle, Pascal Vincent, Gerald Tesauro, and Yoshua Bengio. Hierarchical memory networks. arXiv preprint arXiv:1605.07427, 2016.

Rewon Child, Scott Gray, Alec Radford, and Ilya Sutskever. Generating long sequences with sparse transformers. URL https://openai.com/blog/sparse-transformers, 2019.

Jacob Devlin, Ming-Wei Chang, Kenton Lee, and Kristina Toutanova. BERT: pre-training of deep bidirectional transformers for language understanding. CoRR, abs/1810.04805, 2018. URL http://arxiv.org/abs/1810.04805.

Aidan N Gomez, Mengye Ren, Raquel Urtasun, and Roger B Grosse. The reversible residual network: Backpropagation without storing activations. In Advances in Neural Information Processing Systems, pp. 2214-2224, 2017.

Felix Hill, Antoine Bordes, Sumit Chopra, and Jason Weston. The goldilocks principle: Reading children's books with explicit memory representations. CoRR, abs/1511.02301, 2015. URL http://arxiv.org/abs/1511.02301.

Cheng-Zhi Anna Huang, Ashish Vaswani, Jakob Uszkoreit, Noam Shazeer, Curtis Hawthorne, Andrew M Dai, Matthew D Hoffman, and Douglas Eck. Music transformer: Generating music with long-term structure. arXiv preprint arXiv:1809.04281, 2018.

Guillaume Lample, Alexandre Sablayrolles, Marc'Aurelio Ranzato, Ludovic Denoyer, and Hervé Jégou. Large memory layers with product keys. CoRR, abs/1907.05242, 2019. URL http://arxiv.org/abs/1907.05242.

Peter J. Liu, Mohammad Saleh, Etienne Pot, Ben Goodrich, Ryan Sepassi, Lukasz Kaiser, and Noam Shazeer. Generating wikipedia by summarizing long sequences. CoRR, abs/1801.10198, 2018. URL http://arxiv.org/abs/1801.10198.

Myle Ott, Sergey Edunov, David Grangier, and Michael Auli. Scaling neural machine translation. In Proceedings of the Third Conference on Machine Translation: Research Papers, pp. 1-9, Brussels, Belgium, October 2018. Association for Computational Linguistics. doi: 10.18653/v1/W18-6301. URL https://www.aclweb.org/anthology/W18-6301.
Niki Parmar, Ashish Vaswani, Jakob Uszkoreit, Lukasz Kaiser, Noam Shazeer, and Alexander Ku. Image transformer. CoRR, abs/1802.05751, 2018. URL http://arxiv.org/abs/1802.05751.

Matt Post. A call for clarity in reporting BLEU scores. In Proceedings of the Third Conference on Machine Translation: Research Papers, pp. 186-191, Belgium, Brussels, October 2018. Association for Computational Linguistics. URL https://www.aclweb.org/anthology/W18-6319.

Alec Radford, Jeff Wu, Rewon Child, David Luan, Dario Amodei, and Ilya Sutskever. Language models are unsupervised multitask learners. 2019.

Jack W Rae, Jonathan J Hunt, Tim Harley, Ivo Danihelka, Andrew Senior, Greg Wayne, Alex Graves, and Timothy P Lillicrap. Scaling memory-augmented neural networks with sparse reads and writes. In Advances in Neural Information Processing Systems (NIPS), 2016.

Prajit Ramachandran, Niki Parmar, Ashish Vaswani, Irwan Bello, Anselm Levskaya, and Jonathon Shlens. Stand-alone self-attention in vision models. CoRR, abs/1906.05909, 2019. URL http://arxiv.org/abs/1906.05909.

Adam Santoro, Sergey Bartunov, Matthew Botvinick, Daan Wierstra, and Timothy P. Lillicrap. One-shot learning with memory-augmented neural networks. CoRR, abs/1605.06065, 2016. URL http://arxiv.org/abs/1605.06065.

Noam Shazeer and Mitchell Stern. Adafactor: Adaptive learning rates with sublinear memory cost. CoRR, abs/1804.04235, 2018. URL http://arxiv.org/abs/1804.04235.

Noam Shazeer, Youlong Cheng, Niki Parmar, Dustin Tran, Ashish Vaswani, Penporn Koanantakool, Peter Hawkins, HyoukJoong Lee, Mingsheng Hong, Cliff Young, Ryan Sepassi, and Blake Hechtman. Mesh-TensorFlow: Deep learning for supercomputers. CoRR, abs/1811.02084, 2018. URL http://arxiv.org/abs/1811.02084.

Nimit Sharad Sohoni, Christopher Richard Aberger, Megan Leszczynski, Jian Zhang, and Christopher Ré. Low-memory neural network training: A technical report. CoRR, abs/1904.10631, 2019. URL http://arxiv.org/abs/1904.10631.

Sainbayar Sukhbaatar, Edouard Grave, Piotr Bojanowski, and Armand Joulin. Adaptive attention span in transformers. CoRR, abs/1905.07799, 2019a. URL http://arxiv.org/abs/1905.07799.

Sainbayar Sukhbaatar, Edouard Grave, Guillaume Lample, Hervé Jégou, and Armand Joulin. Augmenting self-attention with persistent memory. CoRR, abs/1907.01470, 2019b. URL http://arxiv.org/abs/1907.01470.

Ashish Vaswani, Noam Shazeer, Niki Parmar, Jakob Uszkoreit, Llion Jones, Aidan N. Gomez, Lukasz Kaiser, and Illia Polosukhin. Attention is all you need. CoRR, 2017. URL http://arxiv.org/abs/1706.03762.

Jason Weston, Sumit Chopra, and Antoine Bordes. Memory networks. CoRR, abs/1410.3916, 2014. URL http://arxiv.org/abs/1410.3916.
A MULTI-ROUND LSH ATTENTION
In this section we describe in more detail the multi-hash version of our LSH attention mechanism.
We first repeat Equation (3) from the main text, which describes a general formulation of attention
with sparsity:
o_i = \sum_{j \in \widetilde{\mathcal{P}}_i} \exp(q_i \cdot k_j - m(j, \mathcal{P}_i) - z(i, \mathcal{P}_i))\, v_j \quad \text{where } m(j, \mathcal{P}_i) = \begin{cases} \infty & \text{if } j \notin \mathcal{P}_i \\ 0 & \text{otherwise} \end{cases} \qquad (3)
In the multi-round case, a query position i can attend to key positions P_i as defined in (6), which we also repeat here:
\mathcal{P}_i = \bigcup_{r=1}^{n_{\mathrm{rounds}}} \mathcal{P}_i^{(r)} \quad \text{where } \mathcal{P}_i^{(r)} = \left\{ j : h^{(r)}(q_i) = h^{(r)}(q_j) \right\} \qquad (6)
For batching purposes, attention is performed on chunks of sorted queries/keys:
\widetilde{\mathcal{P}}_i^{(r)} = \left\{ j : \left\lfloor \frac{s_i^{(r)}}{m} \right\rfloor - 1 \leq \left\lfloor \frac{s_j^{(r)}}{m} \right\rfloor \leq \left\lfloor \frac{s_i^{(r)}}{m} \right\rfloor \right\} \qquad (11)
Combining (3) and (6) gives:
o_i = \sum_{j \in \widetilde{\mathcal{P}}_i} \exp(q_i \cdot k_j - m(j, \mathcal{P}_i) - z(i, \mathcal{P}_i))\, v_j \qquad (12)

    = \sum_{r=1}^{n_{\mathrm{rounds}}} \exp\left(z(i, \mathcal{P}_i^{(r)}) - z(i, \mathcal{P}_i)\right) \sum_{j \in \widetilde{\mathcal{P}}_i^{(r)}} \frac{1}{N_{i,j}} \exp\left(q_i \cdot k_j - m(j, \mathcal{P}_i^{(r)}) - z(i, \mathcal{P}_i^{(r)})\right) v_j \qquad (13)

    = \sum_{r=1}^{n_{\mathrm{rounds}}} \exp\left(z(i, \mathcal{P}_i^{(r)}) - z(i, \mathcal{P}_i)\right) o_i^{(r)} \qquad (14)

o_i^{(r)} = \sum_{j \in \widetilde{\mathcal{P}}_i^{(r)}} \exp\left(q_i \cdot k_j - m_{i,j}^{(r)} - z(i, \mathcal{P}_i^{(r)})\right) v_j \qquad (15)

\text{where } N_{i,j} = \left|\left\{ r' : j \in \mathcal{P}_i^{(r')} \right\}\right| \text{ and } m_{i,j}^{(r)} = \begin{cases} \infty & \text{if } j \notin \mathcal{P}_i^{(r)} \\ 10^5 & \text{if } i = j \\ \log N_{i,j} & \text{otherwise} \end{cases} \qquad (16)
Each round of LSH attention produces a vector o_i^{(r)} that can be computed independently from other rounds, except for the inclusion of a term N_{i,j} to avoid double-counting elements when constructing the union of the P_i^{(r)} sets. In our implementation we fold the N_{i,j} factor into the masking term m_{i,j}^{(r)}.

We also modify m_{i,j}^{(r)} to introduce a special case for i = j. This case is added because causal masking in a standard Transformer allows position i to attend to itself, which is not desirable in a shared-QK formulation. We set the mask to a large but finite value to disallow attention-in-place, except in the situation where a token has no other valid attention targets. For example, the first token in a sequence attends only to itself, because no prior context is available.
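To show how the per-round outputs are combined, here is a small sketch of Equation (14): given each round's output o_i^{(r)} and its log partition function, the final o_i is a softmax-weighted sum over rounds. The function name is made up, and the inputs are assumed to already incorporate the masking and the N_{i,j} correction described above.

```python
import numpy as np

def combine_rounds(o_rounds, z_rounds):
    """Combine per-round outputs as in Equation (14).

    o_rounds: [n_rounds, d_v] per-round attention outputs o_i^{(r)}.
    z_rounds: [n_rounds] per-round log partition functions, assumed to already
              include the log N_{i,j} correction folded into the masking term.
    With that correction, the overall z(i, P_i) is the logsumexp of the per-round
    values, so the weights exp(z^{(r)} - z) form a softmax over rounds.
    """
    z_total = np.logaddexp.reduce(z_rounds)       # z(i, P_i) = logsumexp_r of the corrected z(i, P_i^{(r)})
    weights = np.exp(z_rounds - z_total)          # [n_rounds], sums to 1
    return weights @ o_rounds                     # weighted sum of the round outputs
```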
REFORMER: THEEFFICIENTTRANSFORMER
Nikita Kitaev
U.C. Berkeley & Google Research
kitaev@cs.berkeley.edu
ukasz Kaiser
Google Research
flukaszkaiser,levskaya g@google.com
Anselm Levskaya
Google Research
ABSTRACT
Large Transformer models routinely achieve state-of-the-art results on a number
of tasks but training these models can be prohibitively costly, especially on long
sequences. We introduce two techniques to improve the efciency of Transform-
ers. For one, we replace dot-product attention by one that uses locality-sensitive
hashing, changing its complexity from O(L
2
) to O(LlogL), whereLis the length
of the sequence. Furthermore, we use reversible residual layers instead of the
standard residuals, which allows storing activations only once in the training pro-
cess instead ofNtimes, whereNis the number of layers. The resulting model,
the Reformer, performs on par with Transformer models while being much more
memory-efcient and much faster on long sequences.
1 INTRODUCTION
The Transformer architecture (Vaswani et al., 2017) is widely used in natural language processing
and yields state-of-the-art results on a number of tasks. To obtain these results, researchers have
resorted to training ever larger Transformer models. The number of parameters exceeds 0.5B per
layer in the largest conguration reported in (Shazeer et al., 2018) while the number of layers goes
up to 64 in (Al-Rfou et al., 2018). Transformer models are also used on increasingly long sequences.
Up to 11 thousand tokens of text in a single example were processed in (Liu et al., 2018) and when
processing other modalities, like music (Huang et al., 2018) and images (Parmar et al., 2018), even
longer sequences are commonplace. These large-scale long-sequence models yield great results but
strain resources to the point where some argue that this trend is breaking NLP research
1
. Many
large Transformer models can only realistically be trained in large industrial research laboratories
and such models trained with model parallelism cannot even be ne-tuned on a single GPU as their
memory requirements demand a multi-accelerator hardware setup even for a single training step.
Do large Transformer models fundamentally require such huge resources or are they simply inef-
cient? Consider the following calculation: the 0.5B parameters used in the largest reported Trans-
former layer account for 2GB of memory. Activations for 64K tokens with embedding size 1024
and batch size 8 account for64K1K8 = 0:5B oats, requiring another 2GB of memory. If
our memory use was only per-layer, then we should fairly easily t a large Transformer even on
sequences of length 64K on a single accelerator. Further, the whole corpus used to train BERT
only requires 17GB to store. Why is it then that we cannot even ne-tune these models on single
machines?
The above estimate includes only per-layer memory and input activations cost and does not take into
account the following major sources of memory use in the Transformer.
Memory in a model withNlayers isN-times larger than in a single-layer model due to the
fact that activations need to be stored for back-propagation.
Since the depthdf fof intermediate feed-forward layers is often much larger than the depth
dmodelof attention activations, it accounts for a large fraction of memory use.
Attention on sequences of lengthLis O(L
2
) in both computational and memory complex-
ity, so even for a single sequence of64K tokens can exhaust accelerator memory.
Equal Contribution
1
https://hackingsemantics.xyz/2019/leaderboards/
1
Published as a conference paper at ICLR 2020
We introduce the Reformer model which solves these problems using the following techniques:
Reversible layers, rst introduced in Gomez et al. (2017), enable storing only a single copy
of activations in the whole model, so theNfactor disappears.
Splitting activations inside feed-forward layers and processing them in chunks removes the
df ffactor and saves memory inside feed-forward layers.
Approximate attention computation based on locality-sensitive hashing replaces the O(L
2
)
factor in attention layers with O(LlogL) and so allows operating on long sequences.
We study these techniques and show that they have negligible impact on the training process com-
pared to the standard Transformer. Splitting activations in fact only affects the implementation; it is
numerically identical to the layers used in the Transformer. Applying reversible residuals instead of
the standard ones does change the model but has a negligible effect on training in all congurations
we experimented with. Finally, locality-sensitive hashing in attention is a more major change that
can inuence the training dynamics, depending on the number of concurrent hashes used. We study
this parameter and nd a value which is both efcient to use and yields results very close to full
attention.
We experiment on a synthetic task, a text task (enwik8) with sequences of length 64K and an image
generation task (imagenet-64 generation) with sequences of length 12K. In both cases we show that
Reformer matches the results obtained with full Transformer but runs much faster, especially on the
text task, and with orders of magnitude better memory efciency.
2 LOCALITY-SENSITIVEHASHINGATTENTION
Dot-product attention.The standard attention used in the Transformer is the scaled dot-product
attention (Vaswani et al., 2017). The input consists of queries and keys of dimensiondk, and values
of dimensiondv. The dot products of the query with all keys are computed, scaled by
p
dk, and a
softmax function is applied to obtain the weights on the values. In practice, the attention function
on a set of queries is computed simultaneously, packed together into a matrixQ. Assuming the keys
and values are also packed together into matricesKandV, the matrix of outputs is dened as:
Attention(Q; K; V) = softmax(
QK
T
p
dk
)V (1)
Multi-head attention.In the Transformer, instead of performing a single attention function with
dmodel-dimensional keys, values and queries, one linearly projects the queries, keys and valuesh
times with different, learned linear projections todk,dkanddvdimensions, respectively. Attention
is applied to each of these projected versions of queries, keys and values in parallel, yieldingdv-
dimensional output values. These are concatenated and once again projected, resulting in the nal
values. This mechanism is known as multi-head attention.
Memory-efcient attention.To calculate the memory use of the attention mechanism, let us
focus on the attention computation from Equation 1. Let us assume that Q, K and V all have
the shape[batchsize; length; dmodel]. The main issue is the termQK
T
, which has the shape
[batchsize; length; length]. In the experimental section we train a model on sequences of length
64K in this case, even at batch-size of 1, this is a64K64Kmatrix, which in 32-bit oats would
take 16GB of memory. This is impractical and has hindered the use of the Transformer for long
sequences. But it is important to note that theQK
T
matrix does not need to be fully materialized
in memory. The attention can indeed be computed for each queryqiseparately, only calculating
softmax(
qiK
T
p
dk
)Vonce in memory, and then re-computing it on the backward pass when needed for
gradients. This way of computing attention may be less efcient but it only uses memory propor-
tional tolength. We use this memory-efcient implementation of attention to run the full-attention
baselines presented in the experimental section.
Where do Q, K, V come from?The multi-head attention described above operates on keys,
queries and values, but usually we are only given a single tensor of activations A of the shape
[batchsize; length; dmodel] e.g., coming from embedding the tokens in a sentence into vectors.
2
Published as a conference paper at ICLR 2020
Figure 1: An angular locality sensitive hash uses random rotations of spherically projected points to
establish buckets by an argmax over signed axes projections. In this highly simplied 2D depiction,
two pointsxandyare unlikely to share the same hash buckets (above) for the three different angular
hashes unless their spherical projections are close to one another (below).
To build Q, K and V from A, the Transformer uses 3 different linear layers projecting A into Q, K
and V with different parameters. For models with LSH attention, we want queries and keys (Q and
K) to be identical. This is easily achieved by using the same linear layer to go from A to Q and
K, and a separate one for V. We call a model that behaves like this a shared-QK Transformer. It
turns out that sharing QK does not affect the performance of Transformer, even if we additionally
normalize the length of the keys K, as we show in the experimental Section 5.
Hashing attention.For the LSH attention, we start with two tensors, Q=K and V of the shape
[batchsize; length; dmodel]. We keep the multi-head mechanism intact and focus on the atten-
tion computation from Equation 1. As already mentioned, the main issue is the termQK
T
,
which has the shape[batchsize; length; length]. But note that we are actually only interested
insoftmax(QK
T
). Since softmax is dominated by the largest elements, for each queryqiwe only
need to focus on the keys in K that are closest toqi. For example, if K is of length 64K, for eachqi
we could only consider a small subset of, say, the32or64closest keys. That is much more efcient,
but how can we nd the nearest neighbors among the keys?
Locality sensitive hashing.The problem of nding nearest neighbors quickly in high-dimensional
spaces can be solved by locality-sensitive hashing (LSH). A hashing scheme that assigns each vector
xto a hashh(x)is called locality-sensitive if nearby vectors get the same hash with high probability
and distant ones do not. In our case, we actually only require that nearby vectors get the same hash
with high probability and that hash-buckets are of similar size with high probability.
We achieve this by employing random projections as follows (see Figure 1). To getbhashes, we
rst x a random matrixRof size[dk; b=2]. We then deneh(x) = arg max([xR; xR])where
[u;v]denotes the concatenation of two vectors. This method is a known LSH scheme (Andoni et al.,
2015) and is easy to implement and apply to batches of vectors.
LSH attention.Knowing our LSH scheme and the general idea of hashing attention, we will now
formalize the LSH attention we use in this paper. We rst rewrite the equation for normal attention,
(1), for a single query positioniat a time:
oi=
X
j2Pi
exp (qikj z(i;Pi))vj wherePi=fj:ijg (2)
We introduce the notationPito represent the set that the query at positioniattends to, andzto
denote the partition function (i.e. the normalizing term in the softmax). For clarity, we also omit
scaling by
p
dk.
For batching purposes we typically perform attention over a larger set
e
Pi=f0;1; : : : ; lg Pi
while masking out elements not inPi:
oi=
X
j2
e
Pi
exp (qikj m(j;Pi) z(i;Pi))vjwherem(j;Pi) =
1ifj =2 Pi
0otherwise
(3)
3
Published as a conference paper at ICLR 2020
Figure 2: Simplied depiction of LSH Attention showing the hash-bucketing, sorting, and chunking
steps and the resulting causal attentions. (a-d) Attention matrices for these varieties of attention.
Now we turn to LSH attention, which we can think of in terms of restricting the setPiof target
items a query positionican attend to, by only allowing attention within a single hash bucket.
Pi=fj:h(qi) =h(kj)g (4)
Figure 2(a-b) shows a schematic comparison of full-attention with a hashed variant. Part (a) depicts
that the attention matrix for full attention is typically sparse, but the computation does not take
advantage of this sparsity. In (b), the queries and keys have been sorted according to their hash
bucket. Since similar items fall in the same bucket with high probability, the full attention pattern
can be approximated by only allowing attention within each bucket.
Hash buckets in this formulation tend to be uneven in size, which makes it difcult to batch across
buckets. Moreover, the number of queries and the number of keys within a bucket may be unequal
in fact, it is possible for a bucket to contain many queries but no keys. To alleviate these issues, we
rst ensure thath(kj) =h(qj)by settingkj=
qj
kqjk
. Next, we sort the queries by bucket number
and, within each bucket, by sequence position; this denes a permutation wherei7!siafter sorting.
In the sorted attention matrix, pairs from the same bucket will cluster near the diagonal (as depicted
in Figure 2c). We can follow a batching approach where chunks ofmconsecutive queries (after
sorting) attend to each other, and one chunk back (Figure 2d). Following our earlier notation, this
corresponds to setting:
e
Pi=
n
j:
j
si
m
k
1
j
sj
m
k
j
si
m
ko
(5)
IfmaxijPij< m, thenPi
e
Pi. In practice we setm=
2l
nbuckets
(wherelis the sequence length).
The average bucket size is
l
nbuckets
, and we assume that the probability of a bucket growing to twice
that size is sufciently low. The overall process of LSH attention is summarized in Figure 2.
Multi-round LSH attention.With hashing, there is always a small probability that similar items
nevertheless fall in different buckets. This probability can be reduced by doing multiple rounds of
hashing withnroundsdistinct hash functionsfh
(1)
; h
(2)
; : : :g, such that:
Pi=
nrounds[
r=1
P
(r)
i
whereP
(r)
i
=
n
j:h
(r)
(qi) =h
(r)
(qj)
o
(6)
The multi-round case essentially involves performing LSH attentionnroundstimes in parallel; the
details of the procedure are described in in Appendix A.
Causal masking for shared-QK attention.In a Transformer decoder, masking (denoted by
m(j;Pi)in Equation 3) is used to prevent positions from attending into the future. To implement
masking in LSH attention, we associate every query/key vector with a position index, re-order the
position indices using the same permutations used to sort the query/key vectors, and then use a
comparison operation to compute the mask.
4
Published as a conference paper at ICLR 2020
Table 1: Memory and time complexity of attention variants. We writelfor length,bfor batch size,
nhfor the number of heads,ncfor the number of LSH chunks,nrfor the number of hash repetitions.
Attention Type Memory Complexity Time Complexity
Scaled Dot-Product max(bnhldk; bnhl
2
) max( bnhldk; bnhl
2
)
Memory-Efcient max(bnhldk; bnhl
2
) max( bnhldk; bnhl
2
)
LSH Attention max(bnhldk; bnhlnr(4l=nc)
2
) max(bnhldk; bnhnrl(4l=nc)
2
)
Table 2: Accuracies on the duplication task of a 1-layer Transformer model with full attention and
with locality-sensitive hashing attention using different number of parallel hashes.
Train
Eval
Full AttentionLSH-8LSH-4LSH-2LSH-1
Full Attention100% 94.8%92.5%76.9%52.5%
LSH-4 0.8% 100%99.9%99.4%91.9%
LSH-2 0.8% 100%99.9%98.1%86.8%
LSH-1 0.8% 99.9%99.6%94.8%77.9%
While attention to the future is not allowed, typical implementations of the Transformerdoallow
a position to attend toitself. Such behavior is undesirable in a shared-QK formulation because the
dot-product of a query vector with itself will almost always be greater than the dot product of a
query vector with a vector at another position. We therefore modify the masking to forbid a token
from attending to itself, except in situations where a token has no other valid attention targets (e.g.
the rst token in a sequence).
2.1 ANALYSIS ON A SYNTHETIC TASK
To verify the performance of LSH attention and study its behavior, we start with the following
synthetic task: duplicate a sequence of symbols. In this task, each training and testing example has
the form0w0wwherew2 f1; : : : ; Ng
is a sequence of symbols ranging from1toN(we use
N= 127in our experiments). An example with the wordwof length3is given below.
Example:0191137201911372
To study LSH attention, we train a language model on examples of the above form where eachw
is of length511(so the whole input0w0wis of length1024). As this is a language modeling task,
we always predict the next symbol given all the previous ones, but we mask the loss and accuracy to
only consider positions in the second half of the input, i.e., those that can actually be predicted.
The above task can be solved perfectly (to accuracy 100% and loss 0) by a 1-layer Transformer
model. Note though, that it requires non-local attention lookups, so it cannot be solved by any
model relying on sparse attention with a limited span. To make it easy and fast to train but similar
to models used in NLP, we use a 1-layer Transformer withdmodel=df f= 256and4heads. We
train it for 150K steps in4different settings: with full attention, LSH attention withnrounds= 1,
nrounds= 2andnrounds= 4.
From the results summarized in Table 2 we see that a model trained with full attention can be imme-
diately used with LSH attention, but at some loss of accuracy. When trained from scratch with LSH
attention, the model trained with 4 hashes achieves almost perfect accuracy as well. Interestingly,
the accuracy becomes perfect when evaluated with 8 hashes. It goes down when evaluated with 2 or
1 hashes. Models trained with less hashes show worse results but even the model trained with just 1
hash performs almost perfectly when evaluated with 8 hashes.
5
Published as a conference paper at ICLR 2020
3 REVERSIBLETRANSFORMER
As the above section shows, the complexity of attention can be reduced from square in length to
linear, provided an approximation is acceptable. But it is clear from Table 1 that each eld starts
with abnhlterm: thebnhldk, or alternativelybldmodelcost cannot be avoided. Indeed,
the activations before each layer are already of the sizebldmodel, so the memory use of the whole
model withnllayers is at leastbldmodelnl. Even worse: inside the feed-forward layers of
Transformer this goes up tobldf fnl. In a big Transformer it is usual to setdf f= 4Kand
nl= 16so withl= 64Kthis again would use an impractical16GBof memory
In this section, we show how to reduce this cost by rst dealing with thenlpart of the term using
reversible layers and then showing how chunking can allow us to handle thedf fproblem. The
effects of each of these approaches on memory and time complexity are summarized in Table 3.
RevNets. Reversible residual networks were introduced by Gomez et al. (2017), where it was shown
that they can replace ResNets for image classification. The main idea is to allow the activations at
any given layer to be recovered from the activations at the following layer, using only the model
parameters. Rather than having to checkpoint intermediate values for use in the backward pass,
layers can be reversed one-by-one as back-propagation proceeds from the output of the network to
its input. Whereas a normal residual layer performs a function x ↦ y that operates on a single input
and produces a single output and has the form y = x + F(x), a reversible layer works on pairs of
inputs/outputs, (x_1, x_2) ↦ (y_1, y_2), and follows the equations:
y_1 = x_1 + F(x_2) \qquad\qquad y_2 = x_2 + G(y_1) \qquad\qquad (7)
A layer can be reversed by subtracting (rather than adding) the residuals:
x_2 = y_2 - G(y_1) \qquad\qquad x_1 = y_1 - F(x_2) \qquad\qquad (8)
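The two equations translate directly into code. The sketch below is a minimal illustration, assuming F and G are arbitrary functions with matching shapes; it omits the custom backward pass that makes this useful in practice.

```python
def reversible_forward(x1, x2, F, G):
    """Forward pass of a reversible residual block, Equation (7)."""
    y1 = x1 + F(x2)
    y2 = x2 + G(y1)
    return y1, y2

def reversible_inverse(y1, y2, F, G):
    """Recompute the inputs from the outputs, Equation (8), so the
    activations x1, x2 need not be stored for back-propagation."""
    x2 = y2 - G(y1)
    x1 = y1 - F(x2)
    return x1, x2
```

In the Reversible Transformer described next, F is the attention layer and G the feed-forward layer.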
Reversible Transformer. We apply the RevNet idea to the Transformer by combining the attention
and feed-forward layers inside the revnet block. In the notation above, F becomes an attention layer
while G becomes the feed-forward layer. Note that Layer Normalization (Ba et al., 2016) is moved
inside the residual blocks.
Y_1 = X_1 + \mathrm{Attention}(X_2) \qquad\qquad Y_2 = X_2 + \mathrm{FeedForward}(Y_1) \qquad\qquad (9)
The reversible Transformer does not need to store activations in each layer and so gets rid of the n_l
term. In Section 5 we show that it performs the same as the normal Transformer when using the
same number of parameters; we achieve this by having both x_1 and x_2 have size d_model.
Chunking. While reversibility covers the n_l term, the thicker layers can still use a lot of memory.
The feed-forward layer in particular can use intermediate vectors of dimensionality d_ff = 4K or
higher. However, computations in feed-forward layers are completely independent across positions
in a sequence, so the computation can be split into c chunks:
Y_2 = \left[ Y_2^{(1)}; \ldots; Y_2^{(c)} \right] = \left[ X_2^{(1)} + \mathrm{FeedForward}(Y_1^{(1)}); \ldots; X_2^{(c)} + \mathrm{FeedForward}(Y_1^{(c)}) \right] \qquad (10)
This layer is typically batched by performing operations for all positions in parallel, but operating
on one chunk at a time can reduce memory. The reverse computation in (8) and the backward pass
are also chunked. In addition to the feed-forward layers, for models with a large vocabulary (more
than d_model word types) we also chunk the log-probabilities at the output and calculate the loss for
sections of the sequence at a time.
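A minimal NumPy sketch of the chunked feed-forward computation in Equation (10) is given below; the function names and chunk count are illustrative assumptions, not the Trax implementation, which also chunks the reverse and backward computations.

```python
import numpy as np

def chunked_feed_forward(x2, y1, feed_forward, n_chunks):
    """Compute Y2 = X2 + FeedForward(Y1) one chunk of positions at a time,
    so only one chunk's d_ff-sized intermediate activations exist at once."""
    y2_chunks = [
        x2_c + feed_forward(y1_c)
        for x2_c, y1_c in zip(np.array_split(x2, n_chunks, axis=0),
                              np.array_split(y1, n_chunks, axis=0))
    ]
    return np.concatenate(y2_chunks, axis=0)
```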
Chunking, large batches and parameter reuse. With chunking and reversible layers, the memory
we use for activations in the whole network is independent of the number of layers. The same is
not true for parameters, though, as their number grows with the number of layers. This problem can
be remedied because we can swap layer parameters to and from CPU memory when a layer is not
computing. In a standard Transformer this would be inefficient because memory transfer to CPU is
slow. In the Reformer, however, the batch size multiplied by the sequence length is much larger, so
the amount of compute done with the parameters amortizes the cost of their transfer.
Table 3: Memory and time complexity of Transformer variants. We write d_model and d_ff for model
depth and assume d_ff ≥ d_model; b stands for batch size, l for length, n_l for the number of layers.
We assume n_c = l/32 so 4l/n_c = 128 and we write c = 128^2.

Model Type                      | Memory Complexity                  | Time Complexity
Transformer                     | max(b l d_ff, b n_h l^2) n_l       | (b l d_ff + b n_h l^2) n_l
Reversible Transformer          | max(b l d_ff, b n_h l^2)           | (b n_h l d_ff + b n_h l^2) n_l
Chunked Reversible Transformer  | max(b l d_model, b n_h l^2)        | (b n_h l d_ff + b n_h l^2) n_l
LSH Transformer                 | max(b l d_ff, b n_h l n_r c) n_l   | (b l d_ff + b n_h n_r l c) n_l
Reformer                        | max(b l d_model, b n_h l n_r c)    | (b l d_ff + b n_h n_r l c) n_l
4 RELATED WORK
The Transformer model introduced in (Vaswani et al., 2017) has been used widely in natural lan-
guage tasks and further extended to model diverse data such as music scores (Huang et al., 2018),
and images (Parmar et al., 2018; Ramachandran et al., 2019). Most notably, this model class has
been applied successfully in the self-supervised training of extremely large language models (Devlin
et al., 2018; Radford et al., 2019).
Given the enormous computational requirements of state-of-the-art sequence models, there has been
increasing interest in finding methods to reduce the memory footprint and computational require-
ments of Transformer models. In addition to standard methods such as precision reduction and
gradient checkpointing (Sohoni et al., 2019), more efficient versions of the Transformer model's
self-attention mechanism (Sukhbaatar et al., 2019a,b) have also recently been explored.
In particular, leveraging sparsity in the attention layers has proved fruitful. OpenAI introduced the
sparse Transformer (Child et al., 2019), which exploits a factorized sparse representation of atten-
tion. Product-key attention, which increases the key space, has also been used to reduce memory
requirements in the feed-forward layers with no loss in performance (Lample et al., 2019).
Locality-sensitive hashing (LSH) has, to our knowledge, not been directly applied to Transformer
attention layers before. But previous work using external memory with neural networks has dealt
with memories of large sizes. The original implementation of memory networks (Weston et al.,
2014) and later work on scaling it (Bordes et al., 2015; Chandar et al., 2016) used memory with size
in the millions. The cost of doing so is that the memory must be fixed prior to training. Moreover,
since during the beginning of training the model is unlikely to query the memory correctly, strong
supervision is used to encourage the model to query memory locations that are useful. These hints
are either given as additional supervising information by the task or determined heuristically as in
Hill et al. (2015). The requirement that the memory be fixed before training was removed in Santoro
et al. (2016) at the cost of memory size, and later alleviated by Rae et al. (2016). The last paper
considered memory lookups with approximate nearest neighbors including both LSH and random
kd-trees, but only for lookups in external memory.
5 EXPERIMENTS
In this section we present experimental results demonstrating the techniques described above. We
analyze the techniques one-by-one to make clear which combinations have impact on performance.
We start by showing that reversible layers and shared query-key spaces do not impact performance,
then proceed to analyze hashing attention, and finally the full Reformer model.
We ran our experiments on the imagenet64 and enwik8-64K tasks, where the latter is a variant
of enwik8 that is chunked into subsequences of 2^16 = 64K tokens. We use 3-layer models for
our ablations so as to make it tractable to compare with the regular Transformer, which has high
memory usage and performs full O(l^2) attention. All experiments have d_model = 1024, d_ff = 4096,
n_heads = 8, and a total batch size of 8 sequences. We used the Adafactor optimizer (Shazeer
& Stern, 2018) for training these models. We also evaluate on the WMT 2014 English-to-German
translation task, following the hyperparameters of Vaswani et al. (2017). Training for all experiments
Figure 3: Effect of shared query-key space (left) and reversibility (right) on performance on enwik8
and imagenet64 training. The curves show bits per dim on held-out data.
Table 4: BLEU scores on newstest2014 for WMT English-German (EnDe). We additionally report
detokenized BLEU scores as computed by sacreBLEU (Post, 2018). sacreBLEU signatures:
uncased: BLEU+case.lc+lang.en-de+numrefs.1+smooth.exp+test.wmt14/full+tok.intl+version.1.4.3;
cased: BLEU+case.mixed+lang.en-de+numrefs.1+smooth.exp+test.wmt14/full+tok.intl+version.1.4.3.

Model                                                        | BLEU | sacreBLEU (uncased) | sacreBLEU (cased)
Vaswani et al. (2017), base model                            | 27.3 |                     |
Vaswani et al. (2017), big                                   | 28.4 |                     |
Ott et al. (2018), big                                       | 29.3 |                     |
Reversible Transformer (base, 100K steps)                    | 27.6 | 27.4                | 26.9
Reversible Transformer (base, 500K steps, no weight sharing) | 28.0 | 27.9                | 27.4
Reversible Transformer (big, 300K steps, no weight sharing)  | 29.1 | 28.9                | 28.4
was parallelized across 8 devices (8 GPUs or 8 TPU v3 cores). Code for training our models is made
publicly available at https://github.com/google/trax/tree/master/trax/models/reformer.
Effect of sharing QK. We first consider the effect of shared-QK attention on a regular Transformer
model. Shared-QK attention sets k_j = q_j / ||q_j|| and prevents tokens from attending to themselves
(except when no other context is available). In the left part of Figure 3, we plot perplexity curves
for both regular and shared-QK attention. A shared query-key space does not perform worse than
regular attention; in fact, for enwik8 it appears to train slightly faster. In other words, we are not
sacrificing accuracy by switching to shared-QK attention.
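As a small illustration of the shared-QK setup (a sketch with our own naming, assuming inputs of shape [seq_len, d_model] and a single shared projection matrix), the keys are simply the length-normalized queries:

```python
import numpy as np

def shared_qk_projection(x, w_qk):
    """Shared query/key projection: keys are the normalized queries."""
    queries = x @ w_qk                                             # [seq_len, d_k]
    keys = queries / np.linalg.norm(queries, axis=-1, keepdims=True)
    return queries, keys
```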
Effect of reversible layers. In the two plots on the right in Figure 3, we compare a regular Trans-
former per Vaswani et al. (2017) with the reversible one described in Section 3. The two models have
identical parameter counts, and the learning curves likewise appear to be nearly the same. These
results show that the memory savings in the reversible Transformer do not come at the expense of
accuracy.
Reversible layers in machine translation. We also evaluate reversible layers in the context of an
encoder-decoder Transformer model for machine translation from English to German. We start by
making both the encoder and the decoder fully reversible in the Transformer-base architecture, and
Figure 4: LSH attention performance as a function of hashing rounds on imagenet64.
Figure 5: Left: LSH attention performance as a function of number of layers on enwik8. Right:
Speed of attention evaluation as a function of input length for full- and LSH attention.
see that the resulting model performs comparably to Vaswani et al. (2017) when trained for 100K
steps. We also evaluate training for a greater number of steps and with a larger model. Reformer
models are very memory-efficient, so for the latter two experiments we do not need to save memory
by sharing embedding and output projection weight matrices throughout the model. Results
are shown in Table 4. We do not apply LSH attention in this setting because examples are single
sentences, and sentences tend to be relatively short. Our typical LSH attention configuration uses
chunks of 128 tokens after hashing and sorting, whereas the examples in the WMT14 test set are all
shorter than 128 tokens.
LSH attention in Transformer. LSH attention is an approximation for full attention that, as evi-
denced in Figure 4, becomes more accurate as the number of hashes increases. At n_rounds = 8, it
already almost matches full attention. The computational cost of a model grows with the number
of hashes, so this hyperparameter can be adjusted depending on the available compute budget. Ad-
ditionally, as in Table 2, the number of hashes can be increased at evaluation time to produce more
accurate results. On the right half of Figure 5, we plot the speed of different attention types vs. the
sequence length, while holding the total number of tokens fixed. We see that while regular attention
becomes slower at longer sequence lengths, LSH attention speed remains flat.
Large Reformer models. To verify that the Reformer can indeed fit large models on a single core
and train fast on long sequences, we train up to 20-layer big Reformers on enwik8 and imagenet64.
As can be seen in Figure 5, these models fit into memory and train. We were not able to train Trans-
former baselines in this case as they are too slow and memory-hungry, but we see clear improvement
with the number of layers. A 12-layer model on enwik8 trained for 20K steps with a dropout rate
of 0.1 achieves 1.19 bits/dim on the test set. We also trained a 12-layer Reformer model for longer
with further tuning and improvements, reaching 1.05 bits/dim on the enwik8 test set.
6 CONCLUSION
Reformer combines the modeling capacity of a Transformer with an architecture that can be executed
efficiently on long sequences and with small memory use even for models with a large number of
layers. We believe that this will help large, richly-parameterized Transformer models become more
widespread and accessible. Also, the ability to handle long sequences opens the way for the use
of the Reformer on many generative tasks. In addition to generating very long coherent text, the
Reformer can bring the power of Transformer models to other domains like time-series forecasting,
music, image and video generation.
REFERENCES
Rami Al-Rfou, Dokook Choe, Noah Constant, Mandy Guo, and Llion Jones. Character-level language modeling with deeper self-attention. CoRR, abs/1808.04444, 2018. URL http://arxiv.org/abs/1808.04444.
Alexandr Andoni, Piotr Indyk, Thijs Laarhoven, Ilya P. Razenshteyn, and Ludwig Schmidt. Practical and optimal LSH for angular distance. CoRR, abs/1509.02897, 2015. URL http://arxiv.org/abs/1509.02897.
Jimmy Lei Ba, Jamie Ryan Kiros, and Geoffrey E. Hinton. Layer normalization. arXiv preprint arXiv:1607.06450, 2016. URL http://arxiv.org/abs/1607.06450.
Antoine Bordes, Nicolas Usunier, Sumit Chopra, and Jason Weston. Large-scale simple question answering with memory networks. CoRR, abs/1506.02075, 2015. URL http://arxiv.org/abs/1506.02075.
Sarath Chandar, Sungjin Ahn, Hugo Larochelle, Pascal Vincent, Gerald Tesauro, and Yoshua Bengio. Hierarchical memory networks. arXiv preprint arXiv:1605.07427, 2016.
Rewon Child, Scott Gray, Alec Radford, and Ilya Sutskever. Generating long sequences with sparse transformers. URL https://openai.com/blog/sparse-transformers, 2019.
Jacob Devlin, Ming-Wei Chang, Kenton Lee, and Kristina Toutanova. BERT: Pre-training of deep bidirectional transformers for language understanding. CoRR, abs/1810.04805, 2018. URL http://arxiv.org/abs/1810.04805.
Aidan N. Gomez, Mengye Ren, Raquel Urtasun, and Roger B. Grosse. The reversible residual network: Backpropagation without storing activations. In Advances in Neural Information Processing Systems, pp. 2214-2224, 2017.
Felix Hill, Antoine Bordes, Sumit Chopra, and Jason Weston. The Goldilocks principle: Reading children's books with explicit memory representations. CoRR, abs/1511.02301, 2015. URL http://arxiv.org/abs/1511.02301.
Cheng-Zhi Anna Huang, Ashish Vaswani, Jakob Uszkoreit, Noam Shazeer, Curtis Hawthorne, Andrew M. Dai, Matthew D. Hoffman, and Douglas Eck. Music Transformer: Generating music with long-term structure. arXiv preprint arXiv:1809.04281, 2018.
Guillaume Lample, Alexandre Sablayrolles, Marc'Aurelio Ranzato, Ludovic Denoyer, and Hervé Jégou. Large memory layers with product keys. CoRR, abs/1907.05242, 2019. URL http://arxiv.org/abs/1907.05242.
Peter J. Liu, Mohammad Saleh, Etienne Pot, Ben Goodrich, Ryan Sepassi, Lukasz Kaiser, and Noam Shazeer. Generating Wikipedia by summarizing long sequences. CoRR, abs/1801.10198, 2018. URL http://arxiv.org/abs/1801.10198.
Myle Ott, Sergey Edunov, David Grangier, and Michael Auli. Scaling neural machine translation. In Proceedings of the Third Conference on Machine Translation: Research Papers, pp. 1-9, Brussels, Belgium, October 2018. Association for Computational Linguistics. doi: 10.18653/v1/W18-6301. URL https://www.aclweb.org/anthology/W18-6301.
Niki Parmar, Ashish Vaswani, Jakob Uszkoreit, Lukasz Kaiser, Noam Shazeer, and Alexander Ku. Image Transformer. CoRR, abs/1802.05751, 2018. URL http://arxiv.org/abs/1802.05751.
Matt Post. A call for clarity in reporting BLEU scores. In Proceedings of the Third Conference on Machine Translation: Research Papers, pp. 186-191, Belgium, Brussels, October 2018. Association for Computational Linguistics. URL https://www.aclweb.org/anthology/W18-6319.
Alec Radford, Jeff Wu, Rewon Child, David Luan, Dario Amodei, and Ilya Sutskever. Language models are unsupervised multitask learners. 2019.
Jack W. Rae, Jonathan J. Hunt, Tim Harley, Ivo Danihelka, Andrew Senior, Greg Wayne, Alex Graves, and Timothy P. Lillicrap. Scaling memory-augmented neural networks with sparse reads and writes. In Advances in Neural Information Processing Systems (NIPS), 2016.
Prajit Ramachandran, Niki Parmar, Ashish Vaswani, Irwan Bello, Anselm Levskaya, and Jonathon Shlens. Stand-alone self-attention in vision models. CoRR, abs/1906.05909, 2019. URL http://arxiv.org/abs/1906.05909.
Adam Santoro, Sergey Bartunov, Matthew Botvinick, Daan Wierstra, and Timothy P. Lillicrap. One-shot learning with memory-augmented neural networks. CoRR, abs/1605.06065, 2016. URL http://arxiv.org/abs/1605.06065.
Noam Shazeer and Mitchell Stern. Adafactor: Adaptive learning rates with sublinear memory cost. CoRR, abs/1804.04235, 2018. URL http://arxiv.org/abs/1804.04235.
Noam Shazeer, Youlong Cheng, Niki Parmar, Dustin Tran, Ashish Vaswani, Penporn Koanantakool, Peter Hawkins, HyoukJoong Lee, Mingsheng Hong, Cliff Young, Ryan Sepassi, and Blake Hechtman. Mesh-TensorFlow: Deep learning for supercomputers. CoRR, abs/1811.02084, 2018. URL http://arxiv.org/abs/1811.02084.
Nimit Sharad Sohoni, Christopher Richard Aberger, Megan Leszczynski, Jian Zhang, and Christopher Ré. Low-memory neural network training: A technical report. CoRR, abs/1904.10631, 2019. URL http://arxiv.org/abs/1904.10631.
Sainbayar Sukhbaatar, Edouard Grave, Piotr Bojanowski, and Armand Joulin. Adaptive attention span in Transformers. CoRR, abs/1905.07799, 2019a. URL http://arxiv.org/abs/1905.07799.
Sainbayar Sukhbaatar, Edouard Grave, Guillaume Lample, Hervé Jégou, and Armand Joulin. Augmenting self-attention with persistent memory. CoRR, abs/1907.01470, 2019b. URL http://arxiv.org/abs/1907.01470.
Ashish Vaswani, Noam Shazeer, Niki Parmar, Jakob Uszkoreit, Llion Jones, Aidan N. Gomez, Lukasz Kaiser, and Illia Polosukhin. Attention is all you need. CoRR, 2017. URL http://arxiv.org/abs/1706.03762.
Jason Weston, Sumit Chopra, and Antoine Bordes. Memory networks. CoRR, abs/1410.3916, 2014. URL http://arxiv.org/abs/1410.3916.
A MULTI-ROUND LSH ATTENTION
In this section we describe in more detail the multi-hash version of our LSH attention mechanism.
We first repeat Equation (3) from the main text, which describes a general formulation of attention
with sparsity:
o_i = \sum_{j \in \widetilde{\mathcal{P}}_i} \exp\left(q_i \cdot k_j - m(j, \mathcal{P}_i) - z(i, \mathcal{P}_i)\right) v_j
\quad \text{where} \quad
m(j, \mathcal{P}_i) = \begin{cases} \infty & \text{if } j \notin \mathcal{P}_i \\ 0 & \text{otherwise} \end{cases}
\qquad (3)
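A single-query sketch of Equation (3) in plain NumPy (our own names; `candidates` keys/values play the role of positions in \widetilde{\mathcal{P}}_i, and `allowed` marks membership in \mathcal{P}_i):

```python
import numpy as np

def sparse_attention_single_query(q_i, keys, values, allowed):
    """Attention for one query over a candidate set, Equation (3):
    positions outside P_i receive an infinite mask and thus zero weight."""
    logits = keys @ q_i                              # q_i . k_j for candidates j
    logits = np.where(allowed, logits, -np.inf)      # apply m(j, P_i)
    z = np.max(logits) + np.log(np.sum(np.exp(logits - np.max(logits))))  # z(i, P_i)
    return np.exp(logits - z) @ values               # weighted sum of v_j
```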
In the multi-round case, a query position i can attend to key positions P_i as defined in (6), which we
also repeat here:
\mathcal{P}_i = \bigcup_{r=1}^{n_{\mathrm{rounds}}} \mathcal{P}_i^{(r)}
\quad \text{where} \quad
\mathcal{P}_i^{(r)} = \left\{ j : h^{(r)}(q_i) = h^{(r)}(q_j) \right\}
\qquad (6)
For batching purposes, attention is performed on chunks of sorted queries/keys:
\widetilde{\mathcal{P}}_i^{(r)} = \left\{ j : \left\lfloor \frac{s_i^{(r)}}{m} \right\rfloor - 1 \le \left\lfloor \frac{s_j^{(r)}}{m} \right\rfloor \le \left\lfloor \frac{s_i^{(r)}}{m} \right\rfloor \right\}
\qquad (11)
Combining (3) and (6) gives:
o_i = \sum_{j \in \widetilde{\mathcal{P}}_i} \exp\left(q_i \cdot k_j - m(j, \mathcal{P}_i) - z(i, \mathcal{P}_i)\right) v_j \qquad (12)

= \sum_{r=1}^{n_{\mathrm{rounds}}} \exp\left(z(i, \mathcal{P}_i^{(r)}) - z(i, \mathcal{P}_i)\right)
\sum_{j \in \widetilde{\mathcal{P}}_i^{(r)}} \frac{1}{N_{i,j}} \exp\left(q_i \cdot k_j - m(j, \mathcal{P}_i^{(r)}) - z(i, \mathcal{P}_i^{(r)})\right) v_j \qquad (13)

= \sum_{r=1}^{n_{\mathrm{rounds}}} \exp\left(z(i, \mathcal{P}_i^{(r)}) - z(i, \mathcal{P}_i)\right) o_i^{(r)} \qquad (14)

\text{where} \quad o_i^{(r)} = \sum_{j \in \widetilde{\mathcal{P}}_i^{(r)}} \exp\left(q_i \cdot k_j - m_{i,j}^{(r)} - z(i, \mathcal{P}_i^{(r)})\right) v_j \qquad (15)
where N_{i,j} = \left| \left\{ r' : j \in \mathcal{P}_i^{(r')} \right\} \right| and

m_{i,j}^{(r)} = \begin{cases} \infty & \text{if } j \notin \mathcal{P}_i^{(r)} \\ 10^5 & \text{if } i = j \\ \log N_{i,j} & \text{otherwise} \end{cases}
\qquad (16)
Each round of LSH attention produces a vector o_i^{(r)} that can be computed independently from other
rounds, except for the inclusion of a term N_{i,j} to avoid double-counting elements when constructing
the union of the P_i^{(r)} sets. In our implementation we fold the N_{i,j} factor into the masking term m_{i,j}^{(r)}.
We also modify m_{i,j}^{(r)} to introduce a special case for i = j. This case is added because causal
masking in a standard Transformer allows position i to attend to itself, which is not desirable in a
shared-QK formulation. We set the mask to a large but finite value to disallow attention-in-place,
except in the situation where a token has no other valid attention targets. For example, the first token
in a sequence attends only to itself, because no prior context is available.
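To make the combination in Equations (13)-(14) concrete, here is a small NumPy sketch (our own, not the Trax code). It assumes each round's attention was computed with the mask m_{i,j}^{(r)} folded into the logits, so each per-round log-sum-exp already accounts for the 1/N_{i,j} de-duplication; under that assumption z(i, P_i) is the log-sum-exp of the per-round normalizers.

```python
import numpy as np

def combine_rounds(round_outputs, round_logsumexps):
    """Weight per-round outputs o^{(r)} by exp(z^{(r)} - z) and sum over rounds.

    round_outputs:    [n_rounds, seq_len, d_v] per-round attention outputs
    round_logsumexps: [n_rounds, seq_len] per-round normalizers z(i, P_i^{(r)})
    """
    z_total = np.logaddexp.reduce(round_logsumexps, axis=0)   # z(i, P_i), [seq_len]
    weights = np.exp(round_logsumexps - z_total)              # [n_rounds, seq_len]
    return np.einsum('rs,rsd->sd', weights, round_outputs)    # o_i, [seq_len, d_v]
```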