Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Optimize map unions to avoid building long lists #14215

Merged
merged 6 commits into from
Jan 24, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
155 changes: 117 additions & 38 deletions lib/elixir/lib/module/types/descr.ex
Original file line number Diff line number Diff line change
Expand Up @@ -1278,8 +1278,115 @@ defmodule Module.Types.Descr do

defp map_only?(descr), do: empty?(Map.delete(descr, :map))

# Union is list concatenation
defp map_union(dnf1, dnf2), do: dnf1 ++ (dnf2 -- dnf1)
defp map_union(dnf1, dnf2) do
# Union is just concatenation, but we rely on some optimization strategies to
# avoid the list to grow when possible

# first pass trying to identify patterns where two maps can be fused as one
with [map1] <- dnf1,
[map2] <- dnf2,
optimized when optimized != nil <- maybe_optimize_map_union(map1, map2) do
[optimized]
else
# otherwise we just concatenate and remove structural duplicates
_ -> dnf1 ++ (dnf2 -- dnf1)
end
end

defp maybe_optimize_map_union({tag1, pos1, []} = map1, {tag2, pos2, []} = map2) do
case map_union_optimization_strategy(tag1, pos1, tag2, pos2) do
:all_equal ->
map1

:any_map ->
{:open, %{}, []}

{:one_key_difference, key, v1, v2} ->
new_pos = Map.put(pos1, key, union(v1, v2))
{tag1, new_pos, []}

:left_subtype_of_right ->
map2

:right_subtype_of_left ->
map1

nil ->
nil
end
end

defp maybe_optimize_map_union(_, _), do: nil

defp map_union_optimization_strategy(tag1, pos1, tag2, pos2)
defp map_union_optimization_strategy(tag, pos, tag, pos), do: :all_equal
defp map_union_optimization_strategy(:open, empty, _, _) when empty == %{}, do: :any_map
defp map_union_optimization_strategy(_, _, :open, empty) when empty == %{}, do: :any_map

defp map_union_optimization_strategy(tag, pos1, tag, pos2)
when map_size(pos1) == map_size(pos2) do
:maps.iterator(pos1)
|> :maps.next()
|> do_map_union_optimization_strategy(pos2, :all_equal)
end

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This does not handle cases where the open map on the left has some extra fields that are set to if_set(term()), and the right map is closed.

Example:

assert union(
        open_map(a: if_set(term())),
        closed_map([])
      ) == open_map(a: if_set(term()))

Similarly, what if we have a larger (in size) open map as pos1, but which is a supertype of pos2? Then the only tried strategy will be l.1340 which leads to :left_subtype_of_right.

Example:

 assert union(
        open_map(a: if_set(term()), b: number()),
        open_map(b: integer())
      ) == open_map(a: if_set(term()), b: number())

I don't think those are case that necessarily need to be covered, but adding those tests to highlight it would prevent us discovering this again.

Copy link
Author

@sabiwara sabiwara Jan 23, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Oh I was totally forgetting about if_set.

I don't think those are case that necessarily need to be covered

Yeah this is an optimization supposed to deal with some "obvious" cases that happen frequently, so it might be OK not to catch all cases (we're not dealing with negs either).

But in this case it might be possible to implement in the current pass with something like:

  • if one key is only on the side of the supertype and its value is if_set, continue inferring this supertype relation
  • if we can switch to the supertype strategy, do it
  • otherwise bail

The map size issue is a real problem though... Perhaps by changing the internal representation to store if_set as part of a different map, we can easily compute the size of the required map, and have a separate pass for optional keys?
This might be overkill for this particular use case, but if this new representation can simplify a bunch of other places such as subtyping etc it might be worth considering.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The map size issue is a real problem though...

I think it is fine because our goal is to traverse the smallest map for performance. The full algorithm does require traversing both sides but the point here is precisely to not implement the full algorithm. :)

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sounds good!

I was more thinking if we get bottlenecks in the future due to if_set and if there was a way to optimize them in other parts too, but it's best to avoid speculation and to wait if real world slow/pathological cases to show up, and iterate then.

defp map_union_optimization_strategy(:open, pos1, _, pos2)
when map_size(pos1) <= map_size(pos2) do
:maps.iterator(pos1)
|> :maps.next()
|> do_map_union_optimization_strategy(pos2, :right_subtype_of_left)
end

defp map_union_optimization_strategy(_, pos1, :open, pos2)
when map_size(pos1) >= map_size(pos2) do
:maps.iterator(pos2)
|> :maps.next()
|> do_map_union_optimization_strategy(pos1, :right_subtype_of_left)
|> case do
:right_subtype_of_left -> :left_subtype_of_right
nil -> nil
end
end

defp map_union_optimization_strategy(_, _, _, _), do: nil

defp do_map_union_optimization_strategy(:none, _, status), do: status

