-
Notifications
You must be signed in to change notification settings - Fork 124
/
warprnnt_op.cc
191 lines (158 loc) · 7.32 KB
/
warprnnt_op.cc
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
#ifdef WARPRNNT_ENABLE_GPU
#define EIGEN_USE_GPU
#include <cuda.h>
#endif
#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/bounds_check.h"
#include "tensorflow/core/framework/allocator.h"
#include "rnnt.h"
// Registers the WarpRNNT op with TensorFlow.
//
// Inputs:
//   acts:          float32 activations; Compute() requires rank 4
//                  (batch, max_time, max_u, num_classes).
//   labels:        int32 label ids; Compute() requires rank 2.
//   input_lengths: int32 vector, one entry per batch element.
//   label_lengths: int32 vector, one entry per batch element.
// Attr:
//   blank_label:   index of the blank symbol (default 0).
// Outputs:
//   costs: float32 per-example RNN-T loss (allocated with
//          input_lengths' shape, i.e. one value per batch element).
//   grads: float32 gradient w.r.t. acts, same shape as acts.
//
// NOTE(review): no SetShapeFn is registered, so shape inference for
// callers is unknown-rank; adding one would require
// tensorflow/core/framework/shape_inference.h.
REGISTER_OP("WarpRNNT")
.Input("acts: float32")
.Input("labels: int32")
.Input("input_lengths: int32")
.Input("label_lengths: int32")
.Attr("blank_label: int = 0")
.Output("costs: float32")
.Output("grads: float32");
namespace tf = tensorflow;
namespace warp_rnnt {
// Base kernel shared by the CPU and GPU WarpRNNT implementations.
// Subclasses provide device-specific gradient zeroing (set_zero) and
// library option construction (create_options).
class WarpRNNTOpBase : public tf::OpKernel {
 public:
  explicit WarpRNNTOpBase(tf::OpKernelConstruction* ctx) : tf::OpKernel(ctx) {
    OP_REQUIRES_OK(ctx, ctx->GetAttr("blank_label", &blank_label_));
  }

  // Computes the per-example RNN-T loss ("costs") and its gradient with
  // respect to the activations ("grads") by delegating to warp-rnnt's
  // compute_rnnt_loss().
  //
  // Expected input shapes (validated below):
  //   acts:          float32 [batch, max_time, max_u, num_classes]
  //   labels:        int32   [batch, *]
  //   input_lengths: int32   [batch]
  //   label_lengths: int32   [batch]
  void Compute(tf::OpKernelContext* ctx) override {
    // Grab the input tensors.
    const tf::Tensor* acts;
    const tf::Tensor* labels;
    const tf::Tensor* label_lengths;
    const tf::Tensor* input_lengths;
    OP_REQUIRES_OK(ctx, ctx->input("acts", &acts));
    OP_REQUIRES_OK(ctx, ctx->input("labels", &labels));
    OP_REQUIRES_OK(ctx, ctx->input("label_lengths", &label_lengths));
    OP_REQUIRES_OK(ctx, ctx->input("input_lengths", &input_lengths));

    // Rank checks. TensorShape metadata is host-side, so these are safe
    // even when the tensor data lives on the GPU.
    OP_REQUIRES(ctx, acts->shape().dims() == 4,
                tf::errors::InvalidArgument("acts is not a 4-Tensor"));
    OP_REQUIRES(ctx, labels->shape().dims() == 2,
                tf::errors::InvalidArgument("labels is not a 2-Tensor"));
    OP_REQUIRES(ctx, tf::TensorShapeUtils::IsVector(label_lengths->shape()),
                tf::errors::InvalidArgument("label_lengths is not a vector"));
    OP_REQUIRES(ctx, tf::TensorShapeUtils::IsVector(input_lengths->shape()),
                tf::errors::InvalidArgument("input_lengths is not a vector"));

    const auto& acts_shape = acts->shape();
    const auto batch_size = acts_shape.dim_size(0);
    const auto max_time = acts_shape.dim_size(1);
    const auto max_u = acts_shape.dim_size(2);
    const auto num_classes_raw = acts_shape.dim_size(3);

    auto acts_t = acts->tensor<float, 4>();
    auto labels_t = labels->tensor<int32_t, 2>();

    OP_REQUIRES(
        ctx, tf::FastBoundsCheck(num_classes_raw, std::numeric_limits<int>::max()),
        tf::errors::InvalidArgument("num_classes cannot exceed max int"));
    // FIX: dropped the meaningless `const` from the cast target type.
    const auto alphabet_size = static_cast<int>(num_classes_raw);

    // FIX: the original validated the lengths vectors against batch_size
    // but never labels itself; a mismatched labels batch dimension would
    // reach compute_rnnt_loss and read out of bounds instead of failing
    // with a clean InvalidArgument.
    OP_REQUIRES(
        ctx, batch_size == labels->dim_size(0),
        tf::errors::InvalidArgument("labels batch dim != batch_size. ",
                                    "labels.dim_size(0): ", labels->dim_size(0),
                                    " batch_size: ", batch_size));
    OP_REQUIRES(
        ctx, batch_size == input_lengths->dim_size(0),
        tf::errors::InvalidArgument("len(input_lengths) != batch_size. ",
                                    "len(input_length): ", input_lengths->dim_size(0),
                                    " batch_size: ", batch_size));
    auto input_lengths_t = input_lengths->vec<int32_t>();
    OP_REQUIRES(
        ctx, batch_size == label_lengths->dim_size(0),
        tf::errors::InvalidArgument("len(label_lengths) != batch_size. ",
                                    "len(label_length): ", label_lengths->dim_size(0),
                                    " batch_size: ", batch_size));
    auto label_lengths_t = label_lengths->vec<int32_t>();

    // TODO check that labels are in the alphabet?
    // The per-element input_lengths <= max_time check is intentionally
    // disabled: on the GPU kernel the input_lengths data lives in device
    // memory and cannot be compared element-wise on the host here.
    //for (int b = 0; b < batch_size; b++) {
    //    OP_REQUIRES(ctx, input_lengths_t(b) <= max_time,
    //        tf::errors::InvalidArgument("input_lengths(", b, ") <= ", max_time));
    //}

    // Allocate outputs: one cost per batch element, grads shaped like acts.
    tf::Tensor* costs = nullptr;
    OP_REQUIRES_OK(ctx, ctx->allocate_output("costs", input_lengths->shape(), &costs));
    auto costs_t = costs->vec<float>();

    tf::Tensor* grads = nullptr;
    OP_REQUIRES_OK(ctx, ctx->allocate_output("grads", acts->shape(), &grads));
    set_zero(grads);  // device-specific; may be a no-op on GPU
    auto grads_t = grads->tensor<float, 4>();

    auto options = create_options(ctx);
    options.blank_label = blank_label_;
    options.maxT = max_time;
    options.maxU = max_u;

    // Ask the library how much scratch space it needs, then allocate it
    // as a temp tensor so TensorFlow owns the lifetime.
    // FIX: initialize so the value is never read indeterminate.
    size_t workspace_size_bytes = 0;
    const bool use_gpu = options.loc == RNNT_GPU;
    auto warp_status = get_workspace_size(max_time,
                                          max_u,
                                          batch_size,
                                          use_gpu,
                                          &workspace_size_bytes);
    OP_REQUIRES(ctx, warp_status == RNNT_STATUS_SUCCESS,
                tf::errors::Internal("warp_rnnt error in get_workspace_size: ",
                                     rnntGetStatusString(warp_status)));

    auto workspace_shape = tf::TensorShape{static_cast<int64_t>(workspace_size_bytes)};
    tf::Tensor workspace;
    OP_REQUIRES_OK(ctx, ctx->allocate_temp(tf::DT_UINT8, workspace_shape, &workspace));
    auto workspace_t = workspace.flat<uint8_t>();

    // Compute the RNN-T loss and gradients.
    warp_status = compute_rnnt_loss(acts_t.data(),
                                    grads_t.data(),
                                    labels_t.data(),
                                    label_lengths_t.data(),
                                    input_lengths_t.data(),
                                    alphabet_size, batch_size,
                                    costs_t.data(), workspace_t.data(), options);
    OP_REQUIRES(ctx, warp_status == RNNT_STATUS_SUCCESS,
                tf::errors::Internal("warp_rnnt error in compute_rnnt_loss: ",
                                     rnntGetStatusString(warp_status)));
  }

