Skip to content

Commit

Permalink
Wasm runtime: add support for float16 bigarrays
Browse files Browse the repository at this point in the history
Implemented using Float16Array typed arrays, or Float32Array as a fallback.
  • Loading branch information
vouillon committed Jan 21, 2025
1 parent 8385a5e commit 18a66e2
Show file tree
Hide file tree
Showing 4 changed files with 175 additions and 27 deletions.
187 changes: 160 additions & 27 deletions runtime/wasm/bigarray.wat
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,8 @@
(func $ta_get_f64 (param (ref extern)) (param i32) (result f64)))
(import "bindings" "ta_get_f32"
(func $ta_get_f32 (param (ref extern)) (param i32) (result f64)))
(import "bindings" "ta_get_f16"
(func $ta_get_f16 (param (ref extern)) (param i32) (result f64)))
(import "bindings" "ta_get_i32"
(func $ta_get_i32 (param (ref extern)) (param i32) (result i32)))
(import "bindings" "ta_get_i16"
Expand All @@ -46,6 +48,8 @@
(func $ta_set_f64 (param (ref extern)) (param i32) (param f64)))
(import "bindings" "ta_set_f32"
(func $ta_set_f32 (param (ref extern)) (param i32) (param f64)))
(import "bindings" "ta_set_f16"
(func $ta_set_f16 (param (ref extern)) (param i32) (param f64)))
(import "bindings" "ta_set_i32"
(func $ta_set_i32 (param (ref extern)) (param i32) (param i32)))
(import "bindings" "ta_set_i16"
Expand Down Expand Up @@ -104,6 +108,8 @@
(func $caml_hash_mix_double (param i32) (param f64) (result i32)))
(import "hash" "caml_hash_mix_float"
(func $caml_hash_mix_float (param i32) (param f32) (result i32)))
(import "hash" "caml_hash_mix_float16"
(func $caml_hash_mix_float16 (param i32) (param i32) (result i32)))
(import "marshal" "caml_serialize_int_1"
(func $caml_serialize_int_1 (param (ref eq)) (param i32)))
(import "marshal" "caml_serialize_int_2"
Expand Down Expand Up @@ -177,6 +183,59 @@
(field $ba_kind i8) ;; kind
(field $ba_layout i8)))) ;; layout

;; Convert a double to the 16-bit pattern of an IEEE 754 half-precision
;; float, returned in the low 16 bits of an i32.
;; The value is first demoted to f32 and the conversion then works on the
;; f32 bit pattern: the sign is split off, the magnitude is classified as
;; overflow/NaN, half-subnormal, or half-normal, and the sign is put back
;; at bit 15 of the result.
(func $double_to_float16 (param $f f64) (result i32)
(local $x i32) (local $sign i32) (local $o i32)
;; $x = f32 bit pattern of the demoted input
(local.set $x (i32.reinterpret_f32 (f32.demote_f64 (local.get $f))))
;; isolate the f32 sign bit, then clear it so $x holds the magnitude only
(local.set $sign (i32.and (local.get $x) (i32.const 0x80000000)))
(local.set $x (i32.xor (local.get $x) (local.get $sign)))
;; 0x47800000 is the f32 pattern of 2^16 (65536.0): anything at or above
;; it cannot be represented as a finite half
(if (i32.ge_u (local.get $x) (i32.const 0x47800000))
(then
(local.set $o
;; select yields the first operand when the condition is non-zero:
;; magnitudes strictly above the f32 infinity pattern are NaNs
(select
(i32.const 0x7E00) ;; NaN
(i32.const 0x7C00) ;; infinity
(i32.gt_u (local.get $x) (i32.const 0x7f800000)))))
(else
;; 0x38800000 is the f32 pattern of 2^-14, the smallest normal half;
;; smaller magnitudes become half subnormals (or zero)
(if (i32.lt_u (local.get $x) (i32.const 0x38800000))
(then
(local.set $o
;; adding 0.5 lets the f32 addition align and round the mantissa;
;; subtracting the bit pattern of 0.5 (0x3f000000) then leaves the
;; half subnormal mantissa in the low bits
(i32.sub
(i32.reinterpret_f32
(f32.add (f32.reinterpret_i32 (local.get $x))
(f32.const 0.5)))
(i32.const 0x3f000000))))
(else
(local.set $o
;; normal case: 0xC8000FFF = (-112 << 23) + 0xFFF, i.e. rebias the
;; exponent from f32 (127) to half (15) and add the rounding bias;
;; adding bit 13 of $x (the lowest kept mantissa bit) makes ties
;; round to even; the final shift drops the 13 extra mantissa bits
(i32.shr_u
(i32.add (i32.add (local.get $x) (i32.const 0xC8000FFF))
(i32.and (i32.shr_u (local.get $x) (i32.const 13))
(i32.const 1)))
(i32.const 13)))))))
;; move the f32 sign bit (bit 31) down to the half sign position (bit 15)
(i32.or (local.get $o) (i32.shr_u (local.get $sign) (i32.const 16))))

