# Comparisons

Here is a typical PyTorch program in four different languages:

  • The Python version comes from the official PyTorch tutorial.
  • The C++ version calls the ATen C++ tensor library and Torch's csrc C++ library. Thanks to Jia-Kai Liu, a tech lead of PyTorch, for teaching me everything about the C/C++ core of PyTorch. Please follow the instructions in https://github.com/wangkuiyi/cxxtorch to run this program.
  • The Go version calls an imaginary Go binding of ATen and csrc.
  • The Go+ version is also imaginary.
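
All four programs do the same thing: they build a two-layer network that multiplies a random input x by w1, clamps the result at zero (a ReLU), and multiplies it by w2, then minimize the sum of squared differences between the prediction and a random target y by running the Adam optimizer for 500 iterations.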
**C++**

```cpp
#include <iostream>

#include "torch/script.h"
#include "torch/optim.h"

int main() {
  int N = 64, D_in = 1000, H = 100, D_out = 10;
  double learning_rate = 1e-3;

  auto x = torch::randn({N, D_in},
                        at::TensorOptions().requires_grad(false));
  auto y = torch::randn({N, D_out},
                        at::TensorOptions().requires_grad(false));

  // The Adam optimizer wants parameters in a std::vector.
  std::vector<at::Tensor> params = {
    torch::randn({D_in, H},
                 at::TensorOptions().requires_grad(true)),
    torch::randn({H, D_out},
                 at::TensorOptions().requires_grad(true))};

  // Build the optimizer.
  torch::optim::Adam adam(params,
                          torch::optim::AdamOptions(learning_rate));

  // Make quick references to the parameters for use in the forward pass.
  const at::Tensor& w1 = adam.parameters()[0];
  const at::Tensor& w2 = adam.parameters()[1];

  for (int i = 0; i < 500; ++i) {
    auto y_pred = at::mm(at::clamp(at::mm(x, w1), 0), w2);
    auto loss = at::sum(at::pow(at::sub(y_pred, y), 2));

    if ((i % 100) == 99) {
      std::cout << "loss = " << loss << std::endl;
    }

    adam.zero_grad();
    loss.backward();
    adam.step();
  }
  return 0;
}
```
**Go**

```go
package main

import (
	"fmt"

	at "github.com/gotorch/gotorch/aten"
	"github.com/gotorch/gotorch/torch"
	"github.com/gotorch/gotorch/torch/optim"
)

func main() {
	N, Din, H, Dout := 64, 1000, 100, 10
	learningRate := 1e-3

	x := torch.RandN([]int{N, Din},
		at.TensorOptions().RequiresGrad(false))
	y := torch.RandN([]int{N, Dout},
		at.TensorOptions().RequiresGrad(false))

	// The Adam optimizer wants parameters in a slice.
	params := []at.Tensor{
		torch.RandN([]int{Din, H},
			at.TensorOptions().RequiresGrad(true)),
		torch.RandN([]int{H, Dout},
			at.TensorOptions().RequiresGrad(true)),
	}

	adam := optim.NewAdam(params, optim.AdamOptions(learningRate))

	// Quick references to the parameters for use in the forward pass.
	w1 := adam.Parameters()[0]
	w2 := adam.Parameters()[1]

	for i := 0; i < 500; i++ {
		yPred := at.MM(at.Clamp(at.MM(x, w1), 0), w2)
		loss := at.Sum(at.Pow(at.Sub(yPred, y), 2))

		if i%100 == 99 {
			fmt.Println("loss = ", loss)
		}

		adam.ZeroGrad()
		loss.Backward()
		adam.Step()
	}
}
```
**Go+**

```go
package main

import (
	"fmt"

	"github.com/gotorch/gotorch/at"
	"github.com/gotorch/gotorch/torch"
	"github.com/gotorch/gotorch/torch/optim"
)

func main() {
	N, Din, H, Dout := 64, 1000, 100, 10

	x := torch.RandN(N, Din, requires_grad=false)
	y := torch.RandN(N, Dout, requires_grad=false)

	w1 := torch.RandN(Din, H, requires_grad=true)
	w2 := torch.RandN(H, Dout, requires_grad=true)

	learningRate := 1e-3
	adam := optim.NewAdam([w1, w2], lr=learningRate)

	for i := 0; i < 500; i++ {
		yPred := at.MM(at.Clamp(at.MM(x, w1), 0), w2)
		loss := at.Sum(at.Pow(at.Sub(yPred, y), 2))

		if i%100 == 99 {
			fmt.Println("loss = ", loss)
		}

		adam.ZeroGrad()
		loss.Backward()
		adam.Step()
	}
}
```
**Python**

```python
import torch

N, D_in, H, D_out = 64, 1000, 100, 10

x = torch.randn(N, D_in, requires_grad=False)
y = torch.randn(N, D_out, requires_grad=False)

w1 = torch.randn(D_in, H, requires_grad=True)
w2 = torch.randn(H, D_out, requires_grad=True)

learning_rate = 1e-3
adam = torch.optim.Adam([w1, w2], lr=learning_rate)

for t in range(500):
    y_pred = x.mm(w1).clamp(min=0).mm(w2)
    loss = (y_pred - y).pow(2).sum()

    if t % 100 == 99:
        print(t, loss.item())

    adam.zero_grad()
    loss.backward()
    adam.step()
```

From the above four programs, we can see the following:

  1. The Go binding could be as concise as the C++ API in terms of the number of lines of source code.
  2. If we want the Go+ version to be as concise as the Python version, the primary requirement for the Go+ transpiler is to support named function parameters, for example, `x := torch.RandN(N, Din, requires_grad=false)`. A possible lowering of such calls into plain Go is sketched below.
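
To make the second point concrete, here is a minimal sketch of how a transpiler could rewrite a named-parameter call into the builder-style Go API used by the Go version above. This is one hypothetical lowering scheme, not a description of how Go+ actually works; the `Options` type and the `TensorOptions` and `RandN` functions below are stand-ins defined only for illustration.

```go
package main

import "fmt"

// Options is a stand-in for at.TensorOptions: a builder that the
// transpiler could target when it sees named arguments.
type Options struct{ requiresGrad bool }

// TensorOptions returns the default options, mirroring at.TensorOptions().
func TensorOptions() Options { return Options{} }

// RequiresGrad returns a copy of the options with the flag set.
func (o Options) RequiresGrad(b bool) Options {
	o.requiresGrad = b
	return o
}

// RandN is a stand-in for torch.RandN; it only reports how it was called.
func RandN(size []int, opts Options) string {
	return fmt.Sprintf("randn(size=%v, requires_grad=%v)", size, opts.requiresGrad)
}

func main() {
	N, Din := 64, 1000

	// The Go+ form
	//     x := torch.RandN(N, Din, requires_grad=false)
	// could be transpiled into the plain Go call below: the positional
	// arguments are collected into the size slice, and each named
	// argument becomes a chained setter on the options builder.
	x := RandN([]int{N, Din}, TensorOptions().RequiresGrad(false))
	fmt.Println(x) // randn(size=[64 1000], requires_grad=false)
}
```

Under a scheme like this, the named-parameter syntax is pure front-end sugar: the Go binding itself stays a conventional Go API, and only the Go+ transpiler needs to know the mapping from named arguments to setters.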