open Printf
open Effect
open Effect.Deep
(** * 10.1 From exceptions to effects *)
(** Ordinary exceptions *)
module Plain_Exception = struct
  (** Raised when a string is not an integer literal; carries the string. *)
  type exn += Conversion_error: string -> exn

  (** [parse_int s] is the integer denoted by [s].
      @raise Conversion_error if [s] is not a valid integer literal. *)
  let parse_int s =
    match int_of_string_opt s with
    | None -> raise (Conversion_error s)
    | Some value -> value

  (** [sum_stringlist lst] parses every string of [lst] (left to right) and
      adds the results.
      @raise Conversion_error on the first string that does not parse. *)
  let sum_stringlist lst =
    List.fold_left (fun total s -> total + parse_int s) 0 lst

  (** Like [sum_stringlist], but reports the first bad string on stdout and
      returns [max_int] instead of raising. *)
  let safe_sum_stringlist lst =
    try sum_stringlist lst
    with Conversion_error bad ->
      printf "Bad input: %s\n" bad;
      max_int

  let () =
    ["1"; "xxx"; "2"; "yyy"]
    |> safe_sum_stringlist
    |> printf "Result: %d\n"
end
(** Resumable exceptions *)
(** Performed when a string is not an integer literal; the handler answers
    with the integer to use in its place. *)
type _ eff += Conversion_error: string -> int eff

(** [parse_int s] is the integer denoted by [s]. When [s] does not parse,
    [Conversion_error s] is performed and its answer is returned instead. *)
let parse_int s =
  let ask_handler () = perform (Conversion_error s) in
  match int_of_string_opt s with
  | Some parsed -> parsed
  | None -> ask_handler ()
(** [sum_stringlist lst] parses each string of [lst] with [parse_int], from
    left to right, and returns the sum. Any [Conversion_error] effect is
    performed in that same left-to-right order. *)
let sum_stringlist lst =
  List.fold_left (fun total s -> total + parse_int s) 0 lst
(** [safe_sum_stringlist lst] sums [lst] like [sum_stringlist]. Every
    string that does not parse is reported on stdout and counted as 0, and
    the summation then resumes where it stopped. *)
let safe_sum_stringlist lst =
  match sum_stringlist lst with
  | effect Conversion_error bad, k ->
      printf "Bad input: %s, replaced with 0\n" bad;
      continue k 0
  | total -> total
let () =
  ["1"; "xxx"; "2"; "yyy"]
  |> safe_sum_stringlist
  |> printf "Result: %d\n"
(** Emitting, handling, re-emitting "Print" effects *)
(** Emits a string. What "emitting" means is decided by the enclosing
    handler (for example, write it, collect it or reorder it). *)
type _ eff += Print : string -> unit eff

(** [print s] performs [Print s]. It must run under a [Print] handler. *)
let print s = perform (Print s)
(** Emits "a", "b" and "c", in that order, via [print]. *)
let abc () = List.iter print ["a"; "b"; "c"]
(** [output f] runs [f], writing each string it emits to stdout as soon as
    it is emitted, and ends the line once [f] returns. *)
let output (f: unit -> unit) : unit =
  match f () with
  | effect Print text, resume ->
      print_string text;
      continue resume ()
  | () -> print_newline ()
(** [collect f] runs [f] and returns the concatenation, in emission order,
    of every string it emits. Nothing reaches stdout. *)
let collect (f: unit -> unit) : string =
  match f () with
  | effect Print s, k ->
      let rest = continue k () in
      s ^ rest
  | () -> ""
(** [reverse f] runs [f] and re-emits the strings it emits in reverse
    order, to the enclosing [Print] handler. *)
let reverse f =
  match f () with
  | effect Print s, k ->
      (* Finish the rest of [f] first, so that later strings are re-emitted
         before this one. *)
      let () = continue k () in
      print s
  | () -> ()
(** [number f] runs [f] and re-emits each string it emits as
    ["<n>:<string>\n"], with [n] counting from 1.

    The handler returns a function of the current line number, so the
    counter is threaded through the resumptions without any mutable state. *)
