Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Remove NameSpace usage #434

Merged
merged 1 commit into from
Jan 23, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 1 addition & 11 deletions src/flag_gems/ops/gather.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
import torch

from flag_gems.utils.code_cache import code_cache_dir
from flag_gems.utils.code_utils import IndentedBuffer, NameSpace
from flag_gems.utils.code_utils import IndentedBuffer
from flag_gems.utils.shape_utils import restride_dim

from .scatter import scatter_
Expand Down Expand Up @@ -44,28 +44,18 @@ def generate_gather_kernel(

# signature
code.writeline(f"def {kernel_name}(")
function_ns = NameSpace()
with code.indent():
if rank > 0:
code.writeline("inp,")
function_ns.create_name("inp")
code.writeline("out,")
function_ns.create_name("out")
code.writeline("index,")
function_ns.create_name("index")

for i in range(rank):
function_ns.create_name(f"inp_stride_{i}")
stride_args = ", ".join(f"inp_stride_{i}: int" for i in range(rank))
code.writeline(f"{stride_args}, # stride for inp")

for i in range(rank):
function_ns.create_name(f"index_stride_{i}")
stride_args = ", ".join(f"index_stride_{i}: int" for i in range(rank))
code.writeline(f"{stride_args}, # stride for index")

for i in range(rank):
function_ns.create_name(f"index_shape_{i}")
shape_args = ", ".join(f"index_shape_{i}: int" for i in range(rank))
code.writeline(f"{shape_args}, # shape for index")

Expand Down
15 changes: 1 addition & 14 deletions src/flag_gems/ops/pad.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
import torch

from flag_gems.utils.code_cache import code_cache_dir
from flag_gems.utils.code_utils import IndentedBuffer, NameSpace
from flag_gems.utils.code_utils import IndentedBuffer


# --------------------------- padding wrapper genration -----------------------------------
Expand Down Expand Up @@ -222,42 +222,29 @@ def generate_pad_kernel(

# signature
code.writeline(f"def {kernel_name}(")
function_ns = NameSpace()
with code.indent():
code.writeline("in0_ptr: tl.tensor, # of tl.pointer_type")
function_ns.create_name("in0_ptr")

code.writeline("out0_ptr: tl.tensor, # of tl.pointer_type")
function_ns.create_name("out0_ptr")

if rank > 0:
# shape for inputs
for j in range(rank):
function_ns.create_name(f"x_shape{j}")
shape_args = ", ".join(f"x_shape{j}: int" for j in range(rank))
code.writeline(f"{shape_args}, # shape for x")

# shape for inputs
for j in range(rank):
function_ns.create_name(f"in_strides{j}")
stride_args = ", ".join(f"in_strides{j}: int" for j in range(rank))
code.writeline(f"{stride_args}, # stride for x")

# shape for inputs
for j in range(rank):
function_ns.create_name(f"out_strides{j}")
stride_args = ", ".join(f"out_strides{j}: int" for j in range(rank))
code.writeline(f"{stride_args}, # stride for out")

# shape for inputs
for j in range(rank):
function_ns.create_name(f"valid_dim{j}_start")
stride_args = ", ".join(f"valid_dim{j}_start: int" for j in range(rank))
code.writeline(f"{stride_args}, # valid dim start")

# shape for inputs
for j in range(rank):
function_ns.create_name(f"valid_dim{j}_end")
stride_args = ", ".join(f"valid_dim{j}_end: int" for j in range(rank))
code.writeline(f"{stride_args}, # valid dim end")

Expand Down
29 changes: 1 addition & 28 deletions src/flag_gems/ops/repeat.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
import torch

from flag_gems.utils.code_cache import code_cache_dir
from flag_gems.utils.code_utils import IndentedBuffer, NameSpace
from flag_gems.utils.code_utils import IndentedBuffer


# --------------------------- repeat wrapper genration -----------------------------------
Expand Down Expand Up @@ -220,87 +220,67 @@ def generate_repeat_kernel(

# signature
code.writeline(f"def {kernel_name}(")
function_ns = NameSpace()
with code.indent():
# signature: inputs ptrs & non tensor inputs
code.writeline("in0_ptr: tl.tensor, # of tl.pointer_type")
function_ns.create_name("in0_ptr")

# signature: output ptrs
code.writeline("out0_ptr: tl.tensor, # of tl.pointer_type")
function_ns.create_name("out0_ptr")

# signature: strides, for each tensor arguments
# only add this arguments when rank > 0
if rank > 0:
# strides for inputs
for j in range(rank):
function_ns.create_name(f"in0_stride{j}")
stride_args = ", ".join(f"in0_stride{j}: int" for j in range(rank))
code.writeline(f"{stride_args}, # strides for in0")

# strides for outputs
for j in range(rank):
function_ns.create_name(f"out0_stride{j}")
stride_args = ", ".join(f"out0_stride{j}: int" for j in range(rank))
code.writeline(f"{stride_args}, # strides for out0")

# task space, used to reconstruct multi index
task_space_args = ", ".join(f"s{i}: int" for i in range(rank))
for i in range(rank):
function_ns.create_name(f"s{i}")
code.writeline(f"{task_space_args}, # task_space")

task_space_args2 = ", ".join(f"in_s{i}: int" for i in range(rank))
for i in range(rank):
function_ns.create_name(f"in_s{i}")
code.writeline(
f"{task_space_args2}, # task_space2 used when input and output tensor has different shape"
)

# number of tasks, used to compute mask
code.writeline("num_tasks: int,")
function_ns.create_name("num_tasks")

# tile size & tiles_per_cta, gsl style
if rank > 0:
code.writeline("tiles_per_cta,")
function_ns.create_name("tiles_per_cta")

code.writeline("tile_size: tl.constexpr,")
function_ns.create_name("tile_size")

code.writeline("one_tile_per_cta: tl.constexpr,")
function_ns.create_name("one_tile_per_cta")
code.writeline("):")

with code.indent():
# get pid
code.writeline("# task id & masking")
pid_stmt = "pid = tle.program_id(0)"
code.writeline(pid_stmt)
function_ns.create_name("pid")

code.writeline("num_ctas = tle.num_programs(0)")
function_ns.create_name("num_ctas")

# get tid (a.k.a task id)
tid_stmt = "init_tid = pid * tile_size + tl.arange(0, tile_size)"
code.writeline(tid_stmt)
function_ns.create_name("init_tid")

# one-tile-per-cta, monolithic kernel style
code.writeline("if one_tile_per_cta: # monolitic kernel style")
with code.indent():
tid_stmt = "tid = init_tid"
code.writeline(tid_stmt)
function_ns.create_name("tid")

# only apply masking when rank > 0
# since we only load a value instead of a block of values when the rank is 0
mask_stmt: str = "mask = tid < num_tasks"
code.writeline(mask_stmt)
function_ns.create_name("mask")
code.newline()

# reconstruct multi index
Expand All @@ -311,7 +291,6 @@ def generate_repeat_kernel(
code.writeline(f"tid //= s{i}")
else:
code.writeline(f"i{i} = tid")
function_ns.create_name(f"{i}")
code.newline()

# loads
Expand All @@ -321,7 +300,6 @@ def generate_repeat_kernel(
)
ptrs_expr: str = f"in0_ptr + {ptrs_expr}"
load_stmt: str = f"in0 = tl.load({ptrs_expr}, mask=mask)"
function_ns.create_name("in0") # add to the namespace
code.writeline(load_stmt)
code.newline()

Expand All @@ -341,17 +319,14 @@ def generate_repeat_kernel(
code.writeline("else: # grid-stride-loop style kernel")
with code.indent():
code.writeline("for j in range(0, tiles_per_cta):")
function_ns.create_name("j")
with code.indent():
tid_stmt = "tid = init_tid + j * tile_size * num_ctas"
code.writeline(tid_stmt)
function_ns.create_name("tid")

# only apply masking when rank > 0
# since we only load a value instead of a block of values when the rank is 0
mask_stmt: str = "mask = tid < num_tasks"
code.writeline(mask_stmt)
function_ns.create_name("mask")
code.newline()

# reconstruct multi index
Expand All @@ -362,7 +337,6 @@ def generate_repeat_kernel(
code.writeline(f"tid //= s{i}")
else:
code.writeline(f"i{i} = tid")
function_ns.create_name(f"{i}")
code.newline()

# loads
Expand All @@ -372,7 +346,6 @@ def generate_repeat_kernel(
)
ptrs_expr: str = f"in0_ptr + {ptrs_expr}"
load_stmt: str = f"in0 = tl.load({ptrs_expr}, mask=mask)"
function_ns.create_name("in0") # add to the namespace
code.writeline(load_stmt)
code.newline()

Expand Down
15 changes: 1 addition & 14 deletions src/flag_gems/ops/scatter.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
import torch

from flag_gems.utils.code_cache import code_cache_dir
from flag_gems.utils.code_utils import IndentedBuffer, NameSpace
from flag_gems.utils.code_utils import IndentedBuffer
from flag_gems.utils.shape_utils import has_internal_overlapping, restride_dim


Expand Down Expand Up @@ -66,35 +66,22 @@ def generate_scatter_kernel(

# signature
code.writeline(f"def {kernel_name}(")
function_ns = NameSpace()
with code.indent():
if rank > 0:
code.writeline("src_strided,")
function_ns.create_name("src_strided")
code.writeline("index,")
function_ns.create_name("index")
code.writeline("inp,")
function_ns.create_name("inp")
code.writeline("out,")
function_ns.create_name("out")

for i in range(rank):
function_ns.create_name(f"inp_stride_{i}")
stride_args = ", ".join(f"inp_stride_{i}: int" for i in range(rank))
code.writeline(f"{stride_args}, # stride for inp")

for i in range(rank):
function_ns.create_name(f"index_stride_{i}")
stride_args = ", ".join(f"index_stride_{i}: int" for i in range(rank))
code.writeline(f"{stride_args}, # stride for index")

for i in range(rank):
function_ns.create_name(f"src_stride_{i}")
stride_args = ", ".join(f"src_stride_{i}: int" for i in range(rank))
code.writeline(f"{stride_args}, # stride for src")

for i in range(rank):
function_ns.create_name(f"shape_{i}")
shape_args = ", ".join(f"shape_{i}: int" for i in range(rank))
code.writeline(f"{shape_args}, # shape")
code.writeline("inp_size_dim,")
Expand Down
Loading
Loading