Skip to content

Commit

Permalink
refactor: refactor series of visit_stream_node method
Browse files Browse the repository at this point in the history
  • Loading branch information
wenym1 committed Jan 26, 2025
1 parent 983dd18 commit 17ca3a2
Show file tree
Hide file tree
Showing 7 changed files with 49 additions and 50 deletions.
41 changes: 21 additions & 20 deletions src/common/src/util/stream_graph_visitor.rs
Original file line number Diff line number Diff line change
Expand Up @@ -19,21 +19,11 @@ use risingwave_pb::stream_plan::stream_node::NodeBody;
use risingwave_pb::stream_plan::{agg_call_state, StreamNode};

/// A utility for visiting and mutating the [`NodeBody`] of the [`StreamNode`]s recursively.
pub fn visit_stream_node<F>(stream_node: &mut StreamNode, mut f: F)
where
F: FnMut(&mut NodeBody),
{
fn visit_inner<F>(stream_node: &mut StreamNode, f: &mut F)
where
F: FnMut(&mut NodeBody),
{
pub fn visit_stream_node_mut(stream_node: &mut StreamNode, mut f: impl FnMut(&mut NodeBody)) {
visit_stream_node_cont_mut(stream_node, |stream_node| {
f(stream_node.node_body.as_mut().unwrap());
for input in &mut stream_node.input {
visit_inner(input, f);
}
}

visit_inner(stream_node, &mut f)
true
})
}

/// A utility for to accessing the [`StreamNode`] mutably. The returned bool is used to determine whether the access needs to continue.
Expand All @@ -56,6 +46,14 @@ where
visit_inner(stream_node, &mut f)
}

/// A utility for visiting the [`NodeBody`] of the [`StreamNode`]s recursively.
pub fn visit_stream_node(stream_node: &StreamNode, mut f: impl FnMut(&NodeBody)) {
visit_stream_node_cont(stream_node, |stream_node| {
f(stream_node.node_body.as_ref().unwrap());
true
})
}

/// A utility for to accessing the [`StreamNode`] immutably. The returned bool is used to determine whether the access needs to continue.
pub fn visit_stream_node_cont<F>(stream_node: &StreamNode, mut f: F)
where
Expand All @@ -78,11 +76,14 @@ where

