Skip to content
This repository has been archived by the owner on Oct 16, 2023. It is now read-only.

Feature/example #21

Merged
merged 2 commits into from
Apr 12, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion energon/kernel/cuda_native/__init__.py
Original file line number Diff line number Diff line change
@@ -1,2 +1,3 @@
from .transpose_pad import transpose_pad, transpose_depad
from .scale_mask_softmax import scale_mask_softmax
from .scale_mask_softmax import scale_mask_softmax
from .layer_norm import MixedFusedLayerNorm as LayerNorm
10 changes: 10 additions & 0 deletions energon/kernel/cuda_native/csrc/compat.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
// modified from https://github.com/NVIDIA/apex/blob/master/csrc/compat.h
//
// Small preprocessor shims used by the extension sources that include this
// header.

// If TORCH_CHECK has not been defined by the time this header is processed,
// alias it to AT_CHECK.
// NOTE(review): presumably this targets older PyTorch releases that only ship
// AT_CHECK — confirm. The fallback depends on include order: it only sees
// definitions made by headers included before this one.
#ifndef TORCH_CHECK
#define TORCH_CHECK AT_CHECK
#endif

// DATA_PTR expands to `data_ptr` when VERSION_GE_1_3 is defined and to `data`
// otherwise.
// NOTE(review): VERSION_GE_1_3 is not defined in this header; presumably the
// build script passes it as a compiler define — confirm.
#ifdef VERSION_GE_1_3
#define DATA_PTR data_ptr
#else
#define DATA_PTR data
#endif
141 changes: 141 additions & 0 deletions energon/kernel/cuda_native/csrc/layer_norm_cuda.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,141 @@
/*This code from NVIDIA apex:
* https://github.com/NVIDIA/apex
* with minor changes. */

#include <torch/extension.h>

#include <cassert>
#include <limits>
#include <vector>

// Keep compat.h after <torch/extension.h>: its `#ifndef TORCH_CHECK` fallback
// should only take effect when the PyTorch headers did not already provide
// the macro.
#include "compat.h"

