
Remove NameSpace usage
0x45f committed Jan 23, 2025
1 parent 73acbe6 commit 3072da7
Showing 5 changed files with 5 additions and 95 deletions.
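
All of the deleted lines below either construct a NameSpace or register a generated parameter name with it. For context, here is a minimal sketch of what such a registry plausibly looks like — a hypothetical reconstruction for illustration only; the real class lives in flag_gems/utils/code_utils.py and is not part of this diff:

# Hypothetical minimal NameSpace (illustrative sketch, not the real class).
class NameSpace:
    def __init__(self):
        self._names = set()

    def create_name(self, name: str) -> str:
        # Reject duplicates so parameters in a generated signature stay unique.
        if name in self._names:
            raise ValueError(f"duplicate name in generated code: {name}")
        self._names.add(name)
        return name

Every name the generators register is a fixed prefix plus a loop index (f"inp_stride_{i}", f"x_shape{j}", and so on), so uniqueness already holds by construction; dropping the registry calls does not change the generated kernels.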
12 changes: 1 addition & 11 deletions src/flag_gems/ops/gather.py
@@ -6,7 +6,7 @@
 import torch

 from flag_gems.utils.code_cache import code_cache_dir
-from flag_gems.utils.code_utils import IndentedBuffer, NameSpace
+from flag_gems.utils.code_utils import IndentedBuffer
 from flag_gems.utils.shape_utils import restride_dim

 from .scatter import scatter_
@@ -44,28 +44,18 @@ def generate_gather_kernel(

     # signature
     code.writeline(f"def {kernel_name}(")
-    function_ns = NameSpace()
     with code.indent():
         if rank > 0:
             code.writeline("inp,")
-            function_ns.create_name("inp")
             code.writeline("out,")
-            function_ns.create_name("out")
             code.writeline("index,")
-            function_ns.create_name("index")

-            for i in range(rank):
-                function_ns.create_name(f"inp_stride_{i}")
             stride_args = ", ".join(f"inp_stride_{i}: int" for i in range(rank))
             code.writeline(f"{stride_args}, # stride for inp")

-            for i in range(rank):
-                function_ns.create_name(f"index_stride_{i}")
             stride_args = ", ".join(f"index_stride_{i}: int" for i in range(rank))
             code.writeline(f"{stride_args}, # stride for index")

-            for i in range(rank):
-                function_ns.create_name(f"index_shape_{i}")
             shape_args = ", ".join(f"index_shape_{i}: int" for i in range(rank))
             code.writeline(f"{shape_args}, # shape for index")
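
To make the writeline calls above concrete, the emitted signature for rank == 2 would look roughly like the sketch below. The kernel name and the placeholder body are hypothetical, and the hunk is truncated, so parameters past the ones shown are omitted:

# Rough shape of the generated gather kernel signature for rank == 2 (sketch).
def gather_kernel_rank_2(
    inp,
    out,
    index,
    inp_stride_0: int, inp_stride_1: int,  # stride for inp
    index_stride_0: int, index_stride_1: int,  # stride for index
    index_shape_0: int, index_shape_1: int,  # shape for index
    # ... any remaining parameters fall outside the hunk shown above ...
):
    ...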
15 changes: 1 addition & 14 deletions src/flag_gems/ops/pad.py
@@ -6,7 +6,7 @@
 import torch

 from flag_gems.utils.code_cache import code_cache_dir
-from flag_gems.utils.code_utils import IndentedBuffer, NameSpace
+from flag_gems.utils.code_utils import IndentedBuffer


 # --------------------------- padding wrapper generation -----------------------------------
@@ -222,42 +222,29 @@ def generate_pad_kernel(

     # signature
     code.writeline(f"def {kernel_name}(")
-    function_ns = NameSpace()
     with code.indent():
         code.writeline("in0_ptr: tl.tensor, # of tl.pointer_type")
-        function_ns.create_name("in0_ptr")

         code.writeline("out0_ptr: tl.tensor, # of tl.pointer_type")
-        function_ns.create_name("out0_ptr")

         if rank > 0:
             # shape for inputs
-            for j in range(rank):
-                function_ns.create_name(f"x_shape{j}")
             shape_args = ", ".join(f"x_shape{j}: int" for j in range(rank))
             code.writeline(f"{shape_args}, # shape for x")

             # strides for inputs
-            for j in range(rank):
-                function_ns.create_name(f"in_strides{j}")
             stride_args = ", ".join(f"in_strides{j}: int" for j in range(rank))
             code.writeline(f"{stride_args}, # stride for x")

             # strides for outputs
-            for j in range(rank):
-                function_ns.create_name(f"out_strides{j}")
             stride_args = ", ".join(f"out_strides{j}: int" for j in range(rank))
             code.writeline(f"{stride_args}, # stride for out")

             # valid dim start, per dimension
-            for j in range(rank):
-                function_ns.create_name(f"valid_dim{j}_start")
             stride_args = ", ".join(f"valid_dim{j}_start: int" for j in range(rank))
             code.writeline(f"{stride_args}, # valid dim start")

             # valid dim end, per dimension
-            for j in range(rank):
-                function_ns.create_name(f"valid_dim{j}_end")
             stride_args = ", ".join(f"valid_dim{j}_end: int" for j in range(rank))
             code.writeline(f"{stride_args}, # valid dim end")

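
As with gather, here is roughly what the emitted pad signature would look like for rank == 2 — a sketch with a hypothetical kernel name; the valid_dim arguments presumably delimit, per dimension, the index range that the original tensor occupies inside the padded output:

import triton.language as tl  # the generated kernel runs as a @triton.jit function

# Rough shape of the generated pad kernel signature for rank == 2 (sketch).
def pad_kernel_rank_2(
    in0_ptr: tl.tensor,  # of tl.pointer_type
    out0_ptr: tl.tensor,  # of tl.pointer_type
    x_shape0: int, x_shape1: int,  # shape for x
    in_strides0: int, in_strides1: int,  # stride for x
    out_strides0: int, out_strides1: int,  # stride for out
    valid_dim0_start: int, valid_dim1_start: int,  # valid dim start
    valid_dim0_end: int, valid_dim1_end: int,  # valid dim end
    # ... any remaining parameters fall outside the hunk shown above ...
):
    ...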
29 changes: 1 addition & 28 deletions src/flag_gems/ops/repeat.py
@@ -6,7 +6,7 @@
 import torch

 from flag_gems.utils.code_cache import code_cache_dir
-from flag_gems.utils.code_utils import IndentedBuffer, NameSpace
+from flag_gems.utils.code_utils import IndentedBuffer


 # --------------------------- repeat wrapper generation -----------------------------------
@@ -220,87 +220,67 @@ def generate_repeat_kernel(

     # signature
     code.writeline(f"def {kernel_name}(")
-    function_ns = NameSpace()
     with code.indent():
         # signature: input ptrs & non-tensor inputs
         code.writeline("in0_ptr: tl.tensor, # of tl.pointer_type")
-        function_ns.create_name("in0_ptr")

         # signature: output ptrs
         code.writeline("out0_ptr: tl.tensor, # of tl.pointer_type")
-        function_ns.create_name("out0_ptr")

         # signature: strides, for each tensor argument
         # only add these arguments when rank > 0
         if rank > 0:
             # strides for inputs
-            for j in range(rank):
-                function_ns.create_name(f"in0_stride{j}")
             stride_args = ", ".join(f"in0_stride{j}: int" for j in range(rank))
             code.writeline(f"{stride_args}, # strides for in0")

             # strides for outputs
-            for j in range(rank):
-                function_ns.create_name(f"out0_stride{j}")
             stride_args = ", ".join(f"out0_stride{j}: int" for j in range(rank))
             code.writeline(f"{stride_args}, # strides for out0")

             # task space, used to reconstruct multi index
             task_space_args = ", ".join(f"s{i}: int" for i in range(rank))
-            for i in range(rank):
-                function_ns.create_name(f"s{i}")
             code.writeline(f"{task_space_args}, # task_space")

             task_space_args2 = ", ".join(f"in_s{i}: int" for i in range(rank))
-            for i in range(rank):
-                function_ns.create_name(f"in_s{i}")
             code.writeline(
                 f"{task_space_args2}, # task_space2, used when input and output tensors have different shapes"
             )

             # number of tasks, used to compute mask
             code.writeline("num_tasks: int,")
-            function_ns.create_name("num_tasks")

         # tile size & tiles_per_cta, gsl style
         if rank > 0:
             code.writeline("tiles_per_cta,")
-            function_ns.create_name("tiles_per_cta")

         code.writeline("tile_size: tl.constexpr,")
-        function_ns.create_name("tile_size")

         code.writeline("one_tile_per_cta: tl.constexpr,")
-        function_ns.create_name("one_tile_per_cta")
     code.writeline("):")

     with code.indent():
         # get pid
         code.writeline("# task id & masking")
         pid_stmt = "pid = tle.program_id(0)"
         code.writeline(pid_stmt)
-        function_ns.create_name("pid")

         code.writeline("num_ctas = tle.num_programs(0)")
-        function_ns.create_name("num_ctas")

         # get tid (a.k.a. task id)
         tid_stmt = "init_tid = pid * tile_size + tl.arange(0, tile_size)"
         code.writeline(tid_stmt)
-        function_ns.create_name("init_tid")

         # one-tile-per-cta, monolithic kernel style
         code.writeline("if one_tile_per_cta: # monolithic kernel style")
         with code.indent():
             tid_stmt = "tid = init_tid"
             code.writeline(tid_stmt)
-            function_ns.create_name("tid")

             # only apply masking when rank > 0
             # since we only load a value instead of a block of values when the rank is 0
             mask_stmt: str = "mask = tid < num_tasks"
             code.writeline(mask_stmt)
-            function_ns.create_name("mask")
             code.newline()

             # reconstruct multi index
@@ -311,7 +291,6 @@ def generate_repeat_kernel(
                     code.writeline(f"tid //= s{i}")
                 else:
                     code.writeline(f"i{i} = tid")
-                function_ns.create_name(f"{i}")
             code.newline()

             # loads
@@ -321,7 +300,6 @@ def generate_repeat_kernel(
             )
             ptrs_expr: str = f"in0_ptr + {ptrs_expr}"
             load_stmt: str = f"in0 = tl.load({ptrs_expr}, mask=mask)"
-            function_ns.create_name("in0")  # add to the namespace
             code.writeline(load_stmt)
             code.newline()

@@ -341,17 +319,14 @@ def generate_repeat_kernel(
         code.writeline("else: # grid-stride-loop style kernel")
         with code.indent():
             code.writeline("for j in range(0, tiles_per_cta):")
-            function_ns.create_name("j")
             with code.indent():
                 tid_stmt = "tid = init_tid + j * tile_size * num_ctas"
                 code.writeline(tid_stmt)
-                function_ns.create_name("tid")

                 # only apply masking when rank > 0
                 # since we only load a value instead of a block of values when the rank is 0
                 mask_stmt: str = "mask = tid < num_tasks"
                 code.writeline(mask_stmt)
-                function_ns.create_name("mask")
                 code.newline()

                 # reconstruct multi index
@@ -362,7 +337,6 @@ def generate_repeat_kernel(
                         code.writeline(f"tid //= s{i}")
                     else:
                         code.writeline(f"i{i} = tid")
-                    function_ns.create_name(f"{i}")
                 code.newline()

                 # loads
@@ -372,7 +346,6 @@ def generate_repeat_kernel(
                 )
                 ptrs_expr: str = f"in0_ptr + {ptrs_expr}"
                 load_stmt: str = f"in0 = tl.load({ptrs_expr}, mask=mask)"
-                function_ns.create_name("in0")  # add to the namespace
                 code.writeline(load_stmt)
                 code.newline()

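
The "reconstruct multi index" lines above emit a standard mixed-radix decomposition of the linear task id. A plain-Python illustration (hypothetical helper; the generated kernel performs the same arithmetic on Triton tensors):

def reconstruct_multi_index(tid, shape):
    # Mirrors the generated "i{i} = tid % s{i}; tid //= s{i}" sequence.
    idx = [0] * len(shape)
    for i in reversed(range(len(shape))):
        if i > 0:
            idx[i] = tid % shape[i]
            tid //= shape[i]
        else:
            idx[i] = tid
    return tuple(idx)

assert reconstruct_multi_index(5, (2, 3)) == (1, 2)  # 5 == 1 * 3 + 2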
15 changes: 1 addition & 14 deletions src/flag_gems/ops/scatter.py
@@ -6,7 +6,7 @@
 import torch

 from flag_gems.utils.code_cache import code_cache_dir
-from flag_gems.utils.code_utils import IndentedBuffer, NameSpace
+from flag_gems.utils.code_utils import IndentedBuffer
 from flag_gems.utils.shape_utils import has_internal_overlapping, restride_dim

@@ -66,35 +66,22 @@ def generate_scatter_kernel(

     # signature
     code.writeline(f"def {kernel_name}(")
-    function_ns = NameSpace()
     with code.indent():
         if rank > 0:
             code.writeline("src_strided,")
-            function_ns.create_name("src_strided")
             code.writeline("index,")
-            function_ns.create_name("index")
             code.writeline("inp,")
-            function_ns.create_name("inp")
             code.writeline("out,")
-            function_ns.create_name("out")

-            for i in range(rank):
-                function_ns.create_name(f"inp_stride_{i}")
             stride_args = ", ".join(f"inp_stride_{i}: int" for i in range(rank))
             code.writeline(f"{stride_args}, # stride for inp")

-            for i in range(rank):
-                function_ns.create_name(f"index_stride_{i}")
             stride_args = ", ".join(f"index_stride_{i}: int" for i in range(rank))
             code.writeline(f"{stride_args}, # stride for index")

-            for i in range(rank):
-                function_ns.create_name(f"src_stride_{i}")
             stride_args = ", ".join(f"src_stride_{i}: int" for i in range(rank))
             code.writeline(f"{stride_args}, # stride for src")

-            for i in range(rank):
-                function_ns.create_name(f"shape_{i}")
             shape_args = ", ".join(f"shape_{i}: int" for i in range(rank))
             code.writeline(f"{shape_args}, # shape")
             code.writeline("inp_size_dim,")
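
And the corresponding sketch for scatter at rank == 2 (hypothetical kernel name; inp_size_dim is the last parameter visible in the truncated hunk, and the real signature continues beyond it):

# Rough shape of the generated scatter kernel signature for rank == 2 (sketch).
def scatter_kernel_rank_2(
    src_strided,
    index,
    inp,
    out,
    inp_stride_0: int, inp_stride_1: int,  # stride for inp
    index_stride_0: int, index_stride_1: int,  # stride for index
    src_stride_0: int, src_stride_1: int,  # stride for src
    shape_0: int, shape_1: int,  # shape
    inp_size_dim,
    # ... the signature continues beyond the hunk shown above ...
):
    ...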
(diff for the fifth changed file was not loaded)
