Added shift-by variable amount (#159)
* add neon implementation

* fixed u32x8 min and max that were buggy on avx2

* typo

* improved tests

* made behavior for > 31 shift values consistent for all platforms.

* fix spacing

* fix neon behavior

* simplify by using Array::from_fn

* from_fn should be mut to match array implementation

* added verification for from_fn

* add back comments accidentally removed

* test should be called shr_each not shr_all

* added binary and unary op shortcuts to trait

* added shr each

* add i32x4 and i32x8

* fix nonsimd

* fix nonsimd

* add comments
mcroomp committed Jun 5, 2024
1 parent e7fd53d commit 9dc3458
Showing 10 changed files with 465 additions and 1 deletion.
67 changes: 67 additions & 0 deletions src/i32x4_.rs
@@ -324,6 +324,73 @@ macro_rules! impl_shr_t_for_i32x4 {
}
impl_shr_t_for_i32x4!(i8, u8, i16, u16, i32, u32, i64, u64, i128, u128);

/// Shifts each lane by the corresponding lane of `rhs`.
///
/// Bitwise shift-right; yields self >> mask(rhs), where mask removes any
/// high-order bits of rhs that would cause the shift to exceed the bitwidth of
/// the type. (same as wrapping_shr)
impl Shr<i32x4> for i32x4 {
type Output = Self;
fn shr(self, rhs: i32x4) -> Self::Output {
pick! {
if #[cfg(target_feature="avx2")] {
// mask the shift count to 31 to have same behavior on all platforms
let shift_by = bitand_m128i(rhs.sse, set_splat_i32_m128i(31));
Self { sse: shr_each_i32_m128i(self.sse, shift_by) }
} else if #[cfg(all(target_feature="neon",target_arch="aarch64"))]{
unsafe {
// mask the shift count to 31 to have same behavior on all platforms
// neon has no variable right shift; pass a negated count to the left shift instead
let shift_by = vnegq_s32(vandq_s32(rhs.neon, vmovq_n_s32(31)));
Self { neon: vshlq_s32(self.neon, shift_by) }
}
} else {
let arr: [i32; 4] = cast(self);
let rhs: [i32; 4] = cast(rhs);
cast([
arr[0].wrapping_shr(rhs[0] as u32),
arr[1].wrapping_shr(rhs[1] as u32),
arr[2].wrapping_shr(rhs[2] as u32),
arr[3].wrapping_shr(rhs[3] as u32),
])
}
}
}
}

/// Shifts each lane by the corresponding lane of `rhs`.
///
/// Bitwise shift-left; yields self << mask(rhs), where mask removes any
/// high-order bits of rhs that would cause the shift to exceed the bitwidth of
/// the type. (same as wrapping_shl)
impl Shl<i32x4> for i32x4 {
type Output = Self;
fn shl(self, rhs: i32x4) -> Self::Output {
pick! {
if #[cfg(target_feature="avx2")] {
// mask the shift count to 31 to have same behavior on all platforms
let shift_by = bitand_m128i(rhs.sse, set_splat_i32_m128i(31));
Self { sse: shl_each_u32_m128i(self.sse, shift_by) }
} else if #[cfg(all(target_feature="neon",target_arch="aarch64"))]{
unsafe {
// mask the shift count to 31 to have same behavior on all platforms
let shift_by = vandq_s32(rhs.neon, vmovq_n_s32(31));
Self { neon: vshlq_s32(self.neon, shift_by) }
}
} else {
let arr: [i32; 4] = cast(self);
let rhs: [i32; 4] = cast(rhs);
cast([
arr[0].wrapping_shl(rhs[0] as u32),
arr[1].wrapping_shl(rhs[1] as u32),
arr[2].wrapping_shl(rhs[2] as u32),
arr[3].wrapping_shl(rhs[3] as u32),
])
}
}
}
}

impl CmpEq for i32x4 {
type Output = Self;
#[inline]
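
Together these two impls give i32x4 the same per-lane semantics as scalar wrapping_shr/wrapping_shl on every backend: the shift count is masked to 0..=31 on AVX2, NEON, and the scalar fallback alike, and right shifts are arithmetic because the lanes are signed. Note the NEON right shift is built from vshlq_s32 with negated counts, since AArch64 only has a variable left-shift instruction; masking happens before negation so the result still matches the scalar path. A minimal sketch of the behavior, reusing values from the tests added below (assumes the crate is in scope as wide):

use wide::i32x4;

fn main() {
  let a = i32x4::from([15313, 52322, -1, 4]);
  // a count of 33 is masked to 33 & 31 == 1, matching wrapping_shr/wrapping_shl
  let shift = i32x4::from([1, 30, 8, 33]);

  // arithmetic right shift: -1 >> 8 stays -1
  assert_eq!(a >> shift, i32x4::from([7656, 0, -1, 2]));
  // bits shifted out the top are discarded: 52322 << 30 == -2147483648
  assert_eq!(a << shift, i32x4::from([30626, -2147483648, -256, 8]));
}
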
47 changes: 47 additions & 0 deletions src/i32x8_.rs
@@ -229,6 +229,53 @@ macro_rules! impl_shr_t_for_i32x8 {

impl_shr_t_for_i32x8!(i8, u8, i16, u16, i32, u32, i64, u64, i128, u128);

/// Shifts each lane by the corresponding lane of `rhs`.
///
/// Bitwise shift-right; yields self >> mask(rhs), where mask removes any
/// high-order bits of rhs that would cause the shift to exceed the bitwidth of
/// the type. (same as wrapping_shr)
impl Shr<i32x8> for i32x8 {
type Output = Self;
fn shr(self, rhs: i32x8) -> Self::Output {
pick! {
if #[cfg(target_feature="avx2")] {
// ensure same behavior as scalar wrapping_shr by masking the shift count
let shift_by = bitand_m256i(rhs.avx2, set_splat_i32_m256i(31));
Self { avx2: shr_each_i32_m256i(self.avx2, shift_by) }
} else {
Self {
a : self.a.shr(rhs.a),
b : self.b.shr(rhs.b),
}
}
}
}
}

/// Shifts each lane by the corresponding lane of `rhs`.
///
/// Bitwise shift-left; yields self << mask(rhs), where mask removes any
/// high-order bits of rhs that would cause the shift to exceed the bitwidth of
/// the type. (same as wrapping_shl)
impl Shl<i32x8> for i32x8 {
type Output = Self;
fn shl(self, rhs: i32x8) -> Self::Output {
pick! {
if #[cfg(target_feature="avx2")] {
// ensure same behavior as scalar wrapping_shl by masking the shift count
let shift_by = bitand_m256i(rhs.avx2, set_splat_i32_m256i(31));
// shl is the same for unsigned and signed
Self { avx2: shl_each_u32_m256i(self.avx2, shift_by) }
} else {
Self {
a : self.a.shl(rhs.a),
b : self.b.shl(rhs.b),
}
}
}
}
}

impl CmpEq for i32x8 {
type Output = Self;
#[inline]
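
On AVX2 the eight lanes shift in one shr_each/shl_each operation over the 256-bit register; on every other platform i32x8 is stored as two i32x4 halves (a and b), so these impls simply delegate to the 4-lane operators above and inherit their masking behavior. A short sketch with values from the tests added below (crate in scope as wide, an assumption):

use wide::i32x8;

fn main() {
  let a = i32x8::from([15313, 52322, -1, 4, 1, 2, 3, 4]);
  let shift = i32x8::from([1, 30, 8, 33, 1, 2, 3, 4]); // 33 & 31 == 1
  assert_eq!(
    a << shift,
    i32x8::from([30626, -2147483648, -256, 8, 2, 8, 24, 64])
  );
}
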
67 changes: 67 additions & 0 deletions src/u32x4_.rs
@@ -324,6 +324,73 @@ macro_rules! impl_shr_t_for_u32x4 {
}
impl_shr_t_for_u32x4!(i8, u8, i16, u16, i32, u32, i64, u64, i128, u128);

/// Shifts each lane by the corresponding lane of `rhs`.
///
/// Bitwise shift-right; yields self >> mask(rhs), where mask removes any
/// high-order bits of rhs that would cause the shift to exceed the bitwidth of
/// the type. (same as wrapping_shr)
impl Shr<u32x4> for u32x4 {
type Output = Self;
fn shr(self, rhs: u32x4) -> Self::Output {
pick! {
if #[cfg(target_feature="avx2")] {
// mask the shift count to 31 to have same behavior on all platforms
let shift_by = bitand_m128i(rhs.sse, set_splat_i32_m128i(31));
Self { sse: shr_each_u32_m128i(self.sse, shift_by) }
} else if #[cfg(all(target_feature="neon",target_arch="aarch64"))]{
unsafe {
// mask the shift count to 31 to have same behavior on all platforms
// neon has no variable right shift; pass a negated count to the left shift instead
let shift_by = vnegq_s32(vreinterpretq_s32_u32(vandq_u32(rhs.neon, vmovq_n_u32(31))));
Self { neon: vshlq_u32(self.neon, shift_by) }
}
} else {
let arr: [u32; 4] = cast(self);
let rhs: [u32; 4] = cast(rhs);
cast([
arr[0].wrapping_shr(rhs[0]),
arr[1].wrapping_shr(rhs[1]),
arr[2].wrapping_shr(rhs[2]),
arr[3].wrapping_shr(rhs[3]),
])
}
}
}
}

/// Shifts each lane by the corresponding lane of `rhs`.
///
/// Bitwise shift-left; yields self << mask(rhs), where mask removes any
/// high-order bits of rhs that would cause the shift to exceed the bitwidth of
/// the type. (same as wrapping_shl)
impl Shl<u32x4> for u32x4 {
type Output = Self;
fn shl(self, rhs: u32x4) -> Self::Output {
pick! {
if #[cfg(target_feature="avx2")] {
// mask the shift count to 31 to have same behavior on all platforms
let shift_by = bitand_m128i(rhs.sse, set_splat_i32_m128i(31));
Self { sse: shl_each_u32_m128i(self.sse, shift_by) }
} else if #[cfg(all(target_feature="neon",target_arch="aarch64"))]{
unsafe {
// mask the shift count to 31 to have same behavior on all platforms
let shift_by = vreinterpretq_s32_u32(vandq_u32(rhs.neon, vmovq_n_u32(31)));
Self { neon: vshlq_u32(self.neon, shift_by) }
}
} else {
let arr: [u32; 4] = cast(self);
let rhs: [u32; 4] = cast(rhs);
cast([
arr[0].wrapping_shl(rhs[0]),
arr[1].wrapping_shl(rhs[1]),
arr[2].wrapping_shl(rhs[2]),
arr[3].wrapping_shl(rhs[3]),
])
}
}
}
}

impl u32x4 {
#[inline]
#[must_use]
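
The unsigned version differs from i32x4 only in the right shift, which is logical rather than arithmetic, so high bits fill with zeros instead of the sign. On NEON the counts are masked as u32, reinterpreted as signed, and negated before vshlq_u32 (the same negative-count trick as for i32x4). Sketch with values from the tests added below (crate in scope as wide, an assumption):

use wide::u32x4;

fn main() {
  let a = u32x4::from([15313, 52322, u32::MAX, 4]);
  let shift = u32x4::from([1, 30, 8, 33]); // 33 & 31 == 1
  // logical right shift: u32::MAX >> 8 == 16777215, no sign extension
  assert_eq!(a >> shift, u32x4::from([7656, 0, 16777215, 2]));
  assert_eq!(a << shift, u32x4::from([30626, 2147483648, 4294967040, 8]));
}
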
46 changes: 46 additions & 0 deletions src/u32x8_.rs
@@ -176,6 +176,52 @@ macro_rules! impl_shr_t_for_u32x8 {

impl_shr_t_for_u32x8!(i8, u8, i16, u16, i32, u32, i64, u64, i128, u128);

/// Shifts each lane by the corresponding lane of `rhs`.
///
/// Bitwise shift-right; yields self >> mask(rhs), where mask removes any
/// high-order bits of rhs that would cause the shift to exceed the bitwidth of
/// the type. (same as wrapping_shr)
impl Shr<u32x8> for u32x8 {
type Output = Self;
fn shr(self, rhs: u32x8) -> Self::Output {
pick! {
if #[cfg(target_feature="avx2")] {
// ensure same behavior as scalar wrapping_shr
let shift_by = bitand_m256i(rhs.avx2, set_splat_i32_m256i(31));
Self { avx2: shr_each_u32_m256i(self.avx2, shift_by) }
} else {
Self {
a : self.a.shr(rhs.a),
b : self.b.shr(rhs.b),
}
}
}
}
}

/// Shifts each lane by the corresponding lane of `rhs`.
///
/// Bitwise shift-left; yields self << mask(rhs), where mask removes any
/// high-order bits of rhs that would cause the shift to exceed the bitwidth of
/// the type. (same as wrapping_shl)
impl Shl<u32x8> for u32x8 {
type Output = Self;
fn shl(self, rhs: u32x8) -> Self::Output {
pick! {
if #[cfg(target_feature="avx2")] {
// ensure same behavior as scalar wrapping_shl
let shift_by = bitand_m256i(rhs.avx2, set_splat_i32_m256i(31));
Self { avx2: shl_each_u32_m256i(self.avx2, shift_by) }
} else {
Self {
a : self.a.shl(rhs.a),
b : self.b.shl(rhs.b),
}
}
}
}
}

impl u32x8 {
#[inline]
#[must_use]
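
Same shape as i32x8: one masked AVX2 shift, otherwise delegation to the two u32x4 halves. The u32x8 tests are not among the file sections shown here, so the values in this sketch are computed by hand purely for illustration (crate in scope as wide, an assumption):

use wide::u32x8;

fn main() {
  let a = u32x8::from([1, 2, 4, 8, 16, 32, 64, u32::MAX]);
  let shift = u32x8::from([0, 1, 2, 3, 4, 5, 38 /* & 31 == 6 */, 7]);
  // u32::MAX >> 7 == 33554431; 64 >> (38 & 31) == 1
  assert_eq!(a >> shift, u32x8::from([1, 1, 1, 1, 1, 1, 1, 33554431]));
}
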
8 changes: 7 additions & 1 deletion tests/all_tests/main.rs
@@ -64,7 +64,13 @@ fn test_random_vector_vs_scalar<
}

let expected_vec = vector_fn(V::from(a_arr), V::from(b_arr));
assert_eq!(
expected_arr,
expected_vec.into(),
"scalar = {:?} vec = {:?}",
expected_arr,
expected_vec.into()
);
}
}

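
The diff above only shows the strengthened assertion at the end of the comparison loop. For orientation, here is a hedged sketch of what a harness with this shape can look like, specialized to i32 lanes for brevity; the seed, the round count, and the LCG stand-in for the RNG are assumptions, not the crate's actual code:

use core::array;

// Hypothetical harness sketch; only vector_fn, scalar_fn, and the final
// assert mirror the fragment shown in the diff.
fn test_random_vector_vs_scalar<V, const N: usize>(
  vector_fn: impl Fn(V, V) -> V,
  scalar_fn: impl Fn(i32, i32) -> i32,
) where
  V: From<[i32; N]> + Into<[i32; N]> + Copy,
{
  let mut state = 0x2545_F491_u32; // fixed seed (assumption)
  let mut next = move || {
    // tiny LCG stand-in for whatever RNG the real helper uses
    state = state.wrapping_mul(1664525).wrapping_add(1013904223);
    state as i32
  };
  for _ in 0..100 {
    let a_arr: [i32; N] = array::from_fn(|_| next());
    let b_arr: [i32; N] = array::from_fn(|_| next());
    // scalar reference result, lane by lane
    let mut expected_arr = [0i32; N];
    for i in 0..N {
      expected_arr[i] = scalar_fn(a_arr[i], b_arr[i]);
    }
    // the vector result must match the scalar reference exactly
    let expected_vec = vector_fn(V::from(a_arr), V::from(b_arr));
    assert_eq!(
      expected_arr,
      expected_vec.into(),
      "scalar = {:?} vec = {:?}",
      expected_arr,
      expected_vec.into()
    );
  }
}

The shift tests then call it as, e.g., crate::test_random_vector_vs_scalar(|a: i32x4, b| a >> b, |a, b| a.wrapping_shr(b as u32)).
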
27 changes: 27 additions & 0 deletions tests/all_tests/t_i32x4.rs
@@ -233,3 +233,30 @@ fn impl_i32x4_reduce_max() {
assert_eq!(p.reduce_max(), i32::MAX);
}
}

#[test]
fn impl_i32x4_shr_each() {
let a = i32x4::from([15313, 52322, -1, 4]);
let shift = i32x4::from([1, 30, 8, 33 /* test masking behavior */]);
let expected = i32x4::from([7656, 0, -1, 2]);
let actual = a >> shift;
assert_eq!(expected, actual);

crate::test_random_vector_vs_scalar(
|a: i32x4, b| a >> b,
|a, b| a.wrapping_shr(b as u32),
);
}
#[test]
fn impl_i32x4_shl_each() {
let a = i32x4::from([15313, 52322, -1, 4]);
let shift = i32x4::from([1, 30, 8, 33 /* test masking behavior */]);
let expected = i32x4::from([30626, -2147483648, -256, 8]);
let actual = a << shift;
assert_eq!(expected, actual);

crate::test_random_vector_vs_scalar(
|a: i32x4, b| a << b,
|a, b| a.wrapping_shl(b as u32),
);
}
29 changes: 29 additions & 0 deletions tests/all_tests/t_i32x8.rs
@@ -321,3 +321,32 @@ fn impl_i32x8_reduce_max() {
assert_eq!(p.reduce_max(), i32::MAX);
}
}

#[test]
fn impl_i32x8_shr_each() {
let a = i32x8::from([15313, 52322, -1, 4, 10, 20, 30, 40]);
let shift =
i32x8::from([1, 30, 8, 33 /* test masking behavior */, 1, 2, 3, 4]);
let expected = i32x8::from([7656, 0, -1, 2, 5, 5, 3, 2]);
let actual = a >> shift;
assert_eq!(expected, actual);

crate::test_random_vector_vs_scalar(
|a: i32x8, b| a >> b,
|a, b| a.wrapping_shr(b as u32),
);
}
#[test]
fn impl_i32x8_shl_each() {
let a = i32x8::from([15313, 52322, -1, 4, 1, 2, 3, 4]);
let shift =
i32x8::from([1, 30, 8, 33 /* test masking behavior */, 1, 2, 3, 4]);
let expected = i32x8::from([30626, -2147483648, -256, 8, 2, 8, 24, 64]);
let actual = a << shift;
assert_eq!(expected, actual);

crate::test_random_vector_vs_scalar(
|a: i32x8, b| a << b,
|a, b| a.wrapping_shl(b as u32),
);
}
37 changes: 37 additions & 0 deletions tests/all_tests/t_u32x4.rs
@@ -170,3 +170,40 @@ fn impl_u32x4_min() {

crate::test_random_vector_vs_scalar(|a: u32x4, b| a.min(b), |a, b| a.min(b));
}

#[test]
fn impl_u32x4_not() {
let a = u32x4::from([15313, 52322, u32::MAX, 4]);
let expected = u32x4::from([4294951982, 4294914973, 0, 4294967291]);
let actual = !a;
assert_eq!(expected, actual);

crate::test_random_vector_vs_scalar(|a: u32x4, _b| !a, |a, _b| !a);
}

#[test]
fn impl_u32x4_shr_each() {
let a = u32x4::from([15313, 52322, u32::MAX, 4]);
let shift = u32x4::from([1, 30, 8, 33 /* test masking behavior */]);
let expected = u32x4::from([7656u32, 0, 16777215, 2]);
let actual = a >> shift;
assert_eq!(expected, actual);

crate::test_random_vector_vs_scalar(
|a: u32x4, b| a >> b,
|a, b| a.wrapping_shr(b),
);
}
#[test]
fn impl_u32x4_shl_each() {
let a = u32x4::from([15313, 52322, u32::MAX, 4]);
let shift = u32x4::from([1, 30, 8, 33 /* test masking behavior */]);
let expected = u32x4::from([30626, 2147483648, 4294967040, 8]);
let actual = a << shift;
assert_eq!(expected, actual);

crate::test_random_vector_vs_scalar(
|a: u32x4, b| a << b,
|a, b| a.wrapping_shl(b),
);
}
