add and default to rotary positional embeddings
lucidrains committed Apr 21, 2021
1 parent 365d5c8 commit be3cfac
Showing 3 changed files with 70 additions and 21 deletions.
11 changes: 11 additions & 0 deletions README.md
@@ -562,6 +562,17 @@ print(sample.shape) # (1, <=100) token ids
}
```

```bibtex
@misc{su2021roformer,
    title = {RoFormer: Enhanced Transformer with Rotary Position Embedding},
    author = {Jianlin Su and Yu Lu and Shengfeng Pan and Bo Wen and Yunfeng Liu},
    year = {2021},
    eprint = {2104.09864},
    archivePrefix = {arXiv},
    primaryClass = {cs.CL}
}
```

```bibtex
@misc{vaswani2017attention,
    title = {Attention Is All You Need},
59 changes: 48 additions & 11 deletions reformer_pytorch/reformer_pytorch.py
@@ -13,12 +13,17 @@
from product_key_memory import PKM
from reformer_pytorch.reversible import ReversibleSequence

from einops import rearrange, repeat

# constants

TOKEN_SELF_ATTN_VALUE = -5e4 # carefully set for half precision to work

# helper fns

def exists(val):
    return val is not None

def sort_key_val(t1, t2, dim=-1):
    values, indices = t1.sort(dim=dim)
    t2 = t2.expand_as(t1)
@@ -103,6 +108,14 @@ def split_at_index(dim, index, t):

# helper classes

class Always(nn.Module):
    def __init__(self, val):
        super().__init__()
        self.val = val

    def forward(self, *args, **kwargs):
        return self.val

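Always is a small utility that ignores its inputs and returns a fixed value; the LM changes below use it as a stand-in when a positional embedding is disabled. A minimal sketch, assuming the post-commit module path exposes the class:

```python
from reformer_pytorch.reformer_pytorch import Always

pos_emb = Always(0)           # stands in for a disabled token positional embedding
layer_pos_emb = Always(None)  # signals "no rotary embedding" downstream

assert pos_emb('anything', key = 'ignored') == 0
assert layer_pos_emb() is None
```
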
class MatrixMultiply(nn.Module):
    def __init__(self, tensor, transpose = False, normalize = False):
        super().__init__()
@@ -249,7 +262,7 @@ def hash_vectors(self, n_buckets, vecs):
        buckets = torch.reshape(buckets + offsets, (batch_size, -1,))
        return buckets

    def forward(self, qk, v, query_len = None, input_mask = None, input_attn_mask = None, **kwargs):
    def forward(self, qk, v, query_len = None, input_mask = None, input_attn_mask = None, pos_emb = None, **kwargs):
        batch_size, seqlen, dim, device = *qk.shape, qk.device

        query_len = default(query_len, seqlen)
@@ -279,6 +292,9 @@ def forward(self, qk, v, query_len = None, input_mask = None, input_attn_mask =
        sticker = sticker.detach()
        undo_sort = undo_sort.detach()

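        # rotary embeddings, if given, are applied to qk in its original
        # sequence order, before the bucket-sorted gather below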
        if exists(pos_emb):
            qk = apply_rotary_pos_emb(qk, pos_emb)

        st = (sticker % seqlen)
        sqk = batched_index_select(qk, st)
        sv = batched_index_select(v, st)
@@ -506,7 +522,7 @@ def __init__(self, dim, heads = 8, bucket_size = 64, n_hashes = 8, causal = Fals

        self.callback = None

    def forward(self, x, keys = None, input_mask = None, input_attn_mask = None, context_mask = None, **kwargs):
    def forward(self, x, keys = None, input_mask = None, input_attn_mask = None, context_mask = None, pos_emb = None, **kwargs):
        device, dtype = x.device, x.dtype
        b, t, e, h, dh, m, l_h = *x.shape, self.heads, self.dim_head, self.num_mem_kv, self.n_local_attn_heads

@@ -556,7 +572,7 @@ def split_heads(v):
            masks['input_attn_mask'] = input_attn_mask

        attn_fn = self.lsh_attn if not use_full_attn else self.full_attn
        partial_attn_fn = partial(attn_fn, query_len = t, **kwargs)
        partial_attn_fn = partial(attn_fn, query_len = t, pos_emb = pos_emb, **kwargs)
        attn_fn_in_chunks = process_inputs_chunk(partial_attn_fn, chunks = self.attn_chunks)

        out, attn, buckets = attn_fn_in_chunks(qk, v, **masks)
@@ -623,11 +639,26 @@ def __init__(self, dim):
        inv_freq = 1. / (10000 ** (torch.arange(0, dim, 2).float() / dim))
        self.register_buffer('inv_freq', inv_freq)

    def forward(self, x):
        t = torch.arange(x.shape[1], device=x.device).type_as(self.inv_freq)
        sinusoid_inp = torch.einsum("i,j->ij", t, self.inv_freq)
    def forward(self, x, seq_dim = 1):
        t = torch.arange(x.shape[seq_dim], device = x.device).type_as(self.inv_freq)
        sinusoid_inp = torch.einsum('i , j -> i j', t, self.inv_freq)
        emb = torch.cat((sinusoid_inp.sin(), sinusoid_inp.cos()), dim=-1)
        return emb[None, :, :]
        return emb[None, :, :].type_as(x)

# rotary positional embedding helpers

def rotate_every_two(x):
    x = rearrange(x, '... (d j) -> ... d j', j = 2)
    x1, x2 = x.unbind(dim = -1)
    x = torch.stack((-x2, x1), dim = -1)
    return rearrange(x, '... d j -> ... (d j)')

def apply_rotary_pos_emb(qk, sinu_pos):
    sinu_pos = rearrange(sinu_pos, '() n (j d) -> n j d', j = 2)
    sin, cos = sinu_pos.unbind(dim = -2)
    sin, cos = map(lambda t: repeat(t, 'b n -> b (n j)', j = 2), (sin, cos))
    qk = (qk * cos) + (rotate_every_two(qk) * sin)
    return qk
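
A hedged shape-check sketch, not part of the commit, of how these helpers compose with FixedPositionalEmbedding; the import path and tensor sizes are assumptions:

```python
import torch
from reformer_pytorch.reformer_pytorch import FixedPositionalEmbedding, apply_rotary_pos_emb

dim_head, seq_len = 64, 128
qk = torch.randn(2, seq_len, dim_head)             # (batch, seq, dim_head)

# sinusoidal table sized to the head dimension: first half sin, second half cos
sinu_pos = FixedPositionalEmbedding(dim_head)(qk)  # (1, seq, dim_head)

rotated = apply_rotary_pos_emb(qk, sinu_pos)
assert rotated.shape == qk.shape                   # rotation is shape-preserving

# each consecutive feature pair is rotated by a shared angle, so norms survive
assert torch.allclose(rotated.norm(dim = -1), qk.norm(dim = -1), atol = 1e-5)
```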

# reformer lm

@@ -680,7 +711,7 @@ def forward(self, x, **kwargs):
        return torch.stack(x.chunk(2, dim=-1)).mean(dim=0)

class ReformerLM(nn.Module):
    def __init__(self, num_tokens, dim, depth, max_seq_len, heads = 8, dim_head = None, bucket_size = 64, n_hashes = 4, ff_chunks = 100, attn_chunks = 1, causal = False, weight_tie = False, lsh_dropout = 0., ff_dropout = 0., ff_mult = 4, ff_activation = None, ff_glu = False, post_attn_dropout = 0., layer_dropout = 0., random_rotations_per_head = False, use_scale_norm = False, use_rezero = False, use_full_attn = False, full_attn_thres = 0, reverse_thres = 0, num_mem_kv = 0, one_value_head = False, emb_dim = None, return_embeddings = False, weight_tie_embedding = False, fixed_position_emb = False, absolute_position_emb = False, axial_position_shape = None, n_local_attn_heads = 0, pkm_layers = tuple(), pkm_num_keys = 128):
    def __init__(self, num_tokens, dim, depth, max_seq_len, heads = 8, dim_head = 64, bucket_size = 64, n_hashes = 4, ff_chunks = 100, attn_chunks = 1, causal = False, weight_tie = False, lsh_dropout = 0., ff_dropout = 0., ff_mult = 4, ff_activation = None, ff_glu = False, post_attn_dropout = 0., layer_dropout = 0., random_rotations_per_head = False, use_scale_norm = False, use_rezero = False, use_full_attn = False, full_attn_thres = 0, reverse_thres = 0, num_mem_kv = 0, one_value_head = False, emb_dim = None, return_embeddings = False, weight_tie_embedding = False, fixed_position_emb = False, absolute_position_emb = False, rotary_emb = True, axial_position_shape = None, n_local_attn_heads = 0, pkm_layers = tuple(), pkm_num_keys = 128):
        super().__init__()
        emb_dim = default(emb_dim, dim)
        self.max_seq_len = max_seq_len
@@ -689,7 +720,12 @@ def __init__(self, num_tokens, dim, depth, max_seq_len, heads = 8, dim_head = No

        self.to_model_dim = Identity() if emb_dim == dim else nn.Linear(emb_dim, dim)

        if absolute_position_emb:
        self.pos_emb = Always(0)
        self.layer_pos_emb = Always(None)

        if rotary_emb:
            self.layer_pos_emb = FixedPositionalEmbedding(dim_head)
        elif absolute_position_emb:
            self.pos_emb = AbsolutePositionalEmbedding(emb_dim, max_seq_len)
        elif fixed_position_emb:
            self.pos_emb = FixedPositionalEmbedding(emb_dim)
@@ -711,9 +747,10 @@ def __init__(self, num_tokens, dim, depth, max_seq_len, heads = 8, dim_head = No

    def forward(self, x, **kwargs):
        x = self.token_emb(x)
        x = x + self.pos_emb(x).type_as(x)
        x = x + self.pos_emb(x)

        layer_pos_emb = self.layer_pos_emb(x)
        x = self.to_model_dim(x)
        x = self.reformer(x, **kwargs)
        x = self.reformer(x, pos_emb = layer_pos_emb, **kwargs)
        x = self.norm(x)
        return self.out(x)
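
A hedged usage sketch in the style of the README examples: rotary_emb now defaults to True, so rotary positional embeddings are active without any extra arguments.

```python
import torch
from reformer_pytorch import ReformerLM

model = ReformerLM(
    num_tokens = 20000,
    dim = 512,
    depth = 6,
    max_seq_len = 4096,
    causal = True
)

x = torch.randint(0, 20000, (1, 4096))
logits = model(x)    # (1, 4096, 20000)

# pass rotary_emb = False to fall back to the other positional embedding schemes
```
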
21 changes: 11 additions & 10 deletions setup.py
@@ -3,24 +3,25 @@
setup(
  name = 'reformer_pytorch',
  packages = find_packages(exclude=['examples', 'pretraining']),
  version = '1.2.6',
  version = '1.4.0',
  license='MIT',
  description = 'Reformer, the Efficient Transformer, Pytorch',
  author = 'Phil Wang',
  author_email = 'lucidrains@gmail.com',
  url = 'https://github.com/lucidrains/reformer-pytorch',
  keywords = ['transformers', 'attention', 'artificial intelligence'],
  install_requires=[
    'torch',
    'local-attention',
    'product-key-memory',
    'axial-positional-embedding>=0.1.0'
    'axial-positional-embedding>=0.1.0',
    'einops',
    'local-attention',
    'product-key-memory',
    'torch'
  ],
  classifiers=[
    'Development Status :: 4 - Beta',
    'Intended Audience :: Developers',
    'Topic :: Scientific/Engineering :: Artificial Intelligence',
    'License :: OSI Approved :: MIT License',
    'Programming Language :: Python :: 3.6',
    'Development Status :: 4 - Beta',
    'Intended Audience :: Developers',
    'Topic :: Scientific/Engineering :: Artificial Intelligence',
    'License :: OSI Approved :: MIT License',
    'Programming Language :: Python :: 3.6',
  ],
)
