(* Control structures in programming languages: from goto to algebraic effects

   Xavier Leroy

   Effect handlers in OCaml (chapter 10) *)

open Printf
open Effect
open Effect.Deep

(** * 10.1 From exceptions to effects *)

(** Ordinary exceptions *)

(** Baseline version using an ordinary, non-resumable exception: the first
    unparsable string aborts the whole summation. *)
module Plain_Exception = struct

  (** Carries the string that could not be parsed. *)
  exception Conversion_error of string

  (** [parse_int s] is the integer written in [s].
      @raise Conversion_error if [s] is not a valid integer literal. *)
  let parse_int s =
    try int_of_string s
    with Failure _ -> raise (Conversion_error s)

  (** Sum of the integers written in a list of strings.
      @raise Conversion_error on the first unparsable string. *)
  let sum_stringlist strings =
    List.fold_left (fun total s -> total + parse_int s) 0 strings

  (** Like [sum_stringlist], but reports the bad string on stdout and
      returns [max_int] instead of raising. *)
  let safe_sum_stringlist strings =
    try sum_stringlist strings
    with Conversion_error bad ->
      printf "Bad input: %s\n" bad;
      max_int

  let () =
    let total = safe_sum_stringlist ["1"; "xxx"; "2"; "yyy"] in
    printf "Result: %d\n" total

end

(** Resumable exceptions *)

(** Resumable variant of the conversion error: the handler decides which
    integer stands in for the unparsable string. *)
type _ eff += Conversion_error: string -> int eff

(** [parse_int s] is the integer written in [s]; if [s] is not a valid
    integer literal, performs [Conversion_error s] and returns the value
    the handler resumes with. *)
let parse_int s =
  match int_of_string_opt s with
  | None -> perform (Conversion_error s)
  | Some n -> n

(** Sum of the integers written in a list of strings, parsed left to right. *)
let sum_stringlist strings =
  List.fold_left (fun total s -> total + parse_int s) 0 strings

(** Sums the list, reporting every bad string on stdout and counting it as
    0, then carrying on with the rest of the list. *)
let safe_sum_stringlist strings =
  match sum_stringlist strings with
  | total -> total
  | effect Conversion_error bad, resume ->
      printf "Bad input: %s, replaced with 0\n" bad;
      continue resume 0

(* Demo: two "Bad input" lines, then the sum 1 + 0 + 2 + 0. *)
let () =
  safe_sum_stringlist ["1"; "xxx"; "2"; "yyy"]
  |> printf "Result: %d\n"

(** Emitting, handling, re-emitting "Print" effects *)

(** [Print s] asks the nearest enclosing handler to print the string [s]. *)
type _ eff += Print : string -> unit eff

(** Performs [Print s]; how (or whether) [s] is shown is up to the handler. *)
let print (s: string) : unit = Print s |> perform

(** Prints "a", then "b", then "c". *)
let abc () = List.iter print ["a"; "b"; "c"]

(** [output f] runs [f], writing each string it prints to stdout, and ends
    the line once [f] returns. *)
let output (f: unit -> unit) : unit =
  match f () with
  | () -> print_newline ()
  | effect Print str, resume ->
      print_string str;
      continue resume ()

(** [collect f] runs [f] and returns the concatenation, in printing order,
    of every string it prints; nothing reaches stdout. *)
let collect (f: unit -> unit) : string =
  match f () with
  | () -> ""
  | effect Print chunk, resume ->
      let remainder = continue resume () in
      chunk ^ remainder

(** [reverse f] re-prints the strings printed by [f] in reverse order: each
    [Print] is re-performed only after the rest of [f] has run. *)
let reverse f =
  match f () with
  | () -> ()
  | effect Print s, resume ->
      let () = continue resume () in
      print s

(** [number f] re-prints each string [s] printed by [f] as ["N:s\n"], where
    [N] is a line counter starting at 1.  Each handler clause returns a
    function awaiting the current line number, so the counter is threaded
    through the resumptions instead of living in a mutable variable. *)
