(* Control structures in programming languages: from goto to algebraic effects

   Xavier Leroy

   CPS programming in OCaml (chapter 7) *)

(* Chapter 7, "Programming with continuations" *)

open Printf

(* Section 7.1: the factorial function, in direct style and in CPS *)

(* [fact n] is n!, computed by ordinary (non-tail) recursion.
   Native ints wrap around silently, so the result is only exact for
   n <= 20 on 64-bit platforms.
   @raise Invalid_argument if [n] is negative: the recursion would
   otherwise never reach its base case. *)
let rec fact n =
  if n < 0 then invalid_arg "fact: negative argument"
  else if n = 0 then 1
  else n * fact (n - 1)

(* [cps_fact n k] passes n! to the continuation [k].  Every recursive
   call is a tail call; the pending multiplications are accumulated in
   the continuation instead of on the stack.
   @raise Invalid_argument if [n] is negative: without this check the
   tail-recursive loop would run forever while allocating closures. *)
let rec cps_fact n k =
  if n < 0 then invalid_arg "cps_fact: negative argument"
  else if n = 0 then k 1
  else cps_fact (n - 1) (fun r -> k (n * r))

(* Print 10! computed both in direct style and in CPS.
   [let ()] makes the compiler check that this side effect has type unit. *)
let () =
  printf "fact 10 = %d\n" (fact 10);
  cps_fact 10 (fun n -> printf "cps_fact 10 = %d\n" n)
   
(* Section 7.2: iterators *)

