fix mnist doc and code #371

Merged
merged 3 commits into from
Oct 25, 2020
5 changes: 4 additions & 1 deletion doc/shuffle_tarball.md
@@ -19,6 +19,9 @@ file content. A directory does not have content, but only headers. This format
allows us to read sequentially from `.tar` or `.tar.gz` file for image files
without causing frequent mechanical movements in hard drives.

+## Caution
+gnu-tar should be used on macOS instead of bsdtar.
+
## Shuffling

It is critical in deep learning to ensure that each minibatch or consecutive
@@ -68,7 +71,7 @@ divide-and-merge strategy -- `tarball_divide` and `tarball_merge`.
To install them, we need the Go compiler and run the following commands.

```bash
-go get github.com/wangkuiyi/gotorch/tools/...
+go get github.com/wangkuiyi/gotorch/tool/...
```

We can then find the executable files in `$GOPATH/bin`.
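
Reviewer note on the Caution added above: a script could check which `tar` it is about to call before relying on it. The following is a minimal sketch, assuming that GNU tar identifies itself with the text `GNU tar` on the first line of `tar --version`, while the macOS default starts that line with `bsdtar`. The helper name `isGNUTar` is hypothetical and not part of gotorch.

```go
package main

import (
	"fmt"
	"os/exec"
	"strings"
)

// isGNUTar reports whether the first line of `tar --version` output
// identifies GNU tar (for example "tar (GNU tar) 1.34"). The macOS
// default prints a first line that starts with "bsdtar" instead.
func isGNUTar(versionOutput string) bool {
	firstLine := strings.SplitN(versionOutput, "\n", 2)[0]
	return strings.Contains(firstLine, "GNU tar")
}

func main() {
	// Assumption: a `tar` binary is on PATH and accepts --version.
	out, err := exec.Command("tar", "--version").Output()
	if err != nil {
		fmt.Println("could not run tar:", err)
		return
	}
	if isGNUTar(string(out)) {
		fmt.Println("tar is GNU tar")
	} else {
		fmt.Println("tar is not GNU tar; on macOS, consider installing gnu-tar")
	}
}
```

Keeping the string check in its own function makes it easy to exercise without invoking an external process.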
6 changes: 5 additions & 1 deletion doc/shuffle_tarball_cn.md
@@ -13,6 +13,9 @@
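a descriptive header followed by the file content. In particular, a directory contains only the header and no file content. This structure allows us to
sequentially read a `.tar` or `.tar.gz` archive containing a large number of image files, without frequently moving the disk head to seek files.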

+## Caution
+On macOS, gnu-tar should be used instead of bsdtar.
+
## Shuffling

In deep learning, it is critical to ensure that each minibatch fed to the model contains different labels. This property is called
@@ -53,7 +56,7 @@ drwxr-x--- 0 myleott myleott 0 Dec 10 2015 mnist_png/testing/2/
`tarball_divide` and `tarball_merge`. We can install them with the following command:

```bash
-go get github.com/wangkuiyi/gotorch/tools/...
+go get github.com/wangkuiyi/gotorch/tool/...
```

After running the command above, we can find the binaries of these two tools in `$GOPATH/bin`.
@@ -105,6 +108,7 @@ go get github.com/wangkuiyi/gotorch/tools/...
rm [0-9].tar.gz
tarball_divide mnist_png_testing.tar.gz
tarball_merge -out=mnist_png_testing_shuffled.tar.gz [0-9].tar.gz
rm [0-9].tar.gz
tar tvf mnist_png_testing_shuffled.tar.gz | grep \.png$ | wc -l
tar tvf mnist_png_testing.tar.gz | grep \.png$ | wc -l
```
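
Reviewer note: the two `tar tvf ... | grep \.png$ | wc -l` commands above compare the number of PNG entries before and after shuffling. The same check could be done in Go. This is a minimal sketch using only the standard library; `countPNG` and `makeTarGz` are hypothetical helpers, the entry names are made up, and matching on the entry name only roughly mirrors what `grep` does on the listing.

```go
package main

import (
	"archive/tar"
	"bytes"
	"compress/gzip"
	"fmt"
	"io"
	"strings"
)

// countPNG reads a gzip-compressed tarball sequentially and counts the
// entries whose names end in ".png".
func countPNG(r io.Reader) (int, error) {
	gz, err := gzip.NewReader(r)
	if err != nil {
		return 0, err
	}
	defer gz.Close()

	tr := tar.NewReader(gz)
	n := 0
	for {
		hdr, err := tr.Next()
		if err == io.EOF {
			return n, nil // end of archive
		}
		if err != nil {
			return 0, err
		}
		if strings.HasSuffix(hdr.Name, ".png") {
			n++
		}
	}
}

// makeTarGz builds a small in-memory tarball so the sketch runs
// without any files on disk.
func makeTarGz(names []string) (*bytes.Buffer, error) {
	var buf bytes.Buffer
	gz := gzip.NewWriter(&buf)
	tw := tar.NewWriter(gz)
	for _, name := range names {
		body := []byte("x") // placeholder content, not a real image
		hdr := &tar.Header{Name: name, Mode: 0644, Size: int64(len(body)), Typeflag: tar.TypeReg}
		if err := tw.WriteHeader(hdr); err != nil {
			return nil, err
		}
		if _, err := tw.Write(body); err != nil {
			return nil, err
		}
	}
	if err := tw.Close(); err != nil {
		return nil, err
	}
	if err := gz.Close(); err != nil {
		return nil, err
	}
	return &buf, nil
}

func main() {
	// Hypothetical entry names, loosely modeled on the mnist_png layout.
	buf, err := makeTarGz([]string{
		"mnist_png/testing/2/1.png",
		"mnist_png/testing/2/2.png",
		"mnist_png/testing/labels.txt",
	})
	if err != nil {
		fmt.Println("build failed:", err)
		return
	}
	n, err := countPNG(buf)
	if err != nil {
		fmt.Println("read failed:", err)
		return
	}
	fmt.Println("png entries:", n)
}
```

Running `countPNG` on the original and the shuffled tarball and comparing the two counts gives the same sanity check as the shell pipeline.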
14 changes: 2 additions & 12 deletions example/mnist/mnist.go
@@ -4,7 +4,6 @@ import (
	"encoding/gob"
	"flag"
	"fmt"
-	"image"
	"log"
	"os"
	"path/filepath"
@@ -17,6 +16,7 @@ import (
	"github.com/wangkuiyi/gotorch/vision/imageloader"
	"github.com/wangkuiyi/gotorch/vision/models"
	"github.com/wangkuiyi/gotorch/vision/transforms"
+	"gocv.io/x/gocv"
)

var device torch.Device
@@ -169,17 +169,7 @@ func loadModel(modelFn string) *models.MLPModule {
}

func predictFile(fn string, m *models.MLPModule) {
-	f, e := os.Open(fn)
-	if e != nil {
-		log.Fatal(e)
-	}
-	defer f.Close()
-
-	img, _, e := image.Decode(f)
-	if e != nil {
-		log.Fatalf("Cannot decode input image: %v", e)
-	}
-
+	img := gocv.IMRead(fn, gocv.IMReadGrayScale)
	t := transforms.ToTensor().Run(img)
	n := transforms.Normalize([]float32{0.1307}, []float32{0.3081}).Run(t)
	fmt.Println(m.Forward(n).Argmax().Item())