This repository has been archived by the owner on Oct 16, 2023. It is now read-only.

Commit

Merge pull request #21 from hpcaitech/feature/example
Feature/example
MaruyamaAya authored Apr 12, 2022
2 parents 728f0d6 + ce29e35 commit a0608d2
Showing 7 changed files with 1,225 additions and 1 deletion.
3 changes: 2 additions & 1 deletion energon/kernel/cuda_native/__init__.py
@@ -1,2 +1,3 @@
from .transpose_pad import transpose_pad, transpose_depad
from .scale_mask_softmax import scale_mask_softmax
from .layer_norm import MixedFusedLayerNorm as LayerNorm
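With this change, the fused kernel is re-exported at the package level as LayerNorm. A minimal usage sketch follows; only the import path is taken from this diff, and the constructor is assumed to mirror torch.nn.LayerNorm(normalized_shape, eps), which the diff does not show:

# Hypothetical usage sketch. The MixedFusedLayerNorm constructor signature is
# assumed to match torch.nn.LayerNorm(normalized_shape, eps); only the import
# path below comes from this diff.
import torch
from energon.kernel.cuda_native import LayerNorm  # MixedFusedLayerNorm

layer = LayerNorm(1024, eps=1e-5).cuda()          # assumed signature
x = torch.randn(8, 128, 1024, device="cuda")
y = layer(x)                                      # fused CUDA layer norm, same shape as x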
10 changes: 10 additions & 0 deletions energon/kernel/cuda_native/csrc/compat.h
@@ -0,0 +1,10 @@
// modified from https://github.com/NVIDIA/apex/blob/master/csrc/compat.h
#ifndef TORCH_CHECK
#define TORCH_CHECK AT_CHECK
#endif

#ifdef VERSION_GE_1_3
#define DATA_PTR data_ptr
#else
#define DATA_PTR data
#endif
141 changes: 141 additions & 0 deletions energon/kernel/cuda_native/csrc/layer_norm_cuda.cpp
@@ -0,0 +1,141 @@
/* This code is from NVIDIA apex:
 * https://github.com/NVIDIA/apex
 * with minor changes. */

#include "compat.h"
#include <cassert>
#include <torch/extension.h>
#include <vector>

namespace {

void compute_n1_n2(at::Tensor input, at::IntArrayRef normalized_shape, int &n1,
int &n2) {
int idiff = input.ndimension() - normalized_shape.size();
n2 = 1;
for (int i = 0; i < (int)normalized_shape.size(); ++i) {
assert(input.sizes()[i + idiff] == normalized_shape[i]);
n2 *= normalized_shape[i];
}
n1 = 1;
for (int i = 0; i < idiff; ++i) {
n1 *= input.sizes()[i];
}
}

void check_args(at::IntArrayRef normalized_shape, at::Tensor gamma,
at::Tensor beta) {
TORCH_CHECK(!gamma.defined() || gamma.sizes().equals(normalized_shape));
TORCH_CHECK(!beta.defined() || beta.sizes().equals(normalized_shape));
}

void check_args(at::Tensor input, at::IntArrayRef normalized_shape, int &n1,
int &n2) {
int64_t normalized_ndim = normalized_shape.size();

if (normalized_ndim < 1) {
std::stringstream ss;
ss << "Expected normalized_shape to be at least 1-dimensional, i.e., "
<< "containing at least one element, but got normalized_shape="
<< normalized_shape;
throw std::runtime_error(ss.str());
}

auto input_shape = input.sizes();
auto input_ndim = input.dim();

if (input_ndim < normalized_ndim ||
!input_shape.slice(input_ndim - normalized_ndim)
.equals(normalized_shape)) {
std::stringstream ss;
ss << "Given normalized_shape=" << normalized_shape
<< ", expected input with shape [*";
for (auto size : normalized_shape) {
ss << ", " << size;
}
ss << "], but got input of size" << input_shape;
throw std::runtime_error(ss.str());
}

compute_n1_n2(input, normalized_shape, n1, n2);
}

void check_args(at::Tensor input, at::IntArrayRef normalized_shape,
at::Tensor gamma, at::Tensor beta, int &n1, int &n2) {
check_args(input, normalized_shape, n1, n2);
check_args(normalized_shape, gamma, beta);
}
} // namespace

void cuda_layer_norm(at::Tensor *output, at::Tensor *mean, at::Tensor *invvar,
at::Tensor *input, int n1, int n2,
at::IntArrayRef normalized_shape, at::Tensor *gamma,
at::Tensor *beta, double epsilon);

#define CHECK_CUDA(x) TORCH_CHECK(x.is_cuda(), #x " must be a CUDA tensor")
#define CHECK_CONTIGUOUS(x) \
TORCH_CHECK(x.is_contiguous(), #x " must be contiguous")
#define CHECK_INPUT(x) \
CHECK_CUDA(x); \
CHECK_CONTIGUOUS(x)

std::vector<at::Tensor> layer_norm_affine(at::Tensor input,
at::IntArrayRef normalized_shape,
at::Tensor gamma, at::Tensor beta,
double epsilon) {

CHECK_INPUT(input);
CHECK_INPUT(gamma);
CHECK_INPUT(beta);
int n1, n2;
check_args(input, normalized_shape, gamma, beta, n1, n2);

at::Tensor output =
at::empty_like(input, gamma.options().dtype(gamma.scalar_type()));
at::Tensor mean =
at::empty({n1}, input.options().dtype(at::ScalarType::Float));
at::Tensor invvar = at::empty_like(mean);

cuda_layer_norm(&output, &mean, &invvar, &input, n1, n2, normalized_shape,
&gamma, &beta, epsilon);

return {output, mean, invvar};
}

void cuda_layer_norm_gradient(at::Tensor *dout, at::Tensor *mean,
at::Tensor *invvar, at::Tensor *input, int n1,
int n2, at::IntArrayRef normalized_shape,
at::Tensor *gamma, at::Tensor *beta,
double epsilon, at::Tensor *grad_input,
at::Tensor *grad_gamma, at::Tensor *grad_beta);

std::vector<at::Tensor>
layer_norm_gradient_affine(at::Tensor dout, at::Tensor mean, at::Tensor invvar,
at::Tensor input, at::IntArrayRef normalized_shape,
at::Tensor gamma, at::Tensor beta, double epsilon) {

CHECK_INPUT(dout);
CHECK_INPUT(mean);
CHECK_INPUT(invvar);
CHECK_INPUT(input);
CHECK_INPUT(gamma);
CHECK_INPUT(beta);
int n1, n2;
check_args(input, normalized_shape, gamma, beta, n1, n2);

at::Tensor grad_input = at::empty_like(input);
at::Tensor grad_gamma = at::empty_like(gamma);
at::Tensor grad_beta = at::empty_like(beta);

cuda_layer_norm_gradient(&dout, &mean, &invvar, &input, n1, n2,
normalized_shape, &gamma, &beta, epsilon,
&grad_input, &grad_gamma, &grad_beta);

return {grad_input, grad_gamma, grad_beta};
}

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.def("forward_affine", &layer_norm_affine, "LayerNorm forward (CUDA)");
m.def("backward_affine", &layer_norm_gradient_affine,
"LayerNorm backward (CUDA)");
}
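On the Python side, bindings like these are typically wrapped in a torch.autograd.Function so the fused forward and backward plug into autograd. The sketch below shows one way the two bound functions could be wired up; the extension module name energon_layer_norm and the wrapper class are illustrative assumptions, not part of this commit:

import torch
# "energon_layer_norm" is an assumed name for the compiled extension exposing
# forward_affine / backward_affine above; the real module name may differ.
import energon_layer_norm as _ext

class FusedLayerNormAffineFunction(torch.autograd.Function):
    @staticmethod
    def forward(ctx, input, weight, bias, normalized_shape, eps):
        input_ = input.contiguous()
        weight_ = weight.contiguous()
        bias_ = bias.contiguous()
        # forward_affine returns (output, mean, invvar); see layer_norm_affine above
        output, mean, invvar = _ext.forward_affine(
            input_, normalized_shape, weight_, bias_, eps)
        ctx.save_for_backward(input_, weight_, bias_, mean, invvar)
        ctx.normalized_shape = normalized_shape
        ctx.eps = eps
        return output

    @staticmethod
    def backward(ctx, grad_output):
        input_, weight_, bias_, mean, invvar = ctx.saved_tensors
        # backward_affine returns (grad_input, grad_gamma, grad_beta)
        grad_input, grad_weight, grad_bias = _ext.backward_affine(
            grad_output.contiguous(), mean, invvar, input_,
            ctx.normalized_shape, weight_, bias_, ctx.eps)
        return grad_input, grad_weight, grad_bias, None, None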