(* Binary trees: either empty, or a node holding a left subtree,
   a value of type ['a] and a right subtree. *)
type 'a tree =
  | Leaf
  | Node of 'a tree * 'a * 'a tree

(* Sample tree whose in-order traversal yields 1 2 3 4 5. *)
let my_tree =
  let singleton x = Node (Leaf, x, Leaf) in
  Node (Node (Leaf, 1, singleton 2), 3, Node (singleton 4, 5, Leaf))

(* Direct-style in-order traversal: apply [f] to the left subtree's
   elements, then to the node's value, then to the right subtree's. *)
let rec tree_iter (f: 'a -> unit) : 'a tree -> unit = function
  | Leaf -> ()
  | Node (left, value, right) ->
      tree_iter f left;
      f value;
      tree_iter f right

(* Print the elements of [my_tree] using the direct-style iterator.
   [let ()] makes the compiler check that this side effect has type unit. *)
let () =
  printf "tree_iter (direct):";
  tree_iter (fun x -> printf " %d" x) my_tree;
  printf "\n"

(* CPS in-order traversal.  [f x k] receives each element [x] together
   with the continuation [k] for the rest of the traversal.  The final
   continuation [k] given to [tree_iter] runs once the whole tree has
   been visited.  Because [f] decides if and when to call its
   continuation, the caller can suspend the traversal (see the
   generator and the enumerator below). *)
let rec tree_iter (f: 'a -> (unit -> 'b) -> 'b)
                  (t: 'a tree) (k: unit -> 'b) =
  match t with
  | Leaf -> k ()
  | Node (left, value, right) ->
      let after_value () = tree_iter f right k in
      let after_left () = f value after_value in
      tree_iter f left after_left

(* Print the elements of [my_tree] using the CPS iterator; the final
   continuation prints the newline.
   [let ()] makes the compiler check that this side effect has type unit. *)
let () =
  printf "tree_iter (CPS):";
  tree_iter (fun x k -> printf " %d" x; k ())
            my_tree
            (fun () -> printf "\n")

(* Python-style imperative generator *)

exception StopIteration

(* Python-style generator over [t].  Each call to the returned function
   produces the next element in in-order, and raises [StopIteration]
   once the traversal is finished.  The generator's state is the
   continuation of the CPS traversal: the mutable cell [resume] is
   overwritten with it every time an element is handed out. *)
let tree_generator (t: 'a tree) : unit -> 'a =
  let resume = ref (fun () -> raise StopIteration) in
  let start () =
    tree_iter (fun x k -> resume := k; x)
              t
              (fun () -> raise StopIteration)
  in
  resume := start;
  fun () -> !resume ()

(* Drain a generator over [my_tree], printing each element, until it
   signals exhaustion with [StopIteration].
   [let ()] makes the compiler check that this side effect has type unit. *)
let () =
  printf "tree_generator:";
  let g = tree_generator my_tree in
  try
    while true do printf " %d" (g ()) done
  with StopIteration ->
    printf "\n"

(* Purely functional enumerator *)

(* Purely functional lazy stream: either exhausted, or one element
   paired with a thunk that computes the rest of the stream. *)
type 'a enum =
  | Done
  | More of 'a * (unit -> 'a enum)

(* Enumerate the elements of [t] in in-order.  Each [More] carries the
   CPS traversal's continuation as its "rest" thunk, so the tree is only
   walked as far as the consumer forces it. *)
let tree_enumerator (t: 'a tree) : 'a enum =
  let emit x rest = More (x, rest) in
  let finished () = Done in
  tree_iter emit t finished

(* Print the elements of [my_tree] by forcing its enumeration to the end.
   [let ()] makes the compiler check that this side effect has type unit. *)
let () =
  let rec print_enum = function
    | Done -> printf "\n"
    | More (x, k) -> printf " %d" x; print_enum (k ())
  in
  printf "tree_enumerator:";
  print_enum (tree_enumerator my_tree)

(* 7.3 Multithreading *)

(* Suspended threads, each represented by the continuation to run when
   it is resumed.  [Queue] is FIFO, so threads are resumed in the order
   in which they were suspended or spawned. *)
let ready : (unit -> unit) Queue.t = Queue.create ()

(* Resume the oldest ready thread, if there is one; otherwise return (). *)
let schedule () =
  if not (Queue.is_empty ready) then (Queue.take ready) ()

(* Suspend the current thread, whose continuation is [k], and hand
   control to the next ready thread. *)
let yield (k: unit -> unit) =
  Queue.push k ready;
  schedule ()

(* Finish the current thread and hand control to the next ready one. *)
let terminate () = schedule ()

(* Register [f] as a new thread; it first runs at a later scheduling
   point ([yield], [terminate] or [schedule]). *)
let spawn (f: unit -> unit) = Queue.push f ready

(* A cooperative thread that prints "<name>1 ", "<name>2 ", ...,
   "<name><count> ", yielding to the scheduler after each print, and
   then terminates. *)
let process name count =
  let rec step i =
    if i <= count then begin
      printf "%s%d " name i;
      yield (fun () -> step (i + 1))
    end
    else terminate ()
  in
  step 1

(* Interleave three cooperative threads.  "C" runs in the main flow of
   control, "A" and "B" are spawned; the expected output is
   "C1 A1 B1 C2 A2 B2 C3 A3 B3 C4 A4 C5 A5 C6 done".
   [let ()] makes the compiler check that this side effect has type unit. *)
let () =
  printf "Running 3 processes: ";
  spawn (fun () -> process "A" 5);
  spawn (fun () -> process "B" 3);
  process "C" 6;
  printf "done\n"

(* 7.4 Errors and exceptions *)

(* Multiple return points *)

(* Solve a*x^2 + b*x + c = 0 with two return points: [k1 x1 x2] receives
   the two real roots ((-b + sqrt d) / 2a first), and [k2 ()] is called
   when the discriminant d is negative. *)
let quadratic a b c k1 k2 =
  let discriminant = (b *. b) -. (4. *. a *. c) in
  if discriminant < 0.0 then k2 ()
  else
    let root = sqrt discriminant and denom = 2. *. a in
    k1 ((-. b +. root) /. denom) ((-. b -. root) /. denom)

(* Print the real roots of a*x^2 + b*x + c = 0, or report that there
   are none, using one continuation per outcome. *)
let print_solutions a b c =
  let on_roots x1 x2 = printf "solutions: %g %g\n" x1 x2 in
  let on_none () = printf "no real solutions\n" in
  quadratic a b c on_roots on_none

(* Exercise both return points of [quadratic].
   [let ()] makes the compiler check that this side effect has type unit. *)
let () =
  printf "x2 + x - 2 = 0  : "; print_solutions 1.0 1.0 (-2.0);
  printf "x2 + 1 = 0 : "; print_solutions 1.0 0.0 1.0

(* 7.5 Backtracking *)

(* Parser combinators for regular expressions, using two continuations. *)

module Regexp1 = struct

(* A parser receives the remaining input, a success continuation and a
   failure continuation.  When it matches a prefix of the input, it
   calls the success continuation with the unmatched suffix and a
   failure continuation to invoke if the rest of the match fails later
   (backtracking).  When it cannot match, it calls the failure
   continuation. *)
type 'a parser = char list -> 'a success -> 'a failure -> 'a
and 'a success = char list -> 'a failure -> 'a
and 'a failure = unit -> 'a

(* Matches exactly the character [c]. *)
let char c : 'a parser = fun w succ fail ->
  match w with
  | c' :: w' when c' = c -> succ w' fail
  | _ -> fail ()

(* Matches the empty string without consuming input. *)
let epsilon : 'a parser = fun w succ fail -> succ w fail

(* Matches no string at all. *)
let empty : 'a parser = fun _w _succ fail -> fail ()

(* Concatenation: [r] followed by [r'].  If [r'] fails, control goes to
   the failure continuation provided by [r], which tries [r]'s other
   ways of matching. *)
let seq (r: 'a parser) (r': 'a parser) : 'a parser =
  fun w succ fail ->
    r w (fun w' fail' -> r' w' succ fail') fail

(* Alternation: try [r]; if it (or anything after it) fails, retry with
   [r'] on the same input. *)
let alt (r: 'a parser) (r': 'a parser) : 'a parser =
  fun w succ fail ->
    r w succ (fun () -> r' w succ fail)

(* Greedy Kleene star: first tries one more iteration of [r], and only
   accepts the input matched so far when that alternative fails. *)
let rec star (r: 'a parser) : 'a parser =
  fun w succ fail ->
    r w (fun w' fail' ->
           (* Reject iterations that consumed no input; otherwise a
              sub-parser that matches the empty string would loop
              forever.  [List.compare_lengths] stops at the end of the
              shorter list instead of measuring both lists in full. *)
           if List.compare_lengths w' w < 0
           then star r w' succ fail'
           else fail' ())
        (fun () -> succ w fail)

(* Non-greedy Kleene star: first accepts zero iterations, and only tries
   one more iteration of [r] when the continuation fails. *)
let rec nongreedy_star (r: 'a parser) : 'a parser =
  fun w succ fail ->
    succ w (fun () ->
      r w (fun w' fail' ->
             if List.compare_lengths w' w < 0
             then nongreedy_star r w' succ fail'
             else fail' ())
          fail)

(* One or more repetitions of [r]. *)
let plus (r: 'a parser) : 'a parser =
  seq r (star r)

(* [matches w r] is true iff [r] matches all of [w]: a match that leaves
   input unconsumed is rejected and triggers backtracking. *)
let matches w r =
  r w (fun w' fail -> if w' = [] then true else fail ())
      (fun () -> false)

let tests = [
  (['a';'a';'b'], seq (star (char 'a')) (char 'b'));
  (['a';'a';'b';'c';'d'], seq (star (char 'a')) (char 'b'));
  (['a';'b';'a';'b'], plus (seq (char 'a') (char 'b')));
  (['c'], alt (plus (char 'b')) (char 'c'));
  (['a'; 'a'; 'a'], seq (star (char 'a')) (char 'a'));
  (['a'; 'b'], seq (alt (seq (char 'a') (char 'b')) (char 'a')) (char 'b'));
  (['a'], seq (star empty) (char 'a'));
  (['a';'a'], seq (star (alt empty (char 'a'))) (char 'a'))
]

(* [let ()] makes the compiler check that this side effect has type unit. *)
let () =
  printf "Regular expression matching (1):";
  List.iter
    (fun (w, r) -> printf " %B" (matches w r))
    tests;
  printf "\n"

end

(* 7.6 Generating, filtering and counting *)

(* Parser combinators for regular expressions, using one continuation that returns a Boolean indicating success or failure. *)

module Regexp2 = struct

(* A parser receives the remaining input and one continuation, which is
   applied to the unmatched suffix.  The parser returns true iff some
   way of matching a prefix makes the continuation return true.
   Backtracking comes from [||]: when one alternative's continuation
   returns false, the next alternative is tried. *)
type parser = char list -> (char list -> bool) -> bool

(* Matches exactly the character [c]. *)
let char c : parser = fun w k ->
  match w with
  | c' :: w' when c' = c -> k w'
  | _ -> false

(* Matches the empty string without consuming input. *)
let epsilon : parser = fun w k -> k w

(* Matches no string at all. *)
let empty : parser = fun _w _k -> false

(* Concatenation: [r], then [r'] on whatever [r] left over. *)
let seq (r: parser) (r': parser) : parser = fun w k ->
  r w (fun w' -> r' w' k)

(* Alternation: [r] first, then [r'] if [r] leads to no overall success. *)
let alt (r: parser) (r': parser) : parser = fun w k ->
  r w k || r' w k

(* Greedy Kleene star: tries one more iteration of [r] before accepting
   the input matched so far.  Iterations that consume no input are
   rejected so that sub-parsers matching the empty string cannot loop
   forever.  [List.compare_lengths] stops at the end of the shorter
   list instead of measuring both lists in full. *)
let rec star (r: parser) : parser = fun w k ->
  r w (fun w' -> List.compare_lengths w' w < 0 && star r w' k)
  || k w

(* One or more repetitions of [r]. *)
let plus (r: parser) : parser =
  seq r (star r)

(* [matches w r] is true iff [r] matches all of [w]. *)
let matches w (r: parser) : bool =
  r w (fun w' -> w' = [])

let tests = [
  (['a';'a';'b'], seq (star (char 'a')) (char 'b'));
  (['a';'a';'b';'c';'d'], seq (star (char 'a')) (char 'b'));
  (['a';'b';'a';'b'], plus (seq (char 'a') (char 'b')));
  (['c'], alt (plus (char 'b')) (char 'c'));
  (['a'; 'a'; 'a'], seq (star (char 'a')) (char 'a'));
  (['a'; 'b'], seq (alt (seq (char 'a') (char 'b')) (char 'a')) (char 'b'));
  (['a'], seq (star empty) (char 'a'));
  (['a';'a'], seq (star (alt empty (char 'a'))) (char 'a'))
]

(* [let ()] makes the compiler check that this side effect has type unit. *)
let () =
  printf "Regular expression matching (2):";
  List.iter
    (fun (w, r) -> printf " %B" (matches w r))
    tests;
  printf "\n"

end

(* Counting solutions *)

(* Counting combinators: each one enumerates its possible values, feeds
   each value to the continuation [k], and adds up the counts that [k]
   returns. *)

(* Both booleans. *)
let bool k = k false + k true

(* Every integer in [lo, hi]; an empty range contributes 0. *)
let int lo hi k =
  let rec go i =
    if i > hi then 0
    else k i + go (i + 1)
  in
  go lo

(* Count the outcomes of three six-sided dice whose total is at least 16.
   The printed label matches the [>= 16] filter; "sum to 16" would
   describe a different, smaller count.
   [let ()] makes the compiler check that this side effect has type unit. *)
let () =
  let n =
    int 1 6 (fun d1 -> int 1 6 (fun d2 -> int 1 6 (fun d3 ->
      if d1 + d2 + d3 >= 16 then 1 else 0))) in
  printf "Number of 3d6 that sum to at least 16: %d\n" n

(* [avltree h k] calls [k] on every AVL-balanced tree of height exactly
   [h] (all nodes labelled 0) and returns the sum of the results.  Height
   0 has only [Leaf], and negative heights have no trees.  A node of
   height h has subtrees of heights (h-1, h-1), (h-2, h-1) or (h-1, h-2). *)
let rec avltree h k =
  if h > 0 then
    avltree2 (h - 1) (h - 1) k
    + avltree2 (h - 2) (h - 1) k
    + avltree2 (h - 1) (h - 2) k
  else if h = 0 then k Leaf
  else 0

(* [avltree2 hl hr k] combines every tree of height [hl] (left) with
   every tree of height [hr] (right) under a new node. *)
and avltree2 hl hr k =
  avltree hl (fun left ->
    avltree hr (fun right -> k (Node (left, 0, right))))

(* Count the AVL trees of height 4 by making every tree count as 1.
   [let ()] makes the compiler check that this side effect has type unit. *)
let () =
  let n = avltree 4 (fun _ -> 1) in
  printf "Number of AVL trees of height 4: %d\n" n

(* Computing probabilities *)

(* Probability combinators: each one averages the continuation's results
   over a uniformly distributed choice. *)

(* A fair coin. *)
let bool k = 0.5 *. (k false +. k true)

(* A uniform choice among the integers of [lo, hi]. *)
let int lo hi k =
  let count = float (hi - lo + 1) in
  let rec total i =
    if i > hi then 0.0
    else k i +. total (i + 1)
  in
  total lo /. count

(* Probability that three six-sided dice total exactly 15
   (10 of the 216 equally likely outcomes).
   [let ()] makes the compiler check that this side effect has type unit. *)
let () =
  let p =
    int 1 6 (fun d1 -> int 1 6 (fun d2 -> int 1 6 (fun d3 ->
      if d1 + d2 + d3 = 15 then 1.0 else 0.0))) in
  printf "Probability of rolling 15 with a 3d6: %.4f\n" p