Exploring the Assembly Code generated by Rust Recursive Tree Traversal
In this article, we delve deeper into the optimization capabilities of the Rust compiler. Specifically, we examine how it optimizes tail-recursive calls by mapping them to loops, and how it eliminates the need for an enum discriminator when the variant can be inferred from context. This builds on the Rust compiler's ability to optimize away the stack frame of the last function call, discussed in the article "Rust to Assembly: Static vs. Dynamic Dispatch".
Tree structure
We start by defining a tree structure in Rust. The tree structure is an enum with two variants: Leaf and Node. The Leaf variant contains a value of type T, while the Node variant contains a value of type T, a left child of type Box<Tree<T>>, and a right child of type Box<Tree<T>>.
pub enum Tree<T> {
    Node(T, Box<Tree<T>>, Box<Tree<T>>),
    Leaf(T),
}

use Tree::{Leaf, Node};

pub fn sum(tree: &Tree<u64>) -> u64 {
    match tree {
        Leaf(n) => *n,
        Node(n, left, right) => *n + sum(left) + sum(right),
    }
}
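For example, you can build a small tree and sum it. This is a minimal sketch that assumes the definitions above; the tree shape and the main function are just for illustration.

fn main() {
    // Build Node(1, Leaf(2), Node(3, Leaf(4), Leaf(5))) and sum its values.
    let tree = Node(
        1,
        Box::new(Leaf(2)),
        Box::new(Node(3, Box::new(Leaf(4)), Box::new(Leaf(5)))),
    );
    assert_eq!(sum(&tree), 15);
}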
Tree representation in memory
As previously discussed in the article "Rust to Assembly: Enum Match", Rust typically inserts a discriminator in the memory representation of an enum to keep track of the currently stored variant. One would therefore expect the Tree enum defined above to include a discriminator as well. However, Rust can optimize away the discriminator when the variant can be inferred from the context, which improves memory usage and performance.
In the case of the Tree enum defined earlier, the right field of the Node variant is a Box<Tree<T>>, a pointer to another Tree object. The presence or absence of this pointer can be used to infer the variant of the Tree object: if the right field is non-null, the Tree object is a Node; if it is null, the Tree object is a Leaf. The right field therefore serves both as the discriminator and as the pointer to the right subtree.
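You can verify this with a quick size check. The following is a sketch under the assumption that the compiler applies the null-pointer niche here on a 64-bit target; the exact size is not a language guarantee, but it matches what the generated assembly below relies on.

#[test]
fn tree_has_no_separate_discriminator() {
    // Node payload: 8 (value) + 8 (left) + 8 (right) = 24 bytes.
    // The null-pointer niche in the right Box replaces a separate tag,
    // so the whole enum stays at 24 bytes instead of growing to 32.
    assert_eq!(std::mem::size_of::<Tree<u64>>(), 24);
}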
The following diagram shows the memory representation of the Tree object. It also shows the byte offsets on the left side of each Node or Leaf object. These offsets will help you understand the generated assembly code.
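As a rough textual sketch of that layout (reconstructed from the offsets used in the assembly below, so it describes this particular build rather than a guaranteed layout):
- Offset 0: the value (u64) in both variants.
- Offset 8: the left pointer (non-null Box<Tree<u64>>) in a Node; unused in a Leaf.
- Offset 16: the right pointer (non-null Box<Tree<u64>>) in a Node; null in a Leaf, which is what marks the object as a Leaf.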
Flow chart of the recursive sum function
The following flow chart will help you understand the generated assembly code. The Rust compiler optimizes away the recursive sum(right) call because the sum function is tail recursive for that call: when the right field of the Node variant is another Node, the call becomes a loop iteration, and when it is a Leaf, the leaf's value is simply added in place.
You can think of converting the recursive call to a loop as a two-step process (a hand-written Rust equivalent is sketched after this list):
- The last call in the function is sum(right), and it has the same shape as the enclosing call sum(tree). The compiler can save the overhead of this call by eliminating its stack frame.
- The sum(right) call the compiler wants to eliminate is a call to the same function. Therefore, it can convert the tail call into a loop, which improves performance and reduces stack usage.
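The following hand-written sketch shows roughly what the compiler arrives at. It is an illustration only: the actual transformation happens on the compiler's intermediate representation, not on Rust source, and the name sum_loop is made up here.

pub fn sum_loop(mut tree: &Tree<u64>) -> u64 {
    let mut acc = 0;
    loop {
        match tree {
            // The Leaf case ends the loop: add the value and return.
            Leaf(n) => return acc + *n,
            Node(n, left, right) => {
                // The left call remains a real recursive call.
                acc += *n + sum_loop(left);
                // The tail call sum(right) becomes the next loop iteration.
                tree = right;
            }
        }
    }
}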
Annotated assembly code
Now, let us look at the generated assembly code, annotated to match the flow chart above.
; 🌳 Input: rdi: &Tree<u64>
; 🔢 Output: rax: u64
example::sum:
push r15 ; save callee-saved registers
push r14
push rbx
mov r14, rdi ; r14 = tree
cmp qword ptr [rdi + 16], 0 ; is the right pointer null, i.e. is the tree a Leaf?
je .LBB0_1 ; if so, jump to the Leaf path and return its value
; else return the sum of the value and the sum of the left and right subtrees
xor ebx, ebx ; rbx = 0
mov r15, qword ptr [rip + example::sum@GOTPCREL] ; r15 = sum
; ⤵️ sum(right) loop starts
.LBB0_3:
mov rdi, qword ptr [r14 + 8] ; rdi = left
add rbx, qword ptr [r14] ; rbx += value
; ⭐ Recursive call for the left side of the Node.
call r15 ; rax = sum(left)
mov r14, qword ptr [r14 + 16] ; r14 = right
add rbx, rax ; rbx += sum(left)
cmp qword ptr [r14 + 16], 0 ; check if the right is a Leaf
jne .LBB0_3 ; if right side is a Node, continue
; ⤴️ sum(right) loop back to the top
jmp .LBB0_4 ; the right side is a Leaf - proceed to return after adding the value
.LBB0_1:
; ⭐ Leaf path: zero the accumulator so that the add in .LBB0_4 simply loads the leaf's value.
xor ebx, ebx ; rbx = 0
.LBB0_4:
add rbx, qword ptr [r14] ; rbx += value
mov rax, rbx ; rax = rbx
pop rbx ; restore callee-saved registers
pop r14
pop r15
ret
Experiment with the Compiler Explorer
You can experiment with the code in the Compiler Explorer. Add the following recursive functions to see how the compiler optimizes recursive calls in different scenarios.
- Add the following code to see how the compiler optimizes the recursive factorial call to a loop; a hand-written loop equivalent is sketched after the code.
pub fn factorial(n: u64) -> u64 {
    if n == 0 {
        1
    } else {
        n * factorial(n - 1)
    }
}
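A hand-written loop with roughly the same effect might look like the sketch below. This is an illustration only, not the compiler's literal output: the compiler can introduce an accumulator to reach this shape even though the multiplication happens after the recursive call, and the name factorial_loop is made up here.

pub fn factorial_loop(mut n: u64) -> u64 {
    let mut acc = 1;
    // Multiply the accumulator by n, n - 1, ..., 1.
    while n != 0 {
        acc *= n;
        n -= 1;
    }
    acc
}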
- Add the fibonacci function to see how the compiler optimizes away one of the two recursive calls to the fibonacci function; a sketch of the resulting shape follows the code.
pub fn fibonacci(n: u32) -> u32 {
    if n == 0 {
        return 0;
    } else if n == 1 {
        return 1;
    } else {
        return fibonacci(n - 1) + fibonacci(n - 2);
    }
}
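The effect resembles the hand-written sketch below, where the first call stays recursive and the last call is folded into a loop with an accumulator. This is an illustration only, not the compiler's exact output, and the name fibonacci_partial is made up here.

pub fn fibonacci_partial(mut n: u32) -> u32 {
    let mut acc = 0;
    loop {
        if n == 0 {
            return acc;
        } else if n == 1 {
            return acc + 1;
        }
        // The first recursive call remains a real call.
        acc += fibonacci_partial(n - 1);
        // The second call becomes the next loop iteration.
        n -= 2;
    }
}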
Key takeaways
- The Rust compiler optimizes recursive calls to loops only when the recursive calls are tail calls.
- When a recursive function calls itself multiple times, the compiler optimizes the last recursive call into a loop.
- Rust allows you to write declarative recursive functions, but it's essential to understand the overhead of recursive calls and how they are mapped to loops in specific scenarios.
- The Rust compiler opportunistically optimizes away enum discriminators if it can infer the enum variant from the context.