namespace {

// Splits the shape of `input` into two element counts:
//   n1 - product of the leading dimensions not covered by `normalized_shape`
//   n2 - product of the trailing dimensions, i.e. of `normalized_shape`
//
// The trailing dimensions of `input` must equal `normalized_shape`. This is
// verified with TORCH_CHECK rather than assert() so the check stays active in
// builds that define NDEBUG. The products are accumulated in 64 bits and
// rejected if they do not fit in `int`, instead of being truncated silently.
//
// n1 and n2 are written only after every check has passed.
void compute_n1_n2(at::Tensor input, at::IntArrayRef normalized_shape, int &n1,
                   int &n2) {
  const auto input_sizes = input.sizes();
  const int64_t normalized_ndim =
      static_cast<int64_t>(normalized_shape.size());
  const int64_t leading_ndim = input.ndimension() - normalized_ndim;
  TORCH_CHECK(leading_ndim >= 0,
              "normalized_shape has more dimensions (", normalized_ndim,
              ") than input (", input.ndimension(), ")");

  // Multiplies `acc` by `extent`, failing if the result would not fit in int.
  // Both operands are non-negative here, so the division-based test is exact.
  const auto checked_mul = [](int64_t acc, int64_t extent) -> int64_t {
    const int64_t int_max = std::numeric_limits<int>::max();
    TORCH_CHECK(extent == 0 || acc <= int_max / extent,
                "layer_norm: product of dimensions exceeds the range of int");
    return acc * extent;
  };

  int64_t inner = 1;
  for (int64_t i = 0; i < normalized_ndim; ++i) {
    const int64_t expected = normalized_shape[i];
    const int64_t actual = input_sizes[i + leading_ndim];
    TORCH_CHECK(actual == expected, "Expected input dimension ",
                i + leading_ndim, " to have size ", expected, ", but got ",
                actual);
    inner = checked_mul(inner, expected);
  }

  int64_t outer = 1;
  for (int64_t i = 0; i < leading_ndim; ++i) {
    outer = checked_mul(outer, input_sizes[i]);
  }

  n1 = static_cast<int>(outer);
  n2 = static_cast<int>(inner);
}

// Validates the affine parameters: each of `gamma` and `beta`, when defined,
// must have exactly the shape `normalized_shape`. An undefined tensor is
// accepted by this check.
//
// Failures are reported through TORCH_CHECK with a message naming the
// offending parameter and both shapes.
void check_args(at::IntArrayRef normalized_shape, at::Tensor gamma,
                at::Tensor beta) {
  TORCH_CHECK(!gamma.defined() || gamma.sizes().equals(normalized_shape),
              "Expected gamma to have shape ", normalized_shape,
              ", but got ", gamma.sizes());
  TORCH_CHECK(!beta.defined() || beta.sizes().equals(normalized_shape),
              "Expected beta to have shape ", normalized_shape, ", but got ",
              beta.sizes());
}

// Validates `input` against `normalized_shape` and, on success, fills in n1
// and n2 via compute_n1_n2().
//
// Throws std::runtime_error when `normalized_shape` is empty, or when the
// trailing dimensions of `input` do not equal `normalized_shape` (including
// the case where `input` has fewer dimensions than `normalized_shape`).
void check_args(at::Tensor input, at::IntArrayRef normalized_shape, int &n1,
                int &n2) {
  const int64_t norm_rank = normalized_shape.size();
  if (norm_rank < 1) {
    std::stringstream msg;
    msg << "Expected normalized_shape to be at least 1-dimensional, i.e., "
        << "containing at least one element, but got normalized_shape="
        << normalized_shape;
    throw std::runtime_error(msg.str());
  }

  const auto actual_shape = input.sizes();
  const auto actual_rank = input.dim();

  // Only slice once the rank is known to be large enough; slicing with a
  // negative offset must never be attempted.
  bool trailing_dims_match = false;
  if (actual_rank >= norm_rank) {
    trailing_dims_match =
        actual_shape.slice(actual_rank - norm_rank).equals(normalized_shape);
  }

  if (!trailing_dims_match) {
    std::stringstream msg;
    msg << "Given normalized_shape=" << normalized_shape
        << ", expected input with shape [*";
    for (const auto extent : normalized_shape) {
      msg << ", " << extent;
    }
    msg << "], but got input of size" << actual_shape;
    throw std::runtime_error(msg.str());
  }

  compute_n1_n2(input, normalized_shape, n1, n2);
}

// Full argument validation for the affine entry points: first checks `input`
// against `normalized_shape` (which also fills in n1 and n2), then checks that
// `gamma` and `beta` match `normalized_shape`.
// The order matters only for which error is reported first when several
// arguments are invalid.
void check_args(at::Tensor input, at::IntArrayRef normalized_shape,
at::Tensor gamma, at::Tensor beta, int &n1, int &n2) {
check_args(input, normalized_shape, n1, n2);
check_args(normalized_shape, gamma, beta);
}
} // namespace

// Forward-pass kernel launcher. Declaration only: the definition is not in
// this file (presumably it lives in the accompanying CUDA source — confirm).
// layer_norm_affine() below calls it with pointers to pre-allocated `output`,
// `mean` and `invvar` tensors plus the n1/n2 values computed by check_args().
void cuda_layer_norm(at::Tensor *output, at::Tensor *mean, at::Tensor *invvar,
at::Tensor *input, int n1, int n2,
at::IntArrayRef normalized_shape, at::Tensor *gamma,
at::Tensor *beta, double epsilon);

