Loop unrolling, also known as loop unwinding, is a loop transformation technique that attempts to optimize a program's execution speed at the expense of its binary size (...) On modern processors, loop unrolling is often counterproductive, as the increased code size can cause more cache misses
I wanted to make a note about this technique after watching the "Instructions Per Clock" lecture in the Performance-Aware Programming series by Casey Muratori.
This is the complete C code that generates an array of 4096 elements and sums their values.
#include <stdlib.h>
#define ARRAY_SIZE 4096
#define FIXED_SEED 42
int sum(unsigned int *input, unsigned int size) {
int sum = 0;
for (int i = 0; i < size; i++) {
sum += input[i];
}
return sum;
}
int main() {
unsigned int input[ARRAY_SIZE];
// seed the random number generator with a fixed seed.
srand(FIXED_SEED);
// generate the numbers for the array in the range 0-42.
for (int i = 0; i < ARRAY_SIZE; i++) {
input[i] = rand() % 43;
}
return sum(input, ARRAY_SIZE);
}
We will focus on the sum function, which generates the following x64:
Note: I ran this with -O1 optimisation, so that we don't get all the clutter of non-optimised C code.
sum:
; if (unsigned int size) is zero do not waste CPU time and return.
test esi, esi
je .L4
; store (unsigned int* input) in rax
mov rax, rdi
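; zero-extend (unsigned int size) from esi into rsi. This looks redundant,
; but writing a 32-bit register clears the upper 32 bits of the 64-bit one,
; which is needed before rsi is used in the address calculation below
mov esi, esi
; store the address one past the last element of (unsigned int* input) in rcx
lea rcx, [rdi+rsi*4]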
; initialise our (int sum) variable
mov edx, 0
.L3:
; add the current value pointed by rax to the sum
add edx, DWORD PTR [rax]
; move to the next element
add rax, 4
; loop again unless rax has reached the end of the input
cmp rax, rcx
jne .L3
.L1:
; move the sum to eax so it can be returned to the caller
mov eax, edx
ret
.L4:
mov edx, 0
jmp .L1
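Read back into C, the loop above no longer indexes the array: it walks a pointer from the start of the array to a precomputed one-past-the-end address. A rough hand-written equivalent might look like the sketch below. It is my approximation of the assembly, not something the compiler emits, and sum_ptr is a hypothetical name used only here.

```c
/* Hand-written approximation of the -O1 assembly for sum().
   sum_ptr is a hypothetical name used only for this sketch. */
int sum_ptr(const unsigned int *input, unsigned int size) {
    unsigned int total = 0;                  /* edx: the running sum */
    const unsigned int *p = input;           /* rax: current element */
    const unsigned int *end = input + size;  /* rcx: one past the last element */

    /* The assembly tests size == 0 up front and then runs a do/while loop;
       a plain while loop behaves the same way. */
    while (p != end) {
        total += *p;                         /* add edx, DWORD PTR [rax] */
        p++;                                 /* add rax, 4 */
    }
    return (int)total;
}
```

Calling sum_ptr(input, ARRAY_SIZE) should return the same value as sum(input, ARRAY_SIZE).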
We will do a little adjustment to the main() function so that we can track the execution time of this code:
// 10^10
for (long i = 0; i < 10000000000; i++) {
sum(input, ARRAY_SIZE);
}
Which yields the following results:
~ ➤ time /tmp/a.out
/tmp/a.out 2.44s user 0.00s system 99% cpu 2.450 total
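A caveat worth checking before reading too much into this number: the loop throws away the return value of sum(), and sum() has no side effects, so an optimising compiler is allowed to remove the call altogether. 10^10 calls over 4096 elements is about 4 x 10^13 additions, and finishing those in 2.44 seconds would be far more than a single core can do, which suggests that at least part of the work was optimised away. I have not verified this; the assembly generated for main() would tell. Below is a minimal sketch of one way to keep the work observable. The names sink and run_benchmark are made up for illustration, and the empty asm statement is a GCC/Clang extension rather than standard C.

```c
/* Stand-in for the sum() under test, so the sketch compiles on its own. */
int sum(unsigned int *input, unsigned int size) {
    int total = 0;
    for (unsigned int i = 0; i < size; i++) {
        total += input[i];
    }
    return total;
}

/* Hypothetical sink: stores to a volatile object cannot be dropped,
   so every result of sum() counts as used. */
volatile int sink;

/* Hypothetical helper: runs sum() `iterations` times and keeps each result. */
int run_benchmark(unsigned int *input, unsigned int size, long iterations) {
    for (long i = 0; i < iterations; i++) {
#if defined(__GNUC__)
        /* GCC/Clang extension: pretend the array may have been modified,
           so the compiler cannot compute the sum once and reuse it. */
        __asm__ volatile("" : : "r"(input) : "memory");
#endif
        sink = sum(input, size);
    }
    return sink;
}
```

In main(), the timing loop would then become a single call such as run_benchmark(input, ARRAY_SIZE, iterations). The iteration count probably has to be much smaller than 10^10 for the run to finish in reasonable time.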
The CPU can actually do more than one addition at a time. From Casey's course, we can quote:
Each CPU has some number of instructions per clock it could be doing at peak, and the closer we get to executing that many instructions on our workload, the closer we get to maximizing the potential of that CPU. If a CPU can do four instructions per clock, but we are only doing two instructions per clock on our workload, then there's an extra 2x headroom there.
A first attempt at getting more additions per clock is to unroll the loop by hand, so that each iteration adds four elements:
int sum(unsigned int *input, unsigned int size) {
int sum = 0;
// note: assumes size is a multiple of 4 (true for ARRAY_SIZE = 4096)
for (int i = 0; i < size; i+=4) {
sum += input[i];
sum += input[i+1];
sum += input[i+2];
sum += input[i+3];
}
return sum;
}
The assembly code is very similar, but now the body of the loop has been changed to this:
sum:
test esi, esi
je .L4
mov rdx, rdi
sub esi, 1
shr esi, 2
mov esi, esi
sal rsi, 4
lea rsi, 16[rdi+rsi]
mov ecx, 0
.L3:
; note how we are adding to eax 4 values at a time!
mov eax, DWORD PTR 4[rdx]
add eax, DWORD PTR [rdx]
add eax, DWORD PTR 8[rdx]
add eax, DWORD PTR 12[rdx]
add ecx, eax
add rdx, 16
cmp rdx, rsi
jne .L3
.L1:
mov eax, ecx
ret
.L4:
mov ecx, 0
jmp .L1
However, no substantial improvement was gained:
~ ➤ time /tmp/a.out
/tmp/a.out 2.44s user 0.00s system 99% cpu 2.447 total
In this case, we are doing more additions in the loop body, but we created a dependency chain on eax. The CPU needs to wait for the first add on eax to finish before the second can be processed, and so on...
We will remove that dependency chain on eax by using several independent sum variables, so that the CPU can parallelise some of the work.
More on instruction level parallelism can be found in: source
int sum(unsigned int *input, unsigned int size) {
int sum_1 = 0;
int sum_2 = 0;
int sum_3 = 0;
int sum_4 = 0;
// note: assumes size is a multiple of 4 (true for ARRAY_SIZE = 4096)
for (int i = 0; i < size; i+=4) {
sum_1 += input[i];
sum_2 += input[i+1];
sum_3 += input[i+2];
sum_4 += input[i+3];
}
return sum_1 + sum_2 + sum_3 + sum_4;
}
We end up with the following assembly code:
sum:
test esi, esi
je .L4
mov rax, rdi
sub esi, 1
shr esi, 2
mov esi, esi
sal rsi, 4
lea r8, 16[rdi+rsi]
mov ecx, 0
mov esi, 0
mov edi, 0
mov edx, 0
.L3:
; note how we are adding to different registers now
add edx, DWORD PTR [rax]
add edi, DWORD PTR 4[rax]
add esi, DWORD PTR 8[rax]
add ecx, DWORD PTR 12[rax]
add rax, 16
cmp rax, r8
jne .L3
.L2:
lea eax, [rdx+rdi]
add eax, esi
add eax, ecx
ret
.L4:
mov ecx, 0
mov esi, 0
mov edi, 0
mov edx, 0
jmp .L2
Which yields:
~ ➤ time /tmp/a.out
/tmp/a.out 2.40s user 0.00s system 99% cpu 2.409 total
Not thrilling...
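Both hand-unrolled versions above quietly assume that size is a multiple of 4, which holds for 4096. For any other size they would read past the end of the array. Here is a sketch of the four-accumulator version with a scalar tail loop for the leftovers, under the hypothetical name sum_unrolled. The unsigned accumulators are my choice, to keep overflow well defined.

```c
/* Four independent accumulators plus a tail loop for the leftovers.
   sum_unrolled is a hypothetical name used only for this sketch. */
int sum_unrolled(const unsigned int *input, unsigned int size) {
    unsigned int s0 = 0, s1 = 0, s2 = 0, s3 = 0;
    unsigned int i = 0;
    unsigned int limit = size & ~3u;  /* size rounded down to a multiple of 4 */

    /* Main loop: four additions per iteration, none depending on another. */
    for (; i < limit; i += 4) {
        s0 += input[i];
        s1 += input[i + 1];
        s2 += input[i + 2];
        s3 += input[i + 3];
    }

    /* Tail loop: at most 3 elements remain when size is not a multiple of 4. */
    for (; i < size; i++) {
        s0 += input[i];
    }

    return (int)(s0 + s1 + s2 + s3);
}
```

This mirrors what the -O3 output further down does on its own: a wide main loop followed by a short scalar tail.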
Let's go back to our original listing:
int sum(unsigned int *input, unsigned int size) {
int sum = 0;
for (int i = 0; i < size; i++) {
sum += input[i];
}
return sum;
}
Let's now run it with -O3 to see what GCC gives us.
sum:
; check if (unsigned int size) is zero and just return if so.
mov ecx, esi
test esi, esi
je .L9
; if (size) is 3 or less (size - 1 <= 2, unsigned), skip the SIMD loop
; and jump to the scalar code.
lea eax, -1[rsi]
cmp eax, 2
jbe .L10
; edx is our (unsigned int size) variable.
mov edx, esi
; rax is our (unsigned int *input) variable.
mov rax, rdi
; Streaming SIMD Extensions (packed xor)
; cleans up the register xmm0
pxor xmm0, xmm0
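; divide (size) by four: the number of 4-integer blocks
shr edx, 2
; multiply by 16, the size in bytes of each block
; note `sal` is the same as `shl`
sal rdx, 4
; rdx = input + blocks * 16, the address where the SIMD loop stops
add rdx, rdi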
.p2align 4,,10
.p2align 3
.L4:
; load a double quadword (128 bits) from the address in rax into xmm2
; effectively loads 4 integers from (input) into xmm2
movdqu xmm2, XMMWORD PTR [rax]
add rax, 16
; packed add: add the 4 integers in xmm2 to the 4 running sums in xmm0
paddd xmm0, xmm2
; check if rax has reached the end of the SIMD part, otherwise keep summing
cmp rdx, rax
jne .L4
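; horizontal sum: fold the 4 partial sums in xmm0 into a single value.
; copy xmm0 into xmm1 and shift it right by 8 bytes, so that its two upper
; integers land in the two lower lanes, then add them to xmm0
movdqa xmm1, xmm0
psrldq xmm1, 8
paddd xmm0, xmm1
; same again with a 4-byte shift: the lowest lane of xmm0 now holds the total
movdqa xmm1, xmm0
psrldq xmm1, 4
paddd xmm0, xmm1
; move the lowest 32 bits of xmm0 (the total) into eax
movd eax, xmm0
; if (size) is a multiple of 4 there are no leftovers, so return
test cl, 3
je .L1
; edx = (size) rounded down to a multiple of 4: index of the first leftover
mov edx, ecx
and edx, -4
.L3:
; scalar tail: add the remaining elements (at most 3) one by one
movsx rsi, edx
lea r8, 0[0+rsi*4]
add eax, DWORD PTR [rdi+rsi*4]
lea esi, 1[rdx]
cmp esi, ecx
jnb .L1
add edx, 2
add eax, DWORD PTR 4[rdi+r8]
cmp edx, ecx
jnb .L1
add eax, DWORD PTR 8[rdi+r8]
ret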
.p2align 4,,10
.p2align 3
.L9:
xor eax, eax
.L1:
ret
.L10:
xor edx, edx
xor eax, eax
jmp .L3
A lot of new instructions in there, so I took the time to comment on what they do.
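For readers more comfortable with C than with assembly, the vectorised loop and the reduction after it can be approximated by hand with SSE2 intrinsics from emmintrin.h. This is my own sketch of roughly what GCC generated, not its actual output. sum_sse2 is a hypothetical name, and the scalar fallback is only there so the snippet still compiles on targets without SSE2.

```c
#if defined(__SSE2__)
#include <emmintrin.h>  /* SSE2 intrinsics */
#endif

/* Hand-written approximation of the -O3 assembly for sum().
   sum_sse2 is a hypothetical name used only for this sketch. */
int sum_sse2(const unsigned int *input, unsigned int size) {
    unsigned int total = 0;
    unsigned int i = 0;

#if defined(__SSE2__)
    unsigned int limit = size & ~3u;     /* size rounded down to a multiple of 4 */
    __m128i acc = _mm_setzero_si128();   /* pxor xmm0, xmm0 */

    /* Main loop: one unaligned 128-bit load and one packed add per 4 integers. */
    for (; i < limit; i += 4) {
        __m128i v = _mm_loadu_si128((const __m128i *)(input + i)); /* movdqu */
        acc = _mm_add_epi32(acc, v);                               /* paddd */
    }

    /* Horizontal sum: fold the upper half onto the lower half, twice. */
    acc = _mm_add_epi32(acc, _mm_srli_si128(acc, 8)); /* psrldq 8, paddd */
    acc = _mm_add_epi32(acc, _mm_srli_si128(acc, 4)); /* psrldq 4, paddd */
    total = (unsigned int)_mm_cvtsi128_si32(acc);     /* movd eax, xmm0 */
#endif

    /* Scalar tail: the leftovers, or the whole array without SSE2. */
    for (; i < size; i++) {
        total += input[i];
    }
    return (int)total;
}
```

The compiler's version branches straight to the scalar code when size is 3 or less. The sketch reaches the same result because limit is then 0 and the vector loop never runs.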