Skip to content

Commit

Permalink
[Conv2DTransposed] Fix wrong shape check
Browse files Browse the repository at this point in the history
The default shape format of TVM is `N x Cx iH x iW` for input and `O x I x kH x kW` for weight, a proper shape for Conv2dTransposed should 

* input: (batch, in_channels, iH, iW)
* weight: (out_channels, in_channels // groups, kH, kW)

Thus the original checking
```
ICHECK(reporter->AssertEQ(indexdiv(dshape_nchw[1], param->groups), wshape[0]));
```  
is wrong. The proper comparison dimension should be `wshape[1]` rather than `wshape[0]`.

Besides, the name for debug is also not correct. All logging information are using `conv2d` rather than `conv2d_transposed`, which is confusing.
  • Loading branch information
Lyken17 committed Nov 6, 2021
1 parent 7b58e16 commit 4fdd06e
Showing 1 changed file with 7 additions and 3 deletions.
10 changes: 7 additions & 3 deletions src/relay/op/nn/convolution.h
Original file line number Diff line number Diff line change
Expand Up @@ -1099,16 +1099,20 @@ bool Conv2DTransposeRel(const Array<Type>& types, int num_inputs, const Attrs& a
// check the size
ICHECK(reporter->AssertEQ(param->kernel_size[0], wshape[2]) &&
reporter->AssertEQ(param->kernel_size[1], wshape[3]))
<< "Conv2D: shape of weight is inconsistent with kernel_size, "
<< "Conv2DTransposed: shape of weight is inconsistent with kernel_size, "
<< " kernel_size=" << param->kernel_size << " wshape=" << Array<IndexExpr>(wshape);
}
if (param->channels.defined()) {
ICHECK(reporter->AssertEQ(param->channels, wshape[1]))
<< "Conv2D: shape of weight is inconsistent with channels, "
<< "Conv2DTransposed: shape of weight is inconsistent with channels, "
<< " channels=" << param->channels << " wshape=" << Array<IndexExpr>(wshape);
}
if (!dshape_nchw[1].as<tir::AnyNode>() && !wshape[0].as<tir::AnyNode>()) {
ICHECK(reporter->AssertEQ(indexdiv(dshape_nchw[1], param->groups), wshape[0]));
ICHECK(reporter->AssertEQ(indexdiv(dshape_nchw[1], param->groups), wshape[1]))
<< "Conv2DTransposed: data.shape[1] // groups != weight.shape[1], "
<< " data.shape= " << Array<IndexExpr>(dshape_nchw)
<< " groups= " << param->groups
<< " weight.shape= " << Array<IndexExpr>(wshape);
}
channels = wshape[1];
dilated_ksize_y = 1 + (wshape[2] - 1) * param->dilation[0];
Expand Down

0 comments on commit 4fdd06e

Please sign in to comment.