;; Expand the half-precision bit pattern held in the low 16 bits of $d
;; into a double.
;; The exponent and mantissa bits are moved into f32 position and the
;; resulting float is scaled by 2^112 to correct the exponent bias; values
;; that come out at or above 2^16 were half infinities/NaNs and get an
;; all-ones f32 exponent. The sign bit is re-attached last.
(func $float16_to_double (param $d i32) (result f64)
   (local $bits i32)
   ;; f32 bit pattern of (exponent|mantissa placed at f32 position) * 2^112
   (local.set $bits
      (i32.reinterpret_f32
         (f32.mul
            (f32.const 0x1p+112)
            (f32.reinterpret_i32
               (i32.shl
                  (i32.and (i32.const 0x7FFF) (local.get $d))
                  (i32.const 13))))))
   ;; NaN / infinity: force the f32 exponent field to all ones
   (if (f32.ge (f32.reinterpret_i32 (local.get $bits)) (f32.const 0x1p+16))
      (then
         (local.set $bits
            (i32.or (i32.const 0x7f800000) (local.get $bits)))))
   ;; bit 15 of the half is the sign; move it to bit 31 of the f32
   (f64.promote_f32
      (f32.reinterpret_i32
         (i32.or
            (i32.shl
               (i32.and (i32.const 0x8000) (local.get $d))
               (i32.const 16))
            (local.get $bits)))))

