We’ve seen that Haskell does some pretty crazy things: it infers types, it supports this crazy pattern matching stuff, it disallows side effects, and it computes things lazily. Let’s look at one last crazy thing it does: it uses recursion for everything! There are no loop control structures.
Normally, when you think of recursion, what’s the first thing that comes to mind?
Performance is probably not what you’re thinking of. I think of stack overflow exceptions. Why does recursion lead to stack overflow exceptions? Because every time we call a function, a new stack frame holding the function’s parameters, local variables, and return address gets pushed onto the stack. Recurse deeply enough without returning, and the stack runs out of room.
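To see that limit for yourself, here is a small sketch that recurses with no base case until Ruby gives up with a SystemStackError. The method name dive and the global $max_depth are made up for this demo. The depth it reports depends on your Ruby version and stack size, so treat the number as illustrative only.

```ruby
# Hypothetical demo: recursion with no base case keeps pushing stack frames
$max_depth = 0

def dive n
  $max_depth = n     # remember how deep we got
  dive(n + 1)        # never returns normally
end

begin
  dive(1)
rescue SystemStackError
  puts "Stack overflowed after roughly #{$max_depth} nested calls"
end
```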
Today we’re going to look at two tricks that can make recursive algorithms have much better performance.
The first of these is called memoization. We define it as follows:
Memoization is the act of caching the result of a routine to speed up subsequent invocations of that routine.
We’ve already seen memoization once. When we added lazy evaluation to Ruby using our Lazy class, we delayed the computation by putting it in a block, and then cached its result the first time it was executed. Today we’re going to use a similar idea to make repeated recursive function calls really speedy.
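As a reminder of that idea, here is a minimal sketch of a lazy value that runs its block at most once and caches the result. It is a reconstruction for illustration, so the names (Lazy, value) may not match the class we actually wrote.

```ruby
# Hypothetical reconstruction: a lazy value whose block runs at most once
class Lazy
  def initialize &block
    @block = block       # the delayed computation
    @evaluated = false
  end

  def value
    unless @evaluated
      @value = @block.call   # first access: do the work and cache it
      @evaluated = true
    end
    @value                   # later accesses: just return the cached result
  end
end

runs = 0
answer = Lazy.new { runs += 1; 6 * 7 }
answer.value
answer.value
puts runs   # 1: the block ran only once
```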
Let’s start off by visualizing a recursive algorithm. I really have no tolerance for several tired old programming problems: the Towers of Hanoi, Fahrenheit to Celsius conversion, the shape hierarchy, and the Fibonacci sequence. However, the familiarity of these problems is actually helpful today. Let’s look at a Fibonacci “tree” for the sixth Fibonacci number:
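A small Ruby sketch can print that same tree, one indented line per call. The helper name print_fib_tree is made up for illustration; it mirrors the naive recursion but only prints instead of adding.

```ruby
# Print the call tree of the naive Fibonacci recursion, one indented line per call
def print_fib_tree n, depth = 0
  puts "#{'  ' * depth}fib(#{n})"
  return if n <= 1                 # base cases have no children
  print_fib_tree(n - 1, depth + 1)
  print_fib_tree(n - 2, depth + 1)
end

print_fib_tree(6)
```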
What do you notice? What happens when we compute the seventh Fibonacci number?
That’s a lot of duplicated calls, eh? Is that a big deal? Maybe not. Let’s try implementing it:
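```ruby
# Naive recursive Fibonacci: fib(0) = 0, fib(1) = 1
def fib n
  if n <= 1
    n
  else
    fib(n - 1) + fib(n - 2)
  end
end

n = ARGV[0].to_i   # ARGV is an array of strings; take the first argument
puts fib n
```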
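To put numbers on that duplication, here is a hedged sketch: a copy of fib (renamed counting_fib for this demo) that also tallies how many times it is called with each argument.

```ruby
# Same recursion as fib, but records every call in the `calls` hash
def counting_fib n, calls
  calls[n] += 1
  if n <= 1
    n
  else
    counting_fib(n - 1, calls) + counting_fib(n - 2, calls)
  end
end

[6, 7].each do |n|
  calls = Hash.new(0)   # missing keys start at 0
  counting_fib(n, calls)
  puts "fib(#{n}): #{calls.values.sum} calls in total; fib(1) was called #{calls[1]} times"
end
```

For n = 6 this reports 25 calls, 8 of them fib(1); for n = 7 it reports 41 calls, 13 of them fib(1). Each step up in n multiplies the work by roughly 1.6.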
Let’s see how well this scales. We’ll try various values of n and see what happens.
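One way to run that experiment, sketched with Ruby’s standard Benchmark module. The particular values of n are arbitrary choices.

```ruby
require 'benchmark'

# Naive Fibonacci, same as above, written more compactly
def fib n
  n <= 1 ? n : fib(n - 1) + fib(n - 2)
end

# Time fib for increasing n; the absolute numbers depend on your machine
[15, 20, 25, 30].each do |n|
  seconds = Benchmark.realtime { fib(n) }
  puts format("fib(%2d) took %.3f s", n, seconds)
end
```

Since the number of calls grows by roughly a factor of 1.6 per unit of n, expect each step of 5 here to cost roughly 11 times as much as the previous one.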
We’ll find that all this duplication does indeed take its toll. What if somehow we could cache the value of a “tree” and not have to recurse through it if we’d already done so? Let’s try. But how will we do this? How do we remember what the results are for a previous tree? Well, we can create a dictionary (a Hash, in Ruby) mapping the function’s parameter to its return value. Like this:
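```ruby
# Global cache mapping n to fib(n)
$cache = Hash.new

def fib n
  if $cache.has_key? n
    $cache[n]                       # already computed: reuse it
  elsif n <= 1
    n                               # base cases: fib(0) = 0, fib(1) = 1
  else
    value = fib(n - 1) + fib(n - 2)
    $cache[n] = value               # store it; the assignment also returns value
  end
end

n = ARGV[0].to_i
puts fib n
```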
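This is a whole lot faster. How much work have we saved? Well, think about the call tree now: each Fibonacci number is computed only once, and every repeated subtree collapses into a single cache lookup.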
We have turned what was an algorithm of exponential complexity into one of linear complexity, assuming the cost of a dictionary lookup is negligible.
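We can check the linear claim by counting calls. This sketch repeats the cached fib from above but adds a hypothetical global counter, $calls, that is incremented on every invocation, including cache hits and base cases.

```ruby
$cache = Hash.new
$calls = 0   # hypothetical counter, for this demo only

def fib n
  $calls += 1
  if $cache.has_key? n
    $cache[n]
  elsif n <= 1
    n
  else
    $cache[n] = fib(n - 1) + fib(n - 2)
  end
end

puts fib(30)   # 832040
puts $calls    # 59
```

Each n from 2 to 30 is computed once and every other call is a cache hit or a base case, which works out to 2n − 1 calls (59 for n = 30), compared with 2,692,537 calls for the naive version.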
That’s great, but having to explicitly add memoization to the function is kind of annoying. What we added had very little to do with the Fibonacci sequence. Hmm… Maybe we could write a higher-order function that automatically wrapped up a function inside a memoizing routine? Yes, we can. In fact, the technique we’ll use is called method wrapping. And we can also get rid of that nasty global cache.
It would be great if we could write something like this:
```ruby
fib = memoize(fib)
```
That way we could continue to call fib like normal. Memoization would just be an implementation detail, as it should be.
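In Ruby we can get surprisingly close to that literal line by using lambdas instead of methods. The sketch below is an alternative to the method-wrapping approach we’re about to build, not part of it, and memoize_lambda is a made-up name. It works because the lambda’s body refers to the local variable fib, so once that variable is reassigned to the memoized version, the recursive calls go through the cache too.

```ruby
# Hypothetical higher-order memoizer for one-argument lambdas
def memoize_lambda f
  cache = Hash.new
  lambda do |arg|
    if cache.has_key? arg
      cache[arg]
    else
      cache[arg] = f.call(arg)
    end
  end
end

fib = lambda { |n| n <= 1 ? n : fib.call(n - 1) + fib.call(n - 2) }
fib = memoize_lambda(fib)   # the recursive calls inside the lambda now hit the cache
puts fib.call(80)
```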
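Let’s solve this in steps. The memoizing function will need to accept another function as its parameter. We’re going to pass it the function’s name as a Ruby symbol; a symbol is just a name, but Ruby can look up the method it names and hand us a Method object, which is kind of like a method reference in Java 8: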
```ruby
def memoize fid
  # ...
end

fib = memoize(:fib)
```
Also, the arity of functions is going to be important here. For the time being, let’s say our memoizer only works with functions of arity 1, that is, functions that take exactly one argument:
```ruby
def memoize1 fid
  # ...
end

fib = memoize1(:fib)
```
Next, we’ll need a cache:
```ruby
def memoize1 fid
  cache = Hash.new
  # ...
end
```
Then let’s overwrite our function with a new one. We’ll use some Ruby magic for this: define_method, which (re)defines a method with the given name, using a block as its body:
```ruby
def memoize1 fid
  cache = Hash.new
  define_method(fid) do |arg|
    # method body here
  end
end
```
Hang on. If we overwrite our function with this memoizing wrapper, we won’t be able to call the real one when we need it. Let’s retain a reference to the old one. Here’s a complete wrapper which adds absolutely no extra work around the function:
```ruby
def memoize1 fid
  cache = Hash.new
  f = method(fid)              # keep a handle on the original method
  define_method(fid) do |arg|
    f.call arg                 # just pass the call through
  end
end
```
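Why does wrapping help a recursive function at all? Because fib’s body calls fib by name, and that name now refers to the wrapper, so even the recursive calls pass through it. Here is a sketch that makes this visible with a hypothetical count_calls1 helper: the pass-through wrapper above plus a counter. Like memoize1, it relies on calling define_method from a top-level method.

```ruby
def fib n
  n <= 1 ? n : fib(n - 1) + fib(n - 2)
end

# Hypothetical helper: wraps fid with a pass-through that counts invocations.
# Returns a lambda that reports the count so far.
def count_calls1 fid
  count = 0
  f = method(fid)
  define_method(fid) do |arg|
    count += 1        # every call, including recursive ones, passes through here
    f.call arg
  end
  -> { count }
end

calls = count_calls1(:fib)
fib(6)
puts calls.call   # 25: the recursive calls inside fib were routed through the wrapper too
```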
Let’s add the memoization work:
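```ruby
def memoize1 fid
  f = method(fid)          # keep a handle on the original method
  cache = Hash.new         # private to this wrapper, thanks to the closure
  define_method(fid) do |arg|
    if cache.has_key? arg
      cache[arg]
    else
      value = f.call arg
      cache[arg] = value
    end
  end
end
```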
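A quick usage sketch, assuming the memoize1 above and the naive fib from the start: after wrapping, even fib(90), which the naive version would take far too long to finish, returns immediately.

```ruby
def fib n
  n <= 1 ? n : fib(n - 1) + fib(n - 2)
end

def memoize1 fid
  f = method(fid)
  cache = Hash.new
  define_method(fid) do |arg|
    if cache.has_key? arg
      cache[arg]
    else
      value = f.call arg
      cache[arg] = value
    end
  end
end

memoize1(:fib)   # replaces fib in place; no assignment needed
puts fib(90)
```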
Closures to the rescue here! They saved us from our global cache. First-class functions to the rescue too! We stored the old function in a local variable that is also referenced by the closure.
For the curious, we can also generalize this code to work for functions of arbitrary arity using Ruby’s splat operator:
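```ruby
def memoize fid
  f = method(fid)
  cache = Hash.new
  define_method(fid) do |*args|   # *args gathers all arguments into an array
    if cache.has_key? args          # the whole argument array is the key
      cache[args]
    else
      value = f.call(*args)         # splat the array back into separate arguments
      cache[args] = value
    end
  end
end
```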
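To see that the cache really is keyed on the whole argument list, here is a small sketch with a made-up two-argument function, slow_add, and a hypothetical global counter $runs that tracks how often the real body executes.

```ruby
# memoize from above
def memoize fid
  f = method(fid)
  cache = Hash.new
  define_method(fid) do |*args|
    if cache.has_key? args
      cache[args]
    else
      cache[args] = f.call(*args)
    end
  end
end

$runs = 0   # hypothetical counter: how many times the real body ran

def slow_add a, b
  $runs += 1
  a + b
end

memoize(:slow_add)
slow_add(1, 2)
slow_add(1, 2)   # cache hit: key [1, 2] is already stored
slow_add(2, 1)   # cache miss: [2, 1] is a different key
puts $runs       # 2
```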
Let’s try this out on a different recursive function, for a problem taken from the MICS 2017 programming contest. Suppose you are standing at an intersection of a city. You are trying to get to a certain other intersection. How many ways can you get there, assuming you never go out of your way?
Consider this example where you are 2 blocks north and 3 blocks west of your destination, so you must go down (south) twice and right (east) three times. You could go DDRRR, DRDRR, DRRDR, DRRRD, RDDRR, RDRDR, RDRRD, RRDDR, RRDRD, or RRRDD. That’s ten different options.
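We can double-check that count by brute force: every route is an arrangement of two D’s and three R’s, so this sketch generates all orderings of "DDRRR" and throws away the duplicates.

```ruby
# Brute force: every route uses exactly 2 downs and 3 rights
routes = "DDRRR".chars.permutation.map(&:join).uniq.sort
puts routes.size     # 10
puts routes.join(" ")
```

This gets expensive quickly, since the number of orderings grows factorially with the route length, which is one reason to want a smarter way to count.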
Like most recursive problems, this feels really daunting but has a simple solution:
```ruby
# x = blocks still to go right (east), y = blocks still to go down (south)
def nroutes x, y
  if x == 0 && y == 0     # We've arrived!
    1
  elsif x < 0 || y < 0    # We've gone out of our way. Illegal.
    0
  else
    nroutes(x - 1, y) +   # How many paths if we go right?
      nroutes(x, y - 1)   # How many paths if we go down?
  end
end
```
When we run this, it starts to get bogged down very quickly. But there’s a lot of tree revisiting here, just as with Fibonacci. Let’s memoize:
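A sketch of that step, assuming the general memoize from earlier (the final listing below uses memoize2, which does the same thing for exactly two arguments):

```ruby
def memoize fid
  f = method(fid)
  cache = Hash.new
  define_method(fid) do |*args|
    if cache.has_key? args
      cache[args]
    else
      cache[args] = f.call(*args)
    end
  end
end

# x = blocks still to go right, y = blocks still to go down
def nroutes x, y
  if x == 0 && y == 0
    1
  elsif x < 0 || y < 0
    0
  else
    nroutes(x - 1, y) + nroutes(x, y - 1)
  end
end

memoize(:nroutes)
puts nroutes(3, 2)       # 10, matching the example above
puts nroutes(130, 70)
```

Because each (x, y) pair is now solved only once, the work grows roughly with x · y instead of exponentially.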
And it’s beautifully fast.
P.S. It’s time for a haiku!
Teachers may seem smart
But before they’ve memoized
They’re pretty much you
P.P.S. Here’s the code we wrote together:
```ruby
#!/usr/bin/env ruby

# Fibonacci with an explicit global cache (this version uses fib(0) = fib(1) = 1)
$cache = Hash.new

def fib(n)
  if !$cache.has_key?(n)
    if n < 2
      $cache[n] = 1
    else
      $cache[n] = fib(n - 2) + fib(n - 1)
    end
  end
  $cache[n]
end

# Memoize a one-argument method in place
def memoize1(fid)
  cache = Hash.new
  f = method(fid)
  define_method(fid) do |arg|
    if !cache.has_key?(arg)
      cache[arg] = f.call(arg)
    end
    cache[arg]
  end
end

# Memoize a two-argument method in place, keying the cache on [arg1, arg2]
def memoize2(fid)
  cache = Hash.new
  f = method(fid)
  define_method(fid) do |arg1, arg2|
    args = [arg1, arg2]
    if !cache.has_key?(args)
      cache[args] = f.call(arg1, arg2)
    end
    cache[args]
  end
end

# Plain recursive Fibonacci; it calls fib2 (not fib) so the memoized wrapper is used
def fib2(n)
  if n < 2
    1
  else
    fib2(n - 2) + fib2(n - 1)
  end
end

fib2 = memoize1(:fib2)
# puts fib2(10)
# puts fib2(100)

def nroutes(x, y)
  if x == 0 && y == 0
    1
  elsif x < 0 || y < 0
    0
  else
    nroutes(x - 1, y) + nroutes(x, y - 1)
  end
end

nroutes = memoize2(:nroutes)
puts nroutes(130, 70)
```