-
Notifications
You must be signed in to change notification settings - Fork 4
/
threadpool.ml
418 lines (337 loc) · 10 KB
/
threadpool.ml
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
(*
open Types;;
(* These should be defined elsewhere *)
let num_threads = detected_processors;;
*)
(****
type 'a wait_internal_t =
| Waiting
| Done of 'a
;;
type 'a wait_t = {
mutable w : 'a wait_internal_t;
};;
type slot_internal_t =
| Full
| Processed
;;
type slot_t = {
sw : unit -> unit;
sc : bool ref;
};;
let list_mutex = Mutex.create ();;
let list_length_ref = ref 0;;
let list_ref = ref [];;
let list_signal_if_added = Condition.create ();;
(* This clears out all the things where !sc is true *)
let clean_list_in_mutex () =
(* Printf.printf "%s%!" (Printf.sprintf "Cleaner starts with %d things\n" !list_length_ref);*)
let rec keep_rebuilding = function
| hd :: tl when !(hd.sc) -> (
decr list_length_ref;
keep_rebuilding tl
)
| hd :: tl -> hd :: keep_rebuilding tl
| [] -> []
in
list_ref := keep_rebuilding !list_ref;
(* Printf.printf "%s%!" (Printf.sprintf "Cleaner ends with %d things\n" !list_length_ref);*)
;;
let (thread_pool, first_worker_thread_id, last_worker_thread_id) =
let first_thread_id_ref = ref max_int in
let last_thread_id_ref = ref min_int in
let num_threads_counted_ref = ref 0 in
let thread_id_mutex = Mutex.create () in
let thread_guts : int -> unit = fun thread_i ->
(* The thread IDs are not guaranteed to be consecutive, but if they aren't it won't be a problem *)
Mutex.lock thread_id_mutex;
first_thread_id_ref := min !first_thread_id_ref (Thread.id (Thread.self ()));
last_thread_id_ref := max !last_thread_id_ref (Thread.id (Thread.self ()));
incr num_threads_counted_ref;
Mutex.unlock thread_id_mutex;
Thread.yield ();
let rec loop () =
Mutex.lock list_mutex;
(* Printf.printf "%s%!" (Printf.sprintf "Loop sees %d things to be done\n" !list_length_ref);*)
let rec keep_waiting () = match !list_ref with
| hd :: tl when !(hd.sc) -> (
(* Check really fast to see if this is done without unlocking the mutex *)
list_ref := tl;
decr list_length_ref;
keep_waiting ()
)
| hd :: [] -> (list_ref := []; decr list_length_ref; hd)
| hd :: tl -> (list_ref := tl; decr list_length_ref; Condition.signal list_signal_if_added; hd)
| [] -> (Condition.wait list_signal_if_added list_mutex; keep_waiting ())
in
let got = keep_waiting () in
Mutex.unlock list_mutex;
(* Printf.printf "%s%!" (Printf.sprintf "Thread %d doing job\n" (Thread.id (Thread.self ())));*)
got.sw ();
loop ()
in
loop ()
in
let thread_pool = Array.init num_threads (Thread.create thread_guts) in
(* Wait for all the threads to initialize *)
let rec keep_waiting () =
Mutex.lock thread_id_mutex;
let keep_going = (!num_threads_counted_ref < num_threads) in
Mutex.unlock thread_id_mutex;
if keep_going then (
Thread.yield ();
keep_waiting ()
)
in
keep_waiting ();
Printf.printf "The threads are from %d to %d\n%!" !first_thread_id_ref !last_thread_id_ref;
(thread_pool, !first_thread_id_ref, !last_thread_id_ref)
;;
let thread_might_be_worker () =
(*
let id = Thread.id (Thread.self ()) in
id >= first_worker_thread_id && id <= last_worker_thread_id
*)
false
;;
let wrap f x w =
let m = Mutex.create () in
let c = Condition.create () in
let done_ref = ref false in
let start_processing () =
match w.w with
| Done _ -> () (* Already run; skip this *)
| Waiting -> (
if Mutex.try_lock m then (
(match w.w with
| Done _ -> ()
| Waiting -> (
(* Do it! *)
w.w <- Done (f x);
done_ref := true;
Condition.broadcast c;
)
);
Mutex.unlock m;
()
) else (
(* Something else has the lock; it must be processing *)
()
)
)
in
let wait_for_processing () = match w.w with
| Done out -> out
| Waiting -> (
if thread_might_be_worker () then (
(* Since this may be a worker, it shouldn't block on just this work item *)
(* Instead try to find something else in the list *)
let rec keep_trying_lock () =
if Mutex.try_lock m then (
let ret = match w.w with
| Done out -> out
| Waiting -> (
let ret = f x in
w.w <- Done ret;
done_ref := true;
Condition.broadcast c;
ret
)
in
Mutex.unlock m;
ret
) else (
(* We can't get the lock but this MAY be a worker; check to see if there's something else to do in the meantime *)
Mutex.lock list_mutex;
(* Printf.printf "%s%!" (Printf.sprintf "Stealing sees %d things to be done\n" !list_length_ref);*)
clean_list_in_mutex ();
let do_something = match !list_ref with
| hd :: [] -> (list_ref := []; decr list_length_ref; Some hd)
| hd :: tl -> (list_ref := tl; decr list_length_ref; Condition.signal list_signal_if_added; Some hd)
| [] -> None
in
Mutex.unlock list_mutex;
(match do_something with
| Some got -> (
(* Printf.printf "%s%!" (Printf.sprintf "Thread %d doing a side job\n" (Thread.id (Thread.self ())));*)
got.sw ();
keep_trying_lock ()
)
| None -> (
(* Just lock it *)
Mutex.lock m;
(* Printf.printf "%s%!" (Printf.sprintf "Thread %d found nothing to do; finishing job\n" (Thread.id (Thread.self ())));*)
let ret = match w.w with
| Done out -> out
| Waiting -> (
let ret = f x in
w.w <- Done ret;
done_ref := true;
Condition.broadcast c;
ret
)
in
Mutex.unlock m;
ret
)
)
)
in
keep_trying_lock ()
) else (
(* The thread is not a worker and we just need to wait on the mutex *)
(* Printf.printf "%s%!" (Printf.sprintf "Non-worker thread %d; block in wait_for_processing\n" (Thread.id (Thread.self ())));*)
Mutex.lock m;
let ret = match w.w with
| Done out -> out
| Waiting -> (
(* Printf.printf "%s%!" (Printf.sprintf "Thread %d stealing\n" (Thread.id (Thread.self ())));*)
let ret = f x in
w.w <- Done ret;
done_ref := true;
Condition.broadcast c;
ret
)
in
Mutex.unlock m;
ret
)
)
in
(done_ref, start_processing, wait_for_processing)
;;
let send_unsafe f x =
let w = {w = Waiting} in
let (done_ref, start_processing, wait_for_processing) = wrap f x w in
let slot = {sc = done_ref; sw = start_processing} in
Mutex.lock list_mutex;
let signal = match !list_ref with
| [] -> true
| _ -> false
in
list_ref := slot :: !list_ref;
incr list_length_ref;
if signal then Condition.signal list_signal_if_added;
(* DELETEME *)
if List.length !list_ref <> !list_length_ref then Printf.printf "%s%!" (Printf.sprintf "%d <> %d!\n" (List.length !list_ref) !list_length_ref);
Mutex.unlock list_mutex;
(* Thread.yield ();*)
wait_for_processing
;;
let send f x =
let f2 x = try
Normal (f x)
with
e -> Error e
in
send_unsafe f2 x
;;
****)
(* One unit of work for [per_function]: the argument to apply the function
   to, a lock guarding the result, and the result once it has been stored. *)
type ('a,'b) obj_slot_t = {
  slot_parameters : 'a;          (* argument the function is applied to *)
  slot_mutex : Mutex.t;          (* held while [slot_ret] is read or filled *)
  mutable slot_ret : 'b option;  (* [None] until a result has been stored *)
};;
(* [slot_is_done s] is [true] exactly when [s.slot_ret] already holds a
   result. *)
let slot_is_done slot =
  match slot.slot_ret with
  | Some _ -> true
  | None -> false
;;
(* A pool of [max_threads] worker threads that evaluate [f] on queued
   arguments.

   - [send params] / [send_last params] enqueue [params] at the front / back
     of the pending list and return its slot immediately.
   - [recv slot] returns the slot's result; if no worker has produced it yet,
     [f] is applied in the calling thread.
   - [bypass input output] builds an already-completed slot without queueing.

   The workers are started by the initializer and loop forever; this class
   provides no way to stop them. *)
class ['a,'b] per_function (f : 'a -> 'b) max_threads =
  (* Fresh, not-yet-computed slot for [params]. *)
  let new_slot params = {
    slot_parameters = params;
    slot_mutex = Mutex.create ();
    slot_ret = None;
  } in
  (* [with_lock m g] runs [g ()] while holding [m].  [m] is released even if
     [g] raises; the exception is then re-raised to the caller. *)
  let with_lock m g =
    Mutex.lock m;
    let v = (try g () with e -> (Mutex.unlock m; raise e)) in
    Mutex.unlock m;
    v
  in
  object
    method f = f
    method max_threads = max_threads

    (* Guards the pending list [l]. *)
    val l_mutex = Mutex.create ()
    (* Broadcast when [l] goes from empty to non-empty; idle workers sleep on
       it with [l_mutex] held. *)
    val l_signal_if_added = Condition.create ()
    (* Pending slots, consumed by the worker threads. *)
    val mutable l = List2.create ()
    (* Worker thread handles, filled in by the initializer. *)
    val mutable thread_pool = [||]
    (* Shared by every slot built with [bypass]. *)
    val bypass_mutex = Mutex.create ()

    (* Enqueue [params] at the front of the pending list and return its slot. *)
    method send params =
      let slot = new_slot params in
      with_lock l_mutex (fun () ->
        let was_empty = List2.is_empty l in
        List2.prepend l slot;
        (* Workers only sleep while the list is empty, so waking them on the
           empty -> non-empty transition is sufficient. *)
        if was_empty then Condition.broadcast l_signal_if_added);
      Thread.yield ();
      slot

    (* Enqueue [params] at the back of the pending list and return its slot. *)
    method send_last params =
      let slot = new_slot params in
      with_lock l_mutex (fun () ->
        let was_empty = List2.is_empty l in
        List2.append l slot;
        if was_empty then Condition.broadcast l_signal_if_added);
      Thread.yield ();
      slot

    (* Return the result for [slot].  If no result has been stored yet, [f] is
       applied in the calling thread; if a worker is currently computing it,
       this blocks on the slot's mutex until the worker finishes.  Recursion
       is expected to be rare, so no other pending work is picked up while
       waiting.  The slot mutex is released even if [f] raises, so the
       exception reaches the caller without leaving the slot locked. *)
    method recv slot =
      with_lock slot.slot_mutex (fun () ->
        match slot.slot_ret with
        | Some x -> x
        | None ->
          let ret = f slot.slot_parameters in
          slot.slot_ret <- Some ret;
          ret)

    (* Build a slot that already holds [output] for [input].  It is never
       queued, so no worker runs [f] on it. *)
    method bypass (input : 'a) (output : 'b) =
      {
        slot_parameters = input;
        slot_mutex = bypass_mutex;
        slot_ret = Some output;
      }

    initializer (
      (* Pop the first slot that is not yet done, sleeping on
         [l_signal_if_added] while the list is empty.  Must be called with
         [l_mutex] held ([Condition.wait] releases it while sleeping). *)
      let rec l_keep_waiting () =
        match List2.take_first_perhaps l with
        | Some hd ->
          if slot_is_done hd then
            (* Already computed (e.g. by [recv]); drop it. *)
            l_keep_waiting ()
          else begin
            (* Only wake another worker if there is still work to take. *)
            if not (List2.is_empty l) then Condition.signal l_signal_if_added;
            hd
          end
        | None ->
          Condition.wait l_signal_if_added l_mutex;
          l_keep_waiting ()
      in
      (* Worker body: take a slot and compute it, unless another thread
         (presumably [recv]) holds its mutex or it is already done. *)
      let thread_guts : int -> unit = fun _ ->
        let rec loop () =
          let got = with_lock l_mutex l_keep_waiting in
          if Mutex.try_lock got.slot_mutex then begin
            (match got.slot_ret with
             | None ->
               (* If [f] raises, leave [slot_ret] as [None] and keep this
                  worker alive.  The slot mutex is still released below, so
                  a later [recv] re-runs [f] and the exception surfaces in the
                  caller's thread.  Without this, the worker would die and
                  leave the slot locked forever. *)
               (try got.slot_ret <- Some (f got.slot_parameters) with _ -> ())
             | Some _ -> ());
            Mutex.unlock got.slot_mutex
          end;
          loop ()
        in
        loop ()
      in
      thread_pool <- Array.init max_threads (Thread.create thread_guts)
    )
  end
;;