let number f =
  let numbered : int -> unit =
    match f () with
    | () -> fun _ -> ()
    | effect Print s, k ->
        fun lineno ->
          print (sprintf "%d:%s\n" lineno s);
          continue k () (lineno + 1)
  in
  numbered 1
let () =
  printf "Output: ";
  output abc;
  let collected = collect abc in
  printf "Collect: %S\n" collected;
  printf "Reverse output: ";
  output (fun () -> reverse abc);
  let reversed = collect (fun () -> reverse abc) in
  printf "Reverse collect: %S\n" reversed;
  printf "Number:\n";
  output (fun () -> number abc);
  printf "Reverse number:\n";
  output (fun () -> number (fun () -> reverse abc))
(** * 10.3 Control inversion *)
(** Control inversion on iterators *)
(** Binary trees carrying one value at each internal node. *)
type 'a tree =
  | Leaf
  | Node of 'a tree * 'a * 'a tree

(** [tree_iter f t] applies [f] to the elements of [t] in in-order:
    left subtree, then node value, then right subtree. *)
let rec tree_iter (f: 'a -> unit) : 'a tree -> unit = function
  | Leaf -> ()
  | Node (left, x, right) ->
      tree_iter f left;
      f x;
      tree_iter f right

(** Sample tree whose in-order traversal is 1, 2, 3, 4. *)
let my_tree =
  Node (Node (Leaf, 1, Node (Leaf, 2, Leaf)), 3, Node (Leaf, 4, Leaf))
(** On-demand sequences: [More (x, rest)] holds the current element [x]
    and a thunk [rest] that computes the remainder. *)
type 'a enum =
  | Done
  | More of 'a * (unit -> 'a enum)

(** [iter_on_enum f e] applies [f] to each element of [e] in order, forcing
    each thunk once. *)
let rec iter_on_enum f e =
  match e with
  | Done -> ()
  | More (x, rest) ->
      f x;
      iter_on_enum f (rest ())
(** [tree_enumerator t] turns the internal iterator [tree_iter] into an
    external, on-demand [enum] of the elements of [t] (in-order).

    Each [More] thunk resumes a one-shot continuation, so it can be forced
    only once. Forcing it a second time raises
    [Effect.Continuation_already_resumed]. *)
let tree_enumerator (type elt) (t: elt tree) : elt enum =
  (* A fresh effect per call, so enumerators cannot intercept each other. *)
  let module M = struct
    type _ eff += Next : elt -> unit eff
  end in
  let emit x = perform (M.Next x) in
  match tree_iter emit t with
  | effect M.Next x, k -> More (x, fun () -> continue k ())
  | () -> Done
let () =
  printf "Tree enumeration:";
  tree_enumerator my_tree |> iter_on_enum (fun elt -> printf " %d" elt);
  printf "\n"
(** Raised by a generator once it has no more elements to produce. *)
exception StopIteration

(** [tree_generator t] returns a function [g] such that successive calls
    [g ()] return the elements of [t] in in-order (left subtree, node,
    right subtree).

    Each call resumes the suspended traversal until it reaches the next
    element.
    @raise StopIteration once the traversal is over, and on every later
    call. *)
let tree_generator (type elt) (t: elt tree) : unit -> elt =
  (* A fresh effect per generator, so generators cannot intercept each
     other's [Next] effects. *)
  let module M = struct
    type _ eff += Next : elt -> unit eff
  end in
  let exhausted () = raise StopIteration in
  let rec next : (unit -> elt) ref = ref (fun () ->
    match tree_iter (fun x -> perform (M.Next x)) t with
    | () ->
        (* Traversal finished. Without this reset, [next] would still hold
           the continuation that has just been consumed, and the next call
           would raise [Continuation_already_resumed]. *)
        next := exhausted;
        exhausted ()
    | effect M.Next x, k ->
        (* Keep the rest of the traversal for the next call and hand the
           current element to the caller. *)
        next := continue k;
        x)
  in
  fun () -> !next ()
let () =
  printf "Tree generation:";
  let g = tree_generator my_tree in
  (* Pull elements until the generator signals exhaustion. *)
  let rec drain () =
    printf " %d" (g ());
    drain ()
  in
  try drain () with StopIteration -> printf "\n"
(** General control inversion *)
(** [numbers n ~yield] calls [yield] on 0, then on 1, -1, 2, -2, ..., n, -n.
    Only 0 is yielded when [n < 1]. *)
let numbers n ~yield =
  yield 0;
  let rec from i =
    yield i;
    yield (-i);
    (* Compare before incrementing so that [n = max_int] cannot overflow. *)
    if i < n then from (i + 1)
  in
  if n >= 1 then from 1
let () =
  printf "Numbers (direct):";
  numbers 10 ~yield:(fun i -> printf " %d" i);
  printf "\n"
(** [generator f] turns an internal iterator [f] (which calls [~yield] on
    each element) into a function [g] such that successive calls [g ()]
    return the yielded elements in order.
    @raise StopIteration once [f] has returned, and on every later call. *)
let generator (type elt) (f: yield: (elt -> unit) -> unit)
  : unit -> elt =
  (* A fresh effect per generator, so generators cannot intercept each
     other's [Next] effects. *)
  let module M = struct
    type _ eff += Next : elt -> unit eff
  end in
  let exhausted () = raise StopIteration in
  let rec next = ref (fun () ->
    match f ~yield:(fun x -> perform (M.Next x)) with
    | () ->
        (* [f] finished. Make later calls keep raising [StopIteration]
           rather than resuming the consumed continuation, which would
           raise [Continuation_already_resumed]. *)
        next := exhausted;
        exhausted ()
    | effect M.Next x, k -> next := (fun () -> continue k ()); x)
  in
  fun () -> !next ()
let () =
  printf "Number generation:";
  let g = generator (numbers 10) in
  (* Pull numbers until the generator signals exhaustion. *)
  let rec drain () =
    printf " %d" (g ());
    drain ()
  in
  try drain () with StopIteration -> printf "\n"
(** [enumerator f] turns an internal iterator [f] (which calls [~yield] on
    each element) into an on-demand [enum] of the yielded elements.

    Each [More] thunk resumes a one-shot continuation and can be forced
    only once. *)
let enumerator (type elt) (f: yield: (elt -> unit) -> unit)
  : elt enum =
  let module M = struct
    type _ eff += Next : elt -> unit eff
  end in
  let yield x = perform (M.Next x) in
  match f ~yield with
  | effect M.Next x, k -> More (x, fun () -> continue k ())
  | () -> Done
let () =
  printf "Number enumeration:";
  enumerator (numbers 10) |> iter_on_enum (fun i -> printf " %d" i);
  printf "\n"
(** * 10.4 Cooperative multithreading *)
(** Spawning and yielding *)
(** Scheduler effects, interpreted by [run]:
    [Spawn f] starts a new thread running [f];
    [Yield] lets the other runnable threads go first;
    [Terminate] stops the current thread. *)
type _ eff +=
| Spawn : (unit -> unit) -> unit eff
| Yield : unit eff
| Terminate : unit eff

(** [spawn f] performs [Spawn f]. *)
let spawn f = perform (Spawn f)

(** [yield ()] performs [Yield]. *)
let yield () = perform Yield

(** [terminate ()] performs [Terminate]. *)
let terminate () = perform Terminate
(** Threads that are ready to run, as thunks, in FIFO order. *)
let runnable : (unit -> unit) Queue.t = Queue.create ()

(** [suspend f] appends the thunk [f] to the run queue. *)
let suspend f = Queue.push f runnable

(** [restart ()] removes the oldest runnable thread and runs it. It does
    nothing when the queue is empty. *)
let restart () =
  if not (Queue.is_empty runnable) then (Queue.pop runnable) ()
(** [run f] runs [f] as the first thread of a cooperative scheduler, and
    returns once no thread is runnable any more.

    - [Spawn g]: the spawning thread is queued, and [g] starts running
      immediately under its own scheduler handler.
    - [Yield]: the current thread is queued and the next one is resumed.
    - [Terminate]: the current thread is unwound by raising [Exit] in it,
      then the next thread is resumed.

    When a thread returns, the next queued thread is resumed. *)
let rec run (f: unit -> unit) =
  match f () with
  | () -> restart ()
  | effect Spawn f, k -> suspend (continue k); run f
  | effect Terminate, k ->
      (* [discontinue] raises [Exit] at the [terminate ()] call. If the
         thread does not catch it, the handler has no exception clause, so
         [Exit] comes back out of [discontinue]. Absorb it and move on to
         the next thread instead of letting it escape the scheduler. If the
         thread does catch [Exit] and then returns, the value case has
         already called [restart]. *)
      begin match discontinue k Exit with
      | () -> ()
      | exception Exit -> restart ()
      end
  | effect Yield, k -> suspend (continue k); restart ()
(** [process name count] prints [name1 name2 ... name<count>], each followed
    by a space, and yields to the other threads after every print. *)
let process name count =
  let rec step n =
    printf "%s%d " name n;
    yield ();
    if n < count then step (n + 1)
  in
  if count >= 1 then step 1
let () =
  printf "Three threads: ";
  let main () =
    spawn (fun () -> process "A" 5);
    spawn (fun () -> process "B" 3);
    process "C" 6
  in
  run main;
  printf "\n"
(** Adding communication channels *)
(** A synchronous channel. At most one of the two queues holds parked
    threads at any time. *)
type 'a channel = {
  senders : ('a * (unit, unit) continuation) Queue.t;
    (** senders waiting for a receiver, each with the value it offers *)
  receivers : ('a, unit) continuation Queue.t;
    (** receivers waiting for a value *)
}

(** [new_channel ()] is a fresh channel with no waiting threads. *)
let new_channel () =
  let senders = Queue.create () and receivers = Queue.create () in
  { senders; receivers }
(** Channel effects, interpreted by the channel-aware [run] below. *)
type _ eff +=
| Send : 'a channel * 'a -> unit eff
| Recv : 'a channel -> 'a eff

(** [send ch v] performs [Send (ch, v)]. Under [run], the sender stays
    parked until a receiver takes [v]. *)
let send ch v = perform (Send(ch, v))

(** [recv ch] performs [Recv ch]. Under [run], it waits until a sender
    offers a value on [ch], and returns that value. *)
let recv ch = perform (Recv ch)
(** [run f] is the cooperative scheduler extended with synchronous
    channels. It runs [f] as the first thread and handles [Spawn], [Yield],
    [Terminate], [Send] and [Recv].

    It returns once no thread is runnable any more. Threads still parked on
    a channel at that point are never resumed. *)
let rec run (f: unit -> unit) =
  match f () with
  | () -> restart ()
  | effect Spawn f, k -> suspend (continue k); run f
  | effect Terminate, k ->
      (* [discontinue] raises [Exit] in the terminating thread. If the
         thread does not catch it, [Exit] comes back out of [discontinue];
         absorb it and resume the next thread instead of letting it escape
         [run]. If the thread catches [Exit] and then returns, the value
         case has already called [restart]. *)
      begin match discontinue k Exit with
      | () -> ()
      | exception Exit -> restart ()
      end
  | effect Yield, k -> suspend (continue k); restart ()
  | effect Send (ch, v), k ->
      begin match Queue.take_opt ch.receivers with
      | Some rc ->
          (* A receiver is waiting: give it the value now, and queue the
             sender to continue later. *)
          suspend (continue k); continue rc v
      | None ->
          (* No receiver yet: park the sender together with its value. *)
          Queue.add (v, k) ch.senders; restart ()
      end
  | effect Recv ch, k ->
      begin match Queue.take_opt ch.senders with
      | Some (v, sn) ->
          (* A sender is waiting: take its value and queue the sender. *)
          suspend (continue sn); continue k v
      | None ->
          Queue.add k ch.receivers; restart ()
      end
let () =
  let ch = new_channel () in
  (* Prints received numbers until one reaches 10. *)
  let rec consume () =
    let n = recv ch in
    if n >= 10 then ()
    else begin
      printf " %d" n;
      consume ()
    end
  in
  (* Sends n, n+1, n+2, ... forever (until nobody receives any more). *)
  let rec produce n =
    send ch n;
    produce (n + 1)
  in
  printf "Producer/consumer:";
  run (fun () -> spawn consume; produce 0);
  printf "\n"
let () =
  let primes = new_channel () in
  (* Sends 3, 5, 7, ... on [c], forever. *)
  let odd_numbers c =
    let rec from n =
      send c n;
      from (n + 2)
    in
    from 3
  in
  (* The first value read on [c1] is a prime [p]: publish it, start the next
     sieve stage, then forward every later value not divisible by [p]. *)
  let rec sieve c1 =
    let p = recv c1 in
    send primes p;
    let c2 = new_channel () in
    spawn (fun () -> sieve c2);
    let rec filter () =
      let n = recv c1 in
      if n mod p <> 0 then send c2 n;
      filter ()
    in
    filter ()
  in
  run (fun () ->
    let source = new_channel () in
    spawn (fun () -> odd_numbers source);
    spawn (fun () -> sieve source);
    printf "Eratosthenes sieve:";
    for _k = 1 to 20 do
      printf " %d" (recv primes)
    done;
    printf "\n")
(** Promises *)
(** A write-once result shared between threads. *)
type 'a promise = {
  mutable value : 'a option;
    (** [Some v] once the promise has been resolved with [v] *)
  mutable waiters : ('a, unit) continuation list;
    (** threads blocked in [await] on this promise *)
}

(** [new_promise ()] is an unresolved promise with no waiters. *)
let new_promise () : 'a promise =
  let value = None and waiters = [] in
  { value; waiters }
(** Promise effects, interpreted by the promise-aware [run] below. *)
type _ eff +=
| Await : 'a promise -> 'a eff
| Notify : 'a promise * 'a -> unit eff

(** [await p] performs [Await p]. Under [run], it returns the value of [p],
    parking the calling thread until [p] is resolved if necessary. *)
let await (p: 'a promise) : 'a = perform (Await p)

(** [notify p v] performs [Notify (p, v)]. Under [run], this stores [v] in
    [p] and queues every thread waiting on [p]. *)
let notify (p: 'a promise) (v: 'a): unit = perform (Notify(p, v))
(** [async f] spawns a thread computing [f ()] and returns a promise that
    the thread resolves with the result. *)
let async (f: unit -> 'a) : 'a promise =
  let p = new_promise () in
  let task () = notify p (f ()) in
  spawn task;
  p
(** [run f] is the cooperative scheduler extended with promises. It runs
    [f] as the first thread and handles [Spawn], [Yield], [Terminate],
    [Notify] and [Await]. It returns once no thread is runnable any more. *)
let rec run (f: unit -> unit) =
  match f () with
  | () -> restart ()
  | effect Spawn f, k -> suspend (continue k); run f
  | effect Terminate, k ->
      (* [discontinue] raises [Exit] in the terminating thread. If the
         thread does not catch it, [Exit] comes back out of [discontinue];
         absorb it and resume the next thread instead of letting it escape
         [run]. If the thread catches [Exit] and then returns, the value
         case has already called [restart]. *)
      begin match discontinue k Exit with
      | () -> ()
      | exception Exit -> restart ()
      end
  | effect Yield, k -> suspend (continue k); restart ()
  | effect Notify (p, v), k ->
      p.value <- Some v;
      (* [Await] pushes waiters onto the front of the list. Reverse it so
         that threads are woken in the order in which they started
         waiting. *)
      List.iter
        (fun waiter -> suspend (fun () -> continue waiter v))
        (List.rev p.waiters);
      p.waiters <- [];
      suspend (continue k);
      restart ()
  | effect Await p, k ->
      begin match p.value with
      | Some v -> continue k v
      | None -> p.waiters <- k :: p.waiters; restart ()
      end
let () =
  (* Sums i..j, printing each term and yielding between terms so that
     concurrent sums interleave. *)
  let rec sum i j =
    if i > j then 0
    else begin
      printf "%d " i;
      yield ();
      i + sum (i + 1) j
    end
  in
  printf "Promises: ";
  run (fun () ->
    let p1 = async (fun () -> sum 1 10)
    and p2 = async (fun () -> sum 8 16) in
    printf "\nResult: %d\n" (await p1 + await p2))