Skip to content

Commit

Permalink
[FP16] Implement conversion operations. (#6974)
Browse files Browse the repository at this point in the history
Note: FP16 is a little different from F32/F64 since it can't represent
the full 2^16 integer range. 65504 is the max whole integer. This leads
to some slightly strange behavior when converting integers greater than
65504 since they become infinity.


Specified at
https://github.com/WebAssembly/half-precision/blob/main/proposals/half-precision/Overview.md
  • Loading branch information
brendandahl committed Sep 26, 2024
1 parent 3856a2d commit c3a71ff
Show file tree
Hide file tree
Showing 19 changed files with 303 additions and 6 deletions.
4 changes: 4 additions & 0 deletions scripts/gen-s-parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -540,6 +540,10 @@
("i32x4.trunc_sat_f64x2_u_zero", "makeUnary(UnaryOp::TruncSatZeroUVecF64x2ToVecI32x4)"),
("f32x4.demote_f64x2_zero", "makeUnary(UnaryOp::DemoteZeroVecF64x2ToVecF32x4)"),
("f64x2.promote_low_f32x4", "makeUnary(UnaryOp::PromoteLowVecF32x4ToVecF64x2)"),
("i16x8.trunc_sat_f16x8_s", "makeUnary(UnaryOp::TruncSatSVecF16x8ToVecI16x8)"),
("i16x8.trunc_sat_f16x8_u", "makeUnary(UnaryOp::TruncSatUVecF16x8ToVecI16x8)"),
("f16x8.convert_i16x8_s", "makeUnary(UnaryOp::ConvertSVecI16x8ToVecF16x8)"),
("f16x8.convert_i16x8_u", "makeUnary(UnaryOp::ConvertUVecI16x8ToVecF16x8)"),

# relaxed SIMD ops
("i8x16.relaxed_swizzle", "makeBinary(BinaryOp::RelaxedSwizzleVecI8x16)"),
Expand Down
49 changes: 44 additions & 5 deletions src/gen-s-parser.inc
Original file line number Diff line number Diff line change
Expand Up @@ -326,12 +326,34 @@ switch (buf[0]) {
default: goto parse_error;
}
}
case 'c':
if (op == "f16x8.ceil"sv) {
CHECK_ERR(makeUnary(ctx, pos, annotations, UnaryOp::CeilVecF16x8));
return Ok{};
case 'c': {
switch (buf[7]) {
case 'e':
if (op == "f16x8.ceil"sv) {
CHECK_ERR(makeUnary(ctx, pos, annotations, UnaryOp::CeilVecF16x8));
return Ok{};
}
goto parse_error;
case 'o': {
switch (buf[20]) {
case 's':
if (op == "f16x8.convert_i16x8_s"sv) {
CHECK_ERR(makeUnary(ctx, pos, annotations, UnaryOp::ConvertSVecI16x8ToVecF16x8));
return Ok{};
}
goto parse_error;
case 'u':
if (op == "f16x8.convert_i16x8_u"sv) {
CHECK_ERR(makeUnary(ctx, pos, annotations, UnaryOp::ConvertUVecI16x8ToVecF16x8));
return Ok{};
}
goto parse_error;
default: goto parse_error;
}
}
default: goto parse_error;
}
goto parse_error;
}
case 'd':
if (op == "f16x8.div"sv) {
CHECK_ERR(makeBinary(ctx, pos, annotations, BinaryOp::DivVecF16x8));
Expand Down Expand Up @@ -2038,6 +2060,23 @@ switch (buf[0]) {
default: goto parse_error;
}
}
case 't': {
switch (buf[22]) {
case 's':
if (op == "i16x8.trunc_sat_f16x8_s"sv) {
CHECK_ERR(makeUnary(ctx, pos, annotations, UnaryOp::TruncSatSVecF16x8ToVecI16x8));
return Ok{};
}
goto parse_error;
case 'u':
if (op == "i16x8.trunc_sat_f16x8_u"sv) {
CHECK_ERR(makeUnary(ctx, pos, annotations, UnaryOp::TruncSatUVecF16x8ToVecI16x8));
return Ok{};
}
goto parse_error;
default: goto parse_error;
}
}
default: goto parse_error;
}
}
Expand Down
4 changes: 4 additions & 0 deletions src/ir/child-typer.h
Original file line number Diff line number Diff line change
Expand Up @@ -423,6 +423,10 @@ template<typename Subtype> struct ChildTyper : OverriddenVisitor<Subtype> {
case RelaxedTruncUVecF32x4ToVecI32x4:
case RelaxedTruncZeroSVecF64x2ToVecI32x4:
case RelaxedTruncZeroUVecF64x2ToVecI32x4:
case TruncSatSVecF16x8ToVecI16x8:
case TruncSatUVecF16x8ToVecI16x8:
case ConvertSVecI16x8ToVecF16x8:
case ConvertUVecI16x8ToVecF16x8:
case AnyTrueVec128:
case AllTrueVecI8x16:
case AllTrueVecI16x8:
Expand Down
4 changes: 4 additions & 0 deletions src/ir/cost.h
Original file line number Diff line number Diff line change
Expand Up @@ -257,6 +257,10 @@ struct CostAnalyzer : public OverriddenVisitor<CostAnalyzer, CostType> {
case RelaxedTruncUVecF32x4ToVecI32x4:
case RelaxedTruncZeroSVecF64x2ToVecI32x4:
case RelaxedTruncZeroUVecF64x2ToVecI32x4:
case TruncSatSVecF16x8ToVecI16x8:
case TruncSatUVecF16x8ToVecI16x8:
case ConvertSVecI16x8ToVecF16x8:
case ConvertUVecI16x8ToVecF16x8:
ret = 1;
break;
case InvalidUnary:
Expand Down
8 changes: 8 additions & 0 deletions src/literal.h
Original file line number Diff line number Diff line change
Expand Up @@ -377,14 +377,18 @@ class Literal {
Literal extendS32() const;
Literal wrapToI32() const;

Literal convertSIToF16() const;
Literal convertUIToF16() const;
Literal convertSIToF32() const;
Literal convertUIToF32() const;
Literal convertSIToF64() const;
Literal convertUIToF64() const;
Literal convertF32ToF16() const;

Literal truncSatToSI16() const;
Literal truncSatToSI32() const;
Literal truncSatToSI64() const;
Literal truncSatToUI16() const;
Literal truncSatToUI32() const;
Literal truncSatToUI64() const;

Expand Down Expand Up @@ -693,6 +697,10 @@ class Literal {
Literal truncSatZeroUToI32x4() const;
Literal demoteZeroToF32x4() const;
Literal promoteLowToF64x2() const;
Literal truncSatToSI16x8() const;
Literal truncSatToUI16x8() const;
Literal convertSToF16x8() const;
Literal convertUToF16x8() const;
Literal swizzleI8x16(const Literal& other) const;
Literal relaxedMaddF16x8(const Literal& left, const Literal& right) const;
Literal relaxedNmaddF16x8(const Literal& left, const Literal& right) const;
Expand Down
12 changes: 12 additions & 0 deletions src/passes/Print.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1368,6 +1368,18 @@ struct PrintExpressionContents
case RelaxedTruncZeroUVecF64x2ToVecI32x4:
o << "i32x4.relaxed_trunc_f64x2_u_zero";
break;
case TruncSatSVecF16x8ToVecI16x8:
o << "i16x8.trunc_sat_f16x8_s";
break;
case TruncSatUVecF16x8ToVecI16x8:
o << "i16x8.trunc_sat_f16x8_u";
break;
case ConvertSVecI16x8ToVecF16x8:
o << "f16x8.convert_i16x8_s";
break;
case ConvertUVecI16x8ToVecF16x8:
o << "f16x8.convert_i16x8_u";
break;
case InvalidUnary:
WASM_UNREACHABLE("unvalid unary operator");
}
Expand Down
10 changes: 10 additions & 0 deletions src/support/safe_integer.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -98,6 +98,11 @@ int64_t wasm::toSInteger64(double x) {
* 1 11111111 111...11 => 0xffffffff => -nan(0x7fffff)
*/

bool wasm::isInRangeI16TruncS(int32_t i) {
uint32_t u = i;
return (u < 0x47000000U) || (u >= 0x80000000U && u <= 0xc7000000U);
}

bool wasm::isInRangeI32TruncS(int32_t i) {
uint32_t u = i;
return (u < 0x4f000000U) || (u >= 0x80000000U && u <= 0xcf000000U);
Expand All @@ -108,6 +113,11 @@ bool wasm::isInRangeI64TruncS(int32_t i) {
return (u < 0x5f000000U) || (u >= 0x80000000U && u <= 0xdf000000U);
}

bool wasm::isInRangeI16TruncU(int32_t i) {
uint32_t u = i;
return (u < 0x47800000) || (u >= 0x80000000U && u < 0xbf800000U);
}

bool wasm::isInRangeI32TruncU(int32_t i) {
uint32_t u = i;
return (u < 0x4f800000U) || (u >= 0x80000000U && u < 0xbf800000U);
Expand Down
2 changes: 2 additions & 0 deletions src/support/safe_integer.h
Original file line number Diff line number Diff line change
Expand Up @@ -32,8 +32,10 @@ uint64_t toUInteger64(double x);
int64_t toSInteger64(double x);
// The isInRange* functions all expect to be passed the binary representation
// of a float or double.
bool isInRangeI16TruncS(int32_t i);
bool isInRangeI32TruncS(int32_t i);
bool isInRangeI64TruncS(int32_t i);
bool isInRangeI16TruncU(int32_t i);
bool isInRangeI32TruncU(int32_t i);
bool isInRangeI64TruncU(int32_t i);
bool isInRangeI32TruncS(int64_t i);
Expand Down
6 changes: 5 additions & 1 deletion src/tools/fuzzing/fuzzing.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -3143,7 +3143,11 @@ Expression* TranslateToFuzzReader::makeUnary(Type type) {
CeilVecF16x8,
FloorVecF16x8,
TruncVecF16x8,
NearestVecF16x8)),
NearestVecF16x8,
TruncSatSVecF16x8ToVecI16x8,
TruncSatUVecF16x8ToVecI16x8,
ConvertSVecI16x8ToVecF16x8,
ConvertUVecI16x8ToVecF16x8)),
make(Type::v128)});
}
WASM_UNREACHABLE("invalid value");
Expand Down
4 changes: 4 additions & 0 deletions src/wasm-binary.h
Original file line number Diff line number Diff line change
Expand Up @@ -1039,6 +1039,10 @@ enum ASTNodes {
F16x8Max = 0x142,
F16x8Pmin = 0x143,
F16x8Pmax = 0x144,
I16x8TruncSatF16x8S = 0x145,
I16x8TruncSatF16x8U = 0x146,
F16x8ConvertI16x8S = 0x147,
F16x8ConvertI16x8U = 0x148,

// bulk memory opcodes

Expand Down
8 changes: 8 additions & 0 deletions src/wasm-interpreter.h
Original file line number Diff line number Diff line change
Expand Up @@ -632,6 +632,14 @@ class ExpressionRunner : public OverriddenVisitor<SubType, Flow> {
return value.demoteZeroToF32x4();
case PromoteLowVecF32x4ToVecF64x2:
return value.promoteLowToF64x2();
case TruncSatSVecF16x8ToVecI16x8:
return value.truncSatToSI16x8();
case TruncSatUVecF16x8ToVecI16x8:
return value.truncSatToUI16x8();
case ConvertSVecI16x8ToVecF16x8:
return value.convertSToF16x8();
case ConvertUVecI16x8ToVecF16x8:
return value.convertUToF16x8();
case InvalidUnary:
WASM_UNREACHABLE("invalid unary op");
}
Expand Down
4 changes: 4 additions & 0 deletions src/wasm.h
Original file line number Diff line number Diff line change
Expand Up @@ -238,6 +238,10 @@ enum UnaryOp {

// Half precision SIMD
SplatVecF16x8,
TruncSatSVecF16x8ToVecI16x8,
TruncSatUVecF16x8ToVecI16x8,
ConvertSVecI16x8ToVecF16x8,
ConvertUVecI16x8ToVecF16x8,

InvalidUnary
};
Expand Down
43 changes: 43 additions & 0 deletions src/wasm/literal.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -798,6 +798,20 @@ Literal Literal::wrapToI32() const {
return Literal((int32_t)i64);
}

Literal Literal::convertSIToF16() const {
if (type == Type::i32) {
return Literal(fp16_ieee_from_fp32_value(float(i32)));
}
WASM_UNREACHABLE("invalid type");
}

Literal Literal::convertUIToF16() const {
if (type == Type::i32) {
return Literal(fp16_ieee_from_fp32_value(float(uint16_t(i32))));
}
WASM_UNREACHABLE("invalid type");
}

Literal Literal::convertSIToF32() const {
if (type == Type::i32) {
return Literal(float(i32));
Expand Down Expand Up @@ -861,6 +875,14 @@ static Literal saturating_trunc(typename AsInt<F>::type val) {
return Literal(I(std::trunc(bit_cast<F>(val))));
}

Literal Literal::truncSatToSI16() const {
if (type == Type::f32) {
return saturating_trunc<float, int16_t, isInRangeI16TruncS>(
Literal(*this).castToI32().geti32());
}
WASM_UNREACHABLE("invalid type");
}

Literal Literal::truncSatToSI32() const {
if (type == Type::f32) {
return saturating_trunc<float, int32_t, isInRangeI32TruncS>(
Expand All @@ -885,6 +907,14 @@ Literal Literal::truncSatToSI64() const {
WASM_UNREACHABLE("invalid type");
}

Literal Literal::truncSatToUI16() const {
if (type == Type::f32) {
return saturating_trunc<float, uint16_t, isInRangeI16TruncU>(
Literal(*this).castToI32().geti32());
}
WASM_UNREACHABLE("invalid type");
}

Literal Literal::truncSatToUI32() const {
if (type == Type::f32) {
return saturating_trunc<float, uint32_t, isInRangeI32TruncU>(
Expand Down Expand Up @@ -1997,6 +2027,19 @@ Literal Literal::convertUToF32x4() const {
return unary<4, &Literal::getLanesI32x4, &Literal::convertUIToF32>(*this);
}

Literal Literal::truncSatToSI16x8() const {
return unary<8, &Literal::getLanesF16x8, &Literal::truncSatToSI16>(*this);
}
Literal Literal::truncSatToUI16x8() const {
return unary<8, &Literal::getLanesF16x8, &Literal::truncSatToUI16>(*this);
}
Literal Literal::convertSToF16x8() const {
return unary<8, &Literal::getLanesSI16x8, &Literal::convertSIToF16>(*this);
}
Literal Literal::convertUToF16x8() const {
return unary<8, &Literal::getLanesSI16x8, &Literal::convertUIToF16>(*this);
}

Literal Literal::anyTrueV128() const {
auto lanes = getLanesI32x4();
for (size_t i = 0; i < 4; ++i) {
Expand Down
16 changes: 16 additions & 0 deletions src/wasm/wasm-binary.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -6522,6 +6522,22 @@ bool WasmBinaryReader::maybeVisitSIMDUnary(Expression*& out, uint32_t code) {
curr = allocator.alloc<Unary>();
curr->op = RelaxedTruncZeroUVecF64x2ToVecI32x4;
break;
case BinaryConsts::I16x8TruncSatF16x8S:
curr = allocator.alloc<Unary>();
curr->op = TruncSatSVecF16x8ToVecI16x8;
break;
case BinaryConsts::I16x8TruncSatF16x8U:
curr = allocator.alloc<Unary>();
curr->op = TruncSatUVecF16x8ToVecI16x8;
break;
case BinaryConsts::F16x8ConvertI16x8S:
curr = allocator.alloc<Unary>();
curr->op = ConvertSVecI16x8ToVecF16x8;
break;
case BinaryConsts::F16x8ConvertI16x8U:
curr = allocator.alloc<Unary>();
curr->op = ConvertUVecI16x8ToVecF16x8;
break;
default:
return false;
}
Expand Down
16 changes: 16 additions & 0 deletions src/wasm/wasm-stack.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1324,6 +1324,22 @@ void BinaryInstWriter::visitUnary(Unary* curr) {
o << int8_t(BinaryConsts::SIMDPrefix)
<< U32LEB(BinaryConsts::I32x4RelaxedTruncF64x2UZero);
break;
case TruncSatSVecF16x8ToVecI16x8:
o << int8_t(BinaryConsts::SIMDPrefix)
<< U32LEB(BinaryConsts::I16x8TruncSatF16x8S);
break;
case TruncSatUVecF16x8ToVecI16x8:
o << int8_t(BinaryConsts::SIMDPrefix)
<< U32LEB(BinaryConsts::I16x8TruncSatF16x8U);
break;
case ConvertSVecI16x8ToVecF16x8:
o << int8_t(BinaryConsts::SIMDPrefix)
<< U32LEB(BinaryConsts::F16x8ConvertI16x8S);
break;
case ConvertUVecI16x8ToVecF16x8:
o << int8_t(BinaryConsts::SIMDPrefix)
<< U32LEB(BinaryConsts::F16x8ConvertI16x8U);
break;
case InvalidUnary:
WASM_UNREACHABLE("invalid unary op");
}
Expand Down
4 changes: 4 additions & 0 deletions src/wasm/wasm-validator.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2160,6 +2160,10 @@ void FunctionValidator::visitUnary(Unary* curr) {
case RelaxedTruncUVecF32x4ToVecI32x4:
case RelaxedTruncZeroSVecF64x2ToVecI32x4:
case RelaxedTruncZeroUVecF64x2ToVecI32x4:
case TruncSatSVecF16x8ToVecI16x8:
case TruncSatUVecF16x8ToVecI16x8:
case ConvertSVecI16x8ToVecF16x8:
case ConvertUVecI16x8ToVecF16x8:
shouldBeEqual(curr->type, Type(Type::v128), curr, "expected v128 type");
shouldBeEqual(
curr->value->type, Type(Type::v128), curr, "expected v128 operand");
Expand Down
4 changes: 4 additions & 0 deletions src/wasm/wasm.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -703,6 +703,10 @@ void Unary::finalize() {
case RelaxedTruncUVecF32x4ToVecI32x4:
case RelaxedTruncZeroSVecF64x2ToVecI32x4:
case RelaxedTruncZeroUVecF64x2ToVecI32x4:
case TruncSatSVecF16x8ToVecI16x8:
case TruncSatUVecF16x8ToVecI16x8:
case ConvertSVecI16x8ToVecF16x8:
case ConvertUVecI16x8ToVecF16x8:
type = Type::v128;
break;
case AnyTrueVec128:
Expand Down
Loading

0 comments on commit c3a71ff

Please sign in to comment.