Skip to content

Commit

Permalink
Arithmetic optimizers check device equality.
Browse files Browse the repository at this point in the history
  • Loading branch information
james-choncholas committed Jan 14, 2025
1 parent 973d81d commit db4434b
Show file tree
Hide file tree
Showing 2 changed files with 11 additions and 1 deletion.
6 changes: 5 additions & 1 deletion tf_shell/cc/optimizers/ct_pt.cc
Original file line number Diff line number Diff line change
Expand Up @@ -103,10 +103,14 @@ bool FindAddOrSub(utils::MutableGraphView& graph_view, int node_index,
auto const& inner_fanin_0 = inner_node_view->GetRegularFanin(0);
auto const* inner_context_node_view = inner_fanin_0.node_view();

// If the contexts do not match, the pattern should not be matched..
// If the contexts do not match, the pattern should not be matched.
if (context_node_view->node_index() != inner_context_node_view->node_index())
return false;

// If the devices do not match, the pattern should not be matched.
if (outer_node_def->device() != inner_node_view->node()->device())
return false;

auto const& inner_fanin_1 = inner_node_view->GetRegularFanin(1);
auto const* inner_ct_node_view = inner_fanin_1.node_view();

Expand Down
6 changes: 6 additions & 0 deletions tf_shell/cc/optimizers/pt_pt.cc
Original file line number Diff line number Diff line change
Expand Up @@ -118,6 +118,12 @@ bool FindPtPt(utils::MutableGraphView& graph_view, int node_index,
return false;
}

// Check all ops have the same device.
if (outer_node_def->device() != input_a_node_def->device() ||
outer_node_def->device() != input_b_node_def->device()) {
return false;
}

auto const& encode_b_fanin_0 = input_b_node_view->GetRegularFanin(1);
auto const* tf_input_b_node_view = encode_b_fanin_0.node_view();

Expand Down

0 comments on commit db4434b

Please sign in to comment.