Fixes #295: always create new modules for fresh_backend to never leak any caches
lukstafi committed Dec 16, 2024
1 parent 77b3395 commit 2af41be
Showing 4 changed files with 105 additions and 137 deletions.
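
The gist of the fix, as a minimal self-contained OCaml sketch (module and function names below are illustrative, not arrayjit's actual API): previously the backend modules were built once at the top level of Backends and shared, so mutable state created inside the functor bodies (compilation caches, device tables) survived across fresh_backend calls; applying the functors inside fresh_backend instead returns a module whose caches start empty.

(* Illustrative sketch only; names are hypothetical. *)
module type BACKEND = sig
  val compile : string -> unit
  val cache_size : unit -> int
end

(* A generative functor: every application builds a fresh structure, so the
   [cache] table below is created anew each time. *)
module Make_backend () : BACKEND = struct
  let cache : (string, unit) Hashtbl.t = Hashtbl.create 16
  let compile source = Hashtbl.replace cache source ()
  let cache_size () = Hashtbl.length cache
end

(* Before: one shared instance, so the cache leaks across "fresh" backends. *)
module Shared = Make_backend ()
let leaky_backend () = (module Shared : BACKEND)

(* After: a new functor application per call; every returned module starts
   with an empty cache. *)
let fresh_backend () = (module Make_backend () : BACKEND)

The same before/after contrast appears in the diff below: the top-level Cc_multicore, Gcc_multicore, Cc_sync and Gcc_sync modules are deleted, and fresh_backend applies the corresponding functors on every call.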
65 changes: 17 additions & 48 deletions arrayjit/lib/backends.ml
@@ -442,8 +442,6 @@ module Raise_backend (Device : Lowered_backend) : Backend = struct
(context, Some r))
end

module Cuda_backend : Backend = Raise_backend ((Cuda_backend : Lowered_backend))

module Make_device_backend_from_lowered
(Add_scheduler : functor
(Impl : For_add_scheduler)
@@ -455,36 +453,6 @@ struct
include Backend_device
end

module Cc_multicore = Make_device_backend_from_lowered (Schedulers.Multicore) (Cc_backend)
module Gcc_multicore = Make_device_backend_from_lowered (Schedulers.Multicore) (Gcc_backend)
module Cc_sync = Make_device_backend_from_lowered (Schedulers.Sync) (Cc_backend)
module Gcc_sync = Make_device_backend_from_lowered (Schedulers.Sync) (Gcc_backend)

let%track5_sexp reinitialize (module Backend : Backend) config =
if not @@ Backend.is_initialized () then Backend.initialize config
else (
[%log "reinitialize: cleanup devices"];
for ordinal = 0 to Backend.num_devices () - 1 do
Backend.(sync_device @@ get_device ~ordinal)
done;
[%log "reinitialize: before full_major"];
Stdlib.Gc.full_major ();
[%log "reinitialize: cleanup devices 2"];
(* TODO: does this do anything? *)
for ordinal = 0 to Backend.num_devices () - 1 do
Backend.(sync_device @@ get_device ~ordinal)
done;
[%log "reinitialize: after cleanup 2"];
(* This ensures cleanliness of the streams weak arrays. *)
Stdlib.Gc.full_major ();
[%log "reinitialize: after full_major 2"];
for ordinal = 0 to Backend.num_devices () - 1 do
let device = Backend.get_device ~ordinal in
Utils.weak_iter device.streams ~f:(fun _stream -> assert false)
done;
[%log "reinitialize: after checking devices"];
Backend.initialize config)

let finalize (type buffer_ptr dev runner event)
(module Backend : Backend
with type buffer_ptr = buffer_ptr
@@ -503,19 +471,20 @@ let finalize (type buffer_ptr dev runner event)
&& not (Hashtbl.mem ctx.stream.device.cross_stream_candidates key)
then mem_free ctx.stream data)))

let%track5_sexp fresh_backend ?backend_name ?(config = Only_devices_parallel) () =
let backend =
match
Option.value_or_thunk backend_name ~default:(fun () ->
Utils.get_global_arg ~arg_name:"backend" ~default:"cc")
|> String.lowercase
with
| "cc" -> (module Cc_multicore : Backend)
| "gccjit" -> (module Gcc_multicore : Backend)
| "sync_cc" -> (module Cc_sync : Backend)
| "sync_gccjit" -> (module Gcc_sync : Backend)
| "cuda" -> (module Cuda_backend : Backend)
| backend -> invalid_arg [%string "Backends.fresh_backend: unknown backend %{backend}"]
in
reinitialize backend config;
backend
let%track5_sexp fresh_backend ?backend_name () =
Stdlib.Gc.full_major ();
(* TODO: is running again needed to give time to weak arrays to become empty? *)
Stdlib.Gc.full_major ();
match
Option.value_or_thunk backend_name ~default:(fun () ->
Utils.get_global_arg ~arg_name:"backend" ~default:"cc")
|> String.lowercase
with
| "cc" -> (module Make_device_backend_from_lowered (Schedulers.Multicore) (Cc_backend) : Backend)
| "gccjit" ->
(module Make_device_backend_from_lowered (Schedulers.Multicore) (Gcc_backend) : Backend)
| "sync_cc" -> (module Make_device_backend_from_lowered (Schedulers.Sync) (Cc_backend) : Backend)
| "sync_gccjit" ->
(module Make_device_backend_from_lowered (Schedulers.Sync) (Gcc_backend) : Backend)
| "cuda" -> (module Raise_backend ((Cuda_backend : Lowered_backend)) : Backend)
| backend -> invalid_arg [%string "Backends.fresh_backend: unknown backend %{backend}"]
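
Why the two Gc.full_major calls above: the deleted reinitialize relied on the streams weak arrays emptying after garbage collection (it asserted that Utils.weak_iter found nothing), and the new fresh_backend keeps full major collections so that a freshly created backend does not see stale weakly-held entries. A self-contained demo of the underlying mechanism (hypothetical standalone code, not part of arrayjit):

let () =
  let slots = Stdlib.Weak.create 1 in
  (* Keep the payload behind a ref so the strong reference can be dropped. *)
  let strong = ref (Some (Bytes.create 16)) in
  Stdlib.Weak.set slots 0 !strong;
  Printf.printf "held:    %b\n" (Stdlib.Weak.check slots 0);
  (* Drop the only strong reference; a full major collection can now reclaim
     the payload and clear the weak slot. *)
  strong := None;
  Stdlib.Gc.full_major ();
  Printf.printf "cleared: %b\n" (not (Stdlib.Weak.check slots 0))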
9 changes: 3 additions & 6 deletions arrayjit/lib/backends.mli
@@ -2,9 +2,6 @@

open Base

val reinitialize : (module Backend_intf.Backend) -> Backend_intf.config -> unit
(** Initializes the backend, and if it was already initialized, performs garbage collection. *)

val finalize :
'buffer_ptr 'dev 'runner 'event.
(module Backend_intf.Backend
@@ -21,6 +18,6 @@ val finalize :
Note: this type will get simpler with modular explicits. *)

val fresh_backend :
?backend_name:string -> ?config:Backend_intf.config -> unit -> (module Backend_intf.Backend)
(** Reinitializes and returns a backend corresponding to [backend_name], or if omitted, selected via
the global [backend] setting. See {!reinitialize}. *)
?backend_name:string -> unit -> (module Backend_intf.Backend)
(** Creates a new backend corresponding to [backend_name], or if omitted, selected via the global
[backend] setting. *)
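
A minimal call-site sketch for the new signature (only what this interface declares is assumed; the "cc" name matches the default visible in backends.ml above, and the surrounding module context is hypothetical):

let demo () : (module Backend_intf.Backend) =
  (* Select via ~backend_name, or omit it to fall back to the global [backend]
     setting. *)
  let backend = Backends.fresh_backend ~backend_name:"cc" () in
  let module Backend = (val backend : Backend_intf.Backend) in
  (* [Backend] is a brand-new module instance: no caches carry over from
     modules returned by earlier fresh_backend calls. *)
  (module Backend : Backend_intf.Backend)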
2 changes: 2 additions & 0 deletions arrayjit/lib/cuda_backend.cudajit.ml
@@ -93,6 +93,8 @@ let is_initialized, initialize =
((fun () -> !initialized), init)

let num_devices = Cu.Device.get_count

(* TODO: this doesn't need to be weak array. *)
let devices = ref @@ Stdlib.Weak.create 0

(* Unlike [devices] above, [initialized_devices] never forgets its entries. *)
166 changes: 83 additions & 83 deletions test/zero2hero_1of7.ml
@@ -274,22 +274,22 @@ let%expect_test "Simple gradients hosted" =
Tensor.print_tree ~with_grad:true ~depth:9 l;
[%expect
{|
#12 *._l Host&dev/41
<not-in-yet>
#13 grad_*._l Host&dev/41
<not-in-yet>
#8 +_d Host&dev/41 │#10 f Host-non-const/24
<not-in-yet><not-in-yet>
#9 grad_+_d Host&dev/41 │#11 grad_f Host&dev/41
<not-in-yet><not-in-yet>
#4 *._e Host&dev/41 │#6 c Host-non-const/24
<not-in-yet><not-in-yet>
#5 grad_*._e Host&dev/41 │#7 grad_c Host&dev/41
<not-in-yet><not-in-yet>
#0 a Host-non-const/24│#2 b Host-non-const/24
<not-in-yet><not-in-yet>
#1 grad_a Host&dev/41 │#3 grad_b Host&dev/41
<not-in-yet><not-in-yet>
#12 *._l Host&stream/41
<not-in-yet>
#13 grad_*._l Host&stream/41
<not-in-yet>
#8 +_d Host&stream/41 │#10 f Host&shared/39
<not-in-yet> <not-in-yet>
#9 grad_+_d Host&stream/41 │#11 grad_f Host&stream/41
<not-in-yet> <not-in-yet>
#4 *._e Host&stream/41 │#6 c Host&shared/39
<not-in-yet> <not-in-yet>
#5 grad_*._e Host&stream/41 │#7 grad_c Host&stream/41
<not-in-yet> <not-in-yet>
#0 a Host&shared/39 │#2 b Host&shared/39
<not-in-yet> <not-in-yet>
#1 grad_a Host&stream/41│#3 grad_b Host&stream/41
<not-in-yet> <not-in-yet>
|}];
(* Do not update the params: all values and gradients will be at initial points, which are
specified in the tensor in the brackets. *)
@@ -411,45 +411,45 @@ let%expect_test "Simple gradients virtual" =
Tensor.print_tree ~with_grad:true ~depth:9 l;
[%expect
{|
#12 *._l Host&dev/41
<not-in-yet>
#13 grad_*._l Virt/40
<not-in-yet>
#8 +_d Local/50 │#10 f Host-non-const/24
<not-in-yet><not-in-yet>
#9 grad_+_d Virt/40 │#11 grad_f On-dev/50
<not-in-yet><not-in-yet>
#4 *._e Virt/152 │#6 c Host-non-const/24
<not-in-yet><not-in-yet>
#5 grad_*._e Virt/40 │#7 grad_c On-dev/50
<not-in-yet><not-in-yet>
#0 a Host-non-const/24│#2 b Host-non-const/24
<not-in-yet><not-in-yet>
#1 grad_a On-dev/50 │#3 grad_b On-dev/50
<not-in-yet><not-in-yet>
#12 *._l Host&stream/41
<not-in-yet>
#13 grad_*._l Virt/40
<not-in-yet>
#8 +_d Local/46 │#10 f Host&shared/39
<not-in-yet> <not-in-yet>
#9 grad_+_d Virt/40 │#11 grad_f Dev-stream/41
<not-in-yet> <not-in-yet>
#4 *._e Virt/152 │#6 c Host&shared/39
<not-in-yet> <not-in-yet>
#5 grad_*._e Virt/40 │#7 grad_c Dev-stream/41
<not-in-yet> <not-in-yet>
#0 a Host&shared/39 │#2 b Host&shared/39
<not-in-yet> <not-in-yet>
#1 grad_a Dev-stream/41│#3 grad_b Dev-stream/41
<not-in-yet> <not-in-yet>
|}];
(* Do not update the params: all values and gradients will be at initial points, which are
specified in the tensor in the brackets. *)
Train.sync_run backend grad_routine l;
Tensor.print_tree ~with_grad:true ~depth:9 l;
[%expect
{|
#12 *._l
-8.00e+0
#13 grad_*._l Virt/40
<void>
#8 +_d Local/50 │#10 f
<void>-2.00e+0
#9 grad_+_d Virt/40 │#11 grad_f On-dev/50
<void><void>
#4 *._e Virt/152 │#6 c │
<void>1.00e+1
#5 grad_*._e Virt/40 │#7 grad_c On-dev/50
<void><void>
#0 a │#2 b
2.00e+0-3.00e+0
#1 grad_a On-dev/50│#3 grad_b On-dev/50
<void><void>
#12 *._l
-8.00e+0
#13 grad_*._l Virt/40
<void>
#8 +_d Local/46 │#10 f
<void> -2.00e+0
#9 grad_+_d Virt/40 │#11 grad_f Dev-stream/41
<void> <void>
#4 *._e Virt/152 │#6 c
<void> 1.00e+1
#5 grad_*._e Virt/40 │#7 grad_c Dev-stream/41
<void> <void>
#0 a │#2 b
2.00e+0 -3.00e+0
#1 grad_a Dev-stream/41│#3 grad_b Dev-stream/41
<void> <void>
|}];
(* Only now compile the SGD update. *)
let sgd_routine = Train.to_routine (module Backend) grad_routine.context IDX.empty sgd in
@@ -460,45 +460,45 @@
Tensor.print_tree ~with_grad:true ~depth:9 l;
[%expect
{|
#12 *._l
-8.00e+0
#13 grad_*._l Virt/40
<void>
#8 +_d Local/50 │#10 f
<void>-2.40e+0
#9 grad_+_d Virt/40 │#11 grad_f On-dev/50
<void><void>
#4 *._e Virt/152 │#6 c │
<void>1.02e+1
#5 grad_*._e Virt/40 │#7 grad_c On-dev/50
<void><void>
#0 a │#2 b
1.40e+0-2.60e+0
#1 grad_a On-dev/50│#3 grad_b On-dev/50
<void><void>
#12 *._l
-8.00e+0
#13 grad_*._l Virt/40
<void>
#8 +_d Local/46 │#10 f
<void> -2.40e+0
#9 grad_+_d Virt/40 │#11 grad_f Dev-stream/41
<void> <void>
#4 *._e Virt/152 │#6 c
<void> 1.02e+1
#5 grad_*._e Virt/40 │#7 grad_c Dev-stream/41
<void> <void>
#0 a │#2 b
1.40e+0 -2.60e+0
#1 grad_a Dev-stream/41│#3 grad_b Dev-stream/41
<void> <void>
|}];
(* Now the params will remain as above, but both param gradients and the values and gradients of
other nodes will change thanks to the forward and backward passes. *)
Train.sync_run backend grad_routine l;
Tensor.print_tree ~with_grad:true ~depth:9 l;
[%expect
{|
#12 *._l
-1.57e+1
#13 grad_*._l Virt/40
<void>
#8 +_d Local/50 │#10 f
<void>-2.40e+0
#9 grad_+_d Virt/40 │#11 grad_f On-dev/50
<void><void>
#4 *._e Virt/152 │#6 c │
<void>1.02e+1
#5 grad_*._e Virt/40 │#7 grad_c On-dev/50
<void><void>
#0 a │#2 b
1.40e+0-2.60e+0
#1 grad_a On-dev/50│#3 grad_b On-dev/50
<void><void>
#12 *._l
-1.57e+1
#13 grad_*._l Virt/40
<void>
#8 +_d Local/46 │#10 f
<void> -2.40e+0
#9 grad_+_d Virt/40 │#11 grad_f Dev-stream/41
<void> <void>
#4 *._e Virt/152 │#6 c
<void> 1.02e+1
#5 grad_*._e Virt/40 │#7 grad_c Dev-stream/41
<void> <void>
#0 a │#2 b
1.40e+0 -2.60e+0
#1 grad_a Dev-stream/41│#3 grad_b Dev-stream/41
<void> <void>
|}]

let%expect_test "tanh plot" =
@@ -565,12 +565,12 @@ let%expect_test "2D neuron virtual" =
7.00e-1
#9 grad_+_v Virt/40
<void>
#6 * Local/50 │#0 b
#6 * Local/46 │#0 b
<void>6.70e+0
#7 grad_* Virt/40 │#1 grad_b Local/50
#7 grad_* Virt/40 │#1 grad_b Local/46
<void><void>
#2 w │#4 x │
-3.00e+0 1.00e+02.00e+0 0.00e+0
#3 grad_w Local/50 │#5 grad_x Local/50
#3 grad_w Local/46 │#5 grad_x Local/46
<void><void>
|}]