let number f =
  let print_from =
    match f () with
    (* The final line number is not needed: underscore-prefixed so the
       parameter does not trigger the unused-variable warning (27). *)
    | () -> (fun _lineno -> ())
    | effect Print s, k ->
        (fun lineno ->
           print (sprintf "%d:%s\n" lineno s);
           continue k () (lineno + 1))
  in
  print_from 1

let () =
  let rev_abc () = reverse abc in
  printf "Output: ";
  output abc;
  printf "Collect: %S\n" (collect abc);
  printf "Reverse output: ";
  output rev_abc;
  printf "Reverse collect: %S\n" (collect rev_abc);
  printf "Number:\n";
  output (fun () -> number abc);
  printf "Reverse number:\n";
  output (fun () -> number rev_abc)

(** * 10.3 Control inversion *)

(** Control inversion on iterators *)

(** Binary trees with an element at each internal node. *)
type 'a tree = Leaf | Node of 'a tree * 'a * 'a tree

(** [tree_iter f t] calls [f] on every element of [t] in order: left
    subtree, then the node's element, then right subtree. *)
let tree_iter (f: 'a -> unit) (t: 'a tree) =
  let rec walk = function
    | Leaf -> ()
    | Node (left, x, right) -> walk left; f x; walk right
  in
  walk t

(** Sample tree whose in-order traversal is 1, 2, 3, 4. *)
let my_tree =
  let single x = Node (Leaf, x, Leaf) in
  Node (Node (Leaf, 1, single 2), 3, single 4)

(** A lazy sequence: either exhausted, or an element together with a thunk
    computing the rest. *)
type 'a enum = Done | More of 'a * (unit -> 'a enum)

(** [iter_on_enum f e] forces [e] one element at a time, applying [f] to
    each element before asking for the next. *)
let iter_on_enum f =
  let rec loop e =
    match e with
    | Done -> ()
    | More (x, rest) -> f x; loop (rest ())
  in
  loop

(** [tree_enumerator t] turns the internal iterator [tree_iter] into an
    on-demand [enum]: each [Next] effect suspends the traversal and hands
    back the element together with a thunk that resumes it.  The effect is
    declared in a local module, so each call gets its own [Next]. *)
let tree_enumerator (type elt) (t: elt tree) : elt enum =
  let module G = struct
    type _ eff += Next : elt -> unit eff
  end in
  let emit x = perform (G.Next x) in
  match tree_iter emit t with
  | () -> Done
  | effect G.Next x, resume -> More (x, fun () -> continue resume ())

let () =
  printf "Tree enumeration:";
  tree_enumerator my_tree |> iter_on_enum (printf " %d");
  printf "\n"

(** Raised by generators when they have no more elements to produce. *)
exception StopIteration

(** [tree_generator t] returns a function that produces the elements of
    [t] one per call, in in-order, by resuming a suspended traversal.
    @raise StopIteration once the traversal is complete, on that call and
    on every later call. *)
let tree_generator (type elt) (t: elt tree) : unit -> elt =
  let module M = struct
    type _ eff += Next : elt -> unit eff
  end in
  let exhausted () = raise StopIteration in
  let rec next : (unit -> elt) ref = ref (fun () ->
    match tree_iter (fun x -> perform (M.Next x)) t with
    | () ->
        (* The traversal is over and its last continuation has been used:
           make later calls raise StopIteration instead of resuming that
           continuation again (which raises Continuation_already_resumed). *)
        next := exhausted;
        raise StopIteration
    | effect M.Next x, k -> next := continue k; x)
  in fun () -> !next ()

let () =
  printf "Tree generation:";
  let next = tree_generator my_tree in
  let rec drain () =
    match next () with
    | x -> printf " %d" x; drain ()
    | exception StopIteration -> printf "\n"
  in
  drain ()

(** General control inversion *)

(** [numbers n ~yield] passes 0, 1, -1, 2, -2, ..., n, -n to [yield], in
    that order. *)
let numbers n ~yield =
  let rec from i =
    if i <= n then begin
      yield i;
      yield (-i);
      from (i + 1)
    end
  in
  yield 0;
  from 1

let () =
  printf "Numbers (direct):";
  numbers 10 ~yield:(fun i -> printf " %d" i);
  printf "\n"

(** [generator f] turns [f], which pushes elements to its [~yield]
    callback, into a pull-style function: each call runs [f] up to its
    next [yield] and returns the yielded element.
    @raise StopIteration once [f] has returned, on that call and on every
    later call; [f] is never started a second time. *)
let generator (type elt) (f: yield: (elt -> unit) -> unit)
                : unit -> elt =
  let module M = struct
    type _ eff += Next : elt -> unit eff
  end in
  let exhausted () = raise StopIteration in
  let rec next = ref (fun () ->
    match f ~yield:(fun x -> perform (M.Next x)) with
    | () ->
        (* [f] has finished.  Without this, a later call would either resume
           the spent continuation (Continuation_already_resumed) or, if [f]
           never yielded, run [f] again from scratch. *)
        next := exhausted;
        raise StopIteration
    | effect M.Next x, k -> next := (fun () -> continue k ()); x)
  in fun () -> !next ()

let () =
  printf "Number generation:";
  let next = generator (numbers 10) in
  let rec drain () =
    match next () with
    | n -> printf " %d" n; drain ()
    | exception StopIteration -> printf "\n"
  in
  drain ()

(** [enumerator f] is the lazy [enum] of the elements [f] passes to its
    [~yield] callback; [f] only advances when the enum's thunks are forced. *)
let enumerator (type elt) (f: yield: (elt -> unit) -> unit)
                 : elt enum =
  let module E = struct
    type _ eff += Next : elt -> unit eff
  end in
  let yield x = perform (E.Next x) in
  match f ~yield with
  | () -> Done
  | effect E.Next x, resume -> More (x, continue resume)

let () =
  printf "Number enumeration:";
  enumerator (numbers 10) |> iter_on_enum (printf " %d");
  printf "\n"

(** * 10.4 Cooperative multithreading *)

(** Spawning and yielding *)

(** Effects through which a thread talks to the scheduler. *)
type _ eff +=
  | Spawn : (unit -> unit) -> unit eff  (* start a new thread *)
  | Yield : unit eff                     (* let other threads run *)
  | Terminate : unit eff                 (* request that this thread end *)

let spawn f = Spawn f |> perform
let yield () = perform Yield
let terminate () = Terminate |> perform

(** FIFO queue of suspended threads, each represented by the thunk that
    resumes it. *)
let runnable : (unit -> unit) Queue.t = Queue.create ()

(** [suspend f] appends [f] to the back of the run queue. *)
let suspend f = Queue.push f runnable

(** Runs the thread at the front of the run queue, if any; returns at once
    when the queue is empty. *)
let restart () =
  if not (Queue.is_empty runnable) then (Queue.pop runnable) ()

(** [run main] executes [main] as the first thread of a cooperative
    round-robin scheduler and returns once no thread is runnable.
    - [Spawn f]: the spawning thread is queued and [f] starts at once.
    - [Yield]: the current thread goes to the back of the run queue.
    - [Terminate]: the current thread is discontinued with [Exit].
    A thread that ends with [Exit], whether through [Terminate] or by
    raising it itself, simply stops and the next runnable thread starts. *)
let rec run (f: unit -> unit) =
  match f () with
  | () -> restart ()
  (* Without this clause, Exit escapes the handler, so [discontinue]
     re-raises it and aborts the whole scheduler. *)
  | exception Exit -> restart ()
  | effect Spawn f, k -> suspend (continue k); run f
  | effect Terminate, k -> discontinue k Exit
  | effect Yield, k -> suspend (continue k); restart ()

(** [process name count] prints ["<name><n> "] for n = 1 .. count,
    yielding to the scheduler after each item. *)
let process name count =
  let rec step n =
    if n <= count then begin
      printf "%s%d " name n;
      yield ();
      step (n + 1)
    end
  in
  step 1

let () =
  printf "Three threads: ";
  let main () =
    spawn (fun () -> process "A" 5);
    spawn (fun () -> process "B" 3);
    process "C" 6
  in
  run main;
  printf "\n"

(** Adding communication channels *)

(** A communication channel between threads. *)
type 'a channel = {
  (* blocked senders: the value offered and the continuation to resume *)
  senders : ('a * (unit, unit) continuation) Queue.t;
  (* blocked receivers, each to be resumed with the value it receives *)
  receivers : ('a, unit) continuation Queue.t;
}

(** A fresh channel with nobody waiting on either side. *)
let new_channel () =
  let senders = Queue.create () in
  let receivers = Queue.create () in
  { senders; receivers }

(** Channel operations, interpreted by the scheduler. *)
type _ eff +=
  | Send : 'a channel * 'a -> unit eff
  | Recv : 'a channel -> 'a eff

(** [send ch v] offers [v] on [ch]. *)
let send ch v = Send (ch, v) |> perform

(** [recv ch] takes a value from [ch]. *)
let recv ch = perform (Recv ch)

(** [run main] executes [main] as the first thread of a cooperative
    scheduler with synchronous channels, and returns once no thread is
    runnable (threads still blocked on a channel are never resumed).
    - [Spawn f]: the spawning thread is queued and [f] starts at once.
    - [Yield]: the current thread goes to the back of the run queue.
    - [Terminate]: the current thread is discontinued with [Exit].
    - [Send (ch, v)]: if a receiver is blocked on [ch], it is resumed with
      [v] right away and the sender is queued; otherwise the sender blocks
      in [ch.senders].
    - [Recv ch]: if a sender is blocked on [ch], its value is taken and the
      sender is queued; otherwise the receiver blocks in [ch.receivers].
    A thread that ends with [Exit], whether through [Terminate] or by
    raising it itself, simply stops and the next runnable thread starts. *)
let rec run (f: unit -> unit) =
  match f () with
  | () -> restart ()
  (* Without this clause, Exit escapes the handler, so [discontinue]
     re-raises it and aborts the whole scheduler. *)
  | exception Exit -> restart ()
  | effect Spawn f, k -> suspend (continue k); run f
  | effect Terminate, k -> discontinue k Exit
  | effect Yield, k -> suspend (continue k); restart ()
  | effect Send (ch, v), k ->
      begin match Queue.take_opt ch.receivers with
      | Some rc -> suspend (continue k); continue rc v
      | None -> Queue.add (v, k) ch.senders; restart ()
      end
  | effect Recv ch, k ->
      begin match Queue.take_opt ch.senders with
      | Some (v, sn) -> suspend (continue sn); continue k v
      | None -> Queue.add k ch.receivers; restart ()
      end

let () =
  let ch = new_channel () in
  (* Prints received values until the first one >= 10, then stops. *)
  let rec consume () =
    let n = recv ch in
    if n >= 10 then () else begin
      printf " %d" n;
      consume ()
    end
  in
  (* Sends n, n+1, n+2, ... for as long as someone receives. *)
  let rec produce n =
    send ch n;
    produce (n + 1)
  in
  printf "Producer/consumer:";
  run (fun () -> spawn consume; produce 0);
  printf "\n"

let () =
  let primes = new_channel () in
  (* Sends 3, 5, 7, ... on [c] forever.  Since only odd numbers are fed to
     the sieve, 2 never appears among the printed primes. *)
  let odd_numbers c =
    let rec from n = send c n; from (n + 2) in
    from 3
  in
  (* One sieve stage: the first number received on [c1] is a prime [p];
     publish it, start the next stage, then forward to that stage every
     number not divisible by [p]. *)
  let rec sieve c1 =
    let p = recv c1 in
    send primes p;
    let c2 = new_channel () in
    spawn (fun () -> sieve c2);
    let rec filter () =
      let n = recv c1 in
      if n mod p <> 0 then send c2 n;
      filter ()
    in
    filter ()
  in
  run (fun () ->
    let c = new_channel () in
    spawn (fun () -> odd_numbers c);
    spawn (fun () -> sieve c);
    printf "Eratosthenes sieve:";
    for _count = 1 to 20 do
      printf " %d" (recv primes)
    done;
    printf "\n")

(** Promises *)

(** A promise: [value] holds the result once it is known, and [waiters]
    holds the continuations of threads waiting for it. *)
type 'a promise = {
  mutable value : 'a option;
  mutable waiters : ('a, unit) continuation list;
}

(** A promise with no value yet and nobody waiting. *)
let new_promise () : 'a promise =
  { waiters = []; value = None }

(** Promise operations, interpreted by the scheduler. *)
type _ eff +=
  | Await  : 'a promise -> 'a eff
  | Notify : 'a promise * 'a -> unit eff

(** [await p] is the value of [p], as supplied by the handler. *)
let await (p: 'a promise) : 'a = Await p |> perform

(** [notify p v] asks the handler to resolve [p] with [v]. *)
let notify (p: 'a promise) (v: 'a) : unit = Notify (p, v) |> perform

(** [async f] spawns a thread that computes [f ()] and notifies the
    returned promise with the result. *)
let async (f: unit -> 'a) : 'a promise =
  let p = new_promise () in
  let worker () = notify p (f ()) in
  spawn worker;
  p

(** [run main] executes [main] as the first thread of a cooperative
    scheduler with promises, and returns once no thread is runnable.
    - [Spawn f]: the spawning thread is queued and [f] starts at once.
    - [Yield]: the current thread goes to the back of the run queue.
    - [Terminate]: the current thread is discontinued with [Exit].
    - [Notify (p, v)]: stores [v] in [p], queues every thread waiting on
      [p] (in the order in which they started waiting) and then the
      notifier itself.
    - [Await p]: resumes at once if [p] has a value, otherwise the thread
      is parked in [p.waiters].
    A thread that ends with [Exit], whether through [Terminate] or by
    raising it itself, simply stops and the next runnable thread starts. *)
let rec run (f: unit -> unit) =
  match f () with
  | () -> restart ()
  (* Without this clause, Exit escapes the handler, so [discontinue]
     re-raises it and aborts the whole scheduler. *)
  | exception Exit -> restart ()
  | effect Spawn f, k -> suspend (continue k); run f
  | effect Terminate, k -> discontinue k Exit
  | effect Yield, k -> suspend (continue k); restart ()
  | effect Notify (p, v), k ->
      p.value <- Some v;
      (* [waiters] is built by consing, so it is newest-first: reverse it to
         wake waiters in arrival order. *)
      List.iter
        (fun w -> suspend (fun () -> continue w v))
        (List.rev p.waiters);
      p.waiters <- [];
      suspend (continue k);
      restart ()
  | effect Await p, k ->
      begin match p.value with
      | Some v -> continue k v
      | None -> p.waiters <- k :: p.waiters; restart ()
      end

let () =
  (* Sums i..j, printing each term and yielding after it so that two
     concurrent sums interleave. *)
  let rec sum i j =
    if i > j then 0 else (printf "%d " i; yield (); i + sum (i + 1) j) in
  printf "Promises: ";
  run (fun () ->
    let p1 = async (fun () -> sum 1 10) in
    let p2 = async (fun () -> sum 8 16) in
    (* Await in an explicit order: OCaml leaves the evaluation order of
       [await p1 + await p2], and of [let ... and ...], unspecified. *)
    let r1 = await p1 in
    let r2 = await p2 in
    printf "\nResult: %d\n" (r1 + r2))