Skip to content

Commit

Permalink
Fix: upcast constants that exceed fp16 cutoff config
Browse files Browse the repository at this point in the history
  • Loading branch information
lukstafi committed Nov 24, 2024
1 parent e698ef3 commit 0159bfd
Show file tree
Hide file tree
Showing 7 changed files with 45 additions and 21 deletions.
18 changes: 7 additions & 11 deletions arrayjit/lib/low_level.ml
Original file line number Diff line number Diff line change
Expand Up @@ -693,17 +693,13 @@ let simplify_llc llc =
let result = Unop (op, v) in
if equal_float_t llv v then result else loop_float result
in
let check_constant =
match Utils.settings.check_half_prec_constants_cutoff with
| None -> fun _prec _c -> ()
| Some cutoff ->
fun tn c ->
if (Ops.is_fp16 @@ Lazy.force tn.Tn.prec) && Float.(abs c >= cutoff) then
raise
@@ Utils.User_error
("Constant " ^ Float.to_string c
^ " is too big for FP16 aka. half precision, risk of overflow; increase precision \
of tensor node " ^ Tn.debug_name tn)
let check_constant tn c =
(* Guard against fp16 overflow: [Tn.exceeds_fp16_cutoff] consults the
   configured [check_half_prec_constants_cutoff] setting and the node's
   (possibly delayed) precision; on violation we abort with a user-facing
   error naming the offending tensor node. *)
if Tn.exceeds_fp16_cutoff tn c then
raise
@@ Utils.User_error
("Constant " ^ Float.to_string c
^ " is too big for FP16 aka. half precision, risk of overflow; increase precision of \
tensor node " ^ Tn.debug_name tn)
in
let rec check_proc llc =
let loop = check_proc in
Expand Down
2 changes: 1 addition & 1 deletion arrayjit/lib/ops.ml
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ let byte = Byte_prec Byte
let half = Half_prec Half
let single = Single_prec Single
let double = Double_prec Double
let is_fp16 = function Half_prec _ -> true | _ -> false
(* Holds for precisions no wider than half precision: byte ([Byte_prec]) and
   fp16 ([Half_prec]); [false] for all wider (and void) precisions. *)
let is_up_to_fp16 prec =
  match prec with Byte_prec _ | Half_prec _ -> true | _ -> false

let sexp_of_prec = function
| Void_prec -> Sexp.Atom "Void_prec"
Expand Down
18 changes: 18 additions & 0 deletions arrayjit/lib/tnode.ml
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,7 @@ type t = {
(** Display information. It is better if the last element of the list is the most narrow or
alphanumeric, e.g. an identifier. *)
mutable delayed_prec_unsafe : delayed_prec;
(** Participates in the computation of {!field-prec}. *)
mutable memory_mode : (memory_mode * int) option;
mutable backend_info : Sexp.t;
mutable code_name : string option;
Expand Down Expand Up @@ -374,6 +375,23 @@ let update_prec ?only_if tn prec =
if cond old then prec else old))
| _ -> tn.delayed_prec_unsafe <- Specified prec

(** [exceeds_fp16_cutoff tn c] is [true] when a half-precision constants
    cutoff is configured, [abs c] reaches that cutoff, and [tn]'s precision is
    at most fp16 (see [Ops.is_up_to_fp16]). The magnitude test runs first so
    the precision is only inspected when needed; if the lazy precision is not
    yet forced, the delayed spec is consulted to avoid forcing it where the
    precision is already determined. *)
let exceeds_fp16_cutoff tn c =
  match Utils.settings.check_half_prec_constants_cutoff with
  | None -> false
  | Some cutoff ->
      Float.(abs c >= cutoff)
      &&
      let node_prec =
        if Lazy.is_val tn.prec then Lazy.force tn.prec
        else
          match tn.delayed_prec_unsafe with
          | Specified p -> p
          | Default_spec p -> Lazy.force p
          | Not_specified -> Lazy.force tn.prec
      in
      Ops.is_up_to_fp16 node_prec