/// A utility for visiting and mutating the [`NodeBody`] of the [`StreamNode`]s in a
/// [`StreamFragment`] recursively.
pub fn visit_fragment<F>(fragment: &mut StreamFragment, f: F)
where
F: FnMut(&mut NodeBody),
{
visit_stream_node(fragment.node.as_mut().unwrap(), f)
pub fn visit_fragment_mut(fragment: &mut StreamFragment, f: impl FnMut(&mut NodeBody)) {
visit_stream_node_mut(fragment.node.as_mut().unwrap(), f)
}

/// A utility for visiting the [`NodeBody`] of the [`StreamNode`]s in a
/// [`StreamFragment`] recursively.
pub fn visit_fragment(fragment: &StreamFragment, f: impl FnMut(&NodeBody)) {
visit_stream_node(fragment.node.as_ref().unwrap(), f)
}

/// Visit the tables of a [`StreamNode`].
Expand Down Expand Up @@ -279,7 +280,7 @@ pub fn visit_stream_node_tables_inner<F>(
}
};
if visit_child_recursively {
visit_stream_node(stream_node, visit_body)
visit_stream_node_mut(stream_node, visit_body)
} else {
visit_body(stream_node.node_body.as_mut().unwrap())
}
Expand Down
16 changes: 7 additions & 9 deletions src/meta/src/controller/fragment.rs
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ use anyhow::Context;
use itertools::Itertools;
use risingwave_common::bail;
use risingwave_common::hash::{VnodeCount, VnodeCountCompat, WorkerSlotId};
use risingwave_common::util::stream_graph_visitor::visit_stream_node;
use risingwave_common::util::stream_graph_visitor::visit_stream_node_mut;
use risingwave_meta_model::actor::ActorStatus;
use risingwave_meta_model::fragment::DistributionType;
use risingwave_meta_model::object::ObjectType;
Expand Down Expand Up @@ -227,7 +227,7 @@ impl CatalogController {
let stream_node = {
let actor_template = pb_actors.first().cloned().unwrap();
let mut stream_node = actor_template.nodes.unwrap();
visit_stream_node(&mut stream_node, |body| {
visit_stream_node_mut(&mut stream_node, |body| {
if let NodeBody::Merge(m) = body {
m.upstream_actor_id = vec![];
}
Expand All @@ -244,7 +244,7 @@ impl CatalogController {

let node = actor.nodes.as_mut().context("nodes are empty")?;

visit_stream_node(node, |body| {
visit_stream_node_mut(node, |body| {
if let NodeBody::Merge(m) = body {
let mut upstream_actor_ids = vec![];
swap(&mut m.upstream_actor_id, &mut upstream_actor_ids);
Expand Down Expand Up @@ -435,7 +435,7 @@ impl CatalogController {
let pb_nodes = {
let mut nodes = stream_node_template.clone();

visit_stream_node(&mut nodes, |body| {
visit_stream_node_mut(&mut nodes, |body| {
if let NodeBody::Merge(m) = body
&& let Some(upstream_actor_ids) =
upstream_fragment_actors.get(&(m.upstream_fragment_id as _))
Expand Down Expand Up @@ -1625,7 +1625,7 @@ mod tests {
use itertools::Itertools;
use risingwave_common::hash::{ActorMapping, VirtualNode, VnodeCount};
use risingwave_common::util::iter_util::ZipEqDebug;
use risingwave_common::util::stream_graph_visitor::visit_stream_node;
use risingwave_common::util::stream_graph_visitor::{visit_stream_node, visit_stream_node_mut};
use risingwave_meta_model::actor::ActorStatus;
use risingwave_meta_model::fragment::DistributionType;
use risingwave_meta_model::{
Expand Down Expand Up @@ -1800,7 +1800,7 @@ mod tests {
let nodes = nodes.unwrap();
let actor_upstream_actor_ids =
upstream_actor_ids.get(&(actor_id as _)).cloned().unwrap();
visit_stream_node(&mut template_node, |body| {
visit_stream_node_mut(&mut template_node, |body| {
if let NodeBody::Merge(m) = body {
m.upstream_actor_id = actor_upstream_actor_ids
.get(&(m.upstream_fragment_id as _))
Expand Down Expand Up @@ -1978,9 +1978,7 @@ mod tests {

assert_eq!(mview_definition, "");

let mut pb_nodes = pb_nodes.unwrap();

visit_stream_node(&mut pb_nodes, |body| {
visit_stream_node(pb_nodes.as_ref().unwrap(), |body| {
if let PbNodeBody::Merge(m) = body {
let upstream_actor_ids = upstream_actor_ids
.get(&(m.upstream_fragment_id as _))
Expand Down
26 changes: 13 additions & 13 deletions src/meta/src/controller/streaming_job.rs
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ use itertools::Itertools;
use risingwave_common::config::DefaultParallelism;
use risingwave_common::hash::VnodeCountCompat;
use risingwave_common::util::column_index_mapping::ColIndexMapping;
use risingwave_common::util::stream_graph_visitor::visit_stream_node;
use risingwave_common::util::stream_graph_visitor::{visit_stream_node, visit_stream_node_mut};
use risingwave_common::{bail, current_cluster_version};
use risingwave_connector::WithPropertiesExt;
use risingwave_meta_model::actor::ActorStatus;
Expand Down Expand Up @@ -1183,7 +1183,7 @@ impl CatalogController {
.await?
.map(|(id, node, upstream)| (id, node.to_protobuf(), upstream))
.ok_or_else(|| MetaError::catalog_id_not_found("fragment", fragment_id))?;
visit_stream_node(&mut stream_node, |body| {
visit_stream_node_mut(&mut stream_node, |body| {
if let PbNodeBody::Merge(m) = body
&& let Some((new_fragment_id, new_actor_ids)) =
fragment_replace_map.get(&m.upstream_fragment_id)
Expand Down Expand Up @@ -1356,7 +1356,7 @@ impl CatalogController {
fragments.retain_mut(|(_, fragment_type_mask, stream_node)| {
let mut found = false;
if *fragment_type_mask & PbFragmentTypeFlag::Source as i32 != 0 {
visit_stream_node(stream_node, |node| {
visit_stream_node_mut(stream_node, |node| {
if let PbNodeBody::Source(node) = node {
if let Some(node_inner) = &mut node.source_inner
&& node_inner.source_id == source_id as u32
Expand All @@ -1370,7 +1370,7 @@ impl CatalogController {
if is_fs_source {
// in older versions, there's no fragment type flag for `FsFetch` node,
// so we just scan all fragments for StreamFsFetch node if using fs connector
visit_stream_node(stream_node, |node| {
visit_stream_node_mut(stream_node, |node| {
if let PbNodeBody::StreamFsFetch(node) = node {
*fragment_type_mask |= PbFragmentTypeFlag::FsFetch as i32;
if let Some(node_inner) = &mut node.node_inner
Expand Down Expand Up @@ -1486,7 +1486,7 @@ impl CatalogController {
|fragment_type_mask: &mut i32, stream_node: &mut PbStreamNode| {
let mut found = false;
if *fragment_type_mask & PbFragmentTypeFlag::backfill_rate_limit_fragments() != 0 {
visit_stream_node(stream_node, |node| match node {
visit_stream_node_mut(stream_node, |node| match node {
PbNodeBody::StreamCdcScan(node) => {
node.rate_limit = rate_limit;
found = true;
Expand Down Expand Up @@ -1528,7 +1528,7 @@ impl CatalogController {
|fragment_type_mask: &mut i32, stream_node: &mut PbStreamNode| {
let mut found = false;
if *fragment_type_mask & PbFragmentTypeFlag::sink_rate_limit_fragments() != 0 {
visit_stream_node(stream_node, |node| {
visit_stream_node_mut(stream_node, |node| {
if let PbNodeBody::Sink(node) = node {
node.rate_limit = rate_limit;
found = true;
Expand All @@ -1551,7 +1551,7 @@ impl CatalogController {
|fragment_type_mask: &mut i32, stream_node: &mut PbStreamNode| {
let mut found = false;
if *fragment_type_mask & PbFragmentTypeFlag::dml_rate_limit_fragments() != 0 {
visit_stream_node(stream_node, |node| {
visit_stream_node_mut(stream_node, |node| {
if let PbNodeBody::Dml(node) = node {
node.rate_limit = rate_limit;
found = true;
Expand Down Expand Up @@ -1635,7 +1635,7 @@ impl CatalogController {
PbStreamActor {
actor_id,
fragment_id,
mut nodes,
nodes,
dispatcher,
upstream_actor_id,
vnode_bitmap,
Expand All @@ -1648,7 +1648,7 @@ impl CatalogController {
let mut actor_upstreams = BTreeMap::<FragmentId, BTreeSet<ActorId>>::new();
let mut new_actor_dispatchers = vec![];

if let Some(nodes) = &mut nodes {
if let Some(nodes) = &nodes {
visit_stream_node(nodes, |node| {
if let PbNodeBody::Merge(node) = node {
actor_upstreams
Expand Down Expand Up @@ -1922,15 +1922,15 @@ impl CatalogController {

let mut rate_limits = Vec::new();
for (fragment_id, job_id, fragment_type_mask, stream_node) in fragments {
let mut stream_node = stream_node.to_protobuf();
let stream_node = stream_node.to_protobuf();
let mut rate_limit = None;
let mut node_name = None;

visit_stream_node(&mut stream_node, |node| {
visit_stream_node(&stream_node, |node| {
match node {
// source rate limit
PbNodeBody::Source(node) => {
if let Some(node_inner) = &mut node.source_inner {
if let Some(node_inner) = &node.source_inner {
debug_assert!(
rate_limit.is_none(),
"one fragment should only have 1 rate limit node"
Expand All @@ -1940,7 +1940,7 @@ impl CatalogController {
}
}
PbNodeBody::StreamFsFetch(node) => {
if let Some(node_inner) = &mut node.node_inner {
if let Some(node_inner) = &node.node_inner {
debug_assert!(
rate_limit.is_none(),
"one fragment should only have 1 rate limit node"
Expand Down
6 changes: 3 additions & 3 deletions src/meta/src/model/stream.rs
Original file line number Diff line number Diff line change
Expand Up @@ -457,9 +457,9 @@ impl StreamJobFragments {
/// Panics if not found.
pub fn union_fragment_for_table(&mut self) -> &mut Fragment {
let mut union_fragment_id = None;
for (fragment_id, fragment) in &mut self.fragments {
for actor in &mut fragment.actors {
if let Some(node) = &mut actor.nodes {
for (fragment_id, fragment) in &self.fragments {
for actor in &fragment.actors {
if let Some(node) = &actor.nodes {
visit_stream_node(node, |body| {
if let NodeBody::Union(_) = body {
if let Some(union_fragment_id) = union_fragment_id.as_mut() {
Expand Down
6 changes: 3 additions & 3 deletions src/meta/src/rpc/ddl_controller.rs
Original file line number Diff line number Diff line change
Expand Up @@ -752,9 +752,9 @@ impl DdlController {
}

// check if the union fragment is fully assigned.
for fragment in stream_job_fragments.fragments.values_mut() {
for actor in &mut fragment.actors {
if let Some(node) = &mut actor.nodes {
for fragment in stream_job_fragments.fragments.values() {
for actor in &fragment.actors {
if let Some(node) = &actor.nodes {
visit_stream_node(node, |node| {
if let NodeBody::Merge(merge_node) = node {
assert!(!merge_node.upstream_actor_id.is_empty(), "All the mergers for the union should have been fully assigned beforehand.");
Expand Down
2 changes: 1 addition & 1 deletion src/meta/src/stream/stream_graph/fragment.rs
Original file line number Diff line number Diff line change
Expand Up @@ -132,7 +132,7 @@ impl BuildingFragment {
let fragment_id = fragment.fragment_id;
let mut has_job = false;

stream_graph_visitor::visit_fragment(fragment, |node_body| match node_body {
stream_graph_visitor::visit_fragment_mut(fragment, |node_body| match node_body {
NodeBody::Materialize(materialize_node) => {
materialize_node.table_id = job_id;

Expand Down
2 changes: 1 addition & 1 deletion src/meta/src/stream/stream_graph/schedule.rs
Original file line number Diff line number Diff line change
Expand Up @@ -293,7 +293,7 @@ impl Scheduler {
// Vnode count requirements: if a fragment is going to look up an existing table,
// it must have the same vnode count as that table.
for (&id, fragment) in graph.building_fragments() {
visit_fragment(&mut (*fragment).clone(), |node| {
visit_fragment(fragment, |node| {
use risingwave_pb::stream_plan::stream_node::NodeBody;
let vnode_count = match node {
NodeBody::StreamScan(node) => {
Expand Down

0 comments on commit 17ca3a2

Please sign in to comment.