[dataclass] Fix dataclass mutable defaults error #1067

Merged (1 commit) on Nov 6, 2023

carlosgmartin
Contributor
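
For context, here is a minimal sketch of the kind of error the title appears to refer to, assuming it means the `ValueError` that the standard `dataclasses` module raises for mutable field defaults. Since Python 3.11, `dataclasses` rejects any unhashable default value rather than only `list`, `dict`, and `set`. The class names `BadState` and `GoodState` and the `board` field are hypothetical; the actual change to pgx/_src/struct.py is not shown in this thread.

```python
from dataclasses import dataclass, field


def try_define_with_mutable_default():
    """Return the error message raised when a list is used as a field default."""
    try:
        @dataclass
        class BadState:  # hypothetical class, for illustration only
            board: list = [0] * 9  # mutable default: rejected by dataclasses
    except ValueError as err:
        return str(err)
    return None


@dataclass
class GoodState:  # hypothetical class, for illustration only
    # default_factory builds a fresh list for every instance.
    board: list = field(default_factory=lambda: [0] * 9)


error_message = try_define_with_mutable_default()

first, second = GoodState(), GoodState()
first.board[0] = 1  # does not affect `second`, which has its own list

print(error_message)
print(first.board[0], second.board[0])
```

Using `default_factory` is one common way around the error; whether the PR takes this route or another is an assumption, not something the conversation confirms.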

sotetsuk (Owner) commented Oct 16, 2023

Thank you for sending the PR! I've confirmed that it works with Python 3.11.
However, when I inspect the jaxpr produced by the following code, its length grows significantly, from 57 lines to 75 lines, and I'm concerned about how this will affect runtime performance.

Also, pgx/_src/struct.py is forked from Flax, and we'd prefer to keep changes to it minimal.

Support for Python 3.11 is not a high priority for us, and although this PR is promising, we haven't yet settled on a single approach. We'd therefore like to put this PR on hold for now.

import jax
import jax.numpy as jnp
import pgx

env = pgx.make("tic_tac_toe")
print(jax.make_jaxpr(env.init)(jax.random.PRNGKey(0)))
Python 3.9.18 + Pgx v1.4.0: 57 lines
{ lambda a:i8[9] b:f32[2] c:bool[9]; d:u32[2]. let
    e:key<fry>[] = random_wrap[impl=fry] d
    f:key<fry>[2] = random_split[count=2] e
    g:u32[2,2] = random_unwrap f
    h:u32[1,2] = slice[limit_indices=(1, 2) start_indices=(0, 0) strides=(1, 1)] g
    i:u32[2] = squeeze[dimensions=(0,)] h
    j:u32[1,2] = slice[limit_indices=(2, 2) start_indices=(1, 0) strides=(1, 1)] g
    k:u32[2] = squeeze[dimensions=(0,)] j
    l:key<fry>[] = random_wrap[impl=fry] k
    m:key<fry>[2] = random_split[count=2] l
    n:u32[2,2] = random_unwrap m
    o:u32[1,2] = slice[limit_indices=(1, 2) start_indices=(0, 0) strides=(1, 1)] n
    _:u32[2] = squeeze[dimensions=(0,)] o
    p:u32[1,2] = slice[limit_indices=(2, 2) start_indices=(1, 0) strides=(1, 1)] n
    q:u32[2] = squeeze[dimensions=(0,)] p
    r:key<fry>[] = random_wrap[impl=fry] q
    s:u32[] = random_bits[bit_width=32 shape=()] r
    t:u32[] = shift_right_logical s 9
    u:u32[] = or t 1065353216
    v:f32[] = bitcast_convert_type[new_dtype=float32] u
    w:f32[] = sub v 1.0
    x:f32[] = sub 1.0 0.0
    y:f32[] = mul w x
    z:f32[] = add y 0.0
    ba:f32[] = reshape[dimensions=None new_sizes=()] z
    bb:f32[] = max 0.0 ba
    bc:bool[] = lt bb 0.5
    bd:i8[] = convert_element_type[new_dtype=int8 weak_type=False] bc
    be:bool[] = eq bd bd
    bf:i32[] = convert_element_type[new_dtype=int32 weak_type=False] be
    bg:i8[2] = cond[
      branches=(
        { lambda ; . let
            bh:i8[] = sub 1 0
            bi:i8[1] = broadcast_in_dim[broadcast_dimensions=() shape=(1,)] bh
            bj:i8[1] = broadcast_in_dim[broadcast_dimensions=() shape=(1,)] 0
            bk:i8[2] = concatenate[dimension=0] bi bj
          in (bk,) }
        { lambda ; . let
            bl:i8[] = sub 1 0
            bm:i8[1] = broadcast_in_dim[broadcast_dimensions=() shape=(1,)] 0
            bn:i8[1] = broadcast_in_dim[broadcast_dimensions=() shape=(1,)] bl
            bo:i8[2] = concatenate[dimension=0] bm bn
          in (bo,) }
      )
      linear=()
    ] bf
    bp:i8[1,9] = broadcast_in_dim[broadcast_dimensions=(1,) shape=(1, 9)] a
    bq:i8[2,1] = broadcast_in_dim[broadcast_dimensions=(0,) shape=(2, 1)] bg
    br:bool[2,9] = eq bp bq
    bs:bool[2,3,3] = reshape[dimensions=None new_sizes=(2, 3, 3)] br
    bt:bool[2,3,3,1] = broadcast_in_dim[
      broadcast_dimensions=(0, 1, 2)
      shape=(2, 3, 3, 1)
    ] bs
    bu:bool[3,3,2] = reshape[dimensions=(1, 2, 0, 3) new_sizes=(3, 3, 2)] bt
  in (bd, bu, b, False, False, c, i, 0, 0, a) }
Python 3.11.6 + #1067: 75 lines
{ lambda a:i8[] b:i8[9] c:f32[2] d:bool[] e:bool[9] f:i32[]; g:u32[2]. let
    h:key<fry>[] = random_wrap[impl=fry] g
    i:key<fry>[2] = random_split[shape=(2,)] h
    j:u32[2,2] = random_unwrap i
    k:u32[1,2] = slice[limit_indices=(1, 2) start_indices=(0, 0) strides=(1, 1)] j
    l:u32[2] = squeeze[dimensions=(0,)] k
    m:u32[1,2] = slice[limit_indices=(2, 2) start_indices=(1, 0) strides=(1, 1)] j
    n:u32[2] = squeeze[dimensions=(0,)] m
    o:key<fry>[] = random_wrap[impl=fry] n
    p:key<fry>[2] = random_split[shape=(2,)] o
    q:u32[2,2] = random_unwrap p
    r:u32[1,2] = slice[limit_indices=(1, 2) start_indices=(0, 0) strides=(1, 1)] q
    _:u32[2] = squeeze[dimensions=(0,)] r
    s:u32[1,2] = slice[limit_indices=(2, 2) start_indices=(1, 0) strides=(1, 1)] q
    t:u32[2] = squeeze[dimensions=(0,)] s
    u:key<fry>[] = random_wrap[impl=fry] t
    v:bool[] = pjit[
      jaxpr={ lambda ; w:key<fry>[] x:f32[]. let
          y:f32[] = pjit[
            jaxpr={ lambda ; z:key<fry>[] ba:f32[] bb:f32[]. let
                bc:f32[] = convert_element_type[
                  new_dtype=float32
                  weak_type=False
                ] ba
                bd:f32[] = convert_element_type[
                  new_dtype=float32
                  weak_type=False
                ] bb
                be:u32[] = random_bits[bit_width=32 shape=()] z
                bf:u32[] = shift_right_logical be 9
                bg:u32[] = or bf 1065353216
                bh:f32[] = bitcast_convert_type[new_dtype=float32] bg
                bi:f32[] = sub bh 1.0
                bj:f32[] = sub bd bc
                bk:f32[] = mul bi bj
                bl:f32[] = add bk bc
                bm:f32[] = reshape[dimensions=None new_sizes=()] bl
                bn:f32[] = max bc bm
              in (bn,) }
            name=_uniform
          ] w 0.0 1.0
          bo:bool[] = lt y x
        in (bo,) }
      name=_bernoulli
    ] u 0.5
    bp:i8[] = convert_element_type[new_dtype=int8 weak_type=False] v
    bq:bool[] = eq bp bp
    br:i32[] = convert_element_type[new_dtype=int32 weak_type=False] bq
    bs:i8[2] = cond[
      branches=(
        { lambda ; bt_:i8[] bu:i8[]. let
            bv:i8[] = sub 1 bu
            bw:i8[1] = broadcast_in_dim[broadcast_dimensions=() shape=(1,)] bv
            bx:i8[1] = broadcast_in_dim[broadcast_dimensions=() shape=(1,)] bu
            by:i8[2] = concatenate[dimension=0] bw bx
          in (by,) }
        { lambda ; bz:i8[] ca_:i8[]. let
            cb:i8[] = sub 1 bz
            cc:i8[1] = broadcast_in_dim[broadcast_dimensions=() shape=(1,)] bz
            cd:i8[1] = broadcast_in_dim[broadcast_dimensions=() shape=(1,)] cb
            ce:i8[2] = concatenate[dimension=0] cc cd
          in (ce,) }
      )
      linear=(False, False)
    ] br a a
    cf:i8[1,9] = broadcast_in_dim[broadcast_dimensions=(1,) shape=(1, 9)] b
    cg:i8[2,1] = broadcast_in_dim[broadcast_dimensions=(0,) shape=(2, 1)] bs
    ch:bool[2,9] = eq cf cg
    ci:bool[2,3,3] = reshape[dimensions=None new_sizes=(2, 3, 3)] ch
    cj:bool[2,3,3,1] = broadcast_in_dim[
      broadcast_dimensions=(0, 1, 2)
      shape=(2, 3, 3, 1)
    ] ci
    ck:bool[3,3,2] = reshape[dimensions=(1, 2, 0, 3) new_sizes=(3, 3, 2)] cj
  in (bp, ck, c, d, d, e, l, f, a, b) }

Both were run on an M1 Mac.
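
The printed line count can differ from the amount of traced work: in the 57-line dump the Bernoulli sampling is inlined, while in the 75-line dump it is wrapped in nested `pjit` blocks named `_bernoulli` and `_uniform`, which adds lines without necessarily adding operations. A rough sketch of how both numbers could be measured is given below. `toy_init` is a hypothetical stand-in for `env.init` so that the snippet does not depend on pgx; `jax.make_jaxpr` and the `eqns` attribute are existing JAX features.

```python
import jax
import jax.numpy as jnp


def toy_init(key):
    # Hypothetical stand-in for env.init: draw one Bernoulli sample from the key.
    return jax.random.bernoulli(key, 0.5).astype(jnp.int32)


closed_jaxpr = jax.make_jaxpr(toy_init)(jax.random.PRNGKey(0))

# Number of lines in the pretty-printed text (the metric quoted above).
num_lines = len(str(closed_jaxpr).splitlines())

# Number of top-level equations in the traced program.
num_eqns = len(closed_jaxpr.jaxpr.eqns)

print(f"jax {jax.__version__}: {num_lines} printed lines, {num_eqns} top-level equations")
```

Printing `jax.__version__` next to the counts makes it easier to tell whether a difference comes from the code under test or from the JAX release used to trace it.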

sotetsuk (Owner) commented Nov 6, 2023

I found that my inspection above was wrong: this PR does not change the complexity of the jaxpr.
So I'll merge this PR once CI has finished. Thank you for your contribution! 🙏
A new release will be available in a few days.

Python 3.9.16 @ 8764592 (latest)

Package            Version
------------------ -------
importlib-metadata 6.8.0
jax                0.4.20
jaxlib             0.4.20
ml-dtypes          0.3.1
numpy              1.26.1
opt-einsum         3.3.0
pip                23.3.1
scipy              1.11.3
setuptools         65.6.3
svgwrite           1.4.3
typing_extensions  4.8.0
zipp               3.17.0
{ lambda a:i32[] b:i32[9] c:f32[2] d:bool[] e:bool[9] f:i32[]; g:u32[2]. let
    h:key<fry>[] = random_wrap[impl=fry] g
    i:key<fry>[2] = random_split[shape=(2,)] h
    j:u32[2,2] = random_unwrap i
    k:u32[1,2] = slice[limit_indices=(1, 2) start_indices=(0, 0) strides=(1, 1)] j
    l:u32[2] = squeeze[dimensions=(0,)] k
    m:u32[1,2] = slice[limit_indices=(2, 2) start_indices=(1, 0) strides=(1, 1)] j
    n:u32[2] = squeeze[dimensions=(0,)] m
    o:key<fry>[] = random_wrap[impl=fry] n
    p:key<fry>[2] = random_split[shape=(2,)] o
    q:u32[2,2] = random_unwrap p
    r:u32[1,2] = slice[limit_indices=(1, 2) start_indices=(0, 0) strides=(1, 1)] q
    _:u32[2] = squeeze[dimensions=(0,)] r
    s:u32[1,2] = slice[limit_indices=(2, 2) start_indices=(1, 0) strides=(1, 1)] q
    t:u32[2] = squeeze[dimensions=(0,)] s
    u:key<fry>[] = random_wrap[impl=fry] t
    v:bool[] = pjit[
      jaxpr={ lambda ; w:key<fry>[] x:f32[]. let
          y:f32[] = pjit[
            jaxpr={ lambda ; z:key<fry>[] ba:f32[] bb:f32[]. let
                bc:f32[] = convert_element_type[
                  new_dtype=float32
                  weak_type=False
                ] ba
                bd:f32[] = convert_element_type[
                  new_dtype=float32
                  weak_type=False
                ] bb
                be:u32[] = random_bits[bit_width=32 shape=()] z
                bf:u32[] = shift_right_logical be 9
                bg:u32[] = or bf 1065353216
                bh:f32[] = bitcast_convert_type[new_dtype=float32] bg
                bi:f32[] = sub bh 1.0
                bj:f32[] = sub bd bc
                bk:f32[] = mul bi bj
                bl:f32[] = add bk bc
                bm:f32[] = reshape[dimensions=None new_sizes=()] bl
                bn:f32[] = max bc bm
              in (bn,) }
            name=_uniform
          ] w 0.0 1.0
          bo:bool[] = lt y x
        in (bo,) }
      name=_bernoulli
    ] u 0.5
    bp:i32[] = convert_element_type[new_dtype=int32 weak_type=False] v
    bq:bool[] = eq bp bp
    br:i32[] = convert_element_type[new_dtype=int32 weak_type=False] bq
    bs:i32[2] = cond[
      branches=(
        { lambda ; bt_:i32[] bu:i32[]. let
            bv:i32[] = sub 1 bu
            bw:i32[1] = broadcast_in_dim[broadcast_dimensions=() shape=(1,)] bv
            bx:i32[1] = broadcast_in_dim[broadcast_dimensions=() shape=(1,)] bu
            by:i32[2] = concatenate[dimension=0] bw bx
          in (by,) }
        { lambda ; bz:i32[] ca_:i32[]. let
            cb:i32[] = sub 1 bz
            cc:i32[1] = broadcast_in_dim[broadcast_dimensions=() shape=(1,)] bz
            cd:i32[1] = broadcast_in_dim[broadcast_dimensions=() shape=(1,)] cb
            ce:i32[2] = concatenate[dimension=0] cc cd
          in (ce,) }
      )
      linear=(False, False)
    ] br a a
    cf:i32[1,9] = broadcast_in_dim[broadcast_dimensions=(1,) shape=(1, 9)] b
    cg:i32[2,1] = broadcast_in_dim[broadcast_dimensions=(0,) shape=(2, 1)] bs
    ch:bool[2,9] = eq cf cg
    ci:bool[2,3,3] = reshape[dimensions=None new_sizes=(2, 3, 3)] ch
    cj:bool[2,3,3,1] = broadcast_in_dim[
      broadcast_dimensions=(0, 1, 2)
      shape=(2, 3, 3, 1)
    ] ci
    ck:bool[3,3,2] = reshape[dimensions=(1, 2, 0, 3) new_sizes=(3, 3, 2)] cj
  in (bp, ck, c, d, d, e, l, f, a, b) }

Python 3.11.6 @ this PR

Package           Version
----------------- -------
jax               0.4.20
jaxlib            0.4.20
ml-dtypes         0.3.1
numpy             1.26.1
opt-einsum        3.3.0
pip               23.3.1
scipy             1.11.3
setuptools        68.2.2
svgwrite          1.4.3
typing_extensions 4.8.0
{ lambda a:i32[] b:i32[9] c:f32[2] d:bool[] e:bool[9] f:i32[]; g:u32[2]. let
    h:key<fry>[] = random_wrap[impl=fry] g
    i:key<fry>[2] = random_split[shape=(2,)] h
    j:u32[2,2] = random_unwrap i
    k:u32[1,2] = slice[limit_indices=(1, 2) start_indices=(0, 0) strides=(1, 1)] j
    l:u32[2] = squeeze[dimensions=(0,)] k
    m:u32[1,2] = slice[limit_indices=(2, 2) start_indices=(1, 0) strides=(1, 1)] j
    n:u32[2] = squeeze[dimensions=(0,)] m
    o:key<fry>[] = random_wrap[impl=fry] n
    p:key<fry>[2] = random_split[shape=(2,)] o
    q:u32[2,2] = random_unwrap p
    r:u32[1,2] = slice[limit_indices=(1, 2) start_indices=(0, 0) strides=(1, 1)] q
    _:u32[2] = squeeze[dimensions=(0,)] r
    s:u32[1,2] = slice[limit_indices=(2, 2) start_indices=(1, 0) strides=(1, 1)] q
    t:u32[2] = squeeze[dimensions=(0,)] s
    u:key<fry>[] = random_wrap[impl=fry] t
    v:bool[] = pjit[
      jaxpr={ lambda ; w:key<fry>[] x:f32[]. let
          y:f32[] = pjit[
            jaxpr={ lambda ; z:key<fry>[] ba:f32[] bb:f32[]. let
                bc:f32[] = convert_element_type[
                  new_dtype=float32
                  weak_type=False
                ] ba
                bd:f32[] = convert_element_type[
                  new_dtype=float32
                  weak_type=False
                ] bb
                be:u32[] = random_bits[bit_width=32 shape=()] z
                bf:u32[] = shift_right_logical be 9
                bg:u32[] = or bf 1065353216
                bh:f32[] = bitcast_convert_type[new_dtype=float32] bg
                bi:f32[] = sub bh 1.0
                bj:f32[] = sub bd bc
                bk:f32[] = mul bi bj
                bl:f32[] = add bk bc
                bm:f32[] = reshape[dimensions=None new_sizes=()] bl
                bn:f32[] = max bc bm
              in (bn,) }
            name=_uniform
          ] w 0.0 1.0
          bo:bool[] = lt y x
        in (bo,) }
      name=_bernoulli
    ] u 0.5
    bp:i32[] = convert_element_type[new_dtype=int32 weak_type=False] v
    bq:bool[] = eq bp bp
    br:i32[] = convert_element_type[new_dtype=int32 weak_type=False] bq
    bs:i32[2] = cond[
      branches=(
        { lambda ; bt_:i32[] bu:i32[]. let
            bv:i32[] = sub 1 bu
            bw:i32[1] = broadcast_in_dim[broadcast_dimensions=() shape=(1,)] bv
            bx:i32[1] = broadcast_in_dim[broadcast_dimensions=() shape=(1,)] bu
            by:i32[2] = concatenate[dimension=0] bw bx
          in (by,) }
        { lambda ; bz:i32[] ca_:i32[]. let
            cb:i32[] = sub 1 bz
            cc:i32[1] = broadcast_in_dim[broadcast_dimensions=() shape=(1,)] bz
            cd:i32[1] = broadcast_in_dim[broadcast_dimensions=() shape=(1,)] cb
            ce:i32[2] = concatenate[dimension=0] cc cd
          in (ce,) }
      )
      linear=(False, False)
    ] br a a
    cf:i32[1,9] = broadcast_in_dim[broadcast_dimensions=(1,) shape=(1, 9)] b
    cg:i32[2,1] = broadcast_in_dim[broadcast_dimensions=(0,) shape=(2, 1)] bs
    ch:bool[2,9] = eq cf cg
    ci:bool[2,3,3] = reshape[dimensions=None new_sizes=(2, 3, 3)] ch
    cj:bool[2,3,3,1] = broadcast_in_dim[
      broadcast_dimensions=(0, 1, 2)
      shape=(2, 3, 3, 1)
    ] ci
    ck:bool[3,3,2] = reshape[dimensions=(1, 2, 0, 3) new_sizes=(3, 3, 2)] cj
  in (bp, ck, c, d, d, e, l, f, a, b) }
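
The two dumps above appear to match line for line. A check of this kind can be automated with the standard library; the sketch below is one possible approach, not the procedure the maintainer used. The helper name `jaxpr_diff` and the labels `py39` and `py311` are hypothetical, and the sample strings are single lines copied from dumps in this thread.

```python
import difflib


def jaxpr_diff(text_a, text_b):
    """Return unified-diff lines between two jaxpr dumps (empty if identical)."""
    return list(
        difflib.unified_diff(
            text_a.splitlines(),
            text_b.splitlines(),
            fromfile="py39",
            tofile="py311",
            lineterm="",
        )
    )


# Single lines copied from the dumps in this thread.
old_split = "f:key<fry>[2] = random_split[count=2] e"
new_split = "i:key<fry>[2] = random_split[shape=(2,)] h"
wrap_line = "u:key<fry>[] = random_wrap[impl=fry] t"

identical = jaxpr_diff(wrap_line, wrap_line)  # no differences expected
different = jaxpr_diff(old_split, new_split)  # differences expected

print(len(identical), len(different))
```

In practice the two arguments would be the full `str(jax.make_jaxpr(env.init)(key))` outputs from each environment; an empty result means the traced programs print identically.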

codecov bot commented Nov 6, 2023

Codecov Report

Merging #1067 (1ae24be) into main (bce75ec) will increase coverage by 0.02%.
Report is 11 commits behind head on main.
The diff coverage is 100.00%.

@@            Coverage Diff             @@
##             main    #1067      +/-   ##
==========================================
+ Coverage   86.90%   86.92%   +0.02%     
==========================================
  Files          54       54              
  Lines        6047     6052       +5     
==========================================
+ Hits         5255     5261       +6     
+ Misses        792      791       -1     
Files              Coverage Δ
------------------ -------------------------
pgx/_src/struct.py 66.66% <100.00%> (+3.50%) ⬆️

sotetsuk changed the title from "Fix dataclass mutable defaults error." to "[dataclass] Fix dataclass mutable defaults error" on Nov 6, 2023
sotetsuk merged commit a4b0d97 into sotetsuk:main on Nov 6, 2023
3 of 4 checks passed
sotetsuk added a commit that referenced this pull request on Nov 6, 2023