Exploring the Assembly Code Generated for a Recursive Tree Traversal in Rust

In this article, we will delve deeper into the optimization capabilities of the Rust compiler. Specifically, we will examine how it optimizes a tail-recursive call by turning it into a loop, and how it eliminates the need for an enum discriminator when the variant can be inferred from the value of another field. This builds on the discussion of the Rust compiler's ability to optimize away the stack frame of the last function call, covered in the article "Rust to Assembly: Static vs. Dynamic Dispatch".

Tree structure

We start by defining a tree structure in Rust. The tree is an enum with two variants: Leaf and Node. The Leaf variant contains a value of type T, while the Node variant contains a value of type T, a left child of type Box<Tree<T>>, and a right child of type Box<Tree<T>>. The sum function then walks the tree recursively and adds up all of its u64 values.

pub enum Tree<T> {
    Node(T, Box<Tree<T>>, Box<Tree<T>>),
    Leaf(T),
}
use Tree::{Leaf, Node};

pub fn sum(tree: &Tree<u64>) -> u64 {
    match tree {
        Leaf(n) => *n,
        Node(n, left, right) => *n + sum(left) + sum(right),
    }
}
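As a quick check, the following sketch builds a small tree by hand and sums it. The tree shape and the main function are illustrative additions, not part of the original example.

```rust
pub enum Tree<T> {
    Node(T, Box<Tree<T>>, Box<Tree<T>>),
    Leaf(T),
}
use Tree::{Leaf, Node};

pub fn sum(tree: &Tree<u64>) -> u64 {
    match tree {
        Leaf(n) => *n,
        Node(n, left, right) => *n + sum(left) + sum(right),
    }
}

fn main() {
    // Illustrative tree:
    //       1
    //      / \
    //     2   3
    //        / \
    //       4   5
    let tree = Node(
        1,
        Box::new(Leaf(2)),
        Box::new(Node(3, Box::new(Leaf(4)), Box::new(Leaf(5)))),
    );
    println!("sum = {}", sum(&tree)); // prints "sum = 15"
}
```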

Tree representation in memory

As discussed in the article "Rust to Assembly: Enum Match", Rust typically inserts a discriminator into the memory representation of an enum to keep track of the currently stored variant. We would therefore expect the Tree enum defined above to include a discriminator as well. However, Rust can optimize away the discriminator when the variant can be inferred from a value that one of the fields can never legally hold (a so-called niche), which reduces memory usage.
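The same idea is easiest to see with Option. The sketch below compares Option<u64> with Option<Box<u64>>. Every bit pattern of a u64 is a valid value, so Option<u64> needs a separate tag. A Box is never null, so None can be represented as the null pointer. The Option<Box<T>> size is guaranteed by Rust. The printed values in the comments assume a 64-bit target.

```rust
use std::mem::size_of;

fn main() {
    // Every u64 bit pattern is a valid value, so Option<u64> needs an extra
    // tag, padded to 8 bytes for alignment (16 bytes in total).
    println!("u64:              {} bytes", size_of::<u64>());
    println!("Option<u64>:      {} bytes", size_of::<Option<u64>>());

    // A Box is never null, so None is encoded as the null pointer and
    // no separate discriminator is stored (8 bytes in total).
    println!("Box<u64>:         {} bytes", size_of::<Box<u64>>());
    println!("Option<Box<u64>>: {} bytes", size_of::<Option<Box<u64>>>());
}
```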

Tree structure

In the case of the Tree enum defined earlier, the right field of the Node variant is a Box<Tree<T>>, a pointer to another Tree object. A Box can never be null, so the compiler can mark a Leaf by storing zero in the slot that holds right in a Node. When that slot is non-null, the Tree object is a Node; when it is null, the Tree object is a Leaf. The right field therefore doubles as the discriminator and as the pointer to the right subtree.
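One way to observe this is to compare the size of Tree<u64> with a hypothetical RawTree. RawTree has the same shape but stores the children as plain usize values. A usize has no invalid bit patterns, so RawTree needs a separate discriminator. The exact layout is a rustc optimization rather than a language guarantee. The sizes in the comments assume a recent compiler on a 64-bit target.

```rust
use std::mem::size_of;

#[allow(dead_code)]
pub enum Tree<T> {
    Node(T, Box<Tree<T>>, Box<Tree<T>>),
    Leaf(T),
}

// Hypothetical enum with the same shape as Tree<u64>, but with the child
// "pointers" stored as plain integers. Every usize value is valid, so there
// is no niche and the compiler must add a separate discriminator.
#[allow(dead_code)]
pub enum RawTree {
    Node(u64, usize, usize),
    Leaf(u64),
}

fn main() {
    // value + left + right: three 8-byte words, no extra tag
    println!("Tree<u64>: {} bytes", size_of::<Tree<u64>>()); // 24 on a 64-bit target
    // three 8-byte words plus a tag padded to 8 bytes
    println!("RawTree:   {} bytes", size_of::<RawTree>()); // 32 on a 64-bit target
}
```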

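The following diagram shows the memory representation of the Tree object. It also shows the byte offset of each field on the left side of every Node or Leaf object: the value is at offset 0, left at offset 8, and right (or the null marker in a Leaf) at offset 16. These offsets will help you understand the generated assembly code.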

Tree realization in Rust

Flow chart of the recursive sum function

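The following flow chart will help you understand the generated assembly code. The Rust compiler removes the recursive sum(right) call entirely. In its place is a loop that walks down the chain of right children and exits once the right child is a Leaf. Only the sum(left) call remains a real recursive call. Strictly speaking, sum(right) is not a tail call, because its result is still added to the other terms. However, integer addition is associative, so the compiler can keep a running total and add to it before making the call. This turns sum(right) into a tail call.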

You can think of converting the recursive call to a loop as a two-step process:

  1. The last call in the function is sum(right), and its result is only added to a running total. The compiler performs the additions before the call, so nothing is left to do after sum(right) returns. That makes it a tail call, and the compiler can reuse the current stack frame instead of creating a new one.

  2. The tail call the compiler wants to eliminate is a call to the same function. Therefore, it can replace the call with a jump back to the top of the function, turning the recursion into a loop. This avoids the call overhead and keeps stack usage from growing with the length of the chain of right children.
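The two steps can be sketched by hand in Rust. The function names sum_acc and sum_loop are hypothetical. This is a source-level approximation of the optimizer's transformation, not the compiler's actual output. sum_acc threads the running total through an extra parameter so that the call on the right child becomes a tail call (step 1). sum_loop replaces that tail call with a loop (step 2), mirroring the rbx accumulator in the assembly. Rust does not guarantee tail-call elimination, so sum_acc on its own may still use one stack frame per call in unoptimized builds.

```rust
pub enum Tree<T> {
    Node(T, Box<Tree<T>>, Box<Tree<T>>),
    Leaf(T),
}
use Tree::{Leaf, Node};

// Step 1 (hypothetical helper): pass the running total along so that the
// call on the right child is the very last thing the function does.
pub fn sum_acc(tree: &Tree<u64>, acc: u64) -> u64 {
    match tree {
        Leaf(n) => acc + *n,
        Node(n, left, right) => sum_acc(right, acc + *n + sum_acc(left, 0)),
    }
}

// Step 2 (hypothetical helper): a tail call to the same function is a jump
// back to the top, so it can be written as a loop. Only the left child
// still recurses.
pub fn sum_loop(mut tree: &Tree<u64>) -> u64 {
    let mut acc = 0; // plays the role of rbx in the assembly
    loop {
        match tree {
            Node(n, left, right) => {
                acc += *n + sum_loop(left); // recursive call for the left side
                tree = right; // replaces the former sum(right) call
            }
            Leaf(n) => return acc + *n, // add the final Leaf value and return
        }
    }
}

fn main() {
    let tree = Node(
        1,
        Box::new(Leaf(2)),
        Box::new(Node(3, Box::new(Leaf(4)), Box::new(Leaf(5)))),
    );
    println!("{} {}", sum_acc(&tree, 0), sum_loop(&tree)); // prints "15 15"
}
```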

Tree recursive fold generated assembly flow chart

Annotated assembly code

Now, let us look at the generated assembly code, annotated to match the flow chart above.

; 🌳 Input: rdi: &Tree<u64>
; 🔢 Output: rax: u64
example::sum:
        push    r15                     ; save callee-saved registers
        push    r14 
        push    rbx
        mov     r14, rdi                 ; r14 = tree
        cmp     qword ptr [rdi + 16], 0  ; if tree is a Leaf (the slot at offset 16 is null)
        je      .LBB0_1                  ; then just return its value

        ; else return the sum of the value and the sum of the left and right subtrees
        xor     ebx, ebx                 ; rbx = 0
        mov     r15, qword ptr [rip + example::sum@GOTPCREL] ; r15 = address of sum (loaded from the GOT)

        ; ⤵️ sum (right) Loop starts
.LBB0_3:
        mov     rdi, qword ptr [r14 + 8]  ; rdi = left
        add     rbx, qword ptr [r14]      ; rbx += value
        ; ⭐ Recursive call for the left side of the Node.
        call    r15                       ; rax = sum(left)
        mov     r14, qword ptr [r14 + 16] ; r14 = right
        add     rbx, rax                  ; rbx += sum(left)
        cmp     qword ptr [r14 + 16], 0   ; check if the right is a Leaf
        jne     .LBB0_3                   ; if right side is a Node, continue
        ; ⤴️ sum(right) loop back to the top

        jmp     .LBB0_4                   ; the right side is a Leaf: add its value and return
.LBB0_1:
        ; ⭐ The root itself is a Leaf: start the sum at 0.
        xor     ebx, ebx                  ; rbx = 0
                                          ; so the addition below simply yields the leaf's value.
.LBB0_4:
        add     rbx, qword ptr [r14]      ; rbx += value
        mov     rax, rbx                  ; rax = rbx
        pop     rbx                       ; restore callee-saved registers
        pop     r14
        pop     r15
        ret

Experiment with the Compiler Explorer

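You can experiment with the code in the Compiler Explorer. Make sure optimizations are enabled (for example, by passing -O to rustc); without them, the compiler does not turn the recursive calls into loops. You can also add the following recursive functions to see how the compiler optimizes recursive calls in different scenarios.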

pub fn factorial(n: u64) -> u64 {
    if n == 0 {
        1
    } else {
        n * factorial(n - 1)
    }
}

pub fn fibonacci(n: u32) -> u32 {
    match n {
        0 => 0,
        1 => 1,
        _ => fibonacci(n - 1) + fibonacci(n - 2),
    }
}
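For comparison, the factorial function can also be written with an explicit accumulator, so that the recursive call is a genuine tail call. The name factorial_acc is hypothetical. Rust does not guarantee tail-call elimination, but in optimized builds LLVM typically turns such a call into a loop. Pasting both versions into the Compiler Explorer is an easy way to compare the generated assembly.

```rust
// Original version: the multiplication happens after the recursive call returns.
pub fn factorial(n: u64) -> u64 {
    if n == 0 {
        1
    } else {
        n * factorial(n - 1)
    }
}

// Hypothetical accumulator version: the recursive call is the last
// operation in the function (a tail call).
pub fn factorial_acc(n: u64, acc: u64) -> u64 {
    if n == 0 {
        acc
    } else {
        factorial_acc(n - 1, acc * n)
    }
}

fn main() {
    // 10! = 3628800 for both versions
    println!("{} {}", factorial(10), factorial_acc(10, 1));
}
```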

Key takeaways