(func $bigarray_hash (param (ref eq)) (result i32)
(local $b (ref $bigarray))
(local $h i32) (local $len i32) (local $i i32) (local $w i32)
Expand All @@ -192,10 +251,25 @@
(block $uint16
(block $int32
(block $int64
(br_table $float32 $float64 $int8 $uint8 $int16 $uint16
$int32 $int64 $int32 $int32
$float32 $float64 $uint8
(struct.get $bigarray $ba_kind (local.get $b))))
(block $float16
(br_table $float32 $float64 $int8 $uint8 $int16 $uint16
$int32 $int64 $int32 $int32
$float32 $float64 $uint8 $float16
(struct.get $bigarray $ba_kind (local.get $b))))
;; float16
(if (i32.gt_u (local.get $len) (i32.const 128))
(then (local.set $len (i32.const 128))))
(loop $loop
(if (i32.lt_u (local.get $i) (local.get $len))
(then
(local.set $h
(call $caml_hash_mix_float16 (local.get $h)
(call $double_to_float16
(call $ta_get_f16
(local.get $data) (local.get $i)))))
(local.set $i (i32.add (local.get $i) (i32.const 1)))
(br $loop))))
(return (local.get $h)))
;; int64
(if (i32.gt_u (local.get $len) (i32.const 32))
(then (local.set $len (i32.const 32))))
Expand Down Expand Up @@ -424,10 +498,22 @@
(block $int32
(block $int
(block $int64
(br_table $float32 $float64 $int8 $uint8 $int16 $uint16
$int32 $int64 $int $int
$float32 $float64 $uint8
(struct.get $bigarray $ba_kind (local.get $b))))
(block $float16
(br_table $float32 $float64 $int8 $uint8 $int16 $uint16
$int32 $int64 $int $int
$float32 $float64 $uint8 $float16
(struct.get $bigarray $ba_kind (local.get $b))))
;; float16
(loop $loop
(if (i32.lt_u (local.get $i) (local.get $len))
(then
(call $caml_serialize_int_2 (local.get $s)
(call $double_to_float16
(call $ta_get_f16
(local.get $data) (local.get $i))))
(local.set $i (i32.add (local.get $i) (i32.const 1)))
(br $loop))))
(br $done))
;; int64
(loop $loop
(if (i32.lt_u (local.get $i) (local.get $len))
Expand Down Expand Up @@ -568,10 +654,21 @@
(block $int32
(block $int
(block $int64
(br_table $float32 $float64 $int8 $uint8 $int16 $uint16
$int32 $int64 $int $int
$float32 $float64 $uint8
(struct.get $bigarray $ba_kind (local.get $b))))
(block $float16
(br_table $float32 $float64 $int8 $uint8 $int16 $uint16
$int32 $int64 $int $int
$float32 $float64 $uint8 $float16
(struct.get $bigarray $ba_kind (local.get $b))))
;; float16
(loop $loop
(if (i32.lt_u (local.get $i) (local.get $len))
(then
(call $ta_set_f16 (local.get $data) (local.get $i)
(call $float16_to_double
(call $caml_deserialize_uint_2 (local.get $s))))
(local.set $i (i32.add (local.get $i) (i32.const 1)))
(br $loop))))
(br $done))
;; int64
(loop $loop
(if (i32.lt_u (local.get $i) (local.get $len))
Expand Down Expand Up @@ -763,7 +860,7 @@
(call $caml_invalid_argument
(array.new_data $string $ta_unsupported_kind
(i32.const 0) (i32.const 41)))))
(if (i32.eq (local.get $kind) (i32.const 13)) ;; Uint8ClampedArray
(if (i32.eq (local.get $kind) (i32.const 14)) ;; Uint8ClampedArray
(then (local.set $kind (i32.const 3))))
(local.set $len (call $ta_length (local.get $data)))
(if (i32.lt_s (local.get $len) (i32.const 0))
Expand Down Expand Up @@ -801,10 +898,16 @@
(block $nativeint
(block $complex32
(block $complex64
(br_table $float32 $float64 $int8 $uint8 $int16 $uint16
$int32 $int64 $int $nativeint
$complex32 $complex64 $uint8
(struct.get $bigarray $ba_kind (local.get $ba))))
(block $float16
(br_table $float32 $float64 $int8 $uint8 $int16 $uint16
$int32 $int64 $int $nativeint
$complex32 $complex64 $uint8 $float16
(struct.get $bigarray $ba_kind (local.get $ba))))
;; float16
(return
(struct.new $float
(call $ta_get_f16
(local.get $data) (local.get $i)))))
;; complex64
(local.set $i (i32.shl (local.get $i) (i32.const 1)))
(return
Expand Down Expand Up @@ -876,10 +979,16 @@
(block $nativeint
(block $complex32
(block $complex64
(br_table $float32 $float64 $int8 $uint8 $int16 $uint16
$int32 $int64 $int $nativeint
$complex32 $complex64 $uint8
(struct.get $bigarray $ba_kind (local.get $ba))))
(block $float16
(br_table $float32 $float64 $int8 $uint8 $int16 $uint16
$int32 $int64 $int $nativeint
$complex32 $complex64 $uint8 $float16
(struct.get $bigarray $ba_kind (local.get $ba))))
;; float16
(call $ta_set_f16 (local.get $data) (local.get $i)
(struct.get $float 0
(ref.cast (ref $float) (local.get $v))))
(return))
;; complex64
(local.set $i (i32.shl (local.get $i) (i32.const 1)))
(local.set $b (ref.cast (ref $float_array) (local.get $v)))
Expand Down Expand Up @@ -1481,8 +1590,8 @@
(block $complex32
(block $complex64
(br_table $float $float $int $int $int $int $int32 $int64 $int
$int32 $complex32 $complex64 $int
(struct.get $bigarray $ba_kind (local.get $ba))))
$int32 $complex32 $complex64 $int $float
(struct.get $bigarray $ba_kind (local.get $ba))))
;; complex64
(local.set $len (call $ta_length (local.get $data)))
(local.set $b (ref.cast (ref $float_array) (local.get $v)))
Expand Down Expand Up @@ -1750,10 +1859,34 @@
(block $uint16
(block $int32
(block $int64
(br_table $float32 $float64 $int8 $uint8 $int16 $uint16
$int32 $int64 $int32 $int32
$float32 $float64 $uint8
(struct.get $bigarray $ba_kind (local.get $b1))))
(block $float16
(br_table $float32 $float64 $int8 $uint8 $int16 $uint16
$int32 $int64 $int32 $int32
$float32 $float64 $uint8 $float16
(struct.get $bigarray $ba_kind (local.get $b1))))
;; float16
(loop $loop
(if (i32.lt_u (local.get $i) (local.get $len))
(then
(local.set $f1
(call $ta_get_f16 (local.get $d1) (local.get $i)))
(local.set $f2
(call $ta_get_f16 (local.get $d2) (local.get $i)))
(if (f64.lt (local.get $f1) (local.get $f2))
(then (return (i32.const -1))))
(if (f64.gt (local.get $f1) (local.get $f2))
(then (return (i32.const 1))))
(if (f64.ne (local.get $f1) (local.get $f2))
(then
(if (i32.eqz (local.get $total))
(then (return (global.get $unordered))))
(if (f64.eq (local.get $f1) (local.get $f1))
(then (return (i32.const 1))))
(if (f64.eq (local.get $f2) (local.get $f2))
(then (return (i32.const -1))))))
(local.set $i (i32.add (local.get $i) (i32.const 1)))
(br $loop))))
(return (i32.const 0)))
;; int64
(loop $loop
(if (i32.lt_u (local.get $i) (local.get $len))
Expand Down
11 changes: 11 additions & 0 deletions runtime/wasm/hash.wat
Original file line number Diff line number Diff line change
Expand Up @@ -109,6 +109,17 @@
(then (local.set $i (i32.const 0))))
(return_call $caml_hash_mix_int (local.get $h) (local.get $i)))

;; Mix a half-precision float, given as its 16-bit pattern in $i, into the
;; hash state $h and return the new state.
;; Before mixing, the pattern is normalised so that equal values hash
;; equally: every NaN (exponent bits all ones, non-zero mantissa) becomes
;; the single pattern 0x7c01, and negative zero (0x8000) becomes 0.
;; The NaN replacement must itself be a 16-bit half pattern; the previous
;; constant 0x7FC01 was not, and disagreed with the 0x7c01 used by the
;; native runtime's caml_hash_mix_float16.
(func (export "caml_hash_mix_float16")
   (param $h i32) (param $i i32) (result i32)
   ;; exponent field (bits 10-14) all ones: infinity or NaN
   (if (i32.eq (i32.and (local.get $i) (i32.const 0x7c00))
               (i32.const 0x7c00))
      (then
         ;; non-zero mantissa: NaN, collapse to one canonical pattern
         (if (i32.and (local.get $i) (i32.const 0x03ff))
            (then (local.set $i (i32.const 0x7c01))))))
   ;; -0.0 hashes like +0.0
   (if (i32.eq (local.get $i) (i32.const 0x8000))
      (then (local.set $i (i32.const 0))))
   (return_call $caml_hash_mix_int (local.get $h) (local.get $i)))

(func $caml_hash_mix_string (export "caml_hash_mix_string")
(param $h i32) (param $s (ref $string)) (result i32)
(local $i i32) (local $len i32) (local $w i32)
Expand Down
3 changes: 3 additions & 0 deletions runtime/wasm/runtime.js
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,7 @@
Float32Array,
Float64Array,
Uint8Array,
globalThis.Float16Array || Float32Array,
Uint8ClampedArray,
];

Expand Down Expand Up @@ -188,6 +189,7 @@
ta_length: (a) => a.length,
ta_get_f64: (a, i) => a[i],
ta_get_f32: (a, i) => a[i],
ta_get_f16: (a, i) => a[i],
ta_get_i32: (a, i) => a[i],
ta_get_i16: (a, i) => a[i],
ta_get_ui16: (a, i) => a[i],
Expand All @@ -198,6 +200,7 @@
a[i] | (a[i + 1] << 8) | (a[i + 2] << 16) | (a[i + 3] << 24),
ta_set_f64: (a, i, v) => (a[i] = v),
ta_set_f32: (a, i, v) => (a[i] = v),
ta_set_f16: (a, i, v) => (a[i] = v),
ta_set_i32: (a, i, v) => (a[i] = v),
ta_set_i16: (a, i, v) => (a[i] = v),
ta_set_ui16: (a, i, v) => (a[i] = v),
Expand Down
1 change: 1 addition & 0 deletions tools/node_wrapper.ml
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ let extra_args_for_wasoo =
[ "--experimental-wasm-imported-strings"
; "--experimental-wasm-stack-switching"
; "--stack-size=10000"
; "--js-float16array"
]

let env = Unix.environment ()
Expand Down

0 comments on commit 18a66e2

Please sign in to comment.