Skip to content

Commit

Permalink
update checking logic for scale and zp op (#83)
Browse files Browse the repository at this point in the history
For fake_quantize ops, the `scale` and `zero_point` operands may already have been
lowered by `TorchToTcp`, so the check here is relaxed: instead of requiring a
specific defining op, we only verify that the operand types and shapes are valid.
  • Loading branch information
zezhang committed Jul 16, 2024
1 parent a4fba88 commit 05798b6
Showing 1 changed file with 19 additions and 57 deletions.
76 changes: 19 additions & 57 deletions lib/Conversion/TorchToTcp/TcpCustomOp.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -154,39 +154,21 @@ class ConvertAtenFakeQuantizePerTensorAffineTensorQparamsOp
helper.addIntAttr("quant_max", op.getQuantMax());

// scale
auto scaleOp = op.getScale().getDefiningOp();
if (!scaleOp)
return rewriter.notifyMatchFailure(op, "Missing scale operation");
auto scaleTensor = dyn_cast<torch::Torch::ValueTensorLiteralOp>(scaleOp);
if (!scaleTensor)
return rewriter.notifyMatchFailure(
op, "Scale operation is not ValueTensorLiteralOp");
auto scaleElements =
dyn_cast<DenseFPElementsAttr>(scaleTensor.getValueAttr());
// scale should be a [1] tensor.
if (!scaleElements || scaleElements.getNumElements() != 1)
auto scaleTy = adaptor.getScale().getType().dyn_cast<RankedTensorType>();
if (!scaleTy || scaleTy.getShape().size() != 1 ||
scaleTy.getNumElements() != 1)
// scale should be a [1] tensor.
return rewriter.notifyMatchFailure(op, "Unsupported scale type or size");
helper.addOperand("scale", adaptor.getScale());

// zero_point
auto zeroPointOp = op.getZeroPoint().getDefiningOp();
if (!zeroPointOp)
return rewriter.notifyMatchFailure(op, "Missing zero point operation");
if (auto zeroPointTensor =
dyn_cast<torch::Torch::ValueTensorLiteralOp>(zeroPointOp)) {
auto zeroPointElements =
dyn_cast<DenseIntElementsAttr>(zeroPointTensor.getValueAttr());
auto zeroPointTy =
adaptor.getZeroPoint().getType().dyn_cast<RankedTensorType>();
if (!zeroPointTy || zeroPointTy.getShape().size() != 1 ||
zeroPointTy.getNumElements() != scaleTy.getNumElements())
// zero_point should be a [1] tensor.
if (!zeroPointElements || zeroPointElements.getNumElements() != 1)
return rewriter.notifyMatchFailure(
op, "Unsupported zero point type or size");
} else if (!dyn_cast<torch::Torch::AtenZerosOp>(zeroPointOp) &&
!dyn_cast<torch::Torch::AtenZerosLikeOp>(zeroPointOp)) {
// zero like operations are converted through torch-to-tcp
return rewriter.notifyMatchFailure(
op, "Zero point operation is not ValueTensorLiteralOp or Zero "
"operation");
}
return rewriter.notifyMatchFailure(op,
"Unsupported zero point type or size");
helper.addOperand("zero_point", adaptor.getZeroPoint());

return helper.replace();
Expand All @@ -209,40 +191,20 @@ class ConvertAtenFakeQuantizePerChannelAffineOp
helper.addIntAttr("quant_max", op.getQuantMax());

// scale
auto scaleOp = op.getScale().getDefiningOp();
if (!scaleOp)
return rewriter.notifyMatchFailure(op, "Missing scale operation");
auto scaleTensor = dyn_cast<torch::Torch::ValueTensorLiteralOp>(scaleOp);
if (!scaleTensor)
return rewriter.notifyMatchFailure(
op, "Scale operation is not ValueTensorLiteralOp");
auto scaleElements =
dyn_cast<DenseFPElementsAttr>(scaleTensor.getValueAttr());
// scale should be a [C] tensor.
if (!scaleElements || scaleElements.getType().getShape().size() != 1)
auto scaleTy = adaptor.getScale().getType().dyn_cast<RankedTensorType>();
if (!scaleTy || scaleTy.getShape().size() != 1)
// scale should be a [C] tensor.
return rewriter.notifyMatchFailure(op, "Unsupported scale type or size");
helper.addOperand("scale", adaptor.getScale());

// zero_point
auto zeroPointOp = op.getZeroPoint().getDefiningOp();
if (!zeroPointOp)
return rewriter.notifyMatchFailure(op, "Missing zero point operation");
if (auto zeroPointTensor =
dyn_cast<torch::Torch::ValueTensorLiteralOp>(zeroPointOp)) {
auto zeroPointElements =
dyn_cast<DenseIntElementsAttr>(zeroPointTensor.getValueAttr());
auto zeroPointTy =
adaptor.getZeroPoint().getType().dyn_cast<RankedTensorType>();
if (!zeroPointTy || zeroPointTy.getShape().size() != 1 ||
zeroPointTy.getNumElements() != scaleTy.getNumElements())
// zero_point should be a [C] tensor.
if (!zeroPointElements ||
zeroPointElements.getType().getShape().size() != 1)
return rewriter.notifyMatchFailure(
op, "Unsupported zero point type or size");
} else if (!dyn_cast<torch::Torch::AtenZerosOp>(zeroPointOp) &&
!dyn_cast<torch::Torch::AtenZerosLikeOp>(zeroPointOp)) {
// zero like operations are converted through torch-to-tcp
return rewriter.notifyMatchFailure(
op, "Zero point operation is not ValueTensorLiteralOp or Zero "
"operation");
}
return rewriter.notifyMatchFailure(op,
"Unsupported zero point type or size");
helper.addOperand("zero_point", adaptor.getZeroPoint());

return helper.replace();
Expand Down

0 comments on commit 05798b6

Please sign in to comment.