include Comparator.Make (struct
type nonrec t = t

Expand Down
16 changes: 9 additions & 7 deletions bin/moons_benchmark.ml
Original file line number Diff line number Diff line change
Expand Up @@ -214,13 +214,8 @@ let _mem_benchmarks =
~f:(fun batch_size ->
List.concat_map [ 0; (* 1; 2; *) 3 ] ~f:(fun inlining_cutoff ->
List.concat_map [ (* 1; 3; *) 7 (* *) ] ~f:(fun seed ->
List.concat_map
[
(* "gccjit" ; *)
(* "cc"; *)
"cuda";
] ~f:(fun backend_name ->
List.concat_map [ (* CDSL.double; *) CDSL.single (* ; CDSL.half *) ]
List.concat_map [ (* "gccjit" ; *) "cc"; "cuda" ] ~f:(fun backend_name ->
List.concat_map [ (* CDSL.double; *) CDSL.single; CDSL.half ]
~f:(fun value_prec ->
[
classify_moons ~seed ~on_device:true ~inlining_cutoff ~num_streams
Expand All @@ -242,6 +237,13 @@ let _suspended () =
(* let () = List.map benchmarks ~f:(nth_best 2) |> PrintBox_utils.table |> PrintBox_text.output
Stdio.stdout *)

(* One-off benchmark kept for quick debugging: a single half-precision
   moons-classification run on the "cc" backend, printed as a table. *)
let _suspended () =
  let run =
    classify_moons ~seed:7 ~on_device:true ~inlining_cutoff:0 ~num_streams:3 ~batch_size:240
      ~backend_name:"cc" ~value_prec:CDSL.half ~grad_prec:CDSL.half ()
  in
  PrintBox_utils.table [ run ] |> PrintBox_text.output Stdio.stdout

let benchmark benchmarks =
List.map benchmarks ~f:(fun bench -> bench ())
|> PrintBox_utils.table |> PrintBox_text.output Stdio.stdout
Expand Down
7 changes: 7 additions & 0 deletions lib/tensor.ml
Original file line number Diff line number Diff line change
Expand Up @@ -314,6 +314,9 @@ let number ?(label = []) ?axis_label ?(grad_spec = Prohibit_grad) c =
| Some axis_label -> t ~output_axes:[ (axis_label, 1) ] ()
in
Tn.update_memory_mode t.value Effectively_constant 24;
Arrayjit.Ops.(
if Tn.exceeds_fp16_cutoff t.value c then
Tn.update_prec ~only_if:is_up_to_fp16 t.value single);
t

let ndarray ?(label = []) ?(grad_spec = Prohibit_grad) ?batch_dims ?input_dims ?output_dims
Expand Down Expand Up @@ -356,6 +359,10 @@ let ndarray ?(label = []) ?(grad_spec = Prohibit_grad) ?batch_dims ?input_dims ?
()
in
Tn.update_memory_mode t.value Effectively_constant 24;
let max_abs = Array.fold values ~init:0. ~f:(fun acc v -> Float.(max acc @@ abs v)) in
Arrayjit.Ops.(
if Tn.exceeds_fp16_cutoff t.value max_abs then
Tn.update_prec ~only_if:is_up_to_fp16 t.value single);
t

let param ?(more_label = []) ?input_dims ?output_dims ?input_axes ?output_axes ?deduced
Expand Down
3 changes: 2 additions & 1 deletion lib/train.ml
Original file line number Diff line number Diff line change
Expand Up @@ -487,7 +487,8 @@ let example_train_loop ?(disable_rootness_check = false) ~seed ~batch_size ~init
| None -> !.init_lr *. ((2 *. !..steps) - !@step_n) /. !..steps
| Some schedule -> schedule ~batch_n ~step_n
in
Tn.update_prec ~only_if:Ops.is_fp16 learning_rate.value Ops.single;
(* Note: constants at default half-prec are automatically upcasted when they exceed
Utils.settings.check_half_prec_constants_cutoff, no need to upcast learning_rate.value. *)
set_hosted learning_rate.value;
let sgd = sgd_update ~learning_rate ~weight_decay update in
let grad_update = Backend.compile ~shared:true bindings update.fwd_bprop in
Expand Down
2 changes: 1 addition & 1 deletion todo.md
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
# This file is for tasks with a smaller granularity than issues, typically immediate tasks.
(B) bin/moons_benchmark with the cc backend crashes with half-prec overflow
(B) bin/moons_benchmark with the cc backend crashes with half-prec overflow {cm:2024-11-24}
(B) remove syncing from the data parallel algo: stream-to-stream syncing is now automatic {cm:2024-11-23}
(A) cuda backend crashes in bin/moons_benchmark {cm:2024-11-22}
(B) figure out why cuda backend parallelism slows down in later epochs

0 comments on commit 0159bfd

Please sign in to comment.