package moonpool

  1. Overview
  2. Docs

Source file ws_pool.ml

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
open Types_
module WSQ = Ws_deque_
module A = Atomic_
module TLS = Thread_local_storage_
include Runner

(** Binding operator for scoped callbacks: [let@ x = f in body] is
    [f (fun x -> body)]. *)
let ( let@ ) f k = f k

module Id = struct
  type t = unit ref
  (** Unique identifier for a pool. Identifiers are compared by physical
      identity of the underlying reference cell. *)

  (* every call allocates a fresh cell, hence a fresh identity *)
  let create () : t = Sys.opaque_identity (ref ())

  (* physical equality: structural equality would make all ids equal *)
  let equal (a : t) (b : t) : bool = a == b
end

(** Pair of hooks run around each task (see [run_task_now_]): the first is
    called on the runner before the task starts, and the value it returns
    is handed to the second once the task is finished. The type ['a] of
    that intermediate value is existentially hidden. *)
type around_task = AT_pair : (t -> 'a) * (t -> 'a -> unit) -> around_task

(** A unit of work as stored in the queues, bundled with the task-local
    storage [ls] that is installed while it runs. *)
type task_full =
  | T_start of {
      ls: Task_local_storage.t;
      f: task;
    }
    (* a new task: [f] is run under the suspend handler *)
  | T_resume : {
      ls: Task_local_storage.t;
      k: 'a -> unit;
      x: 'a;
    }
      -> task_full
    (* resumption of a suspended task: runs [k x]; ['a] is existential *)

type worker_state = {
  pool_id_: Id.t;  (** Identifier of the owning pool; unique per pool *)
  mutable thread: Thread.t;
      (** Thread running this worker. Holds a placeholder until [create]
          has received the real handle. *)
  q: task_full WSQ.t;  (** Work stealing queue *)
  mutable cur_ls: Task_local_storage.t option;
      (** Storage of the task currently being run by this worker;
          [None] when no task is running. *)
  rng: Random.State.t;
      (** Picks the index at which this worker starts scanning the other
          workers when it tries to steal. *)
}
(** State for a given worker. Only this worker is
    allowed to push into the queue, but other workers
    can come and steal from it if they're idle. *)

type state = {
  id_: Id.t;  (** Identifier shared with every worker of this pool. *)
  active: bool A.t;  (** Becomes [false] when the pool is shutdown. *)
  workers: worker_state array;  (** Fixed set of workers. *)
  main_q: task_full Queue.t;
      (** Main queue for tasks coming from the outside. Also receives tasks
          that could not be pushed into a worker's own queue. Only accessed
          with [mutex] held when pushing or popping. *)
  mutable n_waiting: int;
      (* number of workers blocked in [wait_]; protected by mutex *)
  mutable n_waiting_nonzero: bool;
      (** [n_waiting > 0]. Written with [mutex] held, but read without it
          in [try_wake_someone_] as a cheap hint.
          NOTE(review): [create] initialises this to [true] while
          [n_waiting] is [0]; this looks like a deliberately conservative
          start value, to be confirmed. *)
  mutex: Mutex.t;  (** Protects [main_q], [n_waiting] and waits on [cond]. *)
  cond: Condition.t;  (** Signalled when work is queued or on shutdown. *)
  on_exn: exn -> Printexc.raw_backtrace -> unit;
      (** Called with any exception escaping a task. *)
  around_task: around_task;  (** Hooks run before and after each task. *)
}
(** internal state *)

(** Number of workers in the pool. *)
let[@inline] size_ ({ workers; _ } : state) = Array.length workers

(** Current number of queued tasks: the length of the main queue plus the
    size of every worker's queue. The queues are read without taking the
    lock, so the result is only a snapshot. *)
let num_tasks_ (self : state) : int =
  let in_main_q = Queue.length self.main_q in
  Array.fold_left
    (fun total (w : worker_state) -> total + WSQ.size w.q)
    in_main_q self.workers

(** Thread-local slot in which each worker thread records its own
    {!worker_state}, so that tasks running on that thread can find it again
    when they schedule sub-tasks. Starts out as [None] on every thread. *)
let k_worker_state : worker_state option ref TLS.key =
  let fresh_slot () = ref None in
  TLS.new_key fresh_slot

(** Worker state registered for the calling thread, if any. *)
let[@inline] find_current_worker_ () : worker_state option =
  let slot = TLS.get k_worker_state in
  !slot

(** Signal one sleeping worker, if the pool believes at least one is
    waiting. The flag is read without holding the mutex; the mutex is only
    taken when a signal is actually sent. *)
let[@inline] try_wake_someone_ (self : state) : unit =
  match self.n_waiting_nonzero with
  | false -> ()
  | true ->
    Mutex.lock self.mutex;
    Condition.signal self.cond;
    Mutex.unlock self.mutex

(** Enqueue [task] on the pool.

    [w] is the worker state of the calling thread, if any. When it belongs
    to this very pool the task goes into that worker's own queue, falling
    back to the main queue if the push is refused. Otherwise the task goes
    into the main queue.

    @raise Shutdown if the caller is not a worker of this pool and the pool
    is no longer active. *)
let schedule_task_ (self : state) ~w (task : task_full) : unit =
  (* append to the shared queue and wake a sleeping worker, all under the
     lock *)
  let push_to_main_q () =
    Mutex.lock self.mutex;
    Queue.push task self.main_q;
    if self.n_waiting_nonzero then Condition.signal self.cond;
    Mutex.unlock self.mutex
  in
  match w with
  | Some w when Id.equal self.id_ w.pool_id_ ->
    (* The caller is a worker of this pool. The identifiers must be
       compared because a worker of pool A may be scheduling on pool B. *)
    if WSQ.push w.q task then
      try_wake_someone_ self
    else
      (* local queue refused the task: overflow into the main queue *)
      push_to_main_q ()
  | _ ->
    (* outside caller: scheduling is refused once the pool is shut down *)
    if not (A.get self.active) then raise Shutdown;
    push_to_main_q ()

(** Execute [task] immediately on the calling thread, which must be the
    worker [w].

    The task's local storage is installed for the duration of the call, the
    [around_task] hooks are run before and after it, and any exception
    escaping the task is passed to [self.on_exn] instead of propagating. *)
let run_task_now_ (self : state) ~runner ~(w : worker_state) (task : task_full)
    : unit =
  let (AT_pair (before_task, after_task)) = self.around_task in

  (* both kinds of task carry their storage *)
  let ls =
    match task with
    | T_start { ls; _ } | T_resume { ls; _ } -> ls
  in

  (* publish the storage, then run the "before" hook *)
  w.cur_ls <- Some ls;
  TLS.get k_cur_storage := Some ls;
  let ctx = before_task runner in

  (* worker of the calling thread, kept only if it belongs to this pool *)
  let worker_on_this_pool () : worker_state option =
    match find_current_worker_ () with
    | Some w' when Id.equal w'.pool_id_ self.id_ -> Some w'
    | Some _ | None -> None
  in

  (* handler callback: storage of the task running on the current worker *)
  let[@inline] on_suspend () : _ ref =
    match find_current_worker_ () with
    | Some { cur_ls = Some cur; _ } -> cur
    | _ -> assert false
  in

  (* handler callback: schedule a new task with a copy of storage [ls] *)
  let spawn ls (task' : task) =
    let w = worker_on_this_pool () in
    let ls' = Task_local_storage.Direct.copy ls in
    schedule_task_ self ~w (T_start { ls = ls'; f = task' })
  in

  (* handler callback: schedule the continuation [k x] with storage [ls] *)
  let resume ls k x =
    let w = worker_on_this_pool () in
    schedule_task_ self ~w (T_resume { ls; k; x })
  in

  (* run the task; whatever it raises goes to [on_exn] *)
  (try
     match task with
     | T_start { f; _ } ->
       (* fresh task: install the suspend handler around it *)
       Suspend_.with_suspend (WSH { on_suspend; run = spawn; resume }) f
     | T_resume { k; x; _ } ->
       (* continuation: it already runs inside an effect handler *)
       k x
   with exn ->
     let bt = Printexc.get_raw_backtrace () in
     self.on_exn exn bt);

  (* "after" hook, then unpublish the storage *)
  after_task runner ctx;
  w.cur_ls <- None;
  TLS.get k_cur_storage := None

(** Schedule [f] as a new task with storage [ls], using the calling
    thread's worker state (if any) to pick the target queue. *)
let run_async_ (self : state) ~ls (f : task) : unit =
  let task = T_start { f; ls } in
  schedule_task_ self ~w:(find_current_worker_ ()) task

(* TODO: function to schedule many tasks from the outside.
    - build a queue
    - lock
    - queue transfer
    - wakeup all (broadcast)
    - unlock *)

(** Block on the pool's condition variable, keeping the waiter count and
    its [n_waiting_nonzero] flag up to date.
    Precondition: the caller holds [self.mutex]. *)
let[@inline] wait_ (self : state) : unit =
  (* register as a waiter; the first waiter raises the flag *)
  let waiters = self.n_waiting + 1 in
  self.n_waiting <- waiters;
  if waiters = 1 then self.n_waiting_nonzero <- true;

  Condition.wait self.cond self.mutex;

  (* the count may have changed while we slept, so read it again *)
  let waiters = self.n_waiting - 1 in
  self.n_waiting <- waiters;
  if waiters = 0 then self.n_waiting_nonzero <- false

(* Local control-flow exception: raised with [raise_notrace] in
   [try_to_steal_work_once_] to leave its loop as soon as a task is stolen. *)
exception Got_task of task_full

(** Make one pass over the other workers, starting at a random index, and
    return the first task that could be stolen from one of their queues,
    or [None] if every attempt failed. *)
let try_to_steal_work_once_ (self : state) (w : worker_state) : task_full option
    =
  let n = Array.length self.workers in
  let start = Random.State.int w.rng n in

  (* try the [i]-th candidate after [start]; never steal from ourselves *)
  let probe i =
    let victim = Array.unsafe_get self.workers ((i + start) mod n) in
    if victim != w then (
      match WSQ.steal victim.q with
      | Some t -> raise_notrace (Got_task t)
      | None -> ()
    )
  in

  match
    for i = 0 to n - 1 do
      probe i
    done
  with
  | () -> None
  | exception Got_task t -> Some t

(** Pop and run tasks from the worker's own queue until it is empty or the
    pool is no longer active. *)
let worker_run_self_tasks_ (self : state) ~runner w : unit =
  let rec drain () =
    if A.get self.active then (
      match WSQ.pop w.q with
      | None -> ()
      | Some task ->
        (* more work may be left in our queue: let a sleeper come and
           steal it while we run this one *)
        try_wake_someone_ self;
        run_task_now_ self ~runner ~w task;
        drain ()
    )
  in
  drain ()

(** Main loop for a worker thread.

    Registers [runner] and [w] in thread-local storage, then cycles through
    three sources of work, in order: the worker's own queue, the queues of
    the other workers (stealing), and the pool's main queue. When none has
    a task the worker sleeps on the pool's condition variable. The loop
    returns once the main queue is found empty while the pool is inactive. *)
let worker_thread_ (self : state) ~(runner : t) (w : worker_state) : unit =
  TLS.get Runner.For_runner_implementors.k_cur_runner := Some runner;
  TLS.get k_worker_state := Some w;

  (* The four functions below call each other in tail position only. *)
  let rec main () : unit =
    (* drain our own queue first, then look elsewhere *)
    worker_run_self_tasks_ self ~runner w;
    try_steal ()
  and run_task task : unit =
    (* the task may have pushed sub-tasks into our queue: go back to [main] *)
    run_task_now_ self ~runner ~w task;
    main ()
  and try_steal () =
    match try_to_steal_work_once_ self w with
    | Some task -> run_task task
    | None -> wait ()
  and wait () =
    (* the main queue is only touched with the mutex held; the mutex is
       always released before a task is run *)
    Mutex.lock self.mutex;
    match Queue.pop self.main_q with
    | task ->
      Mutex.unlock self.mutex;
      run_task task
    | exception Queue.Empty ->
      (* wait here *)
      if A.get self.active then (
        wait_ self;

        (* see if a task became available *)
        let task =
          try Some (Queue.pop self.main_q) with Queue.Empty -> None
        in
        Mutex.unlock self.mutex;

        match task with
        | Some t -> run_task t
        | None -> try_steal ()
      ) else
        (* do nothing more: no task in main queue, and we are shutting
           down so no new task should arrive.
           The exception is if another task is creating subtasks
           that overflow into the main queue, but we can ignore that at
           the price of slightly decreased performance for the last few
           tasks *)
        Mutex.unlock self.mutex
  in

  (* handle domain-local await *)
  Dla_.using ~prepare_for_await:Suspend_.prepare_for_await ~while_running:main

(** No-op thread hook, used by [create] as the default for both
    [on_init_thread] and [on_exit_thread]. *)
let default_thread_init_exit_ ~dom_id:_ ~t_id:_ () = ()

(** Mark the pool as inactive and wake every sleeping worker. When [wait]
    is [true], also join all worker threads. Only the call that actually
    flips [active] from [true] to [false] does anything; later calls are
    no-ops. *)
let shutdown_ ~wait (self : state) : unit =
  let was_active = A.exchange self.active false in
  if was_active then begin
    (* wake all sleepers so that they notice the pool is inactive *)
    Mutex.lock self.mutex;
    Condition.broadcast self.cond;
    Mutex.unlock self.mutex;
    if wait then
      Array.iter (fun (w : worker_state) -> Thread.join w.thread) self.workers
  end

type ('a, 'b) create_args =
  ?on_init_thread:(dom_id:int -> t_id:int -> unit -> unit) ->
  ?on_exit_thread:(dom_id:int -> t_id:int -> unit -> unit) ->
  ?on_exn:(exn -> Printexc.raw_backtrace -> unit) ->
  ?around_task:(t -> 'b) * (t -> 'b -> unit) ->
  ?num_threads:int ->
  ?name:string ->
  'a
(** Arguments used in {!create}. See {!create} for explanations.

    As used in this file: [on_init_thread] and [on_exit_thread] are called
    in each worker thread before and after its main loop, with the index of
    the domain and the thread id; [on_exn] receives exceptions escaping a
    task; [around_task] is a pair of hooks run before and after each task,
    the second receiving the value returned by the first; [num_threads] is
    the requested number of workers; [name] is used to name the worker
    threads for tracing. ['a] is the remainder of the function type. *)

(** Placeholder task that does nothing, passed as the [~dummy] element when
    creating the work-stealing queues. *)
let dummy_task_ : task_full =
  let ls = Task_local_storage.dummy in
  T_start { ls; f = ignore }

(** Create a new pool and start its worker threads.

    The worker count comes from [Util_pool_.num_threads ?num_threads ()].
    Workers are spread over the domains of [Domain_pool_] in round-robin
    fashion, starting at a random offset. The function returns the runner
    once the handle of every worker thread has been received.

    - [on_init_thread] is called first thing in each worker thread.
    - [on_exit_thread] is called after the worker's main loop returns.
      NOTE(review): it follows [Fun.protect] rather than being part of its
      [~finally], so it is presumably skipped if the loop raises.
    - [on_exn] receives exceptions raised by tasks; the default ignores them.
    - [around_task] is a pair of hooks run before and after each task;
      the default does nothing.
    - [name], if given, is used to name worker threads
      ["<name>.worker.<index>"] for tracing. *)
let create ?(on_init_thread = default_thread_init_exit_)
    ?(on_exit_thread = default_thread_init_exit_) ?(on_exn = fun _ _ -> ())
    ?around_task ?num_threads ?name () : t =
  let pool_id_ = Id.create () in
  (* wrapper *)
  let around_task =
    match around_task with
    | Some (f, g) -> AT_pair (f, g)
    | None -> AT_pair (ignore, fun _ _ -> ())
  in

  let num_domains = Domain_pool_.max_number_of_domains () in
  let num_threads = Util_pool_.num_threads ?num_threads () in

  (* make sure we don't bias towards the first domain(s) in {!D_pool_} *)
  let offset = Random.int num_domains in

  let workers : worker_state array =
    (* placeholder handle, replaced below once the real threads exist *)
    let dummy = Thread.self () in
    Array.init num_threads (fun i ->
        {
          pool_id_;
          thread = dummy;
          q = WSQ.create ~dummy:dummy_task_ ();
          rng = Random.State.make [| i |];
          cur_ls = None;
        })
  in

  let pool =
    {
      id_ = pool_id_;
      active = A.make true;
      workers;
      main_q = Queue.create ();
      n_waiting = 0;
      (* NOTE(review): starts as [true] although [n_waiting] is 0; this
         only causes extra signal attempts until the flag is first reset
         in [wait_] — confirm it is intentional. *)
      n_waiting_nonzero = true;
      mutex = Mutex.create ();
      cond = Condition.create ();
      around_task;
      on_exn;
    }
  in

  (* public handle: every operation closes over [pool] *)
  let runner =
    Runner.For_runner_implementors.create
      ~shutdown:(fun ~wait () -> shutdown_ pool ~wait)
      ~run_async:(fun ~ls f -> run_async_ pool ~ls f)
      ~size:(fun () -> size_ pool)
      ~num_tasks:(fun () -> num_tasks_ pool)
      ()
  in

  (* temporary queue used to obtain thread handles from domains
     on which the thread are started. *)
  let receive_threads = Bb_queue.create () in

  (* start the thread with index [i] *)
  let start_thread_with_idx i =
    let w = pool.workers.(i) in
    let dom_idx = (offset + i) mod num_domains in

    (* function run in the thread itself *)
    let main_thread_fun () : unit =
      let thread = Thread.self () in
      let t_id = Thread.id thread in
      on_init_thread ~dom_id:dom_idx ~t_id ();
      TLS.get k_cur_storage := None;

      (* set thread name *)
      Option.iter
        (fun name ->
          Tracing_.set_thread_name (Printf.sprintf "%s.worker.%d" name i))
        name;

      let run () = worker_thread_ pool ~runner w in

      (* now run the main loop *)
      Fun.protect run ~finally:(fun () ->
          (* on termination, decrease refcount of underlying domain *)
          Domain_pool_.decr_on dom_idx);
      on_exit_thread ~dom_id:dom_idx ~t_id ()
    in

    (* function called in domain with index [i], to
       create the thread and push it into [receive_threads] *)
    let create_thread_in_domain () =
      let thread = Thread.create main_thread_fun () in
      (* send the thread from the domain back to us *)
      Bb_queue.push receive_threads (i, thread)
    in

    Domain_pool_.run_on dom_idx create_thread_in_domain
  in

  (* start all threads, placing them on the domains
     according to their index and [offset] in a round-robin fashion. *)
  for i = 0 to num_threads - 1 do
    start_thread_with_idx i
  done;

  (* receive the newly created threads back from domains; handles may
     arrive in any order, hence the index sent along with each of them *)
  for _j = 1 to num_threads do
    let i, th = Bb_queue.pop receive_threads in
    let worker_state = pool.workers.(i) in
    worker_state.thread <- th
  done;

  runner

(** [with_ () f] creates a pool with the given arguments (see [create]),
    calls [f] on it, and shuts the pool down when [f] returns or raises. *)
let with_ ?on_init_thread ?on_exit_thread ?on_exn ?around_task ?num_threads
    ?name () f =
  let pool =
    create ?on_init_thread ?on_exit_thread ?on_exn ?around_task ?num_threads
      ?name ()
  in
  let cleanup () = shutdown pool in
  Fun.protect ~finally:cleanup (fun () -> f pool)