From 2cbc36bb280c0a6de46d838baad11a533a015fc3 Mon Sep 17 00:00:00 2001
From: Phil Wang
Date: Wed, 25 Aug 2021 04:57:32 -0700
Subject: [PATCH] update flag for turning on axial positional embedding and update readme

---
 README.md                            |  7 +++++--
 reformer_pytorch/reformer_pytorch.py | 10 +++++-----
 setup.py                             |  2 +-
 3 files changed, 11 insertions(+), 8 deletions(-)

diff --git a/README.md b/README.md
index 59fd33f..1e854df 100644
--- a/README.md
+++ b/README.md
@@ -154,8 +154,11 @@ y = model(x, keys = c, input_mask = i_mask, context_mask = c_mask)
 ```
 
 ## Positional Embeddings
 
-Aran has informed me that the Reformer team used axial position embeddings with great results on longer sequences. I tested it out and indeed it works very well! So well in fact that I have decided to make this the default. You can adjust the shape and dimension of the axial embeddings by following the instructions below.
+The default positional embedding uses rotary embeddings.
+However, Aran has informed me that the Reformer team used axial position embeddings with great results on longer sequences.
+
+You can turn on axial positional embedding and adjust the shape and dimension of the axial embeddings by following the instructions below.
 
 ```python
 import torch
@@ -169,8 +172,8 @@ model = ReformerLM(
     ff_chunks = 8,
     attn_chunks = 2,
     causal = True,
+    axial_position_emb = True, # set this to True
     axial_position_shape = (128, 64),  # the shape must multiply up to the max_seq_len (128 x 64 = 8192)
-    axial_position_dims = (512, 512)  # the dims must sum up to the model dimensions (512 + 512 = 1024)
 )
 
 x = torch.randint(0, 20000, (1, 8192)).long()
diff --git a/reformer_pytorch/reformer_pytorch.py b/reformer_pytorch/reformer_pytorch.py
index 4b7dd32..a3e0d4f 100644
--- a/reformer_pytorch/reformer_pytorch.py
+++ b/reformer_pytorch/reformer_pytorch.py
@@ -713,7 +713,7 @@ def forward(self, x, **kwargs):
         return torch.stack(x.chunk(2, dim=-1)).mean(dim=0)
 
 class ReformerLM(nn.Module):
-    def __init__(self, num_tokens, dim, depth, max_seq_len, heads = 8, dim_head = 64, bucket_size = 64, n_hashes = 4, ff_chunks = 100, attn_chunks = 1, causal = False, weight_tie = False, lsh_dropout = 0., ff_dropout = 0., ff_mult = 4, ff_activation = None, ff_glu = False, post_attn_dropout = 0., layer_dropout = 0., random_rotations_per_head = False, use_scale_norm = False, use_rezero = False, use_full_attn = False, full_attn_thres = 0, reverse_thres = 0, num_mem_kv = 0, one_value_head = False, emb_dim = None, return_embeddings = False, weight_tie_embedding = False, fixed_position_emb = False, absolute_position_emb = False, rotary_emb = True, axial_position_shape = None, n_local_attn_heads = 0, pkm_layers = tuple(), pkm_num_keys = 128):
+    def __init__(self, num_tokens, dim, depth, max_seq_len, heads = 8, dim_head = 64, bucket_size = 64, n_hashes = 4, ff_chunks = 100, attn_chunks = 1, causal = False, weight_tie = False, lsh_dropout = 0., ff_dropout = 0., ff_mult = 4, ff_activation = None, ff_glu = False, post_attn_dropout = 0., layer_dropout = 0., random_rotations_per_head = False, use_scale_norm = False, use_rezero = False, use_full_attn = False, full_attn_thres = 0, reverse_thres = 0, num_mem_kv = 0, one_value_head = False, emb_dim = None, return_embeddings = False, weight_tie_embedding = False, fixed_position_emb = False, absolute_position_emb = False, axial_position_emb = False, axial_position_shape = None, n_local_attn_heads = 0, pkm_layers = tuple(), pkm_num_keys = 128):
         super().__init__()
         emb_dim = default(emb_dim, dim)
         self.max_seq_len = max_seq_len
@@ -725,15 +725,15 @@ def __init__(self, num_tokens, dim, depth, max_seq_len, heads = 8, dim_head = 64
         self.pos_emb = Always(0)
         self.layer_pos_emb = Always(None)
 
-        if rotary_emb:
-            self.layer_pos_emb = FixedPositionalEmbedding(dim_head)
+        if axial_position_emb:
+            axial_position_shape = default(axial_position_shape, (math.ceil(max_seq_len / bucket_size), bucket_size))
+            self.pos_emb = AxialPositionalEmbedding(emb_dim, axial_position_shape)
         elif absolute_position_emb:
             self.pos_emb = AbsolutePositionalEmbedding(emb_dim, max_seq_len)
         elif fixed_position_emb:
             self.pos_emb = FixedPositionalEmbedding(emb_dim)
         else:
-            axial_position_shape = default(axial_position_shape, (math.ceil(max_seq_len / bucket_size), bucket_size))
-            self.pos_emb = AxialPositionalEmbedding(emb_dim, axial_position_shape)
+            self.layer_pos_emb = FixedPositionalEmbedding(dim_head)
 
         self.reformer = Reformer(dim, depth, heads = heads, dim_head = dim_head, bucket_size = bucket_size, n_hashes = n_hashes, ff_chunks = ff_chunks, attn_chunks = attn_chunks, causal = causal, weight_tie = weight_tie, lsh_dropout = lsh_dropout, ff_mult = ff_mult, ff_activation = ff_activation, ff_glu = ff_glu, ff_dropout = ff_dropout, post_attn_dropout = 0., layer_dropout = layer_dropout, random_rotations_per_head = random_rotations_per_head, use_scale_norm = use_scale_norm, use_rezero = use_rezero, use_full_attn = use_full_attn, full_attn_thres = full_attn_thres, reverse_thres = reverse_thres, num_mem_kv = num_mem_kv, one_value_head = one_value_head, n_local_attn_heads = n_local_attn_heads, pkm_layers = pkm_layers, pkm_num_keys = pkm_num_keys)
         self.norm = nn.LayerNorm(dim)
diff --git a/setup.py b/setup.py
index 638f844..65ce35b 100644
--- a/setup.py
+++ b/setup.py
@@ -3,7 +3,7 @@
 setup(
   name = 'reformer_pytorch',
   packages = find_packages(exclude=['examples', 'pretraining']),
-  version = '1.4.2',
+  version = '1.4.3',
   license='MIT',
   description = 'Reformer, the Efficient Transformer, Pytorch',
   author = 'Phil Wang',
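
Below is a minimal usage sketch, not part of the patch itself, illustrating the `axial_position_emb` flag this commit adds. It is based only on the `ReformerLM.__init__` signature and README hunk above, assumes `reformer_pytorch` 1.4.3 is installed, and leaves every other constructor argument at its default.

```python
# Sketch only (not from the patch): exercising the new axial_position_emb flag.
import torch
from reformer_pytorch import ReformerLM

# Default after this patch: rotary-style fixed positional embedding applied
# per layer (the `self.layer_pos_emb` branch in the diff above).
default_model = ReformerLM(
    num_tokens = 20000,
    dim = 1024,
    depth = 12,
    max_seq_len = 8192,
    causal = True
)

# Opt in to axial positional embeddings for long sequences.
# axial_position_shape must multiply up to max_seq_len (128 x 64 = 8192);
# if omitted, the diff falls back to (ceil(max_seq_len / bucket_size), bucket_size).
axial_model = ReformerLM(
    num_tokens = 20000,
    dim = 1024,
    depth = 12,
    max_seq_len = 8192,
    causal = True,
    axial_position_emb = True,
    axial_position_shape = (128, 64)
)

x = torch.randint(0, 20000, (1, 8192)).long()
y = axial_model(x)  # (1, 8192, 20000) token logits
```

In short, the opt-out `rotary_emb` flag is replaced by an opt-in `axial_position_emb` flag: leaving the positional flags at their defaults keeps the rotary-style per-layer embedding, and axial embeddings are used only when explicitly requested.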