 private:
  int blank_label_;

  // Zeroes the gradient output tensor on the kernel's device.
  virtual void set_zero(tf::Tensor* t) = 0;
  // Builds device-specific rnntOptions (location, stream/threads).
  virtual rnntOptions create_options(tf::OpKernelContext* ctx) = 0;
};
// CPU specialization of the WarpRNNT kernel.
class WarpRNNTOpCPU : public WarpRNNTOpBase {
 public:
  explicit WarpRNNTOpCPU(tf::OpKernelConstruction* ctx) : WarpRNNTOpBase(ctx) {}

 private:
  // Zero the gradient buffer in host memory before the loss computation.
  void set_zero(tf::Tensor* t) override {
    t->flat<float>().setZero();
  }

  // CPU options: batch-major layout, sized to the device's intra-op
  // worker thread pool.
  rnntOptions create_options(tf::OpKernelContext* ctx) override {
    rnntOptions opts{};
    opts.loc = RNNT_CPU;
    opts.num_threads = ctx->device()->tensorflow_cpu_worker_threads()->num_threads;
    opts.batch_first = true;
    return opts;
  }
};

REGISTER_KERNEL_BUILDER(Name("WarpRNNT").Device(::tensorflow::DEVICE_CPU), WarpRNNTOpCPU);
#ifdef WARPRNNT_ENABLE_GPU
// GPU specialization of the WarpRNNT kernel.
class WarpRNNTOpGPU : public WarpRNNTOpBase {
 public:
  explicit WarpRNNTOpGPU(tf::OpKernelConstruction* ctx) : WarpRNNTOpBase(ctx) {}

 private:
  // Deliberately a no-op (the original kept the cudaMemset below
  // commented out); presumably the GPU loss kernel initializes the
  // gradient buffer itself — confirm before re-enabling.
  // cudaMemset(t->flat<float>().data(), 0, t->NumElements()*sizeof(float));
  void set_zero(tf::Tensor* t) override {
  }

  // GPU options: run on the op's CUDA stream.
  rnntOptions create_options(tf::OpKernelContext* ctx) override {
    rnntOptions opts{};
    opts.loc = RNNT_GPU;
    opts.stream = ctx->eigen_device<Eigen::GpuDevice>().stream();
    return opts;
  }
};

// costs stays in host memory; the remaining tensors live on the device.
REGISTER_KERNEL_BUILDER(Name("WarpRNNT")
                            .Device(::tensorflow::DEVICE_GPU)
                            .HostMemory("costs"),
                        WarpRNNTOpGPU);
#undef EIGEN_USE_GPU
#endif
}