Skip to content

Commit

Permalink
Fix sexp_of_device/stream to break cyclicity
Browse files Browse the repository at this point in the history
  • Loading branch information
lukstafi committed Nov 20, 2024
1 parent 72bf7ec commit f571d9e
Showing 1 changed file with 49 additions and 12 deletions.
61 changes: 49 additions & 12 deletions arrayjit/lib/backend_intf.ml
Original file line number Diff line number Diff line change
Expand Up @@ -79,42 +79,78 @@ module type Device_config = sig
val name : string
end

type ('buffer_ptr, 'dev, 'runner, 'event) device = {

type ('buffer_ptr, 'dev, 'runner, 'event) device_ref = {
dev : 'dev;
ordinal : int;
mutable shared_merge_buffer : 'buffer_ptr buffer option;
mutable scheduled_shared_merge_node : (Tnode.t * 'event option) option;
mutable latest_stream_id : int;
released : Utils.atomic_bool;
cross_stream_candidates : 'buffer_ptr Hashtbl.M(Tnode).t;
owner_stream : ('buffer_ptr, 'dev, 'runner, 'event) stream_ref Hashtbl.M(Tnode).t;
shared_writer_streams :
(('buffer_ptr, 'dev, 'runner, 'event) stream_ref * 'event) list Hashtbl.M(Tnode).t;
host_reading_streams :
(('buffer_ptr, 'dev, 'runner, 'event) stream_ref * 'event) list Hashtbl.M(Tnode).t;
host_writing_streams :
(('buffer_ptr, 'dev, 'runner, 'event) stream_ref * 'event) list Hashtbl.M(Tnode).t;
}

and ('buffer_ptr, 'dev, 'runner, 'event) stream_ref = {
device : ('buffer_ptr, 'dev, 'runner, 'event) device_ref;
runner : 'runner;
merge_buffer : 'buffer_ptr buffer option ref;
mutable scheduled_merge_node : Tnode.t option;
stream_id : int;
mutable allocated_buffer : 'buffer_ptr buffer option;
updating_for : 'event Hashtbl.M(Tnode).t;
mutable updating_for_merge_buffer : (Tnode.t * 'event) option;
reader_streams :
(('buffer_ptr, 'dev, 'runner, 'event) stream_ref * 'event) list Hashtbl.M(Tnode).t;
}

let sexp_of_device_ref _ _ _ _ device = [%sexp_of: string * int] ("ordinal", device.ordinal)
let sexp_of_stream_ref _ _ _ _ stream = [%sexp_of: string * int] ("stream_id", stream.stream_id)
let equal_stream_ref s1 s2 = s1.stream_id = s2.stream_id && s1.device.ordinal = s2.device.ordinal

type ('buffer_ptr, 'dev, 'runner, 'event) device =
('buffer_ptr, 'dev, 'runner, 'event) device_ref = {
dev : 'dev;
ordinal : int;
mutable shared_merge_buffer : 'buffer_ptr buffer option;
(** Depending on backend implementations, either the currently used cross-stream merge buffer,
or the one most recently scheduled. *)
mutable scheduled_shared_merge_node : (Tnode.t * 'event option) option;
(** The tensor node that was most recently scheduled to be in the cross-stream merge buffer,
and its readiness event. *)
(** The tensor node that was most recently scheduled to be in the cross-stream merge buffer. *)
mutable latest_stream_id : int;
released : Utils.atomic_bool;
cross_stream_candidates : 'buffer_ptr Hashtbl.M(Tnode).t;
(** Freshly created arrays that might be shared across streams. The map can both grow and
shrink. *)
owner_stream : ('buffer_ptr, 'dev, 'runner, 'event) stream Hashtbl.M(Tnode).t;
owner_stream : ('buffer_ptr, 'dev, 'runner, 'event) stream_ref Hashtbl.M(Tnode).t;
(** The stream owning a given node. This map can only grow. Currently, if the memory mode of a
node is inferred, only this stream will modify a cross-stream shared array. But memory
modes can also be set manually. *)
shared_writer_streams :
(('buffer_ptr, 'dev, 'runner, 'event) stream * 'event) list Hashtbl.M(Tnode).t;
(('buffer_ptr, 'dev, 'runner, 'event) stream_ref * 'event) list Hashtbl.M(Tnode).t;
(** The streams that most recently have been scheduled to update (write to) a
cross-stream-shared node, and the associated update completion event. The completed events
are removed opportunistically. *)
host_reading_streams :
(('buffer_ptr, 'dev, 'runner, 'event) stream * 'event) list Hashtbl.M(Tnode).t;
(('buffer_ptr, 'dev, 'runner, 'event) stream_ref * 'event) list Hashtbl.M(Tnode).t;
(** The streams that most recently have been reading from a node's on-host array. The
completed events are removed opportunistically. *)
host_writing_streams :
(('buffer_ptr, 'dev, 'runner, 'event) stream * 'event) list Hashtbl.M(Tnode).t;
(('buffer_ptr, 'dev, 'runner, 'event) stream_ref * 'event) list Hashtbl.M(Tnode).t;
(** The streams that most recently have been writing to a node's on-host array. The completed
events are removed opportunistically. *)
}
[@@deriving sexp_of]

and ('buffer_ptr, 'dev, 'runner, 'event) stream = {
device : ('buffer_ptr, 'dev, 'runner, 'event) device;
type ('buffer_ptr, 'dev, 'runner, 'event) stream =
('buffer_ptr, 'dev, 'runner, 'event) stream_ref = {
device : ('buffer_ptr, 'dev, 'runner, 'event) device_ref;
runner : 'runner;
merge_buffer : 'buffer_ptr buffer option ref;
(** Depending on backend implementations, either the currently used merge buffer, or the one
Expand All @@ -124,17 +160,18 @@ and ('buffer_ptr, 'dev, 'runner, 'event) stream = {
stream_id : int; (** An ID unique within the device. *)
mutable allocated_buffer : 'buffer_ptr buffer option;
updating_for : 'event Hashtbl.M(Tnode).t;
(* The completion event for updating (writing to) a node via this stream, if any. *)
(* The completion event for updating (writing to) a node via this stream, if any. *)
mutable updating_for_merge_buffer : (Tnode.t * 'event) option;
(** Like {!field-updating_for}, but for the merge buffer. *)
reader_streams : (('buffer_ptr, 'dev, 'runner, 'event) stream * 'event) list Hashtbl.M(Tnode).t;
reader_streams :
(('buffer_ptr, 'dev, 'runner, 'event) stream_ref * 'event) list Hashtbl.M(Tnode).t;
(** The streams, other than this stream, that most recently have been reading from a node in
this stream's context, and the associated use completion events. The completed events are
removed opportunistically. *)
}
[@@deriving sexp_of]

let equal_stream s1 s2 = s1.stream_id = s2.stream_id && s1.device.ordinal = s2.device.ordinal
let equal_stream = equal_stream_ref

type ('buffer_ptr, 'stream) context = {
stream : 'stream;
Expand Down

0 comments on commit f571d9e

Please sign in to comment.