From f571d9e05df8cedd9d1dfc20a38a94123853ed92 Mon Sep 17 00:00:00 2001 From: Lukasz Stafiniak Date: Wed, 20 Nov 2024 19:29:15 +0100 Subject: [PATCH] Fix sexp_of_device/stream to break cyclicity --- arrayjit/lib/backend_intf.ml | 61 +++++++++++++++++++++++++++++------- 1 file changed, 49 insertions(+), 12 deletions(-) diff --git a/arrayjit/lib/backend_intf.ml b/arrayjit/lib/backend_intf.ml index ddea5bc7..4e355fce 100644 --- a/arrayjit/lib/backend_intf.ml +++ b/arrayjit/lib/backend_intf.ml @@ -79,42 +79,78 @@ module type Device_config = sig val name : string end -type ('buffer_ptr, 'dev, 'runner, 'event) device = { + +type ('buffer_ptr, 'dev, 'runner, 'event) device_ref = { + dev : 'dev; + ordinal : int; + mutable shared_merge_buffer : 'buffer_ptr buffer option; + mutable scheduled_shared_merge_node : (Tnode.t * 'event option) option; + mutable latest_stream_id : int; + released : Utils.atomic_bool; + cross_stream_candidates : 'buffer_ptr Hashtbl.M(Tnode).t; + owner_stream : ('buffer_ptr, 'dev, 'runner, 'event) stream_ref Hashtbl.M(Tnode).t; + shared_writer_streams : + (('buffer_ptr, 'dev, 'runner, 'event) stream_ref * 'event) list Hashtbl.M(Tnode).t; + host_reading_streams : + (('buffer_ptr, 'dev, 'runner, 'event) stream_ref * 'event) list Hashtbl.M(Tnode).t; + host_writing_streams : + (('buffer_ptr, 'dev, 'runner, 'event) stream_ref * 'event) list Hashtbl.M(Tnode).t; +} + +and ('buffer_ptr, 'dev, 'runner, 'event) stream_ref = { + device : ('buffer_ptr, 'dev, 'runner, 'event) device_ref; + runner : 'runner; + merge_buffer : 'buffer_ptr buffer option ref; + mutable scheduled_merge_node : Tnode.t option; + stream_id : int; + mutable allocated_buffer : 'buffer_ptr buffer option; + updating_for : 'event Hashtbl.M(Tnode).t; + mutable updating_for_merge_buffer : (Tnode.t * 'event) option; + reader_streams : + (('buffer_ptr, 'dev, 'runner, 'event) stream_ref * 'event) list Hashtbl.M(Tnode).t; +} + +let sexp_of_device_ref _ _ _ _ device = [%sexp_of: string * int] ("ordinal", device.ordinal) +let sexp_of_stream_ref _ _ _ _ stream = [%sexp_of: string * int] ("stream_id", stream.stream_id) +let equal_stream_ref s1 s2 = s1.stream_id = s2.stream_id && s1.device.ordinal = s2.device.ordinal + +type ('buffer_ptr, 'dev, 'runner, 'event) device = + ('buffer_ptr, 'dev, 'runner, 'event) device_ref = { dev : 'dev; ordinal : int; mutable shared_merge_buffer : 'buffer_ptr buffer option; (** Depending on backend implementations, either the currently used cross-stream merge buffer, or the one most recently scheduled. *) mutable scheduled_shared_merge_node : (Tnode.t * 'event option) option; - (** The tensor node that was most recently scheduled to be in the cross-stream merge buffer, - and its readiness event. *) + (** The tensor node that was most recently scheduled to be in the cross-stream merge buffer. *) mutable latest_stream_id : int; released : Utils.atomic_bool; cross_stream_candidates : 'buffer_ptr Hashtbl.M(Tnode).t; (** Freshly created arrays that might be shared across streams. The map can both grow and shrink. *) - owner_stream : ('buffer_ptr, 'dev, 'runner, 'event) stream Hashtbl.M(Tnode).t; + owner_stream : ('buffer_ptr, 'dev, 'runner, 'event) stream_ref Hashtbl.M(Tnode).t; (** The stream owning a given node. This map can only grow. Currently, if the memory mode of a node is inferred, only this stream will modify a cross-stream shared array. But memory modes can also be set manually. *) shared_writer_streams : - (('buffer_ptr, 'dev, 'runner, 'event) stream * 'event) list Hashtbl.M(Tnode).t; + (('buffer_ptr, 'dev, 'runner, 'event) stream_ref * 'event) list Hashtbl.M(Tnode).t; (** The streams that most recently have been scheduled to update (write to) a cross-stream-shared node, and the associated update completion event. The completed events are removed opportunistically. *) host_reading_streams : - (('buffer_ptr, 'dev, 'runner, 'event) stream * 'event) list Hashtbl.M(Tnode).t; + (('buffer_ptr, 'dev, 'runner, 'event) stream_ref * 'event) list Hashtbl.M(Tnode).t; (** The streams that most recently have been reading from a node's on-host array. The completed events are removed opportunistically. *) host_writing_streams : - (('buffer_ptr, 'dev, 'runner, 'event) stream * 'event) list Hashtbl.M(Tnode).t; + (('buffer_ptr, 'dev, 'runner, 'event) stream_ref * 'event) list Hashtbl.M(Tnode).t; (** The streams that most recently have been writing to a node's on-host array. The completed events are removed opportunistically. *) } [@@deriving sexp_of] -and ('buffer_ptr, 'dev, 'runner, 'event) stream = { - device : ('buffer_ptr, 'dev, 'runner, 'event) device; +type ('buffer_ptr, 'dev, 'runner, 'event) stream = + ('buffer_ptr, 'dev, 'runner, 'event) stream_ref = { + device : ('buffer_ptr, 'dev, 'runner, 'event) device_ref; runner : 'runner; merge_buffer : 'buffer_ptr buffer option ref; (** Depending on backend implementations, either the currently used merge buffer, or the one @@ -124,17 +160,18 @@ and ('buffer_ptr, 'dev, 'runner, 'event) stream = { stream_id : int; (** An ID unique within the device. *) mutable allocated_buffer : 'buffer_ptr buffer option; updating_for : 'event Hashtbl.M(Tnode).t; - (* The completion event for updating (writing to) a node via this stream, if any. *) + (* The completion event for updating (writing to) a node via this stream, if any. *) mutable updating_for_merge_buffer : (Tnode.t * 'event) option; (** Like {!field-updating_for}, but for the merge buffer. *) - reader_streams : (('buffer_ptr, 'dev, 'runner, 'event) stream * 'event) list Hashtbl.M(Tnode).t; + reader_streams : + (('buffer_ptr, 'dev, 'runner, 'event) stream_ref * 'event) list Hashtbl.M(Tnode).t; (** The streams, other than this stream, that most recently have been reading from a node in this stream's context, and the associated use completion events. The completed events are removed opportunistically. *) } [@@deriving sexp_of] -let equal_stream s1 s2 = s1.stream_id = s2.stream_id && s1.device.ordinal = s2.device.ordinal +let equal_stream = equal_stream_ref type ('buffer_ptr, 'stream) context = { stream : 'stream;