-
Notifications
You must be signed in to change notification settings - Fork 35
Comparisons
Yi Wang edited this page Jul 29, 2020
·
1 revision
Here is a typical PyTorch program in four different languages:
- The Python version comes from the official tutorial.
- The C++ version calls the ATen C++ library and Torch's `csrc` C++ library. Thanks to Jia-Kai Liu, a tech lead of PyTorch, for teaching me everything about the C/C++ core of PyTorch. Please follow the instructions in https://github.com/wangkuiyi/cxxtorch to run this program.
- The Go version calls an imaginary Go binding of ATen and `csrc`.
- The Go+ version is also imaginary.
C++ | Go |
#include <iostream>
#include "torch/script.h"
#include "torch/optim.h"
// Trains a two-layer fully-connected network (linear -> ReLU -> linear)
// on random data with Adam, printing the loss every 100 iterations.
int main() {
// Batch size, input width, hidden width, output width.
int N = 64, D_in = 1000, H = 100, D_out = 10;
double learning_rate = 1e-3;
// Inputs and targets are fixed random data; they need no gradients.
auto x = torch::randn({N, D_in},
at::TensorOptions().requires_grad(false));
auto y = torch::randn({N, D_out},
at::TensorOptions().requires_grad(false));
// The Adam optimizer wants parameters in a std::vector.
std::vector<at::Tensor> params = {
torch::randn({D_in, H},
at::TensorOptions().requires_grad(true)),
torch::randn({H, D_out},
at::TensorOptions().requires_grad(true))};
// Build the optimizer.
torch::optim::Adam adam(params,
torch::optim::AdamOptions(learning_rate));
// Make quick references for using in the forward pass.
const at::Tensor & w1 = adam.parameters()[0];
const at::Tensor & w2 = adam.parameters()[1];
for (int i = 0; i < 500; ++i) {
// Forward pass: y_pred = relu(x * w1) * w2.
auto y_pred = at::mm(at::clamp(at::mm(x, w1), 0), w2);
// Sum-of-squared-errors loss.
auto loss = at::sum(at::pow(at::sub(y_pred, y), 2));
if ((i % 100) == 99) {
std::cout << "loss = " << loss << std::endl;
}
// Clear stale gradients, backprop, and update the weights.
adam.zero_grad();
loss.backward();
adam.step();
}
return 0;
} |
package main
import (
"fmt"
at "github.com/gotorch/gotorch/aten"
"github.com/gotorch/gotorch/torch"
"github.com/gotorch/gotorch/torch/optim"
)
func main() {
N, D_in, H, D_out := 64, 1000, 100, 10
learning_rate := 1e-3
x := torch.RandN([]int{N, Din},
at.TensorOptions().RequiresGrad(false))
y := torch.RandN([]int{N, Dout},
at.TensorOptions().RequiresGrad(false))
params := []at.Tensor{
torch.RandN([]int{Din, H},
at.TensorOptions().RequiresGrad(true)),
torch.RandN([]int{H, Dout},
at.TensorOptions().RequiresGrad(true)),
}
adam := optim.NewAdam(params, optim.AdamOptions(learning_rate))
w1 := adam.parameters()[0]
w2 := adam.parameters()[1]
for i := 0; i < 500; i++ {
y_pred := at.Sum(at.Clamp(at.MM(x, w1), 0), w2)
loss := at.Sum(at.Pow(at.Sub(y_pred, y), 2))
if i%100 == 0 {
fmt.Println("loss = ", loss)
}
adam.ZeroGrad()
loss.Backward()
adam.Step()
}
} |
Go+ | Python |
package main
import (
"fmt"
"github.com/gotorch/gotorch/at"
"github.com/gotorch/gotorch/torch"
"github.com/gotorch/gotorch/torch/optim"
)
func main() {
N, D_in, H, D_out := 64, 1000, 100, 10
x := torch.RandN(N, Din, requires_grad=False)
y := torch.RandN(N, Dout, requires_grad=False)
w1 := torch.randn(D_in, H, requires_grad=True)
w2 := torch.randn(H, D_out, requires_grad=True)
learning_rate := 1e-3
adam := optim.NewAdam([w1, w2], lr=learning_rate)
for i := 0; i < 500; i++ {
y_pred := at.Sum(at.Clamp(at.MM(x, w1), 0), w2)
loss := at.Sum(at.Pow(at.Sub(y_pred, y), 2))
if i%100 == 0 {
fmt.Println("loss = ", loss)
}
adam.ZeroGrad()
loss.Backward()
adam.Step()
}
} |
import torch

# Trains a two-layer fully-connected network (linear -> ReLU -> linear)
# on random data with Adam, printing the loss every 100 iterations.
# Batch size, input width, hidden width, output width.
N, D_in, H, D_out = 64, 1000, 100, 10
# Inputs and targets are fixed random data; they need no gradients.
x = torch.randn(N, D_in, requires_grad=False)
y = torch.randn(N, D_out, requires_grad=False)
# The two weight matrices are the trainable parameters.
w1 = torch.randn(D_in, H, requires_grad=True)
w2 = torch.randn(H, D_out, requires_grad=True)
learning_rate = 1e-3
adam = torch.optim.Adam([w1, w2], lr=learning_rate)
for t in range(500):
    # Forward pass: y_pred = relu(x @ w1) @ w2.
    y_pred = x.mm(w1).clamp(min=0).mm(w2)
    # Sum-of-squared-errors loss.
    loss = (y_pred - y).pow(2).sum()
    if t % 100 == 99:
        print(t, loss.item())
    # Clear stale gradients, backprop, and update the weights.
    adam.zero_grad()
    loss.backward()
    adam.step()
From the above four programs, we can see
- The Go binding could be as effective as the C/C++ API, in terms of the number of lines of source code.
- If we want the Go+ version to be as short and concise as the Python version, the primary requirement for the Go+ transpiler is support for named function parameters, for example,
x := torch.RandN(N, D_in, requires_grad=False)
.