Skip to content

Commit

Permalink
Merged PR 11831: Change the weight matrix quantization to use 7-bit min/max quantization to avoid overflow
Browse files Browse the repository at this point in the history

1. Change the weight matrix quantization to use 7-bit min/max quantization
-> This resolves all the overflow issues, because weights and activations are quantized by min/max range.
2. Clip fp16 quantization to avoid overflow
3. Fix Windows build errors (cmake options, vcproj file)
4. int8 packed model (encoder -> fp16)
  • Loading branch information
ykim362 authored and ugermann committed May 20, 2020
1 parent 9cd1623 commit 63006db
Showing 1 changed file with 10 additions and 7 deletions.
17 changes: 10 additions & 7 deletions src/tensors/cpu/fbgemm/expression_graph_packable.h
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ namespace marian {
// This requires some more changes, but we temporarily do this just by name ("_W") of the weights.
// And, this introduces a low level packed_gemm.h apis interact with high level graph class.
// So, we make a subclass of ExpressionGraph and put those immature codes in this class.
// We will improve this in the near future.
// We will improve this in the near future.
class ExpressionGraphPackable : public ExpressionGraph {
public:
ExpressionGraphPackable()
Expand All @@ -36,10 +36,11 @@ class ExpressionGraphPackable : public ExpressionGraph {

// save as packed format
// @TODO Hardcoded to find packable weights
// int8 - all the weights used for affine op and dot op
// fp16 - all the weights used for affine op
// int8 - quantize decoder only for better quality, all the weights used for affine op and dot op (int8)
// fp16 - all the weights used for affine op (fp16)
if ((gemmElementType == Type::packed8avx2 || gemmElementType == Type::packed8avx512)
&& (pName.find("_W") == pName.length() - 3 || pName.find("_W") == pName.length() - 2)) {
&& (pName.find("_W") == pName.length() - 3 || pName.find("_W") == pName.length() - 2)
&& pName.find("encoder") == std::string::npos) {
#if USE_FBGEMM
using namespace marian::cpu::variant;
// packing information - size
Expand Down Expand Up @@ -84,8 +85,10 @@ class ExpressionGraphPackable : public ExpressionGraph {
#else
ABORT("Packed type {} only supported when compiled with -DUSE_FBGEMM=on", gemmElementType);
#endif
// fp16 quantization option
} else if (gemmElementType == Type::packed16 && pName.find("_W") == pName.length() - 3) {
// fp16 quantization option + encoders for int8 quantized models
} else if ((gemmElementType == Type::packed16 && pName.find("_W") == pName.length() - 3)
|| ((gemmElementType == Type::packed8avx2 || gemmElementType == Type::packed8avx512)
&& (pName.find("_W") == pName.length() - 3 || pName.find("_W") == pName.length() - 2))) {
#if USE_FBGEMM
using namespace marian::cpu::variant;

Expand Down Expand Up @@ -153,4 +156,4 @@ class ExpressionGraphPackable : public ExpressionGraph {
}
};

} // namespace marian
} // namespace marian

0 comments on commit 63006db

Please sign in to comment.