Sunday, 30 September 2012

Tailrec in Scala

One of the trickier aspects of functional programming is optimising your code. This is because you are typically working at a higher level of abstraction than in procedural languages, so the code you write isn't necessarily a natural fit for the way the processor works. For example, compare the factorial function in C# and Scala:

C#:

public int Factorial(int n) 
{
    // Start from 1, the multiplicative identity (starting from 0 would always give 0)
    var result = 1;
    for (int i = n; i > 0; i--)
    {
        result *= i;
    }
    return result;
}

Scala:

def factorial(n: Int): Int = 
    if (n == 0) 1
    else n * factorial(n - 1)

The big difference here is that in a procedural language you'll tend to write a loop that mutates a variable, whereas in a functional language you'll tend to use recursion or a higher-order function to write side-effect-free code. Processors are really good at the former.
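For comparison, here is a minimal sketch of the higher-order-function style mentioned above, using foldLeft from the standard library. The object name FactorialFold is just an illustrative wrapper, not something from the original examples.

```scala
object FactorialFold {
  // Multiply the numbers 1..n together, starting from the identity 1.
  // For n == 0 the range is empty, so the result is simply 1.
  def factorial(n: Int): Int =
    (1 to n).foldLeft(1)(_ * _)
}
```

There is no explicit recursion and no mutable variable here; the iteration is hidden inside foldLeft.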

What we see here is a compromise between performance and readability. The C# code will be quicker because the code will do all its work in the same stack frame. The Scala version will create n stack frames while calling itself recursively. However, the Scala version is easier to reason about because there are no side effects, i.e. you're not rewriting result in a loop. You may disagree for such a simple example, but once you have a couple of loops and several variables, procedural code becomes very difficult to follow.

Most functional languages address this issue by performing an optimisation called tail-call optimisation. However, some languages are better at it than others. In Scala you need to make sure that your recursive call is in the tail position (i.e. the recursive call is the last expression evaluated). In the factorial example above it isn't: the tail expression is actually n * factorial(n - 1), so the multiplication still has to happen after the recursive call returns. (You can read about this in more detail here and here.)

Reasoning about whether recursive functions will or won't be optimised is a problem for beginners because it isn't always obvious what qualifies for optimisation. There are other criteria to take into account: the function must call itself directly, and it must not be overridable, so it has to be a local function or a final (or private) method. A great feature of Scala is an annotation, @tailrec, that raises a compile error if your function won't be optimised:

import scala.annotation.tailrec

@tailrec
def factorial(n: Int): Int = 
    if (n == 0) 1
    else n * factorial(n - 1)

This will generate a compile error in this case because the recursive function can't be tail-call optimised.
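The "must be final" criterion can be sketched in the same way. The class name Maths and the default value for acc are assumptions made for illustration; the point is only that @tailrec accepts a method in a class once it is marked final, because no subclass can then override it.

```scala
import scala.annotation.tailrec

class Maths {
  // Accepted by @tailrec: the method is final and the recursive call
  // is the last expression evaluated in the else branch.
  @tailrec
  final def factorial(n: Int, acc: Int = 1): Int =
    if (n == 0) acc
    else factorial(n - 1, acc * n)
}
```

Removing the final modifier should make the compiler reject the annotation, since the method could then be overridden.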

To allow the compiler to optimise factorial, you will need to find a way to make the recursive call the tail call, for example by passing the running result along in an accumulator:

def factorial(n: Int): Int = {
    @tailrec
    def recFac(n: Int, acc: Int): Int = 
        if (n == 0) acc
        else recFac(n - 1, acc * n)
    recFac(n, 1)
}
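To see the benefit at depth, here is a rough sketch of the same accumulator pattern applied to a sum rather than a factorial, chosen only so that the result of a very deep recursion still fits in a Long. The names DeepRecursion, sumTo and loop are hypothetical.

```scala
import scala.annotation.tailrec

object DeepRecursion {
  // Sums 1..n using an accumulator, so the recursive call is in tail position.
  def sumTo(n: Int): Long = {
    @tailrec
    def loop(i: Int, acc: Long): Long =
      if (i == 0) acc
      else loop(i - 1, acc + i)
    loop(n, 0L)
  }
}
```

Calling DeepRecursion.sumTo(1000000) runs as a loop in a single stack frame. The non-tail-recursive equivalent, n + sumTo(n - 1), would need a million frames, which typically exceeds the default JVM stack size.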

I learnt about @tailrec through Martin Odersky's Scala course on Coursera. The course is fantastic: Odersky's lectures are very interesting and tie theory into practical applications very neatly. The practical assignments really take me back to my first year of university.
