Overthinking Leetcode's Two Sum with SIMD
Summary
Fancy algorithms are slow when n is small, and n is usually small. - Rob Pike
Updated 2022-10-19
Two Sum is one of the most popular interview questions and leetcode’s problem number 1. If you haven’t tried it, maybe go do it now before you read on. It’s free, quick, and hey you might need to get a different job one day.
The problem: You are given an array and a target number. You have to find two numbers in the array that added together equal the target. Those two numbers are guaranteed to exist. Return the indexes of those two numbers. For example if arr=[3,2,4] and target=6, the answer is [1, 2]. Simple enough for us to spend a blog post overthinking it and proving Rob Pike right.
The correct way to answer a technical interview question is to first provide the brute force solution (you don’t always have to code it up), give its Big O runtime, then explain that of course you would never do something so naive, and provide the fancy algorithm your interviewer is looking for. So let’s do that.
All the code and graphs for this post are on github grahamking/two_sum.
The brute force solution: Linear search
pub fn two_sum_linear(target: i32, arr: &[i32]) -> (usize, usize) {
    for i in 0..arr.len() {
        for j in (i + 1)..arr.len() {
            if arr[i] + arr[j] == target {
                return (i, j);
            }
        }
    }
    (0, 0)
}
That’s about as simple as you can get. Complexity is O(N^2) where N is arr.len(). Rust fans, note that iterator chains like arr.iter().skip(i).position(.. are about 30% slower in my benchmarks.
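For comparison, an iterator-style variant could look something like the sketch below. It is a guess at the general shape of such code, not necessarily the version that was benchmarked. The name two_sum_iter and the Option return type are assumptions made for illustration.

```rust
/// Iterator-flavoured brute force. Hypothetical variant, for illustration only.
fn two_sum_iter(target: i32, arr: &[i32]) -> Option<(usize, usize)> {
    for (i, &left) in arr.iter().enumerate() {
        let need = target - left;
        // Only look to the right of i so a value is never paired with itself.
        if let Some(offset) = arr.iter().skip(i + 1).position(|&x| x == need) {
            return Some((i, i + 1 + offset));
        }
    }
    None
}
```

Returning None instead of (0, 0) makes "not found" distinguishable from a real answer, which the problem statement never requires but which is handy when testing.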
The fancy algorithm: Map
The insight the problem is asking for is that if for example the target is 6, and arr[i] is 2, then you need to find a 4. If you know where to find a 4 you’re done. You do a first pass and build a map of where the values are. This is one of the officially correct, interview-passing answers.
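use std::collections::HashMap;

pub fn two_sum_map(target: i32, arr: &[i32]) -> (usize, usize) {
    // First pass: remember the index of each value.
    let mut m = HashMap::with_capacity(arr.len());
    for i in 0..arr.len() {
        m.insert(arr[i], i);
    }
    // Second pass: look up the value that would complete the pair.
    for i in 0..arr.len() {
        if let Some(j) = m.get(&(target - arr[i])) {
            if i != *j {
                return (i, *j);
            }
        }
    }
    (0, 0)
}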
Complexity is now O(2 * N) because we iterate the array twice with an O(1) map access inside. We drop the constant and get O(N). Clearly better. On to the systems design interview!
Note the use of with_capacity so that we allocate all the space we’re going to need upfront. Without that the map will have to regularly re-allocate and copy as it grows. I forgot to include this in an early draft of this post and the performance hit was very obvious in the graphs.
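A quick way to see the effect is to check that a pre-sized map keeps the same capacity while it fills up. The helper below is a hypothetical illustration, not part of the benchmark code. It relies only on the documented guarantee that HashMap::with_capacity(n) can hold at least n elements without reallocating.

```rust
use std::collections::HashMap;

/// Inserts `n` distinct keys into a pre-sized map and reports the capacity
/// before and after. Hypothetical helper, for illustration only.
fn capacity_before_after(n: usize) -> (usize, usize) {
    let mut m: HashMap<i32, usize> = HashMap::with_capacity(n);
    let before = m.capacity();
    for i in 0..n {
        m.insert(i as i32, i);
    }
    (before, m.capacity())
}
```

If the two numbers match, the table never had to grow. Swapping with_capacity(n) for HashMap::new() should show the capacity changing as the map resizes along the way.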
I think you’ll find it’s a bit more complicated than that
The map version is doing a fair amount of work. It is iterating the whole array at least once, even if the answer is right at the beginning of the array. It is allocating at least 12 * N bytes on the heap (4 for the i32 key, 8 for the usize value) and copying the whole array there. It is hashing the values (numbers don’t hash to themselves under SipHash, Rust’s default hasher).
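The per-entry arithmetic can be written down as a small sketch. The function name min_payload_bytes is made up for illustration, and the result is only a lower bound: it counts one i32 and one usize per entry and ignores alignment padding and the table's own bookkeeping, so the real allocation is likely larger.

```rust
use std::mem::size_of;

/// Lower bound on bytes of key/value payload for `n` map entries:
/// one i32 plus one usize each, ignoring padding and table metadata.
fn min_payload_bytes(n: usize) -> usize {
    n * (size_of::<i32>() + size_of::<usize>())
}
```

On a 64-bit target, where usize is 8 bytes, min_payload_bytes(125) comes to 1500 bytes.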
The linear “naive” solution is doing none of that. It is loading two memory locations (arr[i] and arr[j]), adding them and comparing them. The memory locations are contiguous so they will cache very well. We traverse the array in order so the hardware prefetcher will know to fetch memory ahead of us needing it. There must be some values of N for which the brute force solution is faster than the map one. What values?
The answer depends on where in the array we find the two indexes:
- At the start: if the indexes we seek are [0, 1], or close enough to that, the brute force solution will always be faster.
- At the end: if the indexes are [arr.len() - 2, arr.len() - 1], that is the worst case for the linear solution.
- Near the middle: this seems like a fair comparison.
- Random: this is what I’d expect in a real world situation. It doesn’t reproduce well in benchmarks, but it should average out to near the middle.
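To benchmark those cases you need inputs where the answering pair sits at a known position. The sketch below is one way to build them. The Position enum, the make_input helper, and the values 1, 400, 600 and 1000 are all assumptions for illustration, not the setup used for the graphs.

```rust
/// Where in the array the answering pair is placed.
enum Position {
    Start,
    Middle,
    End,
}

/// Builds an array of `n` values (n >= 2) with exactly one pair summing to the
/// returned target, placed at the requested position.
fn make_input(n: usize, pos: Position) -> (Vec<i32>, i32, (usize, usize)) {
    assert!(n >= 2, "need at least two elements");
    let target = 1_000;
    // Filler value 1 can never form the target: 1+1, 1+400 and 1+600 all miss.
    let mut arr = vec![1; n];
    let i = match pos {
        Position::Start => 0,
        Position::Middle => n / 2 - 1,
        Position::End => n - 2,
    };
    arr[i] = 400;
    arr[i + 1] = 600;
    (arr, target, (i, i + 1))
}
```

Because the filler can never complete a pair, every solution has to reach the planted values, which keeps the comparison between algorithms honest.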
Graph: Linear vs map
Let’s compare the two solutions when the answer we seek is in the middle of the array. I used criterion for benchmarking and generating the graphs (thanks criterion!) and runperf to get consistent results.
The x axis is the size of the array. The y axis is runtime in microseconds. The linear search solution has a pleasing textbook O(N^2) plot. The lines intersect around N = 125. That means in the average case (indexes we are looking for are near the middle) the “naive” solution is faster if the array is smaller than 125 elements. The map solution is better beyond that.
How many connections does a social media user have? How many items does a user put in their shopping cart? How many lines of text can you see in your browser right now? Fewer than 125.
SIMD: A significantly bruter force
If you look at the assembly for two_sum_linear it translates quite directly from the Rust, comparing a single value at a time:
mov ecx,DWORD PTR [rsi+r11*4+0x4] ; ecx = arr[i]
add ecx,DWORD PTR [rsi+rax*4] ; ecx += arr[j]
cmp ecx,edi ; compare ecx with target (edi)
jne 87b0 ; if not equal jump to start of inner loop
The compiler hasn’t figured out that we want to compare all the values in the array with target, ideally all at once. We can do this with SIMD instructions. We can compare 64 bytes at once using AVX-512 instructions, a full cache line! It uses the zmmX registers (zmm0, zmm1, etc).
It works like this:
- We can treat one of those registers as 16 i32 values (i32 = 4 bytes, 4 * 16 = 64 bytes = 512 bits).
- We load one register with 16 copies of the value we are looking for: vpbroadcastd zmm0,r13d
- We compare that register i32 by i32 against the next 16 values from the array: vpcmpeqd k0,zmm0,ZMMWORD PTR [r9+r15*1]. This sets one bit per value in a “mask register” (k0): 0 if the value doesn’t match, 1 if it does.
Say we’re looking for the value 8. We will compare the first 8 in zmm0 with 9. They don’t match so the first bit of k0 is 0. Then we compare the second 8 with 2, the third 8 with 6, and so on. The fourth value does match so we set that bit to 1 in k0. We say SIMD instructions work in ‘lanes’ - 16 lanes here.
zmm0: [8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8]
arr: [9, 2, 6, 8, 5, 3, 7, 4, 1, 0, 1, 8, 3, 6, 5, 6]
= k0: [0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0]
We then check if any of the bits in k0 are set, and if they are, we find the position of the first set bit, and that’s our j index. There is some ceremony involved to set things up and extract the result, but the vpbroadcastd and vpcmpeqd instructions each take 3 CPU cycles. Three cycles! That’s about the cost of a single L1 cache access. To compare 16 values!
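If the lane picture is hard to hold in your head, here is a plain scalar stand-in for the broadcast and compare steps. It is a sketch that models the mask as one bit per lane in a u16. The names lane_mask and first_match are made up for illustration, and no SIMD instructions are involved.

```rust
const LANES: usize = 16;

/// Scalar stand-in for broadcast + lane-wise compare: bit `lane` of the result
/// is 1 when `chunk[lane] == need`.
fn lane_mask(chunk: &[i32; LANES], need: i32) -> u16 {
    let mut mask = 0u16;
    for (lane, &value) in chunk.iter().enumerate() {
        if value == need {
            mask |= 1u16 << lane;
        }
    }
    mask
}

/// Index of the first matching lane, if any.
fn first_match(mask: u16) -> Option<usize> {
    if mask == 0 {
        None
    } else {
        // The lowest set bit corresponds to the first matching lane.
        Some(mask.trailing_zeros() as usize)
    }
}
```

With the 16 values from the diagram and need = 8, bits 3 and 11 are set, so the mask is 0b0000_1000_0000_1000 and the first match is lane 3. The real instruction does all 16 comparisons in one go instead of looping.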
The complication is that the values must be 64-byte aligned, and we need 64 bytes worth of them (16 i32’s in our case). That means in practice we have to compare a few values one at a time before we get to a 64-byte boundary, then we can SIMD the middle of the array, and often there are a few extra values at the end of the array.
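The head / middle / tail split is just arithmetic on the address of the first element. The sketch below works it out by hand. The function split_lengths and its constants are assumptions for illustration, and it takes the address as a plain number so the result does not depend on where an allocator happens to place a real array.

```rust
const LANES: usize = 16; // i32 lanes in one 64-byte vector
const VECTOR_BYTES: usize = 64;
const ELEM_BYTES: usize = 4; // size of an i32

/// Given the byte address of the first element and the element count, returns
/// (scalar elements before the first 64-byte boundary, full 16-lane chunks,
/// scalar elements left over at the end). Assumes `addr` is 4-byte aligned.
fn split_lengths(addr: usize, len: usize) -> (usize, usize, usize) {
    let misalign = addr % VECTOR_BYTES;
    let to_boundary = if misalign == 0 {
        0
    } else {
        (VECTOR_BYTES - misalign) / ELEM_BYTES
    };
    // A short array may end before it ever reaches a boundary.
    let before = to_boundary.min(len);
    let rest = len - before;
    (before, rest / LANES, rest % LANES)
}
```

For example, 50 values starting 8 bytes past a boundary (address 0x1008) split into 14 scalar values, 2 full chunks, and 4 leftover values.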
Often the compiler will use these instructions automatically. Here we get to do it ourselves. Rust has recently added an amazing portable SIMD module, so we don’t even have to write the assembly.
Sidebar: At time of writing leetcode’s Rust version is 1.58.2, and portable SIMD is nightly, so you can’t actually use this on leetcode. Note also that portable-simd is highly unstable. It changed while I was writing this post.
// Nightly only. Needs #![feature(portable_simd)] at the crate root and the
// std::simd imports for Simd and simd_eq (paths have moved between nightlies).
fn two_sum_simd(target: i32, arr: &[i32]) -> (usize, usize) {
    const LANES: usize = 16;
    for (i, left) in arr.iter().enumerate() {
        let need = target - left;
        // Split everything to the right of i into an unaligned head,
        // aligned 16-lane chunks, and an unaligned tail.
        let (before, simd_main, after) = arr[(i + 1)..].as_simd::<LANES>();
        for j in 0..before.len() {
            if before[j] == need {
                return (i, i + 1 + j);
            }
        }
        let simd_need: Simd<i32, LANES> = Simd::splat(need);
        for (chunk_num, chunk) in simd_main.iter().enumerate() {
            let mask = chunk.simd_eq(simd_need);
            if mask.any() {
                // found it
                let j = mask.to_bitmask().trailing_zeros() as usize;
                return (i, i + 1 + before.len() + chunk_num * LANES + j);
            }
        }
        for j in 0..after.len() {
            if after[j] == need {
                return (i, i + 1 + j + before.len() + simd_main.len() * LANES);
            }
        }
    }
    (0, 0)
}
If your CPU doesn’t support AVX-512 change 16 to 8, to use AVX2.
The before and after pieces will always be shorter than 16 values. If we have fewer than 16 values we’re doing the normal linear scan, which we now know is faster than the map solution for such a small N. If N gets bigger than that, we switch into using SIMD instructions. The middle section is the magic, particularly this line:
let mask = chunk.simd_eq(simd_need);
That does the lane-by-lane comparison described earlier using the vpcmpeqd instruction.
Graph: SIMD
Let’s add SIMD to our earlier graph. It is significantly faster than linear scan, and faster than map up to at least 500 items. It’s still O(N^2), so a Big-O analysis would have you believe that SIMD and Linear scan are equivalent. I disagree!
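To put rough numbers on "still O(N^2)", here is a sketch that counts worst-case comparisons for both approaches. The vector count is idealised: it assumes every comparison covers a full set of lanes and ignores the scalar head and tail around the aligned chunks. Both function names are made up for illustration.

```rust
/// Worst-case number of scalar comparisons for the nested-loop scan: N(N-1)/2.
fn scalar_compares(n: u64) -> u64 {
    n * n.saturating_sub(1) / 2
}

/// Idealised worst-case number of vector comparisons when each one covers
/// `lanes` candidates. Ignores the scalar head and tail around aligned chunks.
fn vector_compares(n: u64, lanes: u64) -> u64 {
    // For each i there are n - 1 - i candidates to its right; round up to
    // whole vectors.
    (0..n).map(|i| (n - 1 - i + lanes - 1) / lanes).sum()
}
```

For N = 500 that is 124,750 scalar comparisons against 8,032 vector comparisons with 16 lanes. Both counts grow quadratically, but the second is roughly 15 times smaller, and that constant is exactly what Big-O throws away.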
One-pass map
We significantly improved the linear scan solution. Can we also improve the map solution? Yes! We can check the map as we insert, only traversing the array once. In most cases we don’t have to walk the whole array, so it should be faster. This is the answer of most interview win.
pub fn two_sum_map_onepass(target: i32, arr: &[i32]) -> (usize, usize) {
    let mut m = HashMap::with_capacity(arr.len());
    for i in 0..arr.len() {
        if let Some(j) = m.get(&(target - arr[i])) {
            return (i, *j);
        }
        m.insert(arr[i], i);
    }
    (0, 0)
}
Graph: SIMD vs Map
Let’s extend the graph out to 1,000 items and drop the linear scan, so we can focus on SIMD vs Map. Our initial map solution is “two-pass”, the new one is “one-pass”, which is always better than two-pass.
The SIMD line still grows O(N^2) and Map still wins for very large N. But! SIMD beats two-pass until we have over 600 values in our array. It beats one-pass until around 400 i32 elements.
Two Sum: The perfect solution
If you’re in an interview the perfect solution is the answer the interviewer expects you to give. Give the map one (ideally one-pass) and move on.
Otherwise, the fastest solution depends on how many elements you have.
Knowing something about the data, even something as simple as how much of it there is, allows us to choose the most appropriate algorithm. That isn’t always the fancy algorithm. N is often small.
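In code, "choose by size" can be as small as one if statement. The sketch below dispatches between stable-Rust versions of the brute force and one-pass map solutions. The dispatcher two_sum, the SMALL_N constant, and the Option return type are assumptions for illustration. The threshold of 125 is read off the first graph in this post and is specific to that machine and data, so measure on yours before trusting it.

```rust
use std::collections::HashMap;

/// Crossover read off the linear vs map graph; machine-specific, so treat it
/// as a tunable assumption.
const SMALL_N: usize = 125;

fn two_sum_linear(target: i32, arr: &[i32]) -> Option<(usize, usize)> {
    for i in 0..arr.len() {
        for j in (i + 1)..arr.len() {
            if arr[i] + arr[j] == target {
                return Some((i, j));
            }
        }
    }
    None
}

fn two_sum_map_onepass(target: i32, arr: &[i32]) -> Option<(usize, usize)> {
    let mut seen = HashMap::with_capacity(arr.len());
    for (i, &value) in arr.iter().enumerate() {
        // j was inserted on an earlier iteration, so j < i.
        if let Some(&j) = seen.get(&(target - value)) {
            return Some((j, i));
        }
        seen.insert(value, i);
    }
    None
}

/// Hypothetical dispatcher: brute force for small inputs, map otherwise.
fn two_sum(target: i32, arr: &[i32]) -> Option<(usize, usize)> {
    if arr.len() < SMALL_N {
        two_sum_linear(target, arr)
    } else {
        two_sum_map_onepass(target, arr)
    }
}
```

On nightly, the brute force branch could be swapped for the SIMD version, which by the graphs above would push the crossover out to several hundred elements.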