Skip to content

Commit

Permalink
The cuda backend is now a generative functor; Cu.init called at module initialization
Browse files Browse the repository at this point in the history
  • Loading branch information
lukstafi committed Dec 16, 2024
1 parent 2af41be commit 0b7ae75
Show file tree
Hide file tree
Showing 10 changed files with 422 additions and 424 deletions.
1 change: 1 addition & 0 deletions CHANGES.md
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
- Built per-tensor-node stream-to-stream synchronization into copying functions.
- Re-introduced whole-device blocking synchronization, which now is just a slight optimization as it also cleans up event book-keeping.
- Simplifications: no more explicit compilation postponing; no more hard-coded pointers (all non-local arrays are passed by parameter).
- Fresh backends are now fresh modules to structurally prevent any potential cache leaking.

### Fixed

Expand Down
1 change: 0 additions & 1 deletion arrayjit/lib/backend_impl.ml
Original file line number Diff line number Diff line change
Expand Up @@ -107,7 +107,6 @@ struct
{
dev;
ordinal;
released = Atomic.make false;
cross_stream_candidates = Hashtbl.create (module Tnode);
owner_stream = Hashtbl.create (module Tnode);
shared_writer_streams = Hashtbl.create (module Tnode);
Expand Down
2 changes: 0 additions & 2 deletions arrayjit/lib/backend_intf.ml
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,6 @@ end
type ('buffer_ptr, 'dev, 'runner, 'event) device_ref = {
dev : 'dev;
ordinal : int;
released : Utils.atomic_bool;
cross_stream_candidates : 'buffer_ptr Hashtbl.M(Tnode).t;
owner_stream : ('buffer_ptr, 'dev, 'runner, 'event) stream_ref Hashtbl.M(Tnode).t;
shared_writer_streams :
Expand Down Expand Up @@ -112,7 +111,6 @@ type ('buffer_ptr, 'dev, 'runner, 'event) device =
('buffer_ptr, 'dev, 'runner, 'event) device_ref = {
dev : 'dev;
ordinal : int;
released : Utils.atomic_bool;
cross_stream_candidates : 'buffer_ptr Hashtbl.M(Tnode).t;
(** Freshly created arrays that might be shared across streams. The map can both grow and
shrink. *)
Expand Down
5 changes: 3 additions & 2 deletions arrayjit/lib/backends.ml
Original file line number Diff line number Diff line change
Expand Up @@ -462,7 +462,6 @@ let finalize (type buffer_ptr dev runner event)
Option.iter Backend.free_buffer ~f:(fun mem_free ->
if
Atomic.compare_and_set ctx.finalized false true
&& (not @@ Atomic.get ctx.stream.device.released)
then (
Backend.await ctx.stream;
Map.iteri ctx.ctx_arrays ~f:(fun ~key ~data ->
Expand All @@ -475,6 +474,8 @@ let%track5_sexp fresh_backend ?backend_name () =
Stdlib.Gc.full_major ();
(* TODO: is running again needed to give time to weak arrays to become empty? *)
Stdlib.Gc.full_major ();
(* Note: we invoke functors from within fresh_backend to fully isolate backends from distinct
calls to fresh_backend. *)
match
Option.value_or_thunk backend_name ~default:(fun () ->
Utils.get_global_arg ~arg_name:"backend" ~default:"cc")
Expand All @@ -486,5 +487,5 @@ let%track5_sexp fresh_backend ?backend_name () =
| "sync_cc" -> (module Make_device_backend_from_lowered (Schedulers.Sync) (Cc_backend) : Backend)
| "sync_gccjit" ->
(module Make_device_backend_from_lowered (Schedulers.Sync) (Gcc_backend) : Backend)
| "cuda" -> (module Raise_backend ((Cuda_backend : Lowered_backend)) : Backend)
| "cuda" -> (module Raise_backend ((Cuda_backend.Fresh () : Lowered_backend)) : Backend)
| backend -> invalid_arg [%string "Backends.fresh_backend: unknown backend %{backend}"]
6 changes: 3 additions & 3 deletions arrayjit/lib/backends.mli
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ val finalize :
Note: this type will get simpler with modular explicits. *)

val fresh_backend :
?backend_name:string -> unit -> (module Backend_intf.Backend)
val fresh_backend : ?backend_name:string -> unit -> (module Backend_intf.Backend)
(** Creates a new backend corresponding to [backend_name], or if omitted, selected via the global
[backend] setting. *)
[backend] setting. It should be safe to call {!Tensor.unsafe_reinitialize} before
[fresh_backend]. *)
Loading

0 comments on commit 0b7ae75

Please sign in to comment.