From aee62bc935c3c0f3193ac48324317adb55af3c49 Mon Sep 17 00:00:00 2001 From: Lukasz Stafiniak Date: Mon, 18 Nov 2024 21:26:29 +0100 Subject: [PATCH] Untested: synchronization for routines --- arrayjit/lib/anatomy_of_a_backend.md | 37 ++++++++++++++++++------ arrayjit/lib/backend_intf.ml | 7 +++-- arrayjit/lib/backends.ml | 43 +++++++++++++++++++--------- arrayjit/lib/task.ml | 10 +++++++ 4 files changed, 72 insertions(+), 25 deletions(-) diff --git a/arrayjit/lib/anatomy_of_a_backend.md b/arrayjit/lib/anatomy_of_a_backend.md index d7c97a77..d1233348 100644 --- a/arrayjit/lib/anatomy_of_a_backend.md +++ b/arrayjit/lib/anatomy_of_a_backend.md @@ -141,19 +141,38 @@ When using the default stream, CUDA would predictably write to the standard outp ## Synchronization and data transfers -OCANNL expects backends to implement FIFO queue scheduling, and an event mechanism for synchronizing between streams (and ideally devices), matching the CUDA specification. On top of events, OCANNL implements per-tensor-node synchronization, using the fields `reader_streams` and `writer_streams` of the device record, and `updating_for` of the stream record. +OCANNL expects backends to implement FIFO queue scheduling, and an event mechanism for synchronizing between streams (and ideally devices), matching the CUDA specification. On top of events, OCANNL implements per-tensor-node synchronization. A third of the `device` fields have to do with synchronization: + +```ocaml + mutable scheduled_shared_merge_node : (Tnode.t * 'event option) option; + (** The tensor node that was most recently scheduled to be in the cross-stream merge buffer, + and its readiness event. *) + shared_writer_streams : + (('buffer_ptr, 'dev, 'runner, 'event) stream * 'event) list Hashtbl.M(Tnode).t; + (** The streams that most recently have been scheduled to update (write to) a + cross-stream-shared node, and the associated update completion event. The completed events + are removed opportunistically. 
*) + host_reading_streams : + (('buffer_ptr, 'dev, 'runner, 'event) stream * 'event) list Hashtbl.M(Tnode).t; + (** The streams that most recently have been reading from a node's on-host array. The + completed events are removed opportunistically. *) + host_writing_streams : + (('buffer_ptr, 'dev, 'runner, 'event) stream * 'event) list Hashtbl.M(Tnode).t; + (** The streams that most recently have been writing to a node's on-host array. The completed + events are removed opportunistically. *) +``` + +and a third of the `stream` fields as well: ```ocaml -... - writer_streams : (('buffer_ptr, 'dev, 'runner, 'event) stream * 'event) list Hashtbl.M(Tnode).t; - (** The streams that most recently have been scheduled to update (write to) the node, and the - associated update completion event. The completed events are removed opportunistically. *) - reader_streams : (('buffer_ptr, 'dev, 'runner, 'event) stream * 'event) list Hashtbl.M(Tnode).t; - (** The streams that most recently have been reading from the node, and the associated use - completion events. The completed events are removed opportunistically. *) -... updating_for : 'event Hashtbl.M(Tnode).t; (* The completion event for updating (writing to) a node via this stream, if any. *) + mutable updating_for_merge_buffer : (Tnode.t * 'event) option; + (** Like {!field-updating_for}, but for the merge buffer. *) + reader_streams : (('buffer_ptr, 'dev, 'runner, 'event) stream * 'event) list Hashtbl.M(Tnode).t; + (** The streams, other than this stream, that most recently have been reading from a node in + this stream's context, and the associated use completion events. The completed events are + removed opportunistically. *) ``` Besides routines, calling `from_host`, `to_host`, `device_to_device` from a backend puts the corresponding tasks on the device's queue. 
Both invoking a routine and calling these copying functions will perform the necessary event creations and synchronizations to ensure that when scheduling writing into an array precedes scheduling reading from it, the actual writing also precedes the actual reading. diff --git a/arrayjit/lib/backend_intf.ml b/arrayjit/lib/backend_intf.ml index c50b90b6..f5cc18e0 100644 --- a/arrayjit/lib/backend_intf.ml +++ b/arrayjit/lib/backend_intf.ml @@ -85,8 +85,9 @@ type ('buffer_ptr, 'dev, 'runner, 'event) device = { mutable shared_merge_buffer : 'buffer_ptr buffer option; (** Depending on backend implementations, either the currently used cross-stream merge buffer, or the one most recently scheduled. *) - mutable scheduled_shared_merge_node : Tnode.t option; - (** The tensor node that was most recently scheduled to be in the cross-stream merge buffer. *) + mutable scheduled_shared_merge_node : (Tnode.t * 'event option) option; + (** The tensor node that was most recently scheduled to be in the cross-stream merge buffer, + and its readiness event. *) mutable latest_stream_id : int; released : Utils.atomic_bool; cross_stream_candidates : 'buffer_ptr Hashtbl.M(Tnode).t; @@ -123,7 +124,7 @@ and ('buffer_ptr, 'dev, 'runner, 'event) stream = { stream_id : int; (** An ID unique within the device. *) mutable allocated_buffer : 'buffer_ptr buffer option; updating_for : 'event Hashtbl.M(Tnode).t; - (* The completion event for updating (writing to) a node via this stream, if any. *) + (* The completion event for updating (writing to) a node via this stream, if any. *) mutable updating_for_merge_buffer : (Tnode.t * 'event) option; (** Like {!field-updating_for}, but for the merge buffer. 
*) reader_streams : (('buffer_ptr, 'dev, 'runner, 'event) stream * 'event) list Hashtbl.M(Tnode).t; diff --git a/arrayjit/lib/backends.ml b/arrayjit/lib/backends.ml index acb5d8ac..f1d42f8a 100644 --- a/arrayjit/lib/backends.ml +++ b/arrayjit/lib/backends.ml @@ -33,12 +33,13 @@ module Add_buffer_retrieval_and_syncing (Backend : No_buffer_retrieval_or_syncin let wait_for_ready ~dst ~src tn = let s = src.stream in let d = dst.stream in + (* TODO: maybe it's worthwhile to clean up s.updating_for every now and then. *) Hashtbl.find s.updating_for tn |> Option.iter ~f:(fun upd_e -> if not (equal_stream s d || Backend.is_done upd_e) then Backend.will_wait_for dst upd_e) - let update_writer_event ?from s tn = - let e = Backend.all_work s in + let update_writer_event ?e ?from s tn = + let e = Option.value_or_thunk e ~default:(fun () -> Backend.all_work s) in let f l = (s, e) :: Option.value ~default:[] l in (match (from, tn) with | None, _ -> () @@ -102,15 +103,7 @@ module Add_buffer_retrieval_and_syncing (Backend : No_buffer_retrieval_or_syncin device_to_device tn ~into_merge_buffer ~dst_ptr:(Some d_arr) ~dst ~src_ptr:s_arr ~src); update_writer_event ~from:(`Src src.stream) dst.stream @@ Node tn; - [%log - "copying", - Tn.debug_name tn, - "from", - name_of src, - "at", - (s_arr : Backend.buffer_ptr), - "to", - (d_arr : Backend.buffer_ptr)]; + [%log "copying", Tn.debug_name tn, "from", name_of src, "to", name_of dst]; true) | Streaming when same_device -> Backend.( @@ -123,6 +116,26 @@ module Add_buffer_retrieval_and_syncing (Backend : No_buffer_retrieval_or_syncin update_writer_event ~from:(`Src src.stream) dst.stream @@ Merge_buffer tn; [%log "copying into merge buffer", Tn.debug_name tn, "from", name_of src]; true) + + let%track3_l_sexp sync_routine r = + let s = r.context.stream in + let pre () = + Hashtbl.iteri s.device.shared_writer_streams ~f:(fun ~key ~data -> + if Set.mem r.inputs key then + List.iter data ~f:(fun (work_stream, e) -> + if not (equal_stream 
work_stream s) then Backend.will_wait_for r.context e)); + if r.merge_buffer_input then + Option.iter s.device.scheduled_shared_merge_node ~f:(fun (shared_tn, e) -> + match (s.scheduled_merge_node, e) with + | Some merge_tn, Some e -> + if Tn.equal shared_tn merge_tn then Backend.will_wait_for r.context e + | _ -> ()) + in + let post () = + let e = Backend.all_work s in + Set.iter r.outputs ~f:(fun tn -> update_writer_event ~e s @@ Node tn) + in + { r with schedule = Task.(prepend ~work:pre @@ append ~work:post r.schedule) } end let lower_assignments ?name bindings asgns = @@ -397,7 +410,8 @@ module Raise_backend (Device : Lowered_backend) : Backend = struct Task.prepend schedule ~work:(fun () -> check_merge_buffer context.stream ~code_node:code.expected_merge_node) in - { context; schedule; bindings; name = code.name; inputs; merge_buffer_input; outputs } + sync_routine + { context; schedule; bindings; name = code.name; inputs; merge_buffer_input; outputs } let%debug3_sexp link_batch context code_batch = verify_prior_context ~use_host_memory ~ctx_arrays:context.ctx_arrays @@ -423,7 +437,10 @@ module Raise_backend (Device : Lowered_backend) : Backend = struct Task.prepend schedule ~work:(fun () -> check_merge_buffer context.stream ~code_node:expected_merge_node) in - (context, Some { context; schedule; bindings; name; inputs; merge_buffer_input; outputs })) + let r = + sync_routine { context; schedule; bindings; name; inputs; merge_buffer_input; outputs } + in + (context, Some r)) end module Cuda_backend : Backend = Raise_backend ((Cuda_backend : Lowered_backend)) diff --git a/arrayjit/lib/task.ml b/arrayjit/lib/task.ml index d9194eb7..d2426b48 100644 --- a/arrayjit/lib/task.ml +++ b/arrayjit/lib/task.ml @@ -27,6 +27,16 @@ let prepend ~work (Task task) = task.work ()); } +let append ~work (Task task) = + Task + { + task with + work = + (fun () -> + task.work (); + work ()); + } + let%track3_l_sexp enschedule ~schedule_task ~get_stream_name stream (Task { 
description; _ } as task) = [%log_result "enschedule", description, "on", get_stream_name stream];