Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Resnet training skips samples at the beginning of each epoch #239

Merged
merged 9 commits into from
Aug 29, 2020
Merged
Show file tree
Hide file tree
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
121 changes: 71 additions & 50 deletions example/resnet/resnet.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,10 @@ package main

import (
"flag"
"fmt"
"io"
"log"
"math"
"math/rand"
"os"
"time"

Expand All @@ -16,7 +16,9 @@ import (
"github.com/wangkuiyi/gotorch/vision/transforms"
)

var trainSamples = 1281167
Yancey1989 marked this conversation as resolved.
Show resolved Hide resolved
var device torch.Device
var logInterval = 10
Yancey1989 marked this conversation as resolved.
Show resolved Hide resolved

func max(array []int64) int64 {
Yancey1989 marked this conversation as resolved.
Show resolved Hide resolved
max := array[0]
Expand Down Expand Up @@ -63,61 +65,29 @@ func accuracy(output, target torch.Tensor, topk []int64) []float32 {
return res
}

func imageNetLoader(tarFile string, batchSize int) (*datasets.ImageNetLoader, error) {
f, e := os.Open(tarFile)
if e != nil {
panic(e)
}
vocab, e := datasets.BuildLabelVocabulary(f)
fmt.Println("building label vocabulary done.")
if e != nil {
return nil, e
}
if _, e := f.Seek(0, io.SeekStart); e != nil {
return nil, e
}
// imageNetLoader builds an ImageNet minibatch loader over r using the
// standard ImageNet augmentation and normalization pipeline.
//
// vocab maps label directory names to class indices, batchSize is the
// minibatch size, and skipSamples is the number of leading samples the
// loader drops before yielding data.
func imageNetLoader(r io.Reader, vocab map[string]int64, batchSize int, skipSamples int) (*datasets.ImageNetLoader, error) {
	// Conventional ImageNet preprocessing: random crop + horizontal
	// flip, then per-channel normalization with the usual mean/stddev.
	pipeline := transforms.Compose(
		transforms.RandomResizedCrop(224),
		transforms.RandomHorizontalFlip(0.5),
		transforms.ToTensor(),
		transforms.Normalize(
			[]float64{0.485, 0.456, 0.406},
			[]float64{0.229, 0.224, 0.225},
		))
	// Propagate the constructor's (loader, error) pair unchanged.
	return datasets.ImageNet(r, vocab, pipeline, batchSize, skipSamples)
}

func train(model *models.ResnetModule, opt torch.Optimizer, batchSize int, device torch.Device, tarFile string) {
model.Train(true)
loader, e := imageNetLoader(tarFile, batchSize)
if e != nil {
panic(e)
}
batchIdx := 1
startTime := time.Now()
for loader.Scan() {
image, target := loader.Minibatch()
image = image.To(device, torch.Float)
target = target.To(device, torch.Long)
output := model.Forward(image)
loss := F.CrossEntropy(output, target, torch.Tensor{}, -100, "mean")

acc := accuracy(output, target, []int64{1, 5})
acc1 := acc[0]
acc5 := acc[1]
if batchIdx%100 == 0 {
throughput := float64(100*batchSize) / time.Since(startTime).Seconds()
fmt.Printf("batch: %d, loss: %f, acc1 :%f, acc5: %f, throughput: %.2f samples/sec\n", batchIdx, loss.Item(), acc1, acc5, throughput)
}

opt.ZeroGrad()
loss.Backward()
opt.Step()
batchIdx++
}
torch.FinishGC()
func trainOneBatch(image, target torch.Tensor, model *models.ResnetModule, opt torch.Optimizer) (float32, float32, float32) {
Yancey1989 marked this conversation as resolved.
Show resolved Hide resolved
output := model.Forward(image)
loss := F.CrossEntropy(output, target, torch.Tensor{}, -100, "mean")
acc := accuracy(output, target, []int64{1, 5})
acc1 := acc[0]
acc5 := acc[1]
loss.Backward()
opt.Step()
return loss.Item(), acc1, acc5
}

func main() {
Expand All @@ -132,7 +102,6 @@ func main() {
lr := 0.1
momentum := 0.9
weightDecay := 1e-4

if torch.IsCUDAAvailable() {
log.Println("CUDA is valid")
device = torch.NewDevice("cuda")
Expand All @@ -143,14 +112,66 @@ func main() {

model := models.Resnet50()
model.To(device)
model.Train(true)

optimizer := torch.SGD(lr, momentum, 0, weightDecay, false)
optimizer.AddParameters(model.Parameters())

start := time.Now()
for epoch := 0; epoch < epochs; epoch++ {
adjustLearningRate(optimizer, epoch, lr)
train(model, optimizer, batchSize, device, *trainFile)
f, e := os.Open(*trainFile)
if e != nil {
log.Fatal(e)
}
// build label vocabulary
vocab, e := datasets.BuildLabelVocabulary(f)
if e != nil {
log.Fatal(e)
}
log.Print("building label vocabulary done.")

batchs := itersEachEpoch(trainSamples, batchSize)
Yancey1989 marked this conversation as resolved.
Show resolved Hide resolved
batch := 0
epoch := 0
adjustLearningRate(optimizer, epoch, lr)
skipSamples := 0
startTime := time.Now()
for epoch < epochs {
// seek to 0 of the file reader, and create an ImageNet loader
if _, e := f.Seek(0, io.SeekStart); e != nil {
log.Fatal(e)
}
loader, e := imageNetLoader(f, vocab, batchSize, skipSamples)
if e != nil {
panic(e)
}
for loader.Scan() {
batch++
image, target := loader.Minibatch()
loss, acc1, acc5 := trainOneBatch(image.To(device, torch.Float), target.To(device, torch.Long), model, optimizer)
if batch%logInterval == 0 {
throughput := float64(batch/logInterval) / time.Since(startTime).Seconds()
log.Printf("Epoch: %d, Batch: %d, loss:%f, acc1: %f, acc5:%f, throughput: %f samples/secs",
epoch, batch, loss, acc1, acc5, throughput)
startTime = time.Now()
}
if batch == batchs {
break
}
}
if batch == batchs {
// go to next epoch
epoch++
adjustLearningRate(optimizer, epoch, lr)
batch = 0
skipSamples = rand.Intn(batchSize)
log.Printf("skip %d samples at the next epoch", skipSamples)
}
}
}

// itersEachEpoch returns the number of minibatch iterations needed to
// cover samples examples at the given batchSize, counting a final
// partial batch as one full iteration.
func itersEachEpoch(samples, batchSize int) int {
	// Bug fix: the original read the package-level trainSamples and
	// ignored the samples parameter, so any caller passing a different
	// dataset size got the wrong iteration count. Also renamed the
	// local so it no longer shadows the function name.
	iters := samples / batchSize
	if samples%batchSize != 0 {
		iters++ // the trailing partial batch still costs one iteration
	}
	return iters
}
36 changes: 22 additions & 14 deletions vision/datasets/imagenet.go
Original file line number Diff line number Diff line change
Expand Up @@ -28,27 +28,29 @@ type sample struct {
// img, target := imageNet.Minibatch()
// }
type ImageNetLoader struct {
mbSize int
tr *tar.Reader
vocab map[string]int64 // the vocabulary of labels.
err error
inputs []torch.Tensor // inputs and labels form a minibatch.
labels []int64
trans *transforms.ComposeTransformer
mbSize int
skipSamples int
tr *tar.Reader
vocab map[string]int64 // the vocabulary of labels.
err error
inputs []torch.Tensor // inputs and labels form a minibatch.
labels []int64
trans *transforms.ComposeTransformer
}

// ImageNet returns ImageNet dataDataLoader
func ImageNet(r io.Reader, vocab map[string]int64, trans *transforms.ComposeTransformer, mbSize int) (*ImageNetLoader, error) {
func ImageNet(r io.Reader, vocab map[string]int64, trans *transforms.ComposeTransformer, mbSize, skipSamples int) (*ImageNetLoader, error) {
tgz, e := newTarGzReader(r)
if e != nil {
return nil, e
}
return &ImageNetLoader{
mbSize: mbSize,
tr: tgz,
err: nil,
vocab: vocab,
trans: trans,
mbSize: mbSize,
tr: tgz,
err: nil,
vocab: vocab,
trans: trans,
skipSamples: skipSamples,
}, nil
}

Expand All @@ -64,6 +66,7 @@ func (p *ImageNetLoader) tensorGC() {
}

func (p *ImageNetLoader) retreiveMinibatch() {
iter := 0
for {
hdr, err := p.tr.Next()
if err != nil {
Expand All @@ -73,6 +76,12 @@ func (p *ImageNetLoader) retreiveMinibatch() {
if !strings.HasSuffix(strings.ToUpper(hdr.Name), "JPEG") {
continue
}
if iter < p.skipSamples {
iter++
continue
}
// only skip samples at the head of the stream; later batches read everything
p.skipSamples = 0

label := p.vocab[filepath.Base(filepath.Dir(hdr.Name))]
p.labels = append(p.labels, label)
Expand Down Expand Up @@ -147,7 +156,6 @@ func BuildLabelVocabulary(reader io.Reader) (map[string]int64, error) {
}
}
}
return classToIdx, nil
}

func newTarGzReader(r io.Reader) (*tar.Reader, error) {
Expand Down
25 changes: 24 additions & 1 deletion vision/datasets/imagenet_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ func TestImgNetLoader(t *testing.T) {
a.Equal(3, len(vocab))

trans := transforms.Compose(transforms.RandomCrop(224, 224), transforms.RandomHorizontalFlip(0.5), transforms.ToTensor())
loader, e := datasets.ImageNet(bytes.NewReader(tgz.Bytes()), vocab, trans, 2)
loader, e := datasets.ImageNet(bytes.NewReader(tgz.Bytes()), vocab, trans, 2, 0)
a.NoError(e)
{
// the first iteration
Expand All @@ -47,3 +47,26 @@ func TestBuildLabelVocabularyFail(t *testing.T) {
_, e := datasets.BuildLabelVocabulary(strings.NewReader("some string"))
assert.Error(t, e)
}
// TestImgNetLoaderSkipSamples checks that a loader created with
// skipSamples=2 drops the first two of the three synthesized images,
// leaving exactly one sample — a single partial minibatch.
func TestImgNetLoaderSkipSamples(t *testing.T) {
	a := assert.New(t)

	var archive bytes.Buffer
	synthesizeImages(&archive)

	vocab, err := datasets.BuildLabelVocabulary(bytes.NewReader(archive.Bytes()))
	a.NoError(err)
	a.Equal(3, len(vocab))

	pipeline := transforms.Compose(
		transforms.RandomCrop(224, 224),
		transforms.RandomHorizontalFlip(0.5),
		transforms.ToTensor())
	loader, err := datasets.ImageNet(bytes.NewReader(archive.Bytes()), vocab, pipeline, 2, 2)
	a.NoError(err)

	// The first Scan yields the only remaining sample: one image of
	// shape [1, 3, 224, 224] and a single label.
	a.True(loader.Scan())
	data, label := loader.Minibatch()
	a.Equal([]int64{1, 3, 224, 224}, data.Shape())
	a.Equal([]int64{1}, label.Shape())
	a.NoError(loader.Err())

	// With two samples skipped and one consumed, the archive is spent.
	a.False(loader.Scan())
}