diff --git a/arrayjit/lib/backend_impl.ml b/arrayjit/lib/backend_impl.ml index 61de5d60..aa816a5d 100644 --- a/arrayjit/lib/backend_impl.ml +++ b/arrayjit/lib/backend_impl.ml @@ -121,6 +121,7 @@ struct stream_id; allocated_buffer = None; updating_for = Hashtbl.create (module Tnode); + updating_for_merge_buffer = None; reader_streams = Hashtbl.create (module Tnode); } diff --git a/arrayjit/lib/backend_intf.ml b/arrayjit/lib/backend_intf.ml index 9a37a90d..c50b90b6 100644 --- a/arrayjit/lib/backend_intf.ml +++ b/arrayjit/lib/backend_intf.ml @@ -57,6 +57,7 @@ type 'context routine = { inputs : Set.M(Tnode).t; (** The materialized read-only and read-before-write (within the routine) non-constant nodes. They are inputs in a broad sense, as they could be recurrent nodes or parameters. *) + merge_buffer_input : bool; (** Similar to {!field-inputs}, for the merge buffer. *) outputs : Set.M(Tnode).t; (** All the materialized nodes written-to by the routine. *) } [@@deriving sexp_of] @@ -123,6 +124,8 @@ and ('buffer_ptr, 'dev, 'runner, 'event) stream = { mutable allocated_buffer : 'buffer_ptr buffer option; updating_for : 'event Hashtbl.M(Tnode).t; (* The completion event for updating (writing to) a node via this stream, if any. *) + mutable updating_for_merge_buffer : (Tnode.t * 'event) option; + (** Like {!field-updating_for}, but for the merge buffer. *) reader_streams : (('buffer_ptr, 'dev, 'runner, 'event) stream * 'event) list Hashtbl.M(Tnode).t; (** The streams, other than this stream, that most recently have been reading from a node in this stream's context, and the associated use completion events. The completed events are @@ -274,7 +277,7 @@ module type With_buffer_retrieval_and_syncing = sig - If [into_merge_buffer=Streaming], remembers the buffer pointer of the source node to use for streaming. - If [into_merge_buffer=Copy], schedules copying from [src] to the merge buffer of [dst]'s - stream, and registers [dst.stream] with a reader event for the node. + stream, and updates the writer event for the merge buffer. NOTE: If [into_merge_buffer=Streaming], after scheduling the work on [dst] using the merge buffer but before scheduling work on [src] that modifies [tn], execute diff --git a/arrayjit/lib/backends.ml b/arrayjit/lib/backends.ml index 27d58463..acb5d8ac 100644 --- a/arrayjit/lib/backends.ml +++ b/arrayjit/lib/backends.ml @@ -30,23 +30,32 @@ module Add_buffer_retrieval_and_syncing (Backend : No_buffer_retrieval_or_syncin |> List.iter ~f:(fun (work_stream, e) -> if not (equal_stream work_stream s) then Backend.will_wait_for ctx e) - let update_writer_event ?(from_host = false) s tn = - let e = Backend.all_work s in - if from_host then - Hashtbl.update s.device.host_writing_streams tn ~f:(fun l -> - (s, e) :: Option.value ~default:[] l); - (* To be on the safe side, record events for potentially cross-stream nodes. *) - if Tn.potentially_cross_stream tn then - Hashtbl.update s.device.shared_writer_streams tn ~f:(fun l -> - (s, e) :: Option.value ~default:[] l); - Hashtbl.update s.updating_for tn ~f:(fun _ -> e) - - let add_reader s tn from = + let wait_for_ready ~dst ~src tn = + let s = src.stream in + let d = dst.stream in + Hashtbl.find s.updating_for tn + |> Option.iter ~f:(fun upd_e -> + if not (equal_stream s d || Backend.is_done upd_e) then Backend.will_wait_for dst upd_e) + + let update_writer_event ?from s tn = let e = Backend.all_work s in let f l = (s, e) :: Option.value ~default:[] l in - match from with - | `Host -> Hashtbl.update s.device.host_reading_streams tn ~f - | `Src src -> Hashtbl.update src.reader_streams tn ~f + (match (from, tn) with + | None, _ -> () + | Some `Host, Assignments.(Node tn | Merge_buffer tn) -> + Hashtbl.update s.device.host_reading_streams tn ~f + | Some (`Src src), (Node tn | Merge_buffer tn) -> Hashtbl.update src.reader_streams tn ~f); + (* To be on the safe side, record events for potentially cross-stream nodes. *) + match tn with + | Node tn -> + if Tn.potentially_cross_stream tn then + Hashtbl.update s.device.shared_writer_streams tn ~f:(fun l -> + (s, e) :: Option.value ~default:[] l); + Hashtbl.update s.updating_for tn ~f:(fun _ -> e) + | Merge_buffer tn -> + Option.iter s.updating_for_merge_buffer ~f:(fun (_old_tn, old_e) -> + assert (Backend.is_done old_e)); + s.updating_for_merge_buffer <- Some (tn, e) let%diagn2_l_sexp from_host (ctx : Backend.context) tn = match (tn, Map.find ctx.ctx_arrays tn) with @@ -54,8 +63,7 @@ module Add_buffer_retrieval_and_syncing (Backend : No_buffer_retrieval_or_syncin wait_for_all ctx ctx.stream.reader_streams tn; [%log "copying", Tn.debug_name tn, "to", (dst : Backend.buffer_ptr), "from host"]; Backend.from_host ~dst_ptr:dst ~dst:ctx hosted; - update_writer_event ~from_host:true ctx.stream tn; - add_reader ctx.stream tn @@ `Host; + update_writer_event ~from:`Host ctx.stream @@ Node tn; true | _ -> false @@ -66,6 +74,10 @@ module Add_buffer_retrieval_and_syncing (Backend : No_buffer_retrieval_or_syncin wait_for_all ctx ctx.stream.device.shared_writer_streams tn; [%log "copying", Tn.debug_name tn, "at", (src : Backend.buffer_ptr), "to host"]; Backend.to_host ~src_ptr:src ~src:ctx hosted; + let s = ctx.stream in + let e = Backend.all_work s in + Hashtbl.update s.device.host_writing_streams tn ~f:(fun l -> + (s, e) :: Option.value ~default:[] l); true | _ -> false @@ -80,6 +92,7 @@ module Add_buffer_retrieval_and_syncing (Backend : No_buffer_retrieval_or_syncin match Map.find src.ctx_arrays tn with | None -> false | Some s_arr -> ( + wait_for_ready ~dst ~src tn; match into_merge_buffer with | No -> ( match Map.find dst.ctx_arrays tn with @@ -88,8 +101,9 @@ module Add_buffer_retrieval_and_syncing (Backend : No_buffer_retrieval_or_syncin Backend.( device_to_device tn ~into_merge_buffer ~dst_ptr:(Some d_arr) ~dst ~src_ptr:s_arr ~src); + update_writer_event ~from:(`Src src.stream) dst.stream @@ Node tn; [%log - "copied", + "copying", Tn.debug_name tn, "from", name_of src, @@ -106,7 +120,8 @@ module Add_buffer_retrieval_and_syncing (Backend : No_buffer_retrieval_or_syncin | Copy | Streaming -> Backend.( device_to_device tn ~into_merge_buffer ~dst_ptr:None ~dst ~src_ptr:s_arr ~src); - [%log "copied into merge buffer", Tn.debug_name tn, "from", name_of src]; + update_writer_event ~from:(`Src src.stream) dst.stream @@ Merge_buffer tn; + [%log "copying into merge buffer", Tn.debug_name tn, "from", name_of src]; true) end @@ -371,7 +386,7 @@ module Raise_backend (Device : Lowered_backend) : Backend = struct let%debug3_sexp link context (code : code) = verify_prior_context ~use_host_memory ~ctx_arrays:context.ctx_arrays ~from_prior_context:code.from_prior_context; - let inputs, outputs = Low_level.input_and_output_nodes code.lowered in + let (inputs, outputs), merge_buffer_input = Low_level.input_and_output_nodes code.lowered in let ctx_arrays = Hashtbl.fold code.lowered.traced_store ~init:context.ctx_arrays ~f:(alloc_if_needed context.stream) @@ -382,7 +397,7 @@ module Raise_backend (Device : Lowered_backend) : Backend = struct Task.prepend schedule ~work:(fun () -> check_merge_buffer context.stream ~code_node:code.expected_merge_node) in - { context; schedule; bindings; name = code.name; inputs; outputs } + { context; schedule; bindings; name = code.name; inputs; merge_buffer_input; outputs } let%debug3_sexp link_batch context code_batch = verify_prior_context ~use_host_memory ~ctx_arrays:context.ctx_arrays @@ -401,14 +416,14 @@ module Raise_backend (Device : Lowered_backend) : Backend = struct let ctx_arrays = Option.value_exn ctx_arrays.(i) in let context = make_child ~ctx_arrays context in let expected_merge_node = code_batch.expected_merge_nodes.(i) in - let inputs, outputs = + let (inputs, outputs), merge_buffer_input = Low_level.input_and_output_nodes @@ Option.value_exn code_batch.lowereds.(i) in let schedule = Task.prepend schedule ~work:(fun () -> check_merge_buffer context.stream ~code_node:expected_merge_node) in - (context, Some { context; schedule; bindings; name; inputs; outputs })) + (context, Some { context; schedule; bindings; name; inputs; merge_buffer_input; outputs })) end module Cuda_backend : Backend = Raise_backend ((Cuda_backend : Lowered_backend)) diff --git a/arrayjit/lib/low_level.ml b/arrayjit/lib/low_level.ml index 2960c43c..4c577211 100644 --- a/arrayjit/lib/low_level.ml +++ b/arrayjit/lib/low_level.ml @@ -738,22 +738,25 @@ type optimized = { traced_store : traced_store; llc : t; merge_node : Tn.t optio [@@deriving sexp_of] let input_and_output_nodes optimized = - Hashtbl.fold optimized.traced_store - ~init:(Set.empty (module Tn), Set.empty (module Tn)) - ~f:(fun ~key ~data (inputs, outputs) -> - let materialized = Tn.is_materialized_force key 50 in - let inputs = - if - materialized && (not (Tn.known_constant key)) && (data.read_only || data.read_before_write) - then Set.add inputs key - else inputs - in - let outputs = - if materialized && (data.zeroed_out || not (Hash_set.is_empty data.assignments)) then - Set.add outputs key - else outputs - in - (inputs, outputs)) + ( Hashtbl.fold optimized.traced_store + ~init:(Set.empty (module Tn), Set.empty (module Tn)) + ~f:(fun ~key ~data (inputs, outputs) -> + let materialized = Tn.is_materialized_force key 50 in + let inputs = + if + materialized + && (not (Tn.known_constant key)) + && (data.read_only || data.read_before_write) + then Set.add inputs key + else inputs + in + let outputs = + if materialized && (data.zeroed_out || not (Hash_set.is_empty data.assignments)) then + Set.add outputs key + else outputs + in + (inputs, outputs)), + Option.is_some optimized.merge_node ) let%diagn2_sexp optimize_proc static_indices llc = let traced_store = Hashtbl.create (module Tnode) in diff --git a/arrayjit/lib/low_level.mli b/arrayjit/lib/low_level.mli index 91aeaf20..1bcdd951 100644 --- a/arrayjit/lib/low_level.mli +++ b/arrayjit/lib/low_level.mli @@ -99,7 +99,7 @@ val optimize : t -> optimized -val input_and_output_nodes : optimized -> Set.M(Tnode).t * Set.M(Tnode).t +val input_and_output_nodes : optimized -> (Set.M(Tnode).t * Set.M(Tnode).t) * bool (** Inputs are the materialized read-only and read-before-write (within the code) non-constant nodes. They are inputs in a broad sense, as they could be recurrent nodes or parameters.