{date="2024/01/13"}
# Find slice element position in Rust, fast!
I started to learn `Rust` only recently, and while exploring [slice methods][] I was a bit surprised that I didn't find any method for finding the position of an element in a slice:
{.noline}
``` rust
fn find(haystack: &[u8], needle: u8) -> Option<usize> { ... }
```
I had some experience with `Zig`, which has a very useful [`std.mem`][zig stdmem] module with many generic functions, including `indexOf`, which internally implements the [Boyer-Moore-Horspool][] pattern matching algorithm for a generic element type `T`:
{.noline}
``` zig
fn indexOf(comptime T: type, haystack: []const T, needle: []const T) ?usize { ... }
```
[slice methods]: https://doc.rust-lang.org/std/primitive.slice.html
[zig stdmem]: https://ziglang.org/documentation/master/std/#A;std:mem
[Boyer-Moore-Horspool]: https://en.wikipedia.org/wiki/Boyer%E2%80%93Moore%E2%80%93Horspool_algorithm
After discussing this with `Rust` experts, I quickly got the answer that I can just use a method of the `Iterator` trait:
```rust
fn find(haystack: &[u8], needle: u8) -> Option<usize> {
    haystack.iter().position(|&x| x == needle)
}
```
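As a quick sanity check, the sketch below exercises this helper on a few inputs. `Iterator::position` returns the index of the first matching element, or `None` when nothing matches. The helper name `check_find` and the sample inputs are made up for illustration.

```rust
/// Returns the index of the first byte equal to `needle`, if any.
pub fn find(haystack: &[u8], needle: u8) -> Option<usize> {
    haystack.iter().position(|&x| x == needle)
}

/// Illustrative checks; the inputs are arbitrary examples.
pub fn check_find() {
    assert_eq!(find(b"hello", b'l'), Some(2)); // first of the two 'l' bytes
    assert_eq!(find(b"hello", b'z'), None); // no match
    assert_eq!(find(b"", b'a'), None); // empty slice
}
```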
Nice! But what about the performance of this method? At first, I was afraid that using a closure would lead to poor performance (I come from `Go`, whose non-`LLVM` compiler has fairly limited inlining). But, unsurprisingly for most developers, `LLVM` (and `Rust`) can optimize this method very nicely, and `rustc` produces a [very clean][rustc iter] binary with the `-C opt-level=3 -C target-cpu=native` compiler flags:
[rustc iter]: https://godbolt.org/z/YrvjKfx1v
```asm
# input : rdi=haystack.ptr, rsi=haystack.size, rdx=needle
# output: rax=None/Some, rdx=Some(v)
example::find:
test rsi, rsi # if haystack.len() == 0
je .LBB0_1 # return None<usize>
mov ecx, edx # b = needle
xor eax, eax # result = None<usize>
xor edx, edx # i = 0
.LBB0_3: # loop {
cmp byte ptr [rdi + rdx], cl # if haystack[i] == b
je .LBB0_4 # return Option::Some(i)
inc rdx # i++
cmp rsi, rdx # if haystack.len() != i
jne .LBB0_3 # continue
mov rdx, rsi # i = haystack.len()
ret # return None()
.LBB0_1: # }
xor edx, edx
xor eax, eax
ret
.LBB0_4:
mov eax, 1
ret
```
Can we improve the method's performance?
### Implementing `find` without early returns
We can notice that for other iterator methods like `filter` the compiler will use `SSE` / `AVX` instructions (if the target CPU supports them). So what prevents the compiler from using `SIMD` instructions for the `position` method? Within our team we came to the conclusion that the `position` implementation returns early, which makes it harder for `LLVM` to apply `SIMD` (although I have no proof of that).
We can assume that the compiler will be able to vectorize a function if it performs a predictable number of operations (a constant, or a simple function of input properties such as the slice length). How can we achieve that for the `position` function? Actually, there is a nice way to implement it without a `break` in the middle of the loop: we just need to process the slice in reverse order! Then we can simply reassign the result variable every time we find a matching element:
```rust
pub fn find_branchless(haystack: &[u8], needle: u8) -> Option<usize> {
    let mut position = None::<usize>;
    for (i, &b) in haystack.iter().enumerate().rev() {
        if b == needle {
            position = Some(i);
        }
    }
    position
}
```
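To see why the reverse scan still returns the first occurrence: every match overwrites `position`, and because indices are visited from high to low, the last overwrite always comes from the smallest matching index. A minimal check against `Iterator::position` could look like this (the helper `agrees_with_position` is an assumption for this sketch):

```rust
pub fn find_branchless(haystack: &[u8], needle: u8) -> Option<usize> {
    let mut position = None::<usize>;
    // Indices are visited from high to low, so the last write wins
    // and it always belongs to the lowest matching index.
    for (i, &b) in haystack.iter().enumerate().rev() {
        if b == needle {
            position = Some(i);
        }
    }
    position
}

/// Hypothetical helper: compares the reverse scan with `Iterator::position`.
pub fn agrees_with_position(haystack: &[u8], needle: u8) -> bool {
    find_branchless(haystack, needle) == haystack.iter().position(|&x| x == needle)
}
```

Note that, unlike `position`, this version always walks the whole slice, even when the match is at index 0.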
Unfortunately, this doesn't help: there are still no `SIMD` instructions in the output assembly. But wait, we can notice drastic changes in the [output binary][rustc rev]: now it seems that the compiler unrolled our main loop and compares bytes in chunks of size 8:
[rustc rev]: https://godbolt.org/z/5Eh5rfaW3
```asm
# this is only a part of the assembly; the full output is available via the godbolt link
.LBB0_11:
cmp byte ptr [r8 + r11 - 1], dl
lea r14, [rsi + r11 - 1]
lea r15, [rsi + r11 - 7]
cmovne r14, rcx
cmove rax, rbx
cmp byte ptr [r8 + r11 - 2], dl
lea rcx, [rsi + r11 - 2]
cmovne rcx, r14
cmove rax, rbx
...
cmp byte ptr [r8 + r11 - 8], dl
lea rcx, [rsi + r11 - 8]
cmovne rcx, r15
cmove rax, rbx
```
That looks promising! Unrolling helps performance by itself, but we may also be on the right path to guiding the compiler towards successful vectorization!
### Vectorization by any means!
At this point, I had no clue how to make the compiler's life simpler, except for one last thing: we can make the slice length constant and hope that this will finally activate the vectorization engine in `LLVM`. It turns out that this was enough! If we use `[u8; 16]` or `[u8; 32]` types for the input argument, then `LLVM` [will use][rustc static] `128`-bit or `256`-bit `SSE` / `AVX` registers and the corresponding instructions!
[rustc static]: https://godbolt.org/z/csYjj769s
```asm
example::find_branchless:
vmovd xmm0, esi
vpbroadcastb xmm0, xmm0
vpcmpeqb xmm0, xmm0, xmmword ptr [rdi]
vpextrb eax, xmm0, 14
vpextrb ecx, xmm0, 13
vpextrb edx, xmm0, 10
and eax, 1
xor rax, 15
test cl, 1
mov ecx, 13
cmove rcx, rax
vpextrb eax, xmm0, 12
# ... there are a lot of instructions for determining actual position of matched byte ...
vpmovmskb esi, xmm0
cmove rdx, rcx
xor eax, eax
test esi, esi
setne al
ret
```
You can notice that the compiler generates truly branch-less code (literally, no jump instructions!). This can be surprising at first sight, but the compiler makes use of the `cmove` ("conditional move") instruction, which moves a value between operands only if the flags register is in a specific state. Because no branch is involved, there is nothing to mispredict, which often makes it faster than an ordinary `cmp` / `je` pair when the outcome is hard to predict. It is also enough to implement simple conditional scenarios like the one in the branch-less implementation of the `find` function.
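For reference, a fixed-size variant along the lines described in this section might look like the sketch below. The name `find_branchless_16` and the `&[u8; 16]` parameter are assumptions made for illustration; the exact source behind the listing is at the godbolt link. Whether the loop is actually vectorized depends on the compiler version and flags.

```rust
/// Hypothetical fixed-size variant: the length is part of the type,
/// so the trip count of the loop is known at compile time.
pub fn find_branchless_16(haystack: &[u8; 16], needle: u8) -> Option<usize> {
    let mut position = None::<usize>;
    for (i, &b) in haystack.iter().enumerate().rev() {
        if b == needle {
            position = Some(i);
        }
    }
    position
}
```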
### Vectorized version of `find`
OK, it's great that we finally forced `rustc` to use vectorization. But the current `find` implementation is barely usable because it works only for byte arrays of a fixed size! What can we do about that?
Here the steps are straightforward: we can split our input slice into chunks of bounded size and apply our branch-less implementation of `find` to each of them. `Rust` has a nice [`chunks`][chunks] method which does exactly what we want, so let's try it:
[chunks]: https://doc.rust-lang.org/std/primitive.slice.html#method.chunks
```rust
// same body as in the previous section
fn find_branchless(haystack: &[u8], needle: u8) -> Option<usize> { ... }

pub fn find(haystack: &[u8], needle: u8) -> Option<usize> {
    haystack
        .chunks(32)
        .enumerate()
        .find_map(|(i, chunk)| find_branchless(chunk, needle).map(|x| 32 * i + x))
}
```
Unfortunately, this doesn't work: the compiler again produces boring assembly with only the unrolling optimization. But if we stop and think about it, this is actually expected! The chunking logic makes the size of every chunk unpredictable, because there are no guarantees about the exact size of the last chunk (and every chunk can be the last one!).
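A small illustration of this point: with `chunks`, each chunk is just a `&[u8]`, and the final one can be shorter than the rest. The helper name `chunk_lengths` is made up for this sketch.

```rust
/// Collects the length of every chunk produced by `chunks(size)`.
pub fn chunk_lengths(data: &[u8], size: usize) -> Vec<usize> {
    data.chunks(size).map(|chunk| chunk.len()).collect()
}
```

For a 70-byte slice and a chunk size of 32 this yields `[32, 32, 6]`, so code that handles a chunk cannot assume it holds 32 elements.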
Luckily, the `Rust` developers thought about this and added the [`chunks_exact`][chunks_exact] method specifically for such cases! This method splits the slice into equally sized chunks and provides access to the tail, which may be shorter, through an additional method: `remainder`.
[chunks_exact]: https://doc.rust-lang.org/std/primitive.slice.html#method.chunks_exact
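As a rough sketch of how `chunks_exact` divides the input, the hypothetical helper `exact_split` below reports the number of full chunks and the length of the remainder:

```rust
/// Returns (number of full chunks, length of the remainder).
pub fn exact_split(data: &[u8], size: usize) -> (usize, usize) {
    let chunks = data.chunks_exact(size);
    // Every chunk yielded by `chunks` has exactly `size` elements;
    // whatever is left over is only reachable through `remainder()`.
    (chunks.len(), chunks.remainder().len())
}
```

For a 70-byte slice and a chunk size of 32 this gives `(2, 6)`: two full chunks, and a 6-byte tail that has to be handled separately.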
This final step allows us to make our dream come true: a [vectorized `find` function][final vectorized] in safe `Rust` only!
[final vectorized]: https://godbolt.org/z/Kja5WGjMf
```rust
// bonus: refactoring of find_branchless function to make it more elegant!
fn find_branchless(haystack: &[u8], needle: u8) -> Option<usize> {
    // `&(_, &b)` destructures the `&(usize, &u8)` item explicitly,
    // which is accepted by every Rust edition
    haystack.iter().enumerate()
        .filter(|&(_, &b)| b == needle)
        .rfold(None, |_, (i, _)| Some(i))
}

fn find(haystack: &[u8], needle: u8) -> Option<usize> {
    let chunks = haystack.chunks_exact(32);
    let remainder = chunks.remainder();
    chunks.enumerate()
        .find_map(|(i, chunk)| find_branchless(chunk, needle).map(|x| 32 * i + x))
        // the remainder starts at the length rounded down to a multiple of 32
        .or(find_branchless(remainder, needle).map(|x| (haystack.len() & !0x1f) + x))
}
```
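One way to gain confidence in the index arithmetic is to compare the chunked version with `Iterator::position` on inputs whose lengths straddle the 32-byte chunk boundary. The helper `check_against_position` and its test data are assumptions for this sketch and are not part of the benchmark code. `haystack.len() & !0x1f` clears the low five bits of the length, which rounds it down to a multiple of 32, the offset where the remainder starts.

```rust
fn find_branchless(haystack: &[u8], needle: u8) -> Option<usize> {
    haystack.iter().enumerate()
        .filter(|&(_, &b)| b == needle)
        .rfold(None, |_, (i, _)| Some(i))
}

pub fn find(haystack: &[u8], needle: u8) -> Option<usize> {
    let chunks = haystack.chunks_exact(32);
    let remainder = chunks.remainder();
    chunks.enumerate()
        .find_map(|(i, chunk)| find_branchless(chunk, needle).map(|x| 32 * i + x))
        .or(find_branchless(remainder, needle).map(|x| (haystack.len() & !0x1f) + x))
}

/// Hypothetical check: for every length in 0..=100, place a needle at each
/// position (plus a second one at the end) and compare with `position`.
pub fn check_against_position() -> bool {
    (0..=100usize).all(|len| {
        let empty = vec![0u8; len];
        let no_match_ok = find(&empty, 1).is_none();
        let matches_ok = (0..len).all(|pos| {
            let mut data = vec![0u8; len];
            data[pos] = 1;
            data[len - 1] = 1; // a later match must not win over `pos`
            find(&data, 1) == data.iter().position(|&x| x == 1)
        });
        no_match_ok && matches_ok
    })
}
```

Note that `.or(...)` evaluates its argument eagerly, so the remainder is scanned even when a full chunk already contains the match; `or_else` with a closure would defer that work, at the cost of deviating from the code measured here.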
### Benchmarks
The full benchmark source code is available here: [./rust-find-bench](https://github.com/sivukhin/sivukhin.github.io/tree/master/rust-find-bench)
| method | time | speedup |
| :----- | ---: | --: |
| `find_naive` | `366.06 ns` | `x1.0` |
| `find_chunks` | `414.06 ns` | `x0.9` |
| `find_chunks_exact` | `133.53 ns` | `x2.7` |
| `find_chunks_exact_branchless` | `40.48 ns` | *`x9.0`* |