From 0159bfd48d6f457daee0280d5d6ac39c228170ec Mon Sep 17 00:00:00 2001
From: Lukasz Stafiniak
Date: Sun, 24 Nov 2024 21:23:22 +0100
Subject: [PATCH] Fix: upcast constants that exceed fp16 cutoff config

---
 arrayjit/lib/low_level.ml | 18 +++++++-----------
 arrayjit/lib/ops.ml       |  2 +-
 arrayjit/lib/tnode.ml     | 18 ++++++++++++++++++
 bin/moons_benchmark.ml    | 16 +++++++++++-------
 lib/tensor.ml             |  7 +++++++
 lib/train.ml              |  3 ++-
 todo.md                   |  2 +-
 7 files changed, 45 insertions(+), 21 deletions(-)

diff --git a/arrayjit/lib/low_level.ml b/arrayjit/lib/low_level.ml
index 4c577211..96d0f080 100644
--- a/arrayjit/lib/low_level.ml
+++ b/arrayjit/lib/low_level.ml
@@ -693,17 +693,13 @@ let simplify_llc llc =
         let result = Unop (op, v) in
         if equal_float_t llv v then result else loop_float result
   in
-  let check_constant =
-    match Utils.settings.check_half_prec_constants_cutoff with
-    | None -> fun _prec _c -> ()
-    | Some cutoff ->
-        fun tn c ->
-          if (Ops.is_fp16 @@ Lazy.force tn.Tn.prec) && Float.(abs c >= cutoff) then
-            raise
-            @@ Utils.User_error
-                 ("Constant " ^ Float.to_string c
-                ^ " is too big for FP16 aka. half precision, risk of overflow; increase precision \
-                   of tensor node " ^ Tn.debug_name tn)
+  let check_constant tn c =
+    if Tn.exceeds_fp16_cutoff tn c then
+      raise
+      @@ Utils.User_error
+           ("Constant " ^ Float.to_string c
+          ^ " is too big for FP16 aka. half precision, risk of overflow; increase precision of \
+             tensor node " ^ Tn.debug_name tn)
   in
   let rec check_proc llc =
     let loop = check_proc in
diff --git a/arrayjit/lib/ops.ml b/arrayjit/lib/ops.ml
index c51ca59a..15eeea5f 100644
--- a/arrayjit/lib/ops.ml
+++ b/arrayjit/lib/ops.ml
@@ -28,7 +28,7 @@ let byte = Byte_prec Byte
 let half = Half_prec Half
 let single = Single_prec Single
 let double = Double_prec Double
-let is_fp16 = function Half_prec _ -> true | _ -> false
+let is_up_to_fp16 = function Half_prec _ | Byte_prec _ -> true | _ -> false
 
 let sexp_of_prec = function
   | Void_prec -> Sexp.Atom "Void_prec"
diff --git a/arrayjit/lib/tnode.ml b/arrayjit/lib/tnode.ml
index 5fb7e7c4..fce4dee1 100644
--- a/arrayjit/lib/tnode.ml
+++ b/arrayjit/lib/tnode.ml
@@ -76,6 +76,7 @@ type t = {
       (** Display information. It is better if the last element of the list is the most narrow or
           alphanumeric, e.g. an identifier. *)
   mutable delayed_prec_unsafe : delayed_prec;
+      (** Participates in the computation of {!field-prec}. *)
   mutable memory_mode : (memory_mode * int) option;
   mutable backend_info : Sexp.t;
   mutable code_name : string option;
@@ -374,6 +375,23 @@ let update_prec ?only_if tn prec =
               if cond old then prec else old))
   | _ -> tn.delayed_prec_unsafe <- Specified prec
 
+let exceeds_fp16_cutoff tn c =
+  match Utils.settings.check_half_prec_constants_cutoff with
+  | None -> false
+  | Some cutoff ->
+      (* Only force if needed. *)
+      Float.(abs c >= cutoff)
+      &&
+      let prec =
+        if Lazy.is_val tn.prec then Lazy.force tn.prec
+        else
+          match tn.delayed_prec_unsafe with
+          | Specified prec -> prec
+          | Default_spec prec -> Lazy.force prec
+          | Not_specified -> Lazy.force tn.prec
+      in
+      Ops.is_up_to_fp16 prec
+
 include Comparator.Make (struct
   type nonrec t = t
 
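Note (illustration only, not part of the patch): a minimal sketch of how the new Tnode helper is
meant to combine with update_prec, mirroring the lib/tensor.ml hunks below. upcast_if_needed is a
hypothetical name, and Tn / Ops are the module aliases used elsewhere in this patch:

    (* Upcast a node to single precision only when it is currently at most FP16
       (byte or half) and the constant c would exceed the configured cutoff. *)
    let upcast_if_needed tn c =
      if Tn.exceeds_fp16_cutoff tn c then Tn.update_prec ~only_if:Ops.is_up_to_fp16 tn Ops.single
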
diff --git a/bin/moons_benchmark.ml b/bin/moons_benchmark.ml
index 29a83a89..3d8daf9e 100644
--- a/bin/moons_benchmark.ml
+++ b/bin/moons_benchmark.ml
@@ -214,13 +214,8 @@ let _mem_benchmarks =
     ~f:(fun batch_size ->
       List.concat_map [ 0; (* 1; 2; *) 3 ] ~f:(fun inlining_cutoff ->
           List.concat_map [ (* 1; 3; *) 7 (* *) ] ~f:(fun seed ->
-              List.concat_map
-                [
-                  (* "gccjit" ; *)
-                  (* "cc"; *)
-                  "cuda";
-                ] ~f:(fun backend_name ->
-                List.concat_map [ (* CDSL.double; *) CDSL.single (* ; CDSL.half *) ]
+              List.concat_map [ (* "gccjit" ; *) "cc"; "cuda" ] ~f:(fun backend_name ->
+                List.concat_map [ (* CDSL.double; *) CDSL.single; CDSL.half ]
                   ~f:(fun value_prec ->
                     [
                       classify_moons ~seed ~on_device:true ~inlining_cutoff ~num_streams
@@ -242,6 +237,13 @@ let _suspended () =
 (* let () =
    List.map benchmarks ~f:(nth_best 2) |> PrintBox_utils.table |> PrintBox_text.output Stdio.stdout *)
 
+let _suspended () =
+  [
+    classify_moons ~seed:7 ~on_device:true ~inlining_cutoff:0 ~num_streams:3 ~batch_size:240
+      ~backend_name:"cc" ~value_prec:CDSL.half ~grad_prec:CDSL.half ();
+  ]
+  |> PrintBox_utils.table |> PrintBox_text.output Stdio.stdout
+
 let benchmark benchmarks =
   List.map benchmarks ~f:(fun bench -> bench ()) |> PrintBox_utils.table
   |> PrintBox_text.output Stdio.stdout
diff --git a/lib/tensor.ml b/lib/tensor.ml
index fed66b4e..256a20f9 100644
--- a/lib/tensor.ml
+++ b/lib/tensor.ml
@@ -314,6 +314,9 @@ let number ?(label = []) ?axis_label ?(grad_spec = Prohibit_grad) c =
     | Some axis_label -> t ~output_axes:[ (axis_label, 1) ] ()
   in
   Tn.update_memory_mode t.value Effectively_constant 24;
+  Arrayjit.Ops.(
+    if Tn.exceeds_fp16_cutoff t.value c then
+      Tn.update_prec ~only_if:is_up_to_fp16 t.value single);
   t
 
 let ndarray ?(label = []) ?(grad_spec = Prohibit_grad) ?batch_dims ?input_dims ?output_dims
@@ -356,6 +359,10 @@ let ndarray ?(label = []) ?(grad_spec = Prohibit_grad) ?batch_dims ?input_dims ?
       ()
   in
   Tn.update_memory_mode t.value Effectively_constant 24;
+  let max_abs = Array.fold values ~init:0. ~f:(fun acc v -> Float.(max acc @@ abs v)) in
+  Arrayjit.Ops.(
+    if Tn.exceeds_fp16_cutoff t.value max_abs then
+      Tn.update_prec ~only_if:is_up_to_fp16 t.value single);
   t
 
 let param ?(more_label = []) ?input_dims ?output_dims ?input_axes ?output_axes ?deduced
diff --git a/lib/train.ml b/lib/train.ml
index b27f7ab2..f80ffe82 100644
--- a/lib/train.ml
+++ b/lib/train.ml
@@ -487,7 +487,8 @@ let example_train_loop ?(disable_rootness_check = false) ~seed ~batch_size ~init
     | None -> !.init_lr *. ((2 *. !..steps) - !@step_n) /. !..steps
     | Some schedule -> schedule ~batch_n ~step_n
   in
-  Tn.update_prec ~only_if:Ops.is_fp16 learning_rate.value Ops.single;
+  (* Note: constants at default half-prec are automatically upcasted when they exceed
+     Utils.settings.check_half_prec_constants_cutoff, no need to upcast learning_rate.value. *)
   set_hosted learning_rate.value;
   let sgd = sgd_update ~learning_rate ~weight_decay update in
   let grad_update = Backend.compile ~shared:true bindings update.fwd_bprop in
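Note (illustration only, not part of the patch): the guard is driven by
Utils.settings.check_half_prec_constants_cutoff, a float option where None disables the check.
FP16's largest finite value is 65504, so any cutoff below that leaves headroom against overflow.
Assuming the field is mutable like other Utils.settings knobs, a configuration sketch:

    (* Enable the cutoff at an illustrative 16384. (2^14): constants of at least this magnitude
       are upcast in Tensor.number/ndarray, or rejected with User_error in Low_level.simplify_llc. *)
    let () = Utils.settings.check_half_prec_constants_cutoff <- Some 16384.
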
diff --git a/todo.md b/todo.md
index 511e9ad8..da8bfd1c 100644
--- a/todo.md
+++ b/todo.md
@@ -1,5 +1,5 @@
 # This file is for tasks with a smaller granularity than issues, typically immediate tasks.
-(B) bin/moons_benchmark with the cc backend crashes with half-prec overflow
+(B) bin/moons_benchmark with the cc backend crashes with half-prec overflow {cm:2024-11-24}
 (B) remove syncing from the data parallel algo: stream-to-stream syncing is now automatic {cm:2024-11-23}
 (A) cuda backend crashes in bin/moons_benchmark {cm:2024-11-22}
 (B) figure out why cuda backend parallelism slows down in later epochs
\ No newline at end of file