-
Notifications
You must be signed in to change notification settings - Fork 3
/
train.py
67 lines (57 loc) · 1.83 KB
/
train.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
"""
Train the network.
"""
import torch.nn as nn
from torch import optim
from data import read_data, tensors_from_pair
from models.attn_decoder import AttnDecoder
from models.encoder import Encoder
from models.ptr_decoder import PtrDecoder
from train import train
from utils import device, set_max_length
def weight_init(module):
    """
    Re-draw every parameter of <module> uniformly from [-0.08, 0.08], in place.

    Meant to be handed to a model's .apply(), which calls it on each submodule.
    """
    lower, upper = -0.08, 0.08
    for weights in module.parameters():
        nn.init.uniform_(weights, lower, upper)
def run():
    """
    Run the experiment.

    Reads the data set called ``name``, converts every pair to tensors, builds
    an encoder plus either a pointer decoder or an attention decoder (chosen by
    ``is_ptr``), and passes everything to ``train``.
    """
    # Experiment configuration.
    name = "train"
    is_ptr = True
    hidden_dim = embedding_dim = 256
    n_epochs = 1
    grad_clip = 2
    teacher_force_ratio = 0.5
    # The optimizer class (not an instance) and its keyword arguments are
    # passed on to train() as-is.
    optimizer = optim.Adam
    optimizer_params = {}

    max_val, max_length, pairs = read_data(name)
    set_max_length(max_length)
    training_pairs = [tensors_from_pair(pair) for pair in pairs]

    # NOTE(review): the +1 presumably makes room for values 0..max_val
    # inclusive -- confirm against read_data().
    data_dim = max_val + 1

    encoder = Encoder(input_dim=data_dim,
                      embedding_dim=embedding_dim,
                      hidden_dim=hidden_dim).to(device)

    # Both decoder classes are constructed with the same keyword arguments,
    # so only the class depends on the flag.
    decoder_cls = PtrDecoder if is_ptr else AttnDecoder
    decoder = decoder_cls(output_dim=data_dim,
                          embedding_dim=embedding_dim,
                          hidden_dim=hidden_dim).to(device)

    train(encoder=encoder,
          decoder=decoder,
          optim=optimizer,
          optim_params=optimizer_params,
          weight_init=weight_init,
          grad_clip=grad_clip,
          # Forward the flag that selected the decoder; this was previously
          # hard-coded to True, which would mismatch an AttnDecoder.
          is_ptr=is_ptr,
          training_pairs=training_pairs,
          n_epochs=n_epochs,
          teacher_force_ratio=teacher_force_ratio,
          print_every=50,
          plot_every=50,
          save_every=100)