// Argument-validation helpers built on TORCH_CHECK. The error message starts
// with the spelling of the argument expression as written at the call site.
//
// CHECK_CUDA(x)       - fails unless x.is_cuda() is true.
// CHECK_CONTIGUOUS(x) - fails unless x.is_contiguous() is true.
// CHECK_INPUT(x)      - applies both checks, CUDA first.
//
// The macro argument is parenthesized so any expression can be passed, and
// CHECK_INPUT is wrapped in do { ... } while (0) so it behaves as a single
// statement (e.g. under an unbraced `if`).
#define CHECK_CUDA(x) TORCH_CHECK((x).is_cuda(), #x " must be a CUDA tensor")
#define CHECK_CONTIGUOUS(x) \
  TORCH_CHECK((x).is_contiguous(), #x " must be contiguous")
#define CHECK_INPUT(x)   \
  do {                   \
    CHECK_CUDA(x);       \
    CHECK_CONTIGUOUS(x); \
  } while (0)

// Forward pass of affine layer normalization; bound to Python as
// `forward_affine`.
//
// `input`, `gamma` and `beta` must be contiguous CUDA tensors, and their
// shapes must be consistent with `normalized_shape` (see check_args()).
//
// Returns {output, mean, invvar}:
//   output - shaped like `input`, with gamma's dtype
//   mean   - float tensor with n1 elements, using input's other options
//   invvar - allocated like `mean`
// All three are allocated uninitialized here and passed to cuda_layer_norm().
std::vector<at::Tensor> layer_norm_affine(at::Tensor input,
                                          at::IntArrayRef normalized_shape,
                                          at::Tensor gamma, at::Tensor beta,
                                          double epsilon) {
  CHECK_INPUT(input);
  CHECK_INPUT(gamma);
  CHECK_INPUT(beta);

  int n1;
  int n2;
  check_args(input, normalized_shape, gamma, beta, n1, n2);

  const auto output_options = gamma.options().dtype(gamma.scalar_type());
  const auto stats_options = input.options().dtype(at::ScalarType::Float);

  at::Tensor output = at::empty_like(input, output_options);
  at::Tensor mean = at::empty({n1}, stats_options);
  at::Tensor invvar = at::empty_like(mean);

  cuda_layer_norm(&output, &mean, &invvar, &input, n1, n2, normalized_shape,
                  &gamma, &beta, epsilon);

  std::vector<at::Tensor> results{output, mean, invvar};
  return results;
}

// Backward-pass kernel launcher. Declaration only: the definition is not in
// this file (presumably it lives in the accompanying CUDA source — confirm).
// layer_norm_gradient_affine() below calls it with pointers to pre-allocated
// `grad_input`, `grad_gamma` and `grad_beta` tensors.
void cuda_layer_norm_gradient(at::Tensor *dout, at::Tensor *mean,
at::Tensor *invvar, at::Tensor *input, int n1,
int n2, at::IntArrayRef normalized_shape,
at::Tensor *gamma, at::Tensor *beta,
double epsilon, at::Tensor *grad_input,
at::Tensor *grad_gamma, at::Tensor *grad_beta);

// Backward pass of affine layer normalization; bound to Python as
// `backward_affine`.
//
// All tensor arguments must be contiguous CUDA tensors. `input`, `gamma` and
// `beta` are validated against `normalized_shape` exactly as in the forward
// pass. In addition, the element counts of `dout`, `mean` and `invvar` are
// checked before they are handed to cuda_layer_norm_gradient():
//   - `dout` must have as many elements as `input` (the forward pass in this
//     file allocates its output with at::empty_like(input));
//   - `mean` and `invvar` must each have n1 elements (the forward pass in
//     this file allocates them with shape {n1}).
//
// Returns {grad_input, grad_gamma, grad_beta}, allocated like `input`,
// `gamma` and `beta` respectively.
std::vector<at::Tensor>
layer_norm_gradient_affine(at::Tensor dout, at::Tensor mean, at::Tensor invvar,
                           at::Tensor input, at::IntArrayRef normalized_shape,
                           at::Tensor gamma, at::Tensor beta, double epsilon) {
  CHECK_INPUT(dout);
  CHECK_INPUT(mean);
  CHECK_INPUT(invvar);
  CHECK_INPUT(input);
  CHECK_INPUT(gamma);
  CHECK_INPUT(beta);
  int n1, n2;
  check_args(input, normalized_shape, gamma, beta, n1, n2);

  // Reject buffers whose sizes are inconsistent with `input` rather than
  // passing them on unchecked.
  TORCH_CHECK(dout.numel() == input.numel(),
              "dout must have the same number of elements as input, but got ",
              dout.numel(), " and ", input.numel());
  TORCH_CHECK(mean.numel() == n1, "mean must have ", n1,
              " elements, but got ", mean.numel());
  TORCH_CHECK(invvar.numel() == n1, "invvar must have ", n1,
              " elements, but got ", invvar.numel());

  at::Tensor grad_input = at::empty_like(input);
  at::Tensor grad_gamma = at::empty_like(gamma);
  at::Tensor grad_beta = at::empty_like(beta);

  cuda_layer_norm_gradient(&dout, &mean, &invvar, &input, n1, n2,
                           normalized_shape, &gamma, &beta, epsilon,
                           &grad_input, &grad_gamma, &grad_beta);

  return {grad_input, grad_gamma, grad_beta};
}

// Python bindings for this extension module. The module name is supplied by
// the TORCH_EXTENSION_NAME macro.
//   forward_affine  -> layer_norm_affine
//   backward_affine -> layer_norm_gradient_affine
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.def("forward_affine", &layer_norm_affine, "LayerNorm forward (CUDA)");
m.def("backward_affine", &layer_norm_gradient_affine,
"LayerNorm backward (CUDA)");
}
Loading