14 Jig: jumping to tail calls
14.1 Tail Calls
With Iniquity, we’ve finally introduced some computational power via the mechanism of functions and function calls. Together with the notion of inductive data, which we have in the form of pairs, we can write fixed-sized programs that operate over arbitrarily large data.
The problem, however, is that there is a class of programs that should operate with a fixed amount of memory, but instead consume memory in proportion to the size of the data they operate on. This is unfortunate because a design flaw in our compiler now leads to asymptotically worse space consumption than these programs actually need.
We can correct this problem by generating space-efficient code for function calls when those calls are in tail position.
Let’s call this language Jig.
There are no syntactic additions required: we simply need to handle function calls properly.
14.2 What is a Tail Call?
A tail call is a function call that occurs in tail position. What is tail position, and why is it important to consider function calls made in this position?
Tail position captures the notion of “the last subexpression that needs to be computed.” If the whole program is some expression e, then e is in tail position. Computing e is the last thing (it’s the only thing!) the program needs to compute.
Let’s look at some examples to get a sense of the subexpressions in tail position. If e is in tail position and e is of the form:
(let ((x e0)) e1): then e1 is in tail position, while e0 is not. e0 is not in tail position because after it is evaluated, e1 still needs to be evaluated. On the other hand, once e0 is evaluated, whatever e1 evaluates to is what the whole let-expression evaluates to; it is all that remains to compute.
(if e0 e1 e2): then both e1 and e2 are in tail position, while e0 is not. After e0 is evaluated, either e1 or e2 is evaluated depending on its result, and whichever it is determines the result of the if expression.
(+ e0 e1): then neither e0 nor e1 is in tail position because after both are evaluated, their results still must be added together.
(f e0 ...), where f is a function (define (f x ...) e): then none of the arguments e0 ... are in tail position, because after evaluating them, the function still needs to be applied; but the body of the function, e, is in tail position.
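The rules above are mechanical enough to turn into a small checker. The following is a hypothetical helper, not part of the Jig implementation: it works on plain s-expressions rather than the AST structs from ast.rkt, assumes its argument is itself in tail position, and only understands let, if, primitive applications, and function calls.

```racket
#lang racket
;; Hypothetical helper (not part of Jig): find the function calls that
;; occur in tail position of an s-expression, assuming the expression
;; itself is in tail position. Only let, if, primitives, and calls are handled.

;; Primitive operations; applying one of these is not a function call.
(define prims '(add1 sub1 + - zero? box unbox empty? cons car cdr))

;; S-Expr -> (Listof S-Expr)
(define (tail-calls e)
  (match e
    ;; (let ((x e0)) e1): only the body e1 is in tail position
    [`(let ((,x ,e0)) ,e1) (tail-calls e1)]
    ;; (if e0 e1 e2): both branches are in tail position, the test is not
    [`(if ,e0 ,e1 ,e2) (append (tail-calls e1) (tail-calls e2))]
    ;; (f e0 ...): the call itself is in tail position, its arguments are not
    [`(,f ,es ...)
     #:when (and (symbol? f) (not (memq f (list* 'let 'if prims))))
     (list e)]
    ;; variables, literals, and primitive applications such as (+ e0 e1)
    [_ '()]))

;; The body of sum/acc: the recursive call is a tail call.
(tail-calls '(if (empty? xs) a (sum/acc (cdr xs) (+ (car xs) a))))
;; → '((sum/acc (cdr xs) (+ (car xs) a)))

;; A non-tail version: the call to sum is an argument of +.
(tail-calls '(if (empty? xs) 0 (+ (car xs) (sum (cdr xs)))))
;; → '()
```

Note that the checker never descends into arguments or let right-hand sides, which is exactly the point of the rules: a call buried inside (+ e0 e1) is never a tail call, however deep in the expression it sits.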
Tail position matters for how calls are compiled. Recall how a call is compiled in Iniquity (see Iniquity: function definitions and calls): the arguments are pushed on the call stack, then a call instruction is issued, which pushes the address of the return point onto the stack and jumps to the called function. When the function returns, the return point is popped off the stack and jumped back to.
But if the call is in tail position, what is left to do after it returns? Nothing. The callee’s ret transfers control back to the caller, which then immediately returns itself.
This leads to stack space being consumed on every function call, even on calls that have no need for it.
Consider this program:
; (Listof Number) -> Number
(define (sum xs) (sum/acc xs 0))

; (Listof Number) Number -> Number
(define (sum/acc xs a)
  (if (empty? xs)
      a
      (sum/acc (cdr xs) (+ (car xs) a))))
The sum/acc function should operate as efficiently as a loop that iterates over the elements of a list, accumulating their sum. But as currently compiled, the function pushes a new stack frame (its arguments and a return point) on every recursive call, so its stack usage grows linearly with the length of the list.
Matters become worse if we rewrite this program in a seemingly benign way to locally bind a variable:
; (Listof Number) Number -> Number
(define (sum/acc xs a)
  (if (empty? xs)
      a
      (let ((b (+ (car xs) a)))
        (sum/acc (cdr xs) b))))
Now the function pushes a return point and a local binding for b on every recursive call.
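To put rough numbers on this, here is a back-of-the-envelope model. It assumes the non-tail calling convention of compile-call in the compiler later in this chapter: before each call the caller lowers rsp by 8 bytes per slot of its compile-time environment, and the call instruction then pushes an 8-byte return address. The helper names are hypothetical.

```racket
#lang racket
;; Back-of-the-envelope model (hypothetical helpers) of stack growth for
;; non-tail calls, following compile-call: the caller executes
;; (sub rsp (* 8 env-size)), then `call` pushes an 8-byte return address.

;; Natural -> Natural
;; Bytes of stack consumed by one non-tail call made from a context whose
;; compile-time environment has env-size slots.
(define (bytes-per-call env-size)
  (+ (* 8 env-size) 8))

;; Natural Natural -> Natural
;; Bytes consumed by n nested non-tail calls from the same context.
(define (stack-bytes n env-size)
  (* n (bytes-per-call env-size)))

;; sum/acc without the let: the environment at the call is (a xs).
(stack-bytes 1000000 2) ; → 24000000
;; sum/acc with the let: the environment at the call is (b a xs).
(stack-bytes 1000000 3) ; → 32000000
```

Under this model, a list of a million elements costs roughly 24 MB of stack for the first version and 32 MB for the let version, all of it spent on frames that will never be used again.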
But we know that whatever the recursive call produces is the answer to the overall call to sum. There’s no need for a new return point and there’s no need to keep the local binding of b, since there’s no way this program can depend on it after the recursive call. Instead of pushing a new, useless return point, we should make the call with whatever the current return point is. This is the idea of proper tail calls.
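Racket, the language used for the interpreter and compiler in these notes, implements proper tail calls, and the effect can be observed directly with continuation marks: a mark installed by with-continuation-mark in tail position replaces the mark on the current frame instead of adding a new one. The sketch below instruments sum/acc and a non-tail counterpart, sum/rec (a name introduced here for illustration), so that each reports how many marked frames are still live when it reaches the empty list.

```racket
#lang racket
;; Measure call nesting with continuation marks. A mark installed in tail
;; position replaces the mark on the current frame, so the number of marks
;; seen counts only the calls that are still waiting to finish.
(define frame-key (make-continuation-mark-key 'frame))

(define (frame-depth)
  (length (continuation-mark-set->list (current-continuation-marks) frame-key)))

;; sum/acc from above, returning the sum and the depth at the base case.
(define (sum/acc xs a)
  (with-continuation-mark frame-key #t
    (if (empty? xs)
        (list a (frame-depth))
        (sum/acc (cdr xs) (+ (car xs) a)))))

;; A non-tail counterpart (hypothetical): the addition happens after the
;; recursive call returns, so every caller's frame must be kept.
(define (sum/rec xs)
  (with-continuation-mark frame-key #t
    (if (empty? xs)
        (list 0 (frame-depth))
        (let ([r (sum/rec (cdr xs))])
          (list (+ (car xs) (first r)) (second r))))))

(sum/acc (range 1000) 0) ; → '(499500 1)
(sum/rec (range 1000))   ; → '(499500 1001)
```

The tail-recursive sum/acc finishes with a single marked frame no matter how long the list is, while sum/rec needs one frame per element plus one. The first behavior is what we want the compiled sum/acc to exhibit.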
An axe to grind: the notion of proper tail calls is often referred to with misleading terminology such as tail call optimization or tail recursion. Optimization seems to imply it is a nice but optional strategy for implementing function calls. Consequently, a large number of mainstream programming languages, most notably Java, do not properly implement tail calls. But a language without proper tail calls is fundamentally broken. It means that functions cannot reliably be designed to match the structure of the data they operate on. It means iteration cannot be expressed with function calls. There’s really no justification for it. It’s just broken. Similarly, it’s not about recursion alone (although it is critical for recursion); it really is about getting function calls, all calls, right. /rant
14.3 An Interpreter for Proper Calls
Before addressing the issue of compiling proper tail calls, let’s first think about the interpreter, starting from the interpreter we wrote for Iniquity:
#lang racket
(provide (all-defined-out))
(require "ast.rkt")

;; type Prog =
;; | `(begin ,@(Listof Defn) ,Expr)
;; | Expr

;; type Defn = `(define (,Variable ,@(Listof Variable)) ,Expr)

;; Prog -> Answer
(define (interp p)
  (match p
    [(prog ds e)
     (interp-env e '() ds)]))

;; Expr REnv (Listof Defn) -> Answer
(define (interp-env e r ds)
  (match e
    [(var-e v)  (lookup r v)]
    [(int-e i)  i]
    [(bool-e b) b]
    [(nil-e)    '()]
    [(prim-e (? prim? p) es)
     (let ((as (interp-env* es r ds)))
       (interp-prim p as))]
    [(if-e p e1 e2)
     (match (interp-env p r ds)
       ['err 'err]
       [v
        (if v
            (interp-env e1 r ds)
            (interp-env e2 r ds))])]
    [(let-e (list (binding x def)) body)
     (match (interp-env def r ds)
       ['err 'err]
       [v
        (interp-env body (ext r x v) ds)])]
    [(app-e f es)
     (match (interp-env* es r ds)
       [(list vs ...)
        (match (defns-lookup ds f)
          [(fundef f xs body)
           ; check arity matches
           (if (= (length xs) (length vs))
               (interp-env body (zip xs vs) ds)
               'err)])]
       [_ 'err])]))

;; (Listof Defn) Symbol -> Defn
(define (defns-lookup ds f)
  (findf (match-lambda [(fundef g _ _) (eq? f g)])
         ds))

;; (Listof Expr) REnv (Listof Defn) -> (Listof Value) | 'err
(define (interp-env* es r ds)
  (match es
    ['() '()]
    [(cons e es)
     (match (interp-env e r ds)
       ['err 'err]
       [v (cons v (interp-env* es r ds))])]))

;; Any -> Boolean
(define (prim? x)
  (and (symbol? x)
       (memq x '(add1 sub1 + - zero?
                 box unbox empty? cons car cdr))))

;; Any -> Boolean
(define (value? x)
  (or (integer? x)
      (boolean? x)
      (null? x)
      (and (pair? x)
           (value? (car x))
           (value? (cdr x)))))

;; Prim (Listof Value) -> Answer
(define (interp-prim p vs)
  (match (cons p vs)
    [(list 'add1 (? integer? i0))  (add1 i0)]
    [(list 'sub1 (? integer? i0))  (sub1 i0)]
    [(list 'zero? (? integer? i0)) (zero? i0)]
    [(list 'box v0)                (box v0)]
    [(list 'unbox (? box? v0))     (unbox v0)]
    [(list 'empty? v0)             (empty? v0)]
    [(list 'cons v0 v1)            (cons v0 v1)]
    [(list 'car (cons v0 v1))      v0]
    [(list 'cdr (cons v0 v1))      v1]
    [(list '+ (? integer? i0) (? integer? i1))
     (+ i0 i1)]
    [(list '- (? integer? i0) (? integer? i1))
     (- i0 i1)]
    [_ 'err]))

;; Env Variable -> Answer
(define (lookup env x)
  (match env
    ['() 'err]
    [(cons (list y i) env)
     (match (symbol=? x y)
       [#t i]
       [#f (lookup env x)])]))

;; Env Variable Value -> Env
(define (ext r x i)
  (cons (list x i) r))

;; (Listof Variable) (Listof Value) -> Env
(define (zip xs ys)
  (match* (xs ys)
    [('() '()) '()]
    [((cons x xs) (cons y ys))
     (cons (list x y)
           (zip xs ys))]))
What needs to be done to make it implement proper tail calls?
Well... not much. Notice how every Iniquity subexpression that is in tail position is interpreted by a call to interp-env that is itself in tail position in the Racket program!
So long as Racket implements tail calls properly, which it does, this interpreter implements tail calls properly. The interpreter inherits the property of proper tail calls from the meta-language. This is but one reason to do tail calls correctly. Had we transliterated this program to Java, we’d be in trouble: the interpreter would inherit Java’s lack of proper tail calls and we would have to rewrite it. As it is, we’re already done.
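The inherited property is easy to lose, though. The sketch below is a hypothetical miniature of interp-env over plain s-expressions (integers, variables, sub1, zero?, if, one-argument calls, and a made-up (depth) form that counts the interp calls still waiting in non-tail position, measured with continuation marks). When check-result? is set, it inspects the result of a function body after it returns: a harmless-looking change that moves the recursive interp call out of tail position.

```racket
#lang racket
;; Hypothetical miniature of interp-env over s-expressions: integers,
;; variables, sub1, zero?, if, one-argument calls, and a made-up (depth)
;; form that counts the interp calls still waiting in non-tail position.
(define key (make-continuation-mark-key 'interp))

;; Expr Env Defns Boolean -> Value
;; Env is an association list of (variable . value) pairs.
;; Defns is a list of (name param body) lists.
(define (interp e r ds check-result?)
  (define (recur e r) (interp e r ds check-result?))
  (with-continuation-mark key #t
    (match e
      [(? integer? i) i]
      [(? symbol? x) (cdr (assq x r))]
      ['(depth)
       (length (continuation-mark-set->list (current-continuation-marks) key))]
      [`(sub1 ,e0) (sub1 (recur e0 r))]
      [`(zero? ,e0) (zero? (recur e0 r))]
      [`(if ,e0 ,e1 ,e2) (if (recur e0 r) (recur e1 r) (recur e2 r))]
      [`(,f ,e0)
       (match-define `(,x ,body) (cdr (assq f ds)))
       (define env (list (cons x (recur e0 r))))
       (if check-result?
           ;; Not a tail call: the result is inspected after it returns.
           (let ([v (recur body env)])
             (if (integer? v) v (error "expected an integer")))
           ;; A tail call, like the app-e case of interp-env.
           (recur body env))])))

(define ds '((loop n (if (zero? n) (depth) (loop (sub1 n))))))

(interp '(loop 100) '() ds #f) ; → 1
(interp '(loop 100) '() ds #t) ; grows with the argument to loop
```

The only difference between the two modes is whether the interpreter has work left to do after interpreting a function body, and that alone decides whether the interpreted program gets proper tail calls.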
14.4 A Compiler with Proper Tail Calls
#lang racket
(provide (all-defined-out))
(require "ast.rkt")

;; An immediate is anything ending in #b000
;; All other tags in mask #b111 are pointers

(define result-shift     3)
(define result-type-mask (sub1 (arithmetic-shift 1 result-shift)))
(define type-imm         #b000)
(define type-box         #b001)
(define type-pair        #b010)
(define type-string      #b011)
(define type-proc        #b100)

(define imm-shift        (+ 2 result-shift))
(define imm-type-mask    (sub1 (arithmetic-shift 1 imm-shift)))
(define imm-type-int     (arithmetic-shift #b00 result-shift))
(define imm-type-bool    (arithmetic-shift #b01 result-shift))
(define imm-type-char    (arithmetic-shift #b10 result-shift))
(define imm-type-empty   (arithmetic-shift #b11 result-shift))
(define imm-val-false    imm-type-bool)
(define imm-val-true
  (bitwise-ior (arithmetic-shift 1 (add1 imm-shift)) imm-type-bool))

;; Allocate in 64-bit (8-byte) increments, so pointers
;; end in #b000 and we tag with #b001 for boxes, etc.

;; type CEnv = (Listof (Maybe Variable))

;; Prog -> Asm
(define (compile p)
  (match p
    [(prog defs e)
     (let ((ds (compile-defines defs))
           (c0 (compile-entry e)))
       `(,@c0
         ,@ds))]))

;; Expr -> Asm
;; Compile e as the entry point
(define (compile-entry e)
  `(entry
    ,@(compile-tail-e e '())
    ret
    err
    (push rbp)
    (call error)))

;; Expr CEnv -> Asm
;; Compile an expression in tail position
(define (compile-tail-e e c)
  (match e
    [(var-e v)               (compile-variable v c)]
    [(? imm? i)              (compile-imm i)]
    [(prim-e (? prim? p) es) (compile-prim p es c)]
    [(if-e p t f)            (compile-tail-if p t f c)]
    [(let-e (list b) body)   (compile-tail-let b body c)]
    [(app-e f es)            (compile-tail-call f es c)]))

;; Expr CEnv -> Asm
;; Compile an expression in non-tail position
(define (compile-e e c)
  (match e
    [(var-e v)               (compile-variable v c)]
    [(? imm? i)              (compile-imm i)]
    [(prim-e (? prim? p) es) (compile-prim p es c)]
    [(if-e p t f)            (compile-if p t f c)]
    [(let-e (list b) body)   (compile-let b body c)]
    [(app-e f es)            (compile-call f es c)]))

;; Our current set of primitive operations require no function calls,
;; so there's no difference between tail and non-tail call positions
(define (compile-prim p es c)
  (match (cons p es)
    [`(box ,e0)      (compile-box e0 c)]
    [`(unbox ,e0)    (compile-unbox e0 c)]
    [`(cons ,e0 ,e1) (compile-cons e0 e1 c)]
    [`(car ,e0)      (compile-car e0 c)]
    [`(cdr ,e0)      (compile-cdr e0 c)]
    [`(add1 ,e0)     (compile-add1 e0 c)]
    [`(sub1 ,e0)     (compile-sub1 e0 c)]
    [`(zero? ,e0)    (compile-zero? e0 c)]
    [`(empty? ,e0)   (compile-empty? e0 c)]
    [`(+ ,e0 ,e1)    (compile-+ e0 e1 c)]
    [_ (error
        (format "prim applied to wrong number of args: ~a ~a" p es))]))

;; Variable (Listof Expr) CEnv -> Asm
;; Statically know the function we're calling
(define (compile-call f es c)
  (let ((cs (compile-es es (cons #f c)))
        (stack-size (* 8 (length c))))
    `(,@cs
      (sub rsp ,stack-size)
      (call ,(symbol->label f))
      (add rsp ,stack-size))))

;; Variable (Listof Expr) CEnv -> Asm
;; Compile a call in tail position: overwrite the current arguments
;; and jump, reusing the current return point
(define (compile-tail-call f es c)
  (let ((cs (compile-es es c)))
    `(,@cs
      ,@(move-args (length es) (- (length c)))
      (jmp ,(symbol->label f)))))

;; Integer Integer -> Asm
;; Move i arguments upward on stack by offset off
(define (move-args i off)
  (match i
    [0 '()]
    [_ `(,@(move-args (sub1 i) off)
         (mov rbx (offset rsp ,(- off i)))
         (mov (offset rsp ,(- i)) rbx))]))

;; (Listof Expr) CEnv -> Asm
(define (compile-es es c)
  (match es
    ['() '()]
    [(cons e es)
     (let ((c0 (compile-e e c))
           (cs (compile-es es (cons #f c))))
       `(,@c0
         (mov (offset rsp ,(- (add1 (length c)))) rax)
         ,@cs))]))

;; Defn -> Asm
;; The body of a function is in tail position
(define (compile-define def)
  (match def
    [(fundef name args body)
     (let ((c0 (compile-tail-e body (reverse args))))
       `(,(symbol->label name)
         ,@c0
         ret))]))

;; (Listof Defn) -> Asm
(define (compile-defines defs)
  (append-map compile-define defs))

;; Any -> Boolean
(define (imm? x)
  (or (int-e? x)
      (bool-e? x)
      (char-e? x)
      (nil-e? x)))

;; Imm -> Asm
(define (compile-imm i)
  `((mov rax ,(imm->bits i))))

;; Imm -> Integer
(define (imm->bits i)
  (match i
    [(int-e i)  (arithmetic-shift i imm-shift)]
    [(char-e c) (+ (arithmetic-shift (char->integer c) imm-shift) imm-type-char)]
    [(bool-e b) (if b imm-val-true imm-val-false)]
    [(nil-e)    imm-type-empty]))

;; Variable CEnv -> Asm
(define (compile-variable x c)
  (let ((i (lookup x c)))
    `((mov rax (offset rsp ,(- (add1 i)))))))

;; Expr CEnv -> Asm
(define (compile-box e0 c)
  (let ((c0 (compile-e e0 c)))
    `(,@c0
      (mov (offset rdi 0) rax)
      (mov rax rdi)
      (or rax ,type-box)
      (add rdi 8)))) ; allocate 8 bytes

;; Expr CEnv -> Asm
(define (compile-unbox e0 c)
  (let ((c0 (compile-e e0 c)))
    `(,@c0
      ,@assert-box
      (xor rax ,type-box)
      (mov rax (offset rax 0)))))

;; Expr Expr CEnv -> Asm
(define (compile-cons e0 e1 c)
  (let ((c0 (compile-e e0 c))
        (c1 (compile-e e1 (cons #f c))))
    `(,@c0
      (mov (offset rsp ,(- (add1 (length c)))) rax)
      ,@c1
      (mov (offset rdi 0) rax)
      (mov rax (offset rsp ,(- (add1 (length c)))))
      (mov (offset rdi 1) rax)
      (mov rax rdi)
      (or rax ,type-pair)
      (add rdi 16))))

;; Expr CEnv -> Asm
(define (compile-car e0 c)
  (let ((c0 (compile-e e0 c)))
    `(,@c0
      ,@assert-pair
      (xor rax ,type-pair)
      (mov rax (offset rax 1)))))

;; Expr CEnv -> Asm
(define (compile-cdr e0 c)
  (let ((c0 (compile-e e0 c)))
    `(,@c0
      ,@assert-pair
      (xor rax ,type-pair)
      (mov rax (offset rax 0)))))

;; Expr CEnv -> Asm
(define (compile-empty? e0 c)
  (let ((c0 (compile-e e0 c))
        (l0 (gensym)))
    `(,@c0
      (and rax ,imm-type-mask)
      (cmp rax ,imm-type-empty)
      (mov rax ,imm-val-false)
      (jne ,l0)
      (mov rax ,imm-val-true)
      ,l0)))

;; Expr CEnv -> Asm
(define (compile-add1 e0 c)
  (let ((c0 (compile-e e0 c)))
    `(,@c0
      ,@assert-integer
      (add rax ,(arithmetic-shift 1 imm-shift)))))

;; Expr CEnv -> Asm
(define (compile-sub1 e0 c)
  (let ((c0 (compile-e e0 c)))
    `(,@c0
      ,@assert-integer
      (sub rax ,(arithmetic-shift 1 imm-shift)))))

;; Expr CEnv -> Asm
(define (compile-zero? e0 c)
  (let ((c0 (compile-e e0 c))
        (l0 (gensym)))
    `(,@c0
      ,@assert-integer
      (cmp rax 0)
      (mov rax ,imm-val-false)
      (jne ,l0)
      (mov rax ,imm-val-true)
      ,l0)))

;; Expr Expr Expr CEnv -> Asm
(define (compile-if e0 e1 e2 c)
  (let ((c0 (compile-e e0 c))
        (c1 (compile-e e1 c))
        (c2 (compile-e e2 c))
        (l0 (gensym))
        (l1 (gensym)))
    `(,@c0
      (cmp rax ,imm-val-false)
      (je ,l0)
      ,@c1
      (jmp ,l1)
      ,l0
      ,@c2
      ,l1)))

;; Expr Expr Expr CEnv -> Asm
;; Both branches of an if in tail position are in tail position
(define (compile-tail-if e0 e1 e2 c)
  (let ((c0 (compile-e e0 c))
        (c1 (compile-tail-e e1 c))
        (c2 (compile-tail-e e2 c))
        (l0 (gensym))
        (l1 (gensym)))
    `(,@c0
      (cmp rax ,imm-val-false)
      (je ,l0)
      ,@c1
      (jmp ,l1)
      ,l0
      ,@c2
      ,l1)))

;; Binding Expr CEnv -> Asm
;; The body of a let in tail position is in tail position
(define (compile-tail-let b e1 c)
  (match b
    [(binding v def)
     (let ((c0 (compile-e def c))
           (c1 (compile-tail-e e1 (cons v c))))
       `(,@c0
         (mov (offset rsp ,(- (add1 (length c)))) rax)
         ,@c1))]
    [_ (error "Compile-let can only handle bindings")]))

;; Binding Expr CEnv -> Asm
(define (compile-let b e1 c)
  (match b
    [(binding v def)
     (let ((c0 (compile-e def c))
           (c1 (compile-e e1 (cons v c))))
       `(,@c0
         (mov (offset rsp ,(- (add1 (length c)))) rax)
         ,@c1))]
    [_ (error "Compile-let can only handle bindings")]))

;; Expr Expr CEnv -> Asm
(define (compile-+ e0 e1 c)
  (let ((c1 (compile-e e1 c))
        (c0 (compile-e e0 (cons #f c))))
    `(,@c1
      ,@assert-integer
      (mov (offset rsp ,(- (add1 (length c)))) rax)
      ,@c0
      ,@assert-integer
      (add rax (offset rsp ,(- (add1 (length c))))))))

(define (type-pred->mask p)
  (match p
    [(or 'box? 'cons? 'string? 'procedure?) result-type-mask]
    [_ imm-type-mask]))

(define (type-pred->tag p)
  (match p
    ['box?       type-box]
    ['cons?      type-pair]
    ['string?    type-string]
    ['procedure? type-proc]
    ['integer?   imm-type-int]
    ['empty?     imm-type-empty]
    ['char?      imm-type-char]
    ['boolean?   imm-type-bool]))

;; Variable CEnv -> Natural
(define (lookup x cenv)
  (match cenv
    ['() (error "undefined variable:" x)]
    [(cons y cenv)
     (match (eq? x y)
       [#t (length cenv)]
       [#f (lookup x cenv)])]))

(define (assert-type p)
  `((mov rbx rax)
    (and rbx ,(type-pred->mask p))
    (cmp rbx ,(type-pred->tag p))
    (jne err)))

(define assert-integer (assert-type 'integer?))
(define assert-box     (assert-type 'box?))
(define assert-pair    (assert-type 'cons?))
(define assert-string  (assert-type 'string?))
(define assert-char    (assert-type 'char?))
(define assert-proc    (assert-type 'procedure?))

;; Asm
(define assert-natural
  `(,@assert-integer
    (cmp rax -1)
    (jle err)))

;; Asm
(define assert-integer-codepoint
  `((mov rbx rax)
    (and rbx ,imm-type-mask)
    (cmp rbx 0)
    (jne err)
    (cmp rax ,(arithmetic-shift -1 imm-shift))
    (jle err)
    (cmp rax ,(arithmetic-shift #x10FFFF imm-shift))
    (jg err)
    (mov rbx rax)
    (sar rbx ,(+ 11 imm-shift))
    (cmp rbx #b11011)
    (je err)))

;; Symbol -> Label
;; Produce a symbol that is a valid Nasm label
(define (symbol->label s)
  (string->symbol
   (string-append
    "label_"
    (list->string
     (map (λ (c)
            (if (or (char<=? #\a c #\z)
                    (char<=? #\A c #\Z)
                    (char<=? #\0 c #\9)
                    (memq c '(#\_ #\$ #\# #\@ #\~ #\. #\?)))
                c
                #\_))
          (string->list (symbol->string s))))
    "_"
    (number->string (eq-hash-code s) 16))))
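As a worked example of the tail-call sequence, consider the let version of sum/acc. A function body is compiled with the environment (reverse args), here (a xs), and the let adds b, so when compile-tail-call compiles (sum/acc (cdr xs) b) the environment is (b a xs) and (length c) is 3. Offsets count 8-byte words below rsp, so xs lives at -1, a at -2, and b at -3; compile-es leaves the two new arguments at -4 and -5. The snippet copies move-args from the compiler above so that it runs on its own.

```racket
#lang racket
;; move-args, copied from the compiler above. off is the negated size of the
;; caller's environment; argument k (1-based) sits at slot (off - k) and is
;; copied into slot -k, where the callee expects its k-th parameter.
(define (move-args i off)
  (match i
    [0 '()]
    [_ `(,@(move-args (sub1 i) off)
         (mov rbx (offset rsp ,(- off i)))
         (mov (offset rsp ,(- i)) rbx))]))

;; Tail call (sum/acc (cdr xs) b) with environment (b a xs): off = -3.
(move-args 2 -3)
;; → '((mov rbx (offset rsp -4)) (mov (offset rsp -1) rbx)
;;     (mov rbx (offset rsp -5)) (mov (offset rsp -2) rbx))
```

The new xs overwrites the old xs at -1, the new a overwrites the old a at -2, and b's slot at -3 is simply abandoned. The return address at rsp is never touched, so the jmp reuses the current return point and the stack does not grow. Copying in increasing order is safe because every source slot not yet read lies strictly below the slot being written.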