Loop unrolling

Loop unrolling, also known as loop unwinding, is a loop transformation technique that attempts to optimize a program's execution speed at the expense of its binary size (...) On modern processors, loop unrolling is often counterproductive, as the increased code size can cause more cache misses.

source

I wanted to make a note about this technique after watching the "Instructions Per Clock" lecture in the Performance-Aware Programming series by Casey Muratori.

Listing 1

This is the complete C code that generates an array of 4096 elements and sums their values.

#include <stdlib.h>

#define ARRAY_SIZE 4096
#define FIXED_SEED 42

int sum(unsigned int *input, unsigned int size) {
  int sum = 0;

  for (int i = 0; i < size; i++) {
    sum += input[i];
  }

  return sum;
}

int main() {
  unsigned int input[ARRAY_SIZE];

  // seed the random number generator with a fixed seed.
  srand(FIXED_SEED);

  // generate the numbers for the array in the range 0-42.
  for (int i = 0; i < ARRAY_SIZE; i++) {
    input[i] = rand() % 43;
  }

  return sum(input, ARRAY_SIZE);
}

Note: I compiled this with -O1 optimisation, so that we don't get all the clutter of non-optimised code.

We will focus on the sum function, which compiles to the following x86-64 assembly:

sum:
  ; if (unsigned int size) is zero do not waste CPU time and return.
  test  esi, esi
  je  .L4

  ; store (unsigned int* input) in rax
  mov  rax, rdi

  ; zero-extend (unsigned int size) into rsi: writing a 32-bit register
  ; clears the upper 32 bits, so this is not redundant
  mov  esi, esi

  ; store the address one past the last element (input + size * 4) in rcx
  lea  rcx, [rdi+rsi*4]

  ; initialise our (int sum) variable
  mov  edx, 0

.L3:
  ; add the current value pointed by rax to the sum
  add  edx, DWORD PTR [rax]

  ; move to the next element
  add  rax, 4

  ; stop once rax reaches the end address in rcx, otherwise loop again
  cmp  rax, rcx
  jne  .L3
.L1:
  ; move the sum to eax so it can be returned to the caller
  mov  eax, edx
  ret
.L4:
  mov  edx, 0
  jmp  .L1

We will make a small adjustment to the main() function so that we can time this code:

  // 10^10
  for (long i = 0; i < 10000000000; i++) {
    sum(input, ARRAY_SIZE);
  }

Which yields the following results:

~ ➤ time /tmp/a.out
/tmp/a.out  2.44s user 0.00s system 99% cpu 2.450 total
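
One caveat about this measurement: the loop discards the return value of sum(), so the compiler is allowed to drop the call altogether. The numbers hint at this: 10^10 calls over 4096 elements would be about 4 x 10^13 additions, far more than one core can do in 2.44 seconds. Below is a minimal sketch of a loop that keeps the work alive, assuming GCC or Clang (the empty asm statement is a compiler extension, and sum_repeatedly is a hypothetical helper name, not part of the original code):

```c
#include <stddef.h>

// Same scalar loop as Listing 1, with unsigned types throughout.
static unsigned int sum(const unsigned int *input, size_t size) {
  unsigned int total = 0;

  for (size_t i = 0; i < size; i++) {
    total += input[i];
  }

  return total;
}

// Hypothetical helper: calls sum() `iterations` times and makes sure
// the compiler cannot skip or hoist the work.
unsigned int sum_repeatedly(const unsigned int *input, size_t size,
                            unsigned long long iterations) {
  // Writes to a volatile object cannot be optimised away.
  volatile unsigned int sink = 0;

  for (unsigned long long i = 0; i < iterations; i++) {
    // Assumption: GCC/Clang extended asm. The "memory" clobber tells the
    // compiler the array may have changed, so sum() is recomputed each time.
    __asm__ volatile("" : : "r"(input) : "memory");
    sink = sum(input, size);
  }

  return sink;
}
```

With a harness like this a much smaller iteration count is probably needed, since every call really walks the whole array.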

Listing 2

The CPU can actually do more than one addition per clock cycle, so in this listing we unroll the loop by hand to process four elements per iteration. From Casey's course, we can quote:

Each CPU has some number of instructions per clock it could be doing at peak, and the closer we get to executing that many instructions on our workload, the closer we get to maximizing the potential of that CPU. If a CPU can do four instructions per clock, but we are only doing two instructions per clock on our workload, then there's an extra 2x headroom there.

int sum(unsigned int *input, unsigned int size) {
  int sum = 0;

  // note: assumes size is a multiple of 4 (true for ARRAY_SIZE)
  for (int i = 0; i < size; i+=4) {
    sum += input[i];
    sum += input[i+1];
    sum += input[i+2];
    sum += input[i+3];
  }

  return sum;
}

The assembly code is very similar, but note how the body of the loop (.L3) has changed:

sum:
  test  esi, esi
  je  .L4
  mov  rdx, rdi
  sub  esi, 1
  shr  esi, 2
  mov  esi, esi
  sal  rsi, 4
  lea  rsi, 16[rdi+rsi]
  mov  ecx, 0
.L3:
  ; note how four values are added up in eax per iteration,
  ; then added to the running sum in ecx
  mov  eax, DWORD PTR 4[rdx]
  add  eax, DWORD PTR [rdx]
  add  eax, DWORD PTR 8[rdx]
  add  eax, DWORD PTR 12[rdx]
  add  ecx, eax
  add  rdx, 16
  cmp  rdx, rsi
  jne  .L3
.L1:
  mov  eax, ecx
  ret
.L4:
  mov  ecx, 0
  jmp  .L1

However, no substantial improvement was gained:

~ ➤ time /tmp/a.out
/tmp/a.out  2.44s user 0.00s system 99% cpu 2.447 total

In this case, we are doing more additions in the loop body, but we created a dependency chain on eax: each add into eax needs the result of the previous one, so the CPU cannot execute them in parallel.

Listing 3

We will remove that dependency chain on eax by using separate accumulators, so that the CPU can execute some of the additions in parallel.

More on instruction-level parallelism can be found in: source

int sum(unsigned int *input, unsigned int size) {
  int sum_1 = 0;
  int sum_2 = 0;
  int sum_3 = 0;
  int sum_4 = 0;

  // note: assumes size is a multiple of 4 (true for ARRAY_SIZE)
  for (int i = 0; i < size; i+=4) {
    sum_1 += input[i];
    sum_2 += input[i+1];
    sum_3 += input[i+2];
    sum_4 += input[i+3];
  }

  return sum_1 + sum_2 + sum_3 + sum_4;
}

We end up with the following assembly code:

sum:
  test  esi, esi
  je  .L4
  mov  rax, rdi
  sub  esi, 1
  shr  esi, 2
  mov  esi, esi
  sal  rsi, 4
  lea  r8, 16[rdi+rsi]
  mov  ecx, 0
  mov  esi, 0
  mov  edi, 0
  mov  edx, 0
.L3:
  ; note how we are adding to different registers now
  add  edx, DWORD PTR [rax]
  add  edi, DWORD PTR 4[rax]
  add  esi, DWORD PTR 8[rax]
  add  ecx, DWORD PTR 12[rax]
  add  rax, 16
  cmp  rax, r8
  jne  .L3
.L2:
  lea  eax, [rdx+rdi]
  add  eax, esi
  add  eax, ecx
  ret
.L4:
  mov  ecx, 0
  mov  esi, 0
  mov  edi, 0
  mov  edx, 0
  jmp  .L2

Which yields:

~ ➤ time /tmp/a.out
/tmp/a.out  2.40s user 0.00s system 99% cpu 2.409 total

Not thrilling...
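
One thing to keep in mind with the hand-unrolled loops: they step four elements at a time, so they read past the end of the array whenever size is not a multiple of 4. That is fine for the 4096-element array used here, but not in general. Below is a minimal sketch of the multi-accumulator version with a scalar tail for the leftovers (sum_unrolled is a hypothetical name, and I switched to unsigned types and size_t, which is a choice of mine rather than what the listings use):

```c
#include <stddef.h>

// Four independent accumulators plus a scalar tail, so any size is safe.
unsigned int sum_unrolled(const unsigned int *input, size_t size) {
  unsigned int sum_1 = 0;
  unsigned int sum_2 = 0;
  unsigned int sum_3 = 0;
  unsigned int sum_4 = 0;
  size_t i = 0;

  // Main loop: runs only while a full group of four elements remains.
  for (; i + 4 <= size; i += 4) {
    sum_1 += input[i];
    sum_2 += input[i + 1];
    sum_3 += input[i + 2];
    sum_4 += input[i + 3];
  }

  unsigned int total = sum_1 + sum_2 + sum_3 + sum_4;

  // Tail loop: at most three leftover elements.
  for (; i < size; i++) {
    total += input[i];
  }

  return total;
}
```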

Listing 4

Let's go back to our original listing:

int sum(unsigned int *input, unsigned int size) {
  int sum = 0;

  for (int i = 0; i < size; i++) {
    sum += input[i];
  }

  return sum;
}

Let's now compile it with -O3 to see what GCC gives us:

sum:
  ; check if (unsigned int size) is zero and just return if so.
  mov  ecx, esi
  test  esi, esi
  je  .L9

  ; if size is 3 or less (size - 1 <= 2, unsigned compare), skip the
  ; SIMD loop and handle everything in the scalar tail at .L3.
  lea  eax, -1[rsi]
  cmp  eax, 2
  jbe  .L10

  ; edx is our (unsigned int size) variable.
  mov  edx, esi
  ; rax is our (unsigned int *input) variable.
  mov  rax, rdi

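  ; Streaming SIMD Extensions (packed xor)
  ; zeroes xmm0, which will hold four 32-bit partial sums
  pxor  xmm0, xmm0

  ; divide (size) by four: the number of 4-element chunks
  shr  edx, 2
  ; multiply the chunk count by 16, the size of a chunk in bytes
  ; note `sal` is the same as `shl`
  sal  rdx, 4

  ; rdx = input + chunks * 16, the end address of the vectorised part
  add  rdx, rdi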
  .p2align 4,,10
  .p2align 3
.L4:

  ; load a double quadword (16 bytes) from the address in rax into xmm2
  ; effectively moves 4 integers from (input) into xmm2
  ; (movdqu is the unaligned load, so no 16-byte alignment is required)
  movdqu  xmm2, XMMWORD PTR [rax]
  add  rax, 16
  ; packed add: adds the 4 integers in xmm2 to the 4 partial sums in xmm0
  paddd  xmm0, xmm2

  ; check if rax reached the end of the vectorised part, otherwise keep summing
  cmp  rdx, rax
  jne  .L4

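  ; horizontal sum: fold the four partial sums in xmm0 into one value.
  ; copy xmm0 into xmm1, shift the copy right by 8 bytes and add:
  ; lanes 0 and 1 now hold lane0+lane2 and lane1+lane3
  movdqa  xmm1, xmm0
  psrldq  xmm1, 8
  paddd  xmm0, xmm1
  ; same again with a 4-byte shift: lane 0 now holds the total
  movdqa  xmm1, xmm0
  psrldq  xmm1, 4
  paddd  xmm0, xmm1
  ; move lane 0 (the low 32 bits) into eax, the return register
  movd  eax, xmm0
  ; if size is a multiple of 4 (low two bits clear) we are done
  test  cl, 3
  je  .L1
  ; otherwise edx = size rounded down to a multiple of 4,
  ; the index of the first element not summed yet
  mov  edx, ecx
  and  edx, -4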
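.L3:
  ; scalar tail: add the remaining 1 to 3 elements one by one
  movsx  rsi, edx
  lea  r8, 0[0+rsi*4]
  add  eax, DWORD PTR [rdi+rsi*4]
  ; return if index + 1 has reached size
  lea  esi, 1[rdx]
  cmp  esi, ecx
  jnb  .L1
  ; add the second leftover element, return if index + 2 has reached size
  add  edx, 2
  add  eax, DWORD PTR 4[rdi+r8]
  cmp  edx, ecx
  jnb  .L1
  ; add the third and last leftover element
  add  eax, DWORD PTR 8[rdi+r8]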
  ret
  .p2align 4,,10
  .p2align 3
.L9:
  xor  eax, eax
.L1:
  ret
.L10:
  xor  edx, edx
  xor  eax, eax
  jmp  .L3

A lot of new instructions in there, so I took the time to comment on what they do.
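
To make the structure of the -O3 output easier to follow, here is a rough scalar model of it in plain C. It is only a sketch: the lanes array stands in for the xmm0 register, sum_lanes is a hypothetical name, and the real code performs each group of four additions with a single paddd rather than an inner loop.

```c
#include <stddef.h>

// Scalar model of the vectorised sum: four 32-bit lanes stand in for xmm0.
unsigned int sum_lanes(const unsigned int *input, size_t size) {
  // pxor xmm0, xmm0: start with four zeroed partial sums.
  unsigned int lanes[4] = {0, 0, 0, 0};

  // Largest multiple of 4 that is <= size (the `and edx, -4` in the assembly).
  size_t vec_end = size & ~(size_t)3;

  // .L4: movdqu loads four elements, paddd adds them lane by lane.
  for (size_t i = 0; i < vec_end; i += 4) {
    for (size_t lane = 0; lane < 4; lane++) {
      lanes[lane] += input[i + lane];
    }
  }

  // psrldq 8 + paddd: fold the upper two lanes onto the lower two.
  lanes[0] += lanes[2];
  lanes[1] += lanes[3];

  // psrldq 4 + paddd, then movd: lane 0 ends up holding the total.
  unsigned int total = lanes[0] + lanes[1];

  // .L3: scalar tail for the last size % 4 elements.
  for (size_t i = vec_end; i < size; i++) {
    total += input[i];
  }

  return total;
}
```

Note how this is the same idea as Listing 3 (independent partial sums combined at the end), except that the compiler keeps the four accumulators in one vector register.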