defp do_map_union_optimization_strategy({key, v1, iterator}, pos2, status) do
with %{^key => v2} <- pos2,
next_status when next_status != nil <- map_union_next_strategy(key, v1, v2, status) do
do_map_union_optimization_strategy(:maps.next(iterator), pos2, next_status)
else
_ -> nil
end
end

defp map_union_next_strategy(key, v1, v2, status)

# structurally equal values do not impact the ongoing strategy
defp map_union_next_strategy(_key, same, same, status), do: status

defp map_union_next_strategy(key, v1, v2, :all_equal) do
if key != :__struct__, do: {:one_key_difference, key, v1, v2}
end

defp map_union_next_strategy(_key, v1, v2, {:one_key_difference, _, d1, d2}) do
# we have at least two key differences now, we switch strategy
# if both are subtypes in one direction, keep checking
cond do
subtype?(d1, d2) and subtype?(v1, v2) -> :left_subtype_of_right
subtype?(d2, d1) and subtype?(v2, v1) -> :right_subtype_of_left
true -> nil
end
end

defp map_union_next_strategy(_key, v1, v2, :left_subtype_of_right) do
if subtype?(v1, v2), do: :left_subtype_of_right
end

defp map_union_next_strategy(_key, v1, v2, :right_subtype_of_left) do
if subtype?(v2, v1), do: :right_subtype_of_left
end

# Given two unions of maps, intersects each pair of maps.
defp map_intersection(dnf1, dnf2) do
Expand Down Expand Up @@ -1761,49 +1868,21 @@ defmodule Module.Types.Descr do

defp map_non_negated_fuse(maps) do
Enum.reduce(maps, [], fn map, acc ->
case Enum.split_while(acc, &non_fusible_maps?(map, &1)) do
{_, []} ->
[map | acc]

{others, [match | rest]} ->
fused = map_non_negated_fuse_pair(map, match)
others ++ [fused | rest]
end
fuse_with_first_fusible(map, acc)
end)
end

# Two maps are fusible if they differ in at most one element.
# Given they are of the same size, the side you traverse is not important.
defp non_fusible_maps?({_, fields1, []}, {_, fields2, []}) do
not fusible_maps?(Map.to_list(fields1), fields2, 0)
end

defp fusible_maps?([{:__struct__, value} | rest], fields, count) do
case Map.fetch!(fields, :__struct__) do
^value -> fusible_maps?(rest, fields, count)
_ -> false
end
end
defp fuse_with_first_fusible(map, []), do: [map]

defp fusible_maps?([{key, value} | rest], fields, count) do
case Map.fetch!(fields, key) do
^value -> fusible_maps?(rest, fields, count)
_ when count == 1 -> false
_ when count == 0 -> fusible_maps?(rest, fields, count + 1)
defp fuse_with_first_fusible(map, [candidate | rest]) do
if fused = maybe_optimize_map_union(map, candidate) do
# we found a fusible candidate, we're done
[fused | rest]
else
[candidate | fuse_with_first_fusible(map, rest)]
end
end

defp fusible_maps?([], _fields, _count), do: true

defp map_non_negated_fuse_pair({tag, fields1, []}, {_, fields2, []}) do
fields =
symmetrical_merge(fields1, fields2, fn _k, v1, v2 ->
if v1 == v2, do: v1, else: union(v1, v2)
end)

{tag, fields, []}
end

# If all fields are the same except one, we can optimize map difference.
defp map_all_but_one?(tag1, fields1, tag2, fields2) do
keys1 = Map.keys(fields1)
Expand Down
40 changes: 40 additions & 0 deletions lib/elixir/test/elixir/module/types/descr_test.exs
Original file line number Diff line number Diff line change
Expand Up @@ -105,6 +105,46 @@ defmodule Module.Types.DescrTest do
assert union(difference(list(term()), list(integer())), list(integer()))
|> equal?(list(term()))
end

test "optimizations" do
# The tests are checking the actual implementation, not the semantics.
# This is why we are using structural comparisons.
# It's fine to remove these if the implementation changes, but breaking
# these might have an important impact on compile times.

# Optimization one: same tags, all but one key are structurally equal
assert union(
open_map(a: float(), b: atom()),
open_map(a: integer(), b: atom())
) == open_map(a: union(float(), integer()), b: atom())

assert union(
closed_map(a: float(), b: atom()),
closed_map(a: integer(), b: atom())
) == closed_map(a: union(float(), integer()), b: atom())

# Optimization two: we can tell that one map is a trivial subtype of the other:

assert union(
closed_map(a: term(), b: term()),
closed_map(a: float(), b: binary())
) == closed_map(a: term(), b: term())

assert union(
open_map(a: term()),
closed_map(a: float(), b: binary())
) == open_map(a: term())

assert union(
closed_map(a: float(), b: binary()),
open_map(a: term())
) == open_map(a: term())

assert union(
closed_map(a: term(), b: tuple([term(), term()])),
closed_map(a: float(), b: tuple([atom(), binary()]))
) == closed_map(a: term(), b: tuple([term(), term()]))
end
end

describe "intersection" do
Expand Down
Loading