-
Notifications
You must be signed in to change notification settings - Fork 0
/
train.py
53 lines (31 loc) · 1.3 KB
/
train.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
import torch
import torch.nn as nn
import torch.nn.parallel
import torch.backends.cudnn as cudnn
import torch.optim as optim
import torch.utils.data
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.animation as animation
def train_model(model, dataloader, num_epochs, device, lr=0.1):
    """Train ``model`` in place with SGD and a negative-log-likelihood loss.

    Args:
        model: A ``torch.nn.Module``. Its outputs are passed through
            ``torch.log`` before ``nn.NLLLoss``, so it is presumably expected
            to return per-class probabilities (e.g. after a softmax) --
            TODO confirm against the model definition.
        dataloader: Iterable of ``(images, labels)`` batches that also
            supports ``len()`` (used only for the progress printout).
        num_epochs: Number of full passes over ``dataloader``.
        device: Device each batch is moved to before the forward pass; it
            should be the device ``model`` lives on.
        lr: SGD learning rate (defaults to the previously hard-coded 0.1).

    Returns:
        The same ``model`` object, after training.
    """
    criterion = nn.NLLLoss()
    # Only parameters that require gradients are optimized, so frozen
    # layers are left untouched.
    optimizer = optim.SGD(
        (p for p in model.parameters() if p.requires_grad), lr=lr
    )
    # Lower bound applied to probabilities before the log: log(0) = -inf
    # would make the loss infinite and the gradients NaN.
    eps = 1e-12
    # Enable training-mode behaviour (dropout, batch-norm statistics) even if
    # the caller left the model in eval mode.
    model.train()
    print("Starting to train..")
    for epoch in range(num_epochs):
        for count, (images, labels) in enumerate(dataloader):
            images = images.to(device)
            labels = labels.to(device)
            model.zero_grad()
            output = model(images)
            loss = criterion(torch.log(torch.clamp(output, min=eps)), labels)
            loss.backward()
            optimizer.step()
            print("Epoches:[", epoch, "/", num_epochs, "]",
                  "\tBatch:[", count, "/", len(dataloader), "]",
                  "\tCategorial cross entropy loss=", loss.item())
    return model