From 9dc3458b2573c90b969abded6ab5d08306e08138 Mon Sep 17 00:00:00 2001
From: Kristof Roomp
Date: Wed, 5 Jun 2024 15:51:36 +0200
Subject: [PATCH] Added shift-by variable amount (#159)

* add neon implementation
* fixed u32x8 min and max that were buggy on avx2
* typo
* improved tests
* made behavior for > 31 shift values consistent for all platforms.
* fix spacing
* fix neon behavior
* simplify by using Array::from_fn
* fromfn should be mut to match array implementation
* added verification for from_fn
* add back comments accidentally removed
* test should be called shr_each not shr_all
* added binary and unary op shortcuts to trait
* added shr each
* add i32x4 and i32x8
* fix nonsimd
* fix nonsimd
* add comments
---
 src/i32x4_.rs                   | 67 +++++++++++++++++++++++
 src/i32x8_.rs                   | 47 ++++++++++++++++
 src/u32x4_.rs                   | 67 +++++++++++++++++++++++
 src/u32x8_.rs                   | 46 ++++++++++++++++
 tests/all_tests/main.rs         |  8 ++-
 tests/all_tests/t_i32x4.rs      | 27 ++++++++++
 tests/all_tests/t_i32x8.rs      | 29 ++++++++++
 tests/all_tests/t_u32x4.rs      | 37 +++++++++++++
 tests/all_tests/t_u32x8.rs      | 42 +++++++++++++++
 tests/all_tests/t_usefulness.rs | 96 +++++++++++++++++++++++++++++++++
 10 files changed, 465 insertions(+), 1 deletion(-)

diff --git a/src/i32x4_.rs b/src/i32x4_.rs
index a93caa51..70cba149 100644
--- a/src/i32x4_.rs
+++ b/src/i32x4_.rs
@@ -324,6 +324,73 @@ macro_rules! impl_shr_t_for_i32x4 {
 }
 impl_shr_t_for_i32x4!(i8, u8, i16, u16, i32, u32, i64, u64, i128, u128);

+/// Shifts lanes by the corresponding lane.
+///
+/// Bitwise shift-right; yields self >> mask(rhs), where mask removes any
+/// high-order bits of rhs that would cause the shift to exceed the bitwidth of
+/// the type. (same as wrapping_shr)
+impl Shr for i32x4 {
+  type Output = Self;
+  fn shr(self, rhs: i32x4) -> Self::Output {
+    pick! {
+      if #[cfg(target_feature="avx2")] {
+        // mask the shift count to 31 to have same behavior on all platforms
+        let shift_by = bitand_m128i(rhs.sse, set_splat_i32_m128i(31));
+        Self { sse: shr_each_i32_m128i(self.sse, shift_by) }
+      } else if #[cfg(all(target_feature="neon",target_arch="aarch64"))]{
+        unsafe {
+          // mask the shift count to 31 to have same behavior on all platforms
+          // no right shift, have to pass negative value to left shift on neon
+          let shift_by = vnegq_s32(vandq_s32(rhs.neon, vmovq_n_s32(31)));
+          Self { neon: vshlq_s32(self.neon, shift_by) }
+        }
+      } else {
+        let arr: [i32; 4] = cast(self);
+        let rhs: [i32; 4] = cast(rhs);
+        cast([
+          arr[0].wrapping_shr(rhs[0] as u32),
+          arr[1].wrapping_shr(rhs[1] as u32),
+          arr[2].wrapping_shr(rhs[2] as u32),
+          arr[3].wrapping_shr(rhs[3] as u32),
+        ])
+      }
+    }
+  }
+}
+
+/// Shifts lanes by the corresponding lane.
+///
+/// Bitwise shift-left; yields self << mask(rhs), where mask removes any
+/// high-order bits of rhs that would cause the shift to exceed the bitwidth of
+/// the type. (same as wrapping_shl)
+impl Shl for i32x4 {
+  type Output = Self;
+  fn shl(self, rhs: i32x4) -> Self::Output {
+    pick! {
+      if #[cfg(target_feature="avx2")] {
+        // mask the shift count to 31 to have same behavior on all platforms
+        let shift_by = bitand_m128i(rhs.sse, set_splat_i32_m128i(31));
+        Self { sse: shl_each_u32_m128i(self.sse, shift_by) }
+      } else if #[cfg(all(target_feature="neon",target_arch="aarch64"))]{
+        unsafe {
+          // mask the shift count to 31 to have same behavior on all platforms
+          let shift_by = vandq_s32(rhs.neon, vmovq_n_s32(31));
+          Self { neon: vshlq_s32(self.neon, shift_by) }
+        }
+      } else {
+        let arr: [i32; 4] = cast(self);
+        let rhs: [i32; 4] = cast(rhs);
+        cast([
+          arr[0].wrapping_shl(rhs[0] as u32),
+          arr[1].wrapping_shl(rhs[1] as u32),
+          arr[2].wrapping_shl(rhs[2] as u32),
+          arr[3].wrapping_shl(rhs[3] as u32),
+        ])
+      }
+    }
+  }
+}
+
 impl CmpEq for i32x4 {
   type Output = Self;
   #[inline]
diff --git a/src/i32x8_.rs b/src/i32x8_.rs
index e187429e..759f5e88 100644
--- a/src/i32x8_.rs
+++ b/src/i32x8_.rs
@@ -229,6 +229,53 @@ macro_rules! impl_shr_t_for_i32x8 {
 impl_shr_t_for_i32x8!(i8, u8, i16, u16, i32, u32, i64, u64, i128, u128);

+/// Shifts lanes by the corresponding lane.
+///
+/// Bitwise shift-right; yields self >> mask(rhs), where mask removes any
+/// high-order bits of rhs that would cause the shift to exceed the bitwidth of
+/// the type. (same as wrapping_shr)
+impl Shr for i32x8 {
+  type Output = Self;
+  fn shr(self, rhs: i32x8) -> Self::Output {
+    pick! {
+      if #[cfg(target_feature="avx2")] {
+        // ensure same behavior as scalar
+        let shift_by = bitand_m256i(rhs.avx2, set_splat_i32_m256i(31));
+        Self { avx2: shr_each_i32_m256i(self.avx2, shift_by) }
+      } else {
+        Self {
+          a : self.a.shr(rhs.a),
+          b : self.b.shr(rhs.b),
+        }
+      }
+    }
+  }
+}
+
+/// Shifts lanes by the corresponding lane.
+///
+/// Bitwise shift-left; yields self << mask(rhs), where mask removes any
+/// high-order bits of rhs that would cause the shift to exceed the bitwidth of
+/// the type. (same as wrapping_shl)
+impl Shl for i32x8 {
+  type Output = Self;
+  fn shl(self, rhs: i32x8) -> Self::Output {
+    pick! {
+      if #[cfg(target_feature="avx2")] {
+        // ensure same behavior as scalar wrapping_shl by masking the shift count
+        let shift_by = bitand_m256i(rhs.avx2, set_splat_i32_m256i(31));
+        // shl is the same for unsigned and signed
+        Self { avx2: shl_each_u32_m256i(self.avx2, shift_by) }
+      } else {
+        Self {
+          a : self.a.shl(rhs.a),
+          b : self.b.shl(rhs.b),
+        }
+      }
+    }
+  }
+}
+
 impl CmpEq for i32x8 {
   type Output = Self;
   #[inline]
diff --git a/src/u32x4_.rs b/src/u32x4_.rs
index cd8cce45..db68a380 100644
--- a/src/u32x4_.rs
+++ b/src/u32x4_.rs
@@ -324,6 +324,73 @@ macro_rules! impl_shr_t_for_u32x4 {
 }
 impl_shr_t_for_u32x4!(i8, u8, i16, u16, i32, u32, i64, u64, i128, u128);

+/// Shifts lanes by the corresponding lane.
+///
+/// Bitwise shift-right; yields self >> mask(rhs), where mask removes any
+/// high-order bits of rhs that would cause the shift to exceed the bitwidth of
+/// the type. (same as wrapping_shr)
+impl Shr for u32x4 {
+  type Output = Self;
+  fn shr(self, rhs: u32x4) -> Self::Output {
+    pick! {
+      if #[cfg(target_feature="avx2")] {
+        // mask the shift count to 31 to have same behavior on all platforms
+        let shift_by = bitand_m128i(rhs.sse, set_splat_i32_m128i(31));
+        Self { sse: shr_each_u32_m128i(self.sse, shift_by) }
+      } else if #[cfg(all(target_feature="neon",target_arch="aarch64"))]{
+        unsafe {
+          // mask the shift count to 31 to have same behavior on all platforms
+          // no right shift, have to pass negative value to left shift on neon
+          let shift_by = vnegq_s32(vreinterpretq_s32_u32(vandq_u32(rhs.neon, vmovq_n_u32(31))));
+          Self { neon: vshlq_u32(self.neon, shift_by) }
+        }
+      } else {
+        let arr: [u32; 4] = cast(self);
+        let rhs: [u32; 4] = cast(rhs);
+        cast([
+          arr[0].wrapping_shr(rhs[0]),
+          arr[1].wrapping_shr(rhs[1]),
+          arr[2].wrapping_shr(rhs[2]),
+          arr[3].wrapping_shr(rhs[3]),
+        ])
+      }
+    }
+  }
+}
+
+/// Shifts lanes by the corresponding lane.
+///
+/// Bitwise shift-left; yields self << mask(rhs), where mask removes any
+/// high-order bits of rhs that would cause the shift to exceed the bitwidth of
+/// the type. (same as wrapping_shl)
+impl Shl for u32x4 {
+  type Output = Self;
+  fn shl(self, rhs: u32x4) -> Self::Output {
+    pick! {
+      if #[cfg(target_feature="avx2")] {
+        // mask the shift count to 31 to have same behavior on all platforms
+        let shift_by = bitand_m128i(rhs.sse, set_splat_i32_m128i(31));
+        Self { sse: shl_each_u32_m128i(self.sse, shift_by) }
+      } else if #[cfg(all(target_feature="neon",target_arch="aarch64"))]{
+        unsafe {
+          // mask the shift count to 31 to have same behavior on all platforms
+          let shift_by = vreinterpretq_s32_u32(vandq_u32(rhs.neon, vmovq_n_u32(31)));
+          Self { neon: vshlq_u32(self.neon, shift_by) }
+        }
+      } else {
+        let arr: [u32; 4] = cast(self);
+        let rhs: [u32; 4] = cast(rhs);
+        cast([
+          arr[0].wrapping_shl(rhs[0]),
+          arr[1].wrapping_shl(rhs[1]),
+          arr[2].wrapping_shl(rhs[2]),
+          arr[3].wrapping_shl(rhs[3]),
+        ])
+      }
+    }
+  }
+}
+
 impl u32x4 {
   #[inline]
   #[must_use]
diff --git a/src/u32x8_.rs b/src/u32x8_.rs
index 698d0f39..5b5c9474 100644
--- a/src/u32x8_.rs
+++ b/src/u32x8_.rs
@@ -176,6 +176,52 @@ macro_rules! impl_shr_t_for_u32x8 {
 impl_shr_t_for_u32x8!(i8, u8, i16, u16, i32, u32, i64, u64, i128, u128);

+/// Shifts lanes by the corresponding lane.
+///
+/// Bitwise shift-right; yields self >> mask(rhs), where mask removes any
+/// high-order bits of rhs that would cause the shift to exceed the bitwidth of
+/// the type. (same as wrapping_shr)
+impl Shr for u32x8 {
+  type Output = Self;
+  fn shr(self, rhs: u32x8) -> Self::Output {
+    pick! {
+      if #[cfg(target_feature="avx2")] {
+        // ensure same behavior as scalar wrapping_shr
+        let shift_by = bitand_m256i(rhs.avx2, set_splat_i32_m256i(31));
+        Self { avx2: shr_each_u32_m256i(self.avx2, shift_by) }
+      } else {
+        Self {
+          a : self.a.shr(rhs.a),
+          b : self.b.shr(rhs.b),
+        }
+      }
+    }
+  }
+}
+
+/// Shifts lanes by the corresponding lane.
+///
+/// Bitwise shift-left; yields self << mask(rhs), where mask removes any
+/// high-order bits of rhs that would cause the shift to exceed the bitwidth of
+/// the type. (same as wrapping_shl)
+impl Shl for u32x8 {
+  type Output = Self;
+  fn shl(self, rhs: u32x8) -> Self::Output {
+    pick! {
+      if #[cfg(target_feature="avx2")] {
+        // ensure same behavior as scalar wrapping_shl
+        let shift_by = bitand_m256i(rhs.avx2, set_splat_i32_m256i(31));
+        Self { avx2: shl_each_u32_m256i(self.avx2, shift_by) }
+      } else {
+        Self {
+          a : self.a.shl(rhs.a),
+          b : self.b.shl(rhs.b),
+        }
+      }
+    }
+  }
+}
+
 impl u32x8 {
   #[inline]
   #[must_use]
diff --git a/tests/all_tests/main.rs b/tests/all_tests/main.rs
index 1c4cb94a..589c7af4 100644
--- a/tests/all_tests/main.rs
+++ b/tests/all_tests/main.rs
@@ -64,7 +64,13 @@ fn test_random_vector_vs_scalar<
     }

     let expected_vec = vector_fn(V::from(a_arr), V::from(b_arr));
-    assert_eq!(expected_arr, expected_vec.into());
+    assert_eq!(
+      expected_arr,
+      expected_vec.into(),
+      "scalar = {:?} vec = {:?}",
+      expected_arr,
+      expected_vec.into()
+    );
   }
 }
diff --git a/tests/all_tests/t_i32x4.rs b/tests/all_tests/t_i32x4.rs
index 09d51a06..2c9235a2 100644
--- a/tests/all_tests/t_i32x4.rs
+++ b/tests/all_tests/t_i32x4.rs
@@ -233,3 +233,30 @@ fn impl_i32x4_reduce_max() {
     assert_eq!(p.reduce_max(), i32::MAX);
   }
 }
+
+#[test]
+fn impl_i32x4_shr_each() {
+  let a = i32x4::from([15313, 52322, -1, 4]);
+  let shift = i32x4::from([1, 30, 8, 33 /* test masking behavior */]);
+  let expected = i32x4::from([7656, 0, -1, 2]);
+  let actual = a >> shift;
+  assert_eq!(expected, actual);
+
+  crate::test_random_vector_vs_scalar(
+    |a: i32x4, b| a >> b,
+    |a, b| a.wrapping_shr(b as u32),
+  );
+}
+#[test]
+fn impl_i32x4_shl_each() {
+  let a = i32x4::from([15313, 52322, -1, 4]);
+  let shift = i32x4::from([1, 30, 8, 33 /* test masking behavior */]);
+  let expected = i32x4::from([30626, -2147483648, -256, 8]);
+  let actual = a << shift;
+  assert_eq!(expected, actual);
+
+  crate::test_random_vector_vs_scalar(
+    |a: i32x4, b| a << b,
+    |a, b| a.wrapping_shl(b as u32),
+  );
+}
diff --git a/tests/all_tests/t_i32x8.rs b/tests/all_tests/t_i32x8.rs
index ad481c6b..c9b1c143 100644
--- a/tests/all_tests/t_i32x8.rs
+++ b/tests/all_tests/t_i32x8.rs
@@ -321,3 +321,32 @@ fn impl_i32x8_reduce_max() {
     assert_eq!(p.reduce_max(), i32::MAX);
   }
 }
+
+#[test]
+fn impl_i32x8_shr_each() {
+  let a = i32x8::from([15313, 52322, -1, 4, 10, 20, 30, 40]);
+  let shift =
+    i32x8::from([1, 30, 8, 33 /* test masking behavior */, 1, 2, 3, 4]);
+  let expected = i32x8::from([7656, 0, -1, 2, 5, 5, 3, 2]);
+  let actual = a >> shift;
+  assert_eq!(expected, actual);
+
+  crate::test_random_vector_vs_scalar(
+    |a: i32x8, b| a >> b,
+    |a, b| a.wrapping_shr(b as u32),
+  );
+}
+#[test]
+fn impl_i32x8_shl_each() {
+  let a = i32x8::from([15313, 52322, -1, 4, 1, 2, 3, 4]);
+  let shift =
+    i32x8::from([1, 30, 8, 33 /* test masking behavior */, 1, 2, 3, 4]);
+  let expected = i32x8::from([30626, -2147483648, -256, 8, 2, 8, 24, 64]);
+  let actual = a << shift;
+  assert_eq!(expected, actual);
+
+  crate::test_random_vector_vs_scalar(
+    |a: i32x8, b| a << b,
+    |a, b| a.wrapping_shl(b as u32),
+  );
+}
diff --git a/tests/all_tests/t_u32x4.rs b/tests/all_tests/t_u32x4.rs
index 723d2c17..f70ad7b5 100644
--- a/tests/all_tests/t_u32x4.rs
+++ b/tests/all_tests/t_u32x4.rs
@@ -170,3 +170,40 @@ fn impl_u32x4_min() {

   crate::test_random_vector_vs_scalar(|a: u32x4, b| a.min(b), |a, b| a.min(b));
 }
+
+#[test]
+fn impl_u32x4_not() {
+  let a = u32x4::from([15313, 52322, u32::MAX, 4]);
+  let expected = u32x4::from([4294951982, 4294914973, 0, 4294967291]);
+  let actual = !a;
+  assert_eq!(expected, actual);
+
+  crate::test_random_vector_vs_scalar(|a: u32x4, _b| !a, |a, _b| !a);
+}
+
+#[test]
+fn impl_u32x4_shr_each() {
+  let a = u32x4::from([15313, 52322, u32::MAX, 4]);
+  let shift = u32x4::from([1, 30, 8, 33 /* test masking behavior */]);
+  let expected = u32x4::from([7656u32, 0, 16777215, 2]);
+  let actual = a >> shift;
+  assert_eq!(expected, actual);
+
+  crate::test_random_vector_vs_scalar(
+    |a: u32x4, b| a >> b,
+    |a, b| a.wrapping_shr(b),
+  );
+}
+#[test]
+fn impl_u32x4_shl_each() {
+  let a = u32x4::from([15313, 52322, u32::MAX, 4]);
+  let shift = u32x4::from([1, 30, 8, 33 /* test masking behavior */]);
+  let expected = u32x4::from([30626, 2147483648, 4294967040, 8]);
+  let actual = a << shift;
+  assert_eq!(expected, actual);
+
+  crate::test_random_vector_vs_scalar(
+    |a: u32x4, b| a << b,
+    |a, b| a.wrapping_shl(b),
+  );
+}
diff --git a/tests/all_tests/t_u32x8.rs b/tests/all_tests/t_u32x8.rs
index 95c02a23..44b324f8 100644
--- a/tests/all_tests/t_u32x8.rs
+++ b/tests/all_tests/t_u32x8.rs
@@ -203,3 +203,45 @@ fn impl_u32x8_min() {

   crate::test_random_vector_vs_scalar(|a: u32x8, b| a.min(b), |a, b| a.min(b));
 }
+
+#[test]
+fn impl_u32x8_shr_each() {
+  let a = u32x8::from([15313, 52322, u32::MAX, 4, 10, 20, 30, 40]);
+  let shift =
+    u32x8::from([1, 30, 8, 33 /* test masking behavior */, 1, 2, 3, 4]);
+  let expected = u32x8::from([7656, 0, 16777215, 2, 5, 5, 3, 2]);
+  let actual = a >> shift;
+  assert_eq!(expected, actual);
+
+  crate::test_random_vector_vs_scalar(
+    |a: u32x8, b| a >> b,
+    |a, b| a.wrapping_shr(b),
+  );
+}
+#[test]
+fn impl_u32x8_shl_each() {
+  let a = u32x8::from([15313, 52322, u32::MAX, 4, 1, 2, 3, 4]);
+  let shift =
+    u32x8::from([1, 30, 8, 33 /* test masking behavior */, 1, 2, 3, 4]);
+  let expected = u32x8::from([30626, 2147483648, 4294967040, 8, 2, 8, 24, 64]);
+  let actual = a << shift;
+  assert_eq!(expected, actual);
+
+  crate::test_random_vector_vs_scalar(
+    |a: u32x8, b| a << b,
+    |a, b| a.wrapping_shl(b),
+  );
+}
+
+#[test]
+fn impl_u32x8_not() {
+  let a = u32x8::from([15313, 52322, u32::MAX, 4, 1, 2, 3, 4]);
+  let expected = u32x8::from([
+    4294951982, 4294914973, 0, 4294967291, 4294967294, 4294967293, 4294967292,
+    4294967291,
+  ]);
+  let actual = !a;
+  assert_eq!(expected, actual);
+
+  crate::test_random_vector_vs_scalar(|a: u32x8, _b| !a, |a, _b| !a);
+}
diff --git a/tests/all_tests/t_usefulness.rs b/tests/all_tests/t_usefulness.rs
index 1071d2b5..5e6baaa7 100644
--- a/tests/all_tests/t_usefulness.rs
+++ b/tests/all_tests/t_usefulness.rs
@@ -326,3 +326,99 @@ fn test_dequantize_and_idct_i32() {

   assert_eq!(expected_output, output);
 }
+
+// Example implementation of a branch-free division algorithm using u32x8.
+
+/// Ported from libdivide. Example of how to use branch-free division
+/// with this library.
+fn internal_gen_branch_free_u32(d: u32) -> (u32, u32) {
+  fn div_rem(a: u64, b: u64) -> (u64, u64) {
+    (a / b, a % b)
+  }
+
+  // a branch-free divider cannot be generated for d == 0 or d == 1
+  assert!(d > 1);
+
+  let floor_log_2_d = (32u32 - 1) - d.leading_zeros();
+
+  // Power of 2
+  if (d & (d - 1)) == 0 {
+    // We need to subtract 1 from the shift value in case of an unsigned
+    // branchfree divider because there is a hardcoded right shift by 1
+    // in its division algorithm. Because of this we also need to add back
+    // 1 in its recovery algorithm.
+    (0, floor_log_2_d - 1)
+  } else {
+    let (proposed_m, rem) = div_rem(1u64 << (floor_log_2_d + 32), d as u64);
+
+    let mut proposed_m = proposed_m as u32;
+    let rem = rem as u32;
+    assert!(rem > 0 && rem < d);
+
+    // This power works if e < 2**floor_log_2_d.
+    // We have to use the general 33-bit algorithm. We need to compute
+    // (2**power) / d. However, we already have (2**(power-1))/d and
+    // its remainder. By doubling both, and then correcting the
+    // remainder, we can compute the larger division.
+    // don't care about overflow here - in fact, we expect it
+    proposed_m = proposed_m.wrapping_add(proposed_m);
+    let twice_rem = rem.wrapping_add(rem);
+    if twice_rem >= d || twice_rem < rem {
+      proposed_m += 1;
+    }
+
+    (1 + proposed_m, floor_log_2_d)
+    // result.more's shift should in general be ceil_log_2_d. But if we
+    // used the smaller power, we subtract one from the shift because we're
+    // using the smaller power. If we're using the larger power, we
+    // subtract one from the shift because it's taken care of by the add
+    // indicator. So floor_log_2_d happens to be correct in both cases.
+  }
+}
+
+/// Generate magic and shift values for branch-free division.
+fn generate_branch_free_divide_magic_shift(denom: u32x8) -> (u32x8, u32x8) {
+  let mut magic = u32x8::ZERO;
+  let mut shift = u32x8::ZERO;
+  for i in 0..magic.as_array_ref().len() {
+    let (m, s) = internal_gen_branch_free_u32(denom.as_array_ref()[i]);
+    magic.as_array_mut()[i] = m;
+    shift.as_array_mut()[i] = s;
+  }
+
+  (magic, shift)
+}
+
+// using the previously generated magic and shift, calculate the division
+fn branch_free_divide(numerator: u32x8, magic: u32x8, shift: u32x8) -> u32x8 {
+  // Returns 32 high bits of the 64 bit result of multiplication of two u32s
+  let mul_hi = |a, b| ((u64::from(a) * u64::from(b)) >> 32) as u32;
+
+  let a = numerator.as_array_ref();
+  let b = magic.as_array_ref();
+
+  let q = u32x8::from([
+    mul_hi(a[0], b[0]),
+    mul_hi(a[1], b[1]),
+    mul_hi(a[2], b[2]),
+    mul_hi(a[3], b[3]),
+    mul_hi(a[4], b[4]),
+    mul_hi(a[5], b[5]),
+    mul_hi(a[6], b[6]),
+    mul_hi(a[7], b[7]),
+  ]);
+
+  let t = ((numerator - q) >> 1) + q;
+  t >> shift
+}
+
+#[test]
+fn impl_u32x8_branch_free_divide() {
+  crate::test_random_vector_vs_scalar(
+    |a: u32x8, b| {
+      let (magic, shift) = generate_branch_free_divide_magic_shift(b);
+      branch_free_divide(a, magic, shift)
+    },
+    |a, b| a / b,
+  );
+}
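
Usage sketch (illustrative, not part of the patch): the per-lane shift operators added above mask each shift count to the low five bits, so they match scalar wrapping_shl/wrapping_shr even for counts of 32 or more. A minimal example, assuming the crate is consumed as `wide`:

    use wide::u32x4;

    fn main() {
      let a = u32x4::from([1, 2, 3, 4]);
      // A count of 33 is masked to 1, so no lane shifts by the full bit width.
      let shift = u32x4::from([1, 33, 2, 3]);
      assert_eq!(a << shift, u32x4::from([2, 4, 12, 32]));
      assert_eq!(u32x4::from([8, 8, 8, 8]) >> shift, u32x4::from([4, 4, 2, 1]));
    }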
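
The branch-free division helpers in t_usefulness.rs are intended to be used by computing the magic/shift pair once per denominator vector and then reusing it across many numerators. A sketch under that assumption, reusing the helpers defined in the test above (the `divide_all` wrapper is hypothetical, not part of the patch):

    use wide::u32x8;

    // Hypothetical wrapper around generate_branch_free_divide_magic_shift and
    // branch_free_divide from the test module above.
    fn divide_all(numerators: &[u32x8], denom: u32x8) -> Vec<u32x8> {
      // Comparatively expensive setup, done once per denominator vector.
      let (magic, shift) = generate_branch_free_divide_magic_shift(denom);
      // Each division then costs only multiplies, adds, and shifts.
      numerators.iter().map(|&n| branch_free_divide(n, magic, shift)).collect()
    }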