package eio

  1. Overview
  2. Docs
Legend:
Page
Library
Module
Module type
Parameter
Class
Class type
Source

Source file switch.ml

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
(* Mutable state of a switch.
   [fibers] is raised by [inc_fibers] and lowered by [dec_fibers]; the main
   function waits on [waiter] until it reaches zero. *)
type t = {
  mutable fibers : int;         (* Total, including daemon_fibers and the main function *)
  mutable daemon_fibers : int;  (* Fibers run via [with_daemon]; when only these remain, [dec_fibers] cancels the switch with [Exit]. *)
  mutable exs : (exn * Printexc.raw_backtrace) option;  (* Failures recorded by [fail], merged with [combine_exn]; [None] if nothing failed. *)
  on_release : (unit -> unit) Lwt_dllist.t;  (* Release handlers, run by [await_idle] after the fibers have finished. *)
  waiter : unit Single_waiter.t;              (* The main [top]/[sub] function may wait here for fibers to finish. *)
  cancel : Cancel.t;                          (* The switch's cancellation context; also provides its id, owning domain and state. *)
}

(* A handle on a release handler registered with [on_release_cancellable].
   [Hook] stores the domain that owns the switch together with the handler's
   node in the switch's [on_release] list (the node's element type is hidden
   by the existential ['a]). [Null] refers to no handler. *)
type hook =
  | Null
  | Hook : Domain.id * 'a Lwt_dllist.node -> hook

(* A hook that refers to no handler; passing it to [remove_hook] does nothing. *)
let null_hook : hook = Null

(* Unregister the release handler referred to by [hook]; [Null] is ignored.
   Raises [Invalid_argument] when called from a domain other than the one
   recorded in the hook.
   todo: would be good to make this thread-safe. While a switch can only be turned off from its own domain,
   we might want to allow closing something explicitly from any domain, and that needs to remove the hook. *)
let remove_hook hook =
  match hook with
  | Null -> ()
  | Hook (owner, node) ->
    if Domain.self () = owner then Lwt_dllist.remove node
    else invalid_arg "Switch hook removed from wrong domain!"

(* Pretty-print [t] to [f]: its id, its [fibers] counter (which still includes
   the main function while that is running) and its cancellation context. *)
let dump f t =
  let id = (t.cancel.id :> int) in
  Fmt.pf f "@[<v2>Switch %d (%d extra fibers):@,%a@]" id t.fibers Cancel.dump t.cancel

(* Whether [t]'s cancellation context reports itself finished. *)
let is_finished { cancel; _ } = Cancel.is_finished cancel

(* Check switch belongs to this domain (and isn't finished). It's OK if it's cancelling.
   Raises [Invalid_argument]; the "finished" condition is tested first. *)
let check_our_domain t =
  if is_finished t then invalid_arg "Switch finished!"
  else if Domain.self () <> t.cancel.domain then
    invalid_arg "Switch accessed from wrong domain!"

(* Check isn't cancelled (or finished).
   A finished switch raises [Invalid_argument]; otherwise [Cancel.check]
   decides about cancellation. *)
let check t =
  if is_finished t then invalid_arg "Switch finished!"
  else Cancel.check t.cancel

(* Delegates to [Cancel.get_error] on [t]'s cancellation context. *)
let get_error { cancel; _ } = Cancel.get_error cancel

(* Merge the new failure [ex] (an exception with its backtrace) into the
   previously recorded failure [prev], if any, using [Exn.combine]. *)
let combine_exn ex prev =
  match prev with
  | Some ex1 -> Exn.combine ex1 ex
  | None -> ex

(* Record [ex] as a failure of [t] and cancel [t]'s cancellation context with it.
   [bt] is the backtrace stored alongside [ex]. It defaults to the backtrace of
   the most recently raised exception, so calling [fail t ex] from inside an
   exception handler keeps [ex]'s real backtrace.
   Only the first failure is reported to [Trace]; every failure is merged into
   [t.exs] with [combine_exn]. If cancelling itself raises, that exception is
   recorded in [t.exs] too instead of propagating.
   Note: raises if [t] is finished or called from wrong domain. *)
let fail ?(bt=Printexc.get_raw_backtrace ()) t ex =
  check_our_domain t;
  if Option.is_none t.exs then
    Trace.error t.cancel.id ex;
  t.exs <- Some (combine_exn (ex, bt) t.exs);
  try
    Cancel.cancel t.cancel ex
  with cancel_ex ->
    let cancel_bt = Printexc.get_raw_backtrace () in
    t.exs <- Some (combine_exn (cancel_ex, cancel_bt) t.exs)

(* Count one more running fiber in [t]. [check] runs first, so this fails if
   [t] is finished or its context has been cancelled. *)
let inc_fibers t =
  check t;
  t.fibers <- succ t.fibers

(* Count one fiber of [t] as finished.
   If the fibers that remain are all daemons (and there is at least one),
   cancel [t] with [Exit]. Once no fibers remain, wake [t.waiter]. *)
let dec_fibers t =
  t.fibers <- pred t.fibers;
  let only_daemons_left = t.daemon_fibers > 0 && t.fibers = t.daemon_fibers in
  if only_daemons_left then Cancel.cancel t.cancel Exit;
  (* Re-read the counter: it is checked after any cancellation, as before. *)
  if t.fibers = 0 then Single_waiter.wake t.waiter (Ok ())

(* Run [fn ()] counted as a fiber of [t]. The count is dropped again whether
   [fn] returns or raises. *)
let with_op t fn =
  inc_fibers t;
  Fun.protect ~finally:(fun () -> dec_fibers t) fn

(* Run [fn ()] counted both as a fiber and as a daemon fiber of [t].
   Both counts are dropped when [fn] returns or raises; the daemon count goes
   first, so [dec_fibers] sees the updated value. *)
let with_daemon t fn =
  let release () =
    t.daemon_fibers <- t.daemon_fibers - 1;
    dec_fibers t
  in
  inc_fibers t;
  t.daemon_fibers <- t.daemon_fibers + 1;
  Fun.protect ~finally:release fn

(* Unwrap [Ok x] to [x]; raise the exception carried by [Error]. *)
let or_raise r =
  match r with
  | Error ex -> raise ex
  | Ok x -> x

(* Wait until [t] has no running fibers, then run its release handlers.
   The handlers are moved into a private queue and taken from its right end.
   [on_release_full] appends with [add_r], so (assuming [transfer_l] keeps the
   order) the most recently registered handler runs first.
   An exception raised by a handler is recorded with [fail] rather than
   propagated. If, once the queue is empty, new fibers are running or new
   handlers were registered, the whole process starts over. *)
let rec await_idle t =
  (* Wait for fibers to finish: *)
  while t.fibers > 0 do
    Trace.try_get t.cancel.id;
    Single_waiter.await t.waiter "Switch.await_idle" t.cancel.id
  done;
  (* Call on_release handlers: *)
  let queue = Lwt_dllist.create () in
  Lwt_dllist.transfer_l t.on_release queue;
  let rec release () =
    match Lwt_dllist.take_opt_r queue with
    | None when t.fibers = 0 && Lwt_dllist.is_empty t.on_release -> ()
    (* Handlers started fibers or added handlers: wait and release again. *)
    | None -> await_idle t
    | Some fn ->
      begin
        try fn () with
        | ex -> fail t ex
      end;
      release ()
  in
  release ()

(* The recursive [await_idle] defined just before, run inside [Cancel.protect]
   (presumably so cancellation cannot interrupt the wait or the release
   handlers). *)
let await_idle t =
  Cancel.protect @@ fun _ -> await_idle t

(* Re-raise the failure recorded in [t.exs] with its saved backtrace; do
   nothing if no failure was recorded. *)
let maybe_raise_exs t =
  Option.iter (fun (ex, bt) -> Printexc.raise_with_backtrace ex bt) t.exs

(* A fresh switch using the cancellation context [cancel].
   The fiber count starts at 1 because the main function counts as a fiber. *)
let create cancel =
  let waiter = Single_waiter.create () in
  let on_release = Lwt_dllist.create () in
  { cancel; waiter; on_release; exs = None; daemon_fibers = 0; fibers = 1 }

(* Run the main function [fn t] of switch [t], then wait for [t] to become idle.
   In both outcomes the main function's fiber is uncounted with [dec_fibers],
   and [await_idle] then waits for the remaining fibers and runs release handlers.
   - If [fn] returns [v]: [v] is returned, unless a failure was recorded on [t]
     in the meantime, in which case that failure is raised instead.
   - If [fn] raises: the exception and its backtrace are recorded with [fail]
     (which also cancels [t]), and the combined failure is raised once [t] is
     idle. *)
let run_internal t fn =
  match fn t with
  | v ->
    dec_fibers t;
    await_idle t;
    Trace.get t.cancel.id;
    maybe_raise_exs t;        (* Check for failure while finishing *)
    (* Success. *)
    v
  | exception ex ->
    let bt = Printexc.get_raw_backtrace () in
    (* Main function failed.
       Turn the switch off to cancel any running fibers, if it's not off already. *)
    dec_fibers t;
    fail ~bt t ex;
    await_idle t;
    Trace.get t.cancel.id;
    maybe_raise_exs t;
    (* Unreachable: [fail] stored [ex] in [t.exs], so [maybe_raise_exs] raised. *)
    assert false

(* Run [fn] with a new switch whose cancellation context comes from
   [Cancel.sub_checked], optionally named [name]. *)
let run ?name fn =
  Cancel.sub_checked ?name Switch @@ fun cc ->
  run_internal (create cc) fn

(* Like [run], but the new cancellation context is created with
   [~protected:true] under the current fiber's context. If [name] is given,
   it is attached to the context's id via [Trace.name]. *)
let run_protected ?name fn =
  let ctx = Effect.perform Cancel.Get_context in
  Cancel.with_cc ~ctx ~parent:ctx.cancel_context ~protected:true Switch @@ fun cancel ->
  (match name with
   | Some n -> Trace.name cancel.id n
   | None -> ());
  run_internal (create cancel) fn

(* Run [fn ()] in [t]'s cancellation context.
   This prevents [t] from finishing until [fn] is done,
   and means that cancelling [t] will cancel [fn].
   The fiber is moved back to its previous cancellation context afterwards,
   whether [fn] returns or raises. An exception from [fn] is re-raised with
   its original backtrace. *)
let run_in t fn =
  with_op t @@ fun () ->
  let ctx = Effect.perform Cancel.Get_context in
  let old_cc = ctx.cancel_context in
  Cancel.move_fiber_to t.cancel ctx;
  match fn () with
  | () -> Cancel.move_fiber_to old_cc ctx
  | exception ex ->
    (* Capture the backtrace before restoring the context, then re-raise
       explicitly so it is not replaced by a fresh one. *)
    let bt = Printexc.get_raw_backtrace () in
    Cancel.move_fiber_to old_cc ctx;
    Printexc.raise_with_backtrace ex bt

(* Raised by [on_release_full] (and so by [on_release] and
   [on_release_cancellable]) when a handler cannot be registered, so it is run
   immediately, and running it raises too. The string gives the reason
   registration was refused; the [exn] is what the handler raised. *)
exception Release_error of string * exn

(* Teach [Printexc] how to display [Release_error]: the reason, then the
   handler's exception indented beneath it. *)
let () =
  let print_release_error = function
    | Release_error (msg, ex) -> Some (Fmt.str "@[<v2>%s@,while handling %a@]" msg Exn.pp ex)
    | _ -> None
  in
  Printexc.register_printer print_release_error

(* Register [fn] to run when [t] finishes and return its node in [t.on_release].
   This requires the current domain to own [t] and [t] to be on or cancelling.
   Otherwise [fn] is run at once (inside [Cancel.protect]) and then
   [Invalid_argument] is raised. If [fn] itself raises, [Release_error] is
   raised instead, carrying the same message and [fn]'s exception and backtrace. *)
let on_release_full t fn =
  (* Run [fn] now because it cannot be registered, then report [reason]. *)
  let release_now_then_reject reason =
    match Cancel.protect fn with
    | () -> invalid_arg reason
    | exception ex ->
      let bt = Printexc.get_raw_backtrace () in
      Printexc.raise_with_backtrace (Release_error (reason, ex)) bt
  in
  if Domain.self () <> t.cancel.domain then
    release_now_then_reject "Switch accessed from wrong domain!"
  else
    match t.cancel.state with
    | On | Cancelling _ -> Lwt_dllist.add_r fn t.on_release
    | Finished -> release_now_then_reject "Switch finished!"

(* Like [on_release_full], but the list node is discarded, so the handler
   cannot be unregistered later. *)
let on_release t fn =
  let (_ : _ Lwt_dllist.node) = on_release_full t fn in
  ()

(* Register [fn] like [on_release_full] and return a [hook] that [remove_hook]
   can use to unregister it; the hook records [t]'s owning domain. *)
let on_release_cancellable t fn =
  let node = on_release_full t fn in
  Hook (t.cancel.domain, node)