Skip to content

Commit

Permalink
Get rid of special-casing event creation; proper syncing for `from_ho…
Browse files Browse the repository at this point in the history
…st` and `to_host`
  • Loading branch information
lukstafi committed Nov 15, 2024
1 parent fb04bc0 commit 9041153
Show file tree
Hide file tree
Showing 4 changed files with 43 additions and 83 deletions.
25 changes: 9 additions & 16 deletions arrayjit/lib/anatomy_of_a_backend.md
Original file line number Diff line number Diff line change
Expand Up @@ -141,26 +141,19 @@ When using the default stream, CUDA would predictably write to the standard outp

## Synchronization and data transfers

OCANNL expects backends to implement FIFO queue scheduling, and an event mechanism for synchronizing between streams (and ideally devices), matching the CUDA specification. On top of events, OCANNL implements per-tensor-node synchronization, using the fields `stream_working_on` of the device record, and `queried_work_for` of the stream record.
OCANNL expects backends to implement FIFO queue scheduling, and an event mechanism for synchronizing between streams (and ideally devices), matching the CUDA specification. On top of events, OCANNL implements per-tensor-node synchronization, using the fields `reader_streams` and `writer_streams` of the device record, and `updating_for` of the stream record.

```ocaml
...
stream_working_on : (int * 'event) option Hashtbl.M(Tnode).t;
(** The stream that most recently has been updating the node, and the associated update
completion event. An entry for a tensor node is only populated when
{!field-queried_work_for} is also populated. *)
writer_streams : (('buffer_ptr, 'dev, 'runner, 'event) stream * 'event) list Hashtbl.M(Tnode).t;
(** The streams that most recently have been scheduled to update (write to) the node, and the
associated update completion events. The completed events are removed opportunistically. *)
reader_streams : (('buffer_ptr, 'dev, 'runner, 'event) stream * 'event) list Hashtbl.M(Tnode).t;
(** The streams that most recently have been reading from the node, and the associated use
completion events. The completed events are removed opportunistically. *)
...
queried_work_for : 'event option Hashtbl.M(Tnode).t;
(* The completion event for updating the node via this stream, if any. Only existing entries
are updated, and an entry is populated when {!work_for} is called for the first time on the
tensor node. *)
...
val work_for : context -> Tnode.t -> event option
(** If the tensor node is in the context, returns the event indicating if currently running or
scheduled computations modifying that node on the context's stream have completed.
NOTE: [work_for ctx tn], if work tracking was not yet registered for [tn], will register work
tracking for [tn] and return the [all_work] event for [ctx]'s stream. *)
updating_for : 'event Hashtbl.M(Tnode).t;
(* The completion event for updating (writing to) a node via this stream, if any. *)
```

In addition to invoking routines, calling `from_host`, `to_host`, or `device_to_device` from a backend also puts the corresponding tasks on the device's queue. Both invoking a routine and calling these copying functions perform the necessary event creations and synchronizations to ensure that, when scheduling a write into an array precedes scheduling a read from it, the actual write also precedes the actual read.
Expand Down
4 changes: 2 additions & 2 deletions arrayjit/lib/backend_impl.ml
Original file line number Diff line number Diff line change
Expand Up @@ -107,7 +107,7 @@ struct
released = Atomic.make false;
cross_stream_candidates = Hashtbl.create (module Tnode);
owner_streams = Hashtbl.create (module Tnode);
writer_stream = Hashtbl.create (module Tnode);
writer_streams = Hashtbl.create (module Tnode);
reader_streams = Hashtbl.create (module Tnode);
}

Expand All @@ -119,7 +119,7 @@ struct
scheduled_merge_node = None;
stream_id;
allocated_buffer = None;
queried_work_for = Hashtbl.create (module Tnode);
updating_for = Hashtbl.create (module Tnode);
}

let get_name stream = [%string "%{name}:%{stream.device.ordinal#Int}:%{stream.stream_id#Int}"]
Expand Down
34 changes: 10 additions & 24 deletions arrayjit/lib/backend_intf.ml
Original file line number Diff line number Diff line change
Expand Up @@ -96,13 +96,12 @@ type ('buffer_ptr, 'dev, 'runner, 'event) device = {
Currently, if the memory mode of a node is inferred, only this stream will modify a
cross-stream shared array. But memory modes can also be set manually. *)
writer_stream : (('buffer_ptr, 'dev, 'runner, 'event) stream * 'event option) Hashtbl.M(Tnode).t;
(** The stream that most recently has been updating (writing to) the node, and the associated
update completion event. *)
reader_streams : (('buffer_ptr, 'dev, 'runner, 'event) stream * 'event option) Hashtbl.M(Tnode).t;
writer_streams : (('buffer_ptr, 'dev, 'runner, 'event) stream * 'event) list Hashtbl.M(Tnode).t;
(** The streams that most recently have been scheduled to update (write to) the node, and the
associated update completion events. The completed events are removed opportunistically. *)
reader_streams : (('buffer_ptr, 'dev, 'runner, 'event) stream * 'event) list Hashtbl.M(Tnode).t;
(** The streams that most recently have been reading from the node, and the associated use
completion events. An entry is only populated for cross-stream shared nodes, and an event
only populated when multiple streams worked with the node. *)
completion events. The completed events are removed opportunistically. *)
}
[@@deriving sexp_of]

Expand All @@ -116,11 +115,8 @@ and ('buffer_ptr, 'dev, 'runner, 'event) stream = {
(** The tensor node that was most recently scheduled to be in the [stream]'s merge buffer. *)
stream_id : int; (** An ID unique within the device. *)
mutable allocated_buffer : 'buffer_ptr buffer option;
queried_work_for : 'event option Hashtbl.M(Tnode).t;
(* The completion event for updating (writing to) a node via this stream, if any. An entry
is populated when {!work_for} is called for the first time on the tensor node, or
{!field-writer_stream} needs to be updated with this stream. Otherwise, only existing
entries are updated. *)
updating_for : 'event Hashtbl.M(Tnode).t;
(* The completion event for updating (writing to) a node via this stream, if any. *)
}
[@@deriving sexp_of]

Expand Down Expand Up @@ -247,24 +243,13 @@ module type With_buffer_retrieval_and_syncing = sig
type context
type event

val work_for : context -> Tnode.t -> event option
(** If the tensor node is in the context, returns the event indicating if currently running or
scheduled computations modifying that node on the context's stream have completed.
NOTE: [work_for ctx tn], if work tracking was not yet registered for [tn], will register work
tracking for [tn] and return the [all_work] event for [ctx]'s stream. *)

val from_host : context -> Tnode.t -> bool
(** If the tensor node is both hosted and in-context, schedules a copy from host to context and
returns true, otherwise returns false. NOTE: it's the caller's responsibility to synchronize
the stream (via [await ctx.stream] or [sync (work_for ctx tn)]) before the host's data is
overwritten. *)
returns true, otherwise returns false. *)

val to_host : context -> Tnode.t -> bool
(** If the tensor node is both hosted and in-context, schedules a copy from context to host and
returns true, otherwise returns false. NOTE: it's the caller's responsibility to synchronize
the stream (via [await ctx.stream] or [sync (work_for ctx tn)]) before the host's data is
read. *)
returns true, otherwise returns false. *)

val device_to_device :
Tnode.t -> into_merge_buffer:merge_buffer_use -> dst:context -> src:context -> bool
Expand All @@ -282,6 +267,7 @@ module type With_buffer_retrieval_and_syncing = sig
NOTE: If [into_merge_buffer=Streaming], after scheduling the work on [dst] using the merge
buffer but before scheduling work on [src] that modifies [tn], execute
[will_wait_for src (all_work (get_ctx_stream dst))]. *)
(* FIXME: update the syncing comment. *)
end

module type Backend = sig
Expand Down
63 changes: 22 additions & 41 deletions arrayjit/lib/backends.ml
Original file line number Diff line number Diff line change
Expand Up @@ -21,60 +21,41 @@ let check_merge_buffer stream ~code_node =
^ ", expected by code: " ^ name code_node)

module Add_buffer_retrieval_and_syncing (Backend : No_buffer_retrieval_or_syncing) = struct
let work_for context tn =
let stream = context.stream in
let default () = Some (Backend.all_work stream) in
if not @@ Map.mem context.ctx_arrays tn then None
else
Hashtbl.update_and_return stream.queried_work_for tn ~f:(function
| None | Some None -> default ()
| Some (Some _ as event) -> event)

let wait_for_users ctx tn =
let s = ctx.stream in
let worked_multi_streams =
Hashtbl.find s.device.writer_stream tn
|> Option.map ~f:snd |> Option.join |> Option.is_some |> ref
in
Hashtbl.find s.device.writer_stream tn
|> Option.iter ~f:(fun (work_stream, e) ->
if not (equal_stream work_stream s) then (
worked_multi_streams := true;
Backend.will_wait_for ctx @@ Option.value e ~default:(Backend.all_work work_stream)));
!worked_multi_streams

let wait_for_writers ctx tn =
let wait_for_all ctx streams tn =
let s = ctx.stream in
Hashtbl.find s.device.writer_stream tn
|> Option.iter ~f:(fun (work_stream, e) ->
if not (equal_stream work_stream s) then
Backend.will_wait_for ctx @@ Option.value e ~default:(Backend.all_work work_stream))

let update_writer_event ~worked_multi_streams s tn =
if Hashtbl.mem s.queried_work_for tn || worked_multi_streams then (
let e = Backend.all_work s in
Hashtbl.update s.device.writer_stream tn ~f:(fun _ -> (s, Some e));
Hashtbl.update s.queried_work_for tn ~f:(Option.map ~f:(fun _ -> e)))
else Hashtbl.update s.device.writer_stream tn ~f:(fun _ -> (s, None))
Hashtbl.update_and_return streams tn
~f:
(Fn.compose (List.filter ~f:(fun (_, e) -> not (Backend.is_done e)))
@@ Option.value ~default:[])
|> List.iter ~f:(fun (work_stream, e) ->
if not (equal_stream work_stream s) then Backend.will_wait_for ctx e)

(** [update_writer_event s tn] records that stream [s] has been scheduled to write to the
    tensor node [tn].

    It obtains the event [Backend.all_work s], prepends the pair [(s, e)] to the device-level
    [writer_streams] entry for [tn] (starting from the empty list when there is no entry yet),
    and sets the stream-level [updating_for] entry for [tn] to that same event, replacing any
    previous one.

    NOTE(review): presumably [Backend.all_work s] is an event that completes once all work
    scheduled on [s] so far has finished -- confirm against the backend interface. *)
let update_writer_event s tn =
(* Take the event once so that both tables refer to the same event value. *)
let e = Backend.all_work s in
(* This function only ever adds to the list; it does not remove older entries itself. *)
Hashtbl.update s.device.writer_streams tn ~f:(fun l -> (s, e) :: Option.value ~default:[] l);
(* Per stream, only a single event per node is kept: the previous value is ignored. *)
Hashtbl.update s.updating_for tn ~f:(fun _ -> e)

(** [add_reader s tn] records that stream [s] has been scheduled to read from the tensor node
    [tn].

    It obtains the event [Backend.all_work s] and prepends the pair [(s, e)] to the device-level
    [reader_streams] entry for [tn] (starting from the empty list when there is no entry yet).
    Unlike the writer bookkeeping, no per-stream table is touched here.

    NOTE(review): presumably [Backend.all_work s] is an event that completes once all work
    scheduled on [s] so far has finished -- confirm against the backend interface. *)
let add_reader s tn =
let e = Backend.all_work s in
(* This function only ever adds to the list; it does not remove older entries itself. *)
Hashtbl.update s.device.reader_streams tn ~f:(fun l -> (s, e) :: Option.value ~default:[] l)

let%diagn2_l_sexp from_host (ctx : Backend.context) tn =
match (tn, Map.find ctx.ctx_arrays tn) with
| { Tn.array = (lazy (Some hosted)); _ }, Some dst ->
((* Wait for all users of the array before copying. *)
let worked_multi_streams = wait_for_users ctx tn in
[%log "copying", Tn.debug_name tn, "to", (dst : Backend.buffer_ptr), "from host"];
Backend.from_host ~dst_ptr:dst ~dst:ctx hosted;
update_writer_event ~worked_multi_streams ctx.stream tn);
wait_for_all ctx ctx.stream.device.reader_streams tn;
[%log "copying", Tn.debug_name tn, "to", (dst : Backend.buffer_ptr), "from host"];
Backend.from_host ~dst_ptr:dst ~dst:ctx hosted;
update_writer_event ctx.stream tn;
true
| _ -> false

let%diagn2_l_sexp to_host (ctx : Backend.context) (tn : Tn.t) =
match (tn, Map.find ctx.ctx_arrays tn) with
| { Tn.array = (lazy (Some hosted)); _ }, Some src ->
(* Only wait for writers of the array before copying. *)
wait_for_writers ctx tn;
wait_for_all ctx ctx.stream.device.writer_streams tn;
[%log "copying", Tn.debug_name tn, "at", (src : Backend.buffer_ptr), "to host"];
Backend.to_host ~src_ptr:src ~src:ctx hosted;
add_reader ctx.stream tn;
true
| _ -> false

Expand Down

0 comments on commit 9041153

Please sign in to comment.