diff --git a/assets/Customflash2_a100_fwd_bwd_benchmark.png b/assets/Customflash2_a100_fwd_bwd_benchmark.png new file mode 100644 index 000000000..281f52420 Binary files /dev/null and b/assets/Customflash2_a100_fwd_bwd_benchmark.png differ diff --git a/autotuner/arch/A100.py b/autotuner/arch/A100.py new file mode 100644 index 000000000..2c3a16839 --- /dev/null +++ b/autotuner/arch/A100.py @@ -0,0 +1,15 @@ +from .arch_base import Arch +class A100(Arch): + def __init__(self): + self.reg_cap = 65536 # 32768 + self.smem_cap = 163*1024 # 164*1024 + self.compute_max_core = 108 + self.warp_size = 32 + self.sm_partition = 4 + self.transaction_size = [32, 128] # in bytes + self.max_smem_usage = 164 * 1024 + self.bandwidth = [1319, 16308] + self.platform = "CUDA" + self.compute_capability = "80" + self.cutlass_mma = [16, 8, 16] + self.register_per_thread = 255 diff --git a/autotuner/arch/RTX4090.py b/autotuner/arch/RTX4090.py new file mode 100644 index 000000000..bae3c291e --- /dev/null +++ b/autotuner/arch/RTX4090.py @@ -0,0 +1,15 @@ +from .arch_base import Arch +class RTX4090(Arch): + def __init__(self): + self.reg_cap = 65536 # 32768 + self.smem_cap = 100*1024 # 164*1024 + self.compute_max_core = 128 + self.warp_size = 32 + self.sm_partition = 4 + self.transaction_size = [32, 128] # in bytes + self.max_smem_usage = 100 * 1024 + self.bandwidth = [1008, 0] # TODO: 1 + self.platform = "CUDA" + self.compute_capability = "89" + self.cutlass_mma = [16, 8, 16] + self.register_per_thread = 255 \ No newline at end of file diff --git a/autotuner/arch/__init__.py b/autotuner/arch/__init__.py new file mode 100644 index 000000000..9ba8ec3a7 --- /dev/null +++ b/autotuner/arch/__init__.py @@ -0,0 +1,3 @@ +from .arch_base import Arch +from .A100 import * +from .RTX4090 import * diff --git a/autotuner/arch/arch_base.py b/autotuner/arch/arch_base.py new file mode 100644 index 000000000..74d6144d0 --- /dev/null +++ b/autotuner/arch/arch_base.py @@ -0,0 +1,13 @@ +class Arch: + def __init__(self) -> None: + self.reg_cap = 0 + self.smem_cap = 0 + self.compute_max_core = 0 + self.warp_size = 0 + self.sm_partition = 0 + self.transaction_size = [0, 0] + self.max_smem_usage = 0 + self.bandwidth = [0, 0] + self.platform = "unknown" + self.compute_capability = "unknown" + self.register_per_thread = 0 diff --git a/autotuner/base_tunner.py b/autotuner/base_tunner.py new file mode 100644 index 000000000..c2111dfef --- /dev/null +++ b/autotuner/base_tunner.py @@ -0,0 +1,247 @@ +import ctypes +import os +from concurrent.futures import ThreadPoolExecutor +# import multiprocessing +# from functools import partial +import tempfile +import subprocess +import importlib.util + +import ctypes +import torch +from configs import BaseConfig, supported_configs + +import pprint +import json + +import time + +from code_emitter import CodeEmitter, ShapeConfig, ProfileConfig +from profile_attn import profile_fwd + + + + + +class CompileResult: + def __init__(self, config: BaseConfig, lib_name: str) -> None: + self.config = config + self.lib_name = lib_name + +def _create_code_for_profiling(config): + profile_code_path = os.path.join(config.template_dir , config.operation, "profile_code.py") + + spec = importlib.util.spec_from_file_location("ProfileCode", profile_code_path) + foo = importlib.util.module_from_spec(spec) + spec.loader.exec_module(foo) + # from template.flash_kernels.retnet.regfuse.profile_code import profile_code + # return profile_code.format(Br=config.Br, Bc=config.Bc, Kd=config.Kd, D=config.D, 
unrollLastIter=int(config.unrollLastIter), BlockKSmem=config.BlockKSmem, num_stages_qk=config.num_stages_qk, num_stages_mask=config.num_stages_mask, BlockKSmem2=config.BlockKSmem2, num_stages_v=config.num_stages_v, Nthreads=config.Nthreads) + # from template.flash_kernels.retnet.smemfuse.profile_code import profile_code + # return profile_code.format(Br=config.Br, Bc=config.Bc, Kd=config.Kd, D=config.D, unrollLastIter=int(config.unrollLastIter), BlockKSmem=config.BlockKSmem, num_stages_qk=config.num_stages_qk, num_stages_mask=config.num_stages_mask, BlockKSmem2=config.BlockKSmem2, num_stages_v=config.num_stages_v, Nthreads=config.Nthreads, warps_mma1_n=config.warps_mma1_n, warps_mma_n=config.warps_mma_n) + return foo.profile_code.format_map(config.__dict__) + +# def _compile(config, arch, temp_dir:str, timeout: float = None): +# ## compile + +# profiling_code = _create_code_for_profiling(config) +# src = tempfile.NamedTemporaryFile(mode="w",suffix=".cu", delete=True, dir=temp_dir) +# lib_name = src.name.replace(".cu", ".so") +# compute_version = arch.compute_capability +# cutlass_dir = os.path.join(os.path.dirname(__file__), "../../third_party/cutlass/include") +# csrc_dir = os.path.join(os.path.dirname(__file__), "../../csrc") +# if config.fuse_type == "register": +# template_dir = os.path.join(config.template_dir , "regfuse/") +# elif config.fuse_type == "shared": +# template_dir = os.path.join(config.template_dir , "smemfuse/") +# else: # bwd +# template_dir = config.template_dir +# command = ["nvcc","-std=c++17","-O3","--use_fast_math","--expt-relaxed-constexpr","--disable-warnings", "--compiler-options", "'-fPIC'", "--shared", src.name, "-lcuda", +# f"-gencode=arch=compute_{compute_version},code=sm_{compute_version}", +# f"-I{cutlass_dir}",f"-I{template_dir}",f"-I{csrc_dir}", "-o", lib_name] +# src.write(profiling_code) +# src.flush() +# try: +# ret = subprocess.run(command, timeout=timeout) +# except subprocess.TimeoutExpired: +# return None +# if ret.returncode != 0: +# return None +# return CompileResult(config,lib_name) + +class BaseTunner: + def __init__(self, arch, torch_array: list, op_name, shape_config: ShapeConfig, profile_config: ProfileConfig, tempdir): + self.arch = arch + self.torch_array = torch_array + self.Br_list = [32, 64, 96, 128, 160, 192, 224, 256] # [32, 64, 128, 256] + self.Bc_list = [32, 64, 96, 128, 160, 192, 224, 256] # [32, 64, 128, 256] + + self.template_dir = "autotuner/template" + self.op_name = op_name + # TODO: workaround for dropout_p + self.cache_path = os.path.join(os.path.dirname(__file__), "./cache/", str(profile_config.dropout_p!=0)) + self.problem_key = { + "dim_qk": torch_array[0].shape[-1], + "dim_v": torch_array[2].shape[-1] + } + assert torch_array[0].shape[-1] == shape_config.Kd + assert torch_array[2].shape[-1] == shape_config.D + self.shape_config = shape_config + self.profile_config = profile_config + self.tempdir = tempdir + + def compile(self, configs:list, timeout: float = None): + temp_dir = self.tempdir + code_emitter = CodeEmitter(self.template_dir, temp_dir) + code_emitter.generate_code(self.shape_config, configs) + + + def profile(self, config:BaseConfig, repeat=30, load_only=False) -> float: + spec = importlib.util.spec_from_file_location("flash_attn_func", self.tempdir+"/"+config.output_dir+"/flash_attn_profile_interface.py") + mod = importlib.util.module_from_spec(spec) + spec.loader.exec_module(mod) + flash_attn_func = mod.flash_attn_func + if load_only: + return None + latency = profile_fwd(flash_attn_func, 
self.shape_config.Kd, self.shape_config.D, batch_size=self.profile_config.batch_size, seqlen=self.profile_config.seqlen_q, nheads=self.profile_config.nheads, dropout_p=self.profile_config.dropout_p,is_bf16=self.shape_config.is_bf16, causal=self.shape_config.is_causal, device=self.profile_config.device, repeats=repeat) + if latency < 0: + latency = 1e8 + # remove lib + # subprocess.run(["rm", libname], check=True) + return latency + + def get_tuned_configs(self): + dim_qk = self.problem_key["dim_qk"] + dim_v = self.problem_key["dim_v"] + configs = [] + for Br in self.Br_list: + for Bc in self.Bc_list: + cur_configs = self.generate_configs(Br,Bc,dim_qk,dim_v) + for cur_config in cur_configs: + if self.op_name == "flash_fwd" and self.validate_register_fuse(cur_config): + configs.append(cur_config) + else: # BWD + if self.validate_kernel(cur_config): + configs.append(cur_config) + return configs + + def tune(self, log_path="./logs/"): + st = time.time() + + dim_qk = self.problem_key["dim_qk"] + dim_v = self.problem_key["dim_v"] + + best_config = self.check_cache() + if best_config is not None: + # print("Best config found in cache: ") + # pprint.pprint(best_config) + return best_config + + configs = self.get_tuned_configs() + + # print configs + print("Configs to be tuned: ") + for config in configs: + # print(config) + pprint.pprint(config) + + + # cresults = self.compile(configs,src_dir.name,timeout=1200) + # cresults = self.compile_parallel(configs,src_dir.name,timeout=120) + self.compile(configs,timeout=120) + + # warm up (parallel compile module) + # module name must be different in api.py + with ThreadPoolExecutor(max_workers=os.cpu_count()) as executor: + latencys = executor.map(self.profile, configs, [1 for _ in range(len(configs))], [True for _ in range(len(configs))]) + # with ThreadPoolExecutor(max_workers=os.cpu_count()) as executor: + # latencys = executor.map(_profile,[self.tempdir for _ in range(len(configs))],[self.shape_config for _ in range(len(configs))], configs, ["cuda:0" for _ in range(len(configs))], [1 for _ in range(len(configs))]) + # multiprocessing.set_start_method('spawn', force=True) + # pool = multiprocessing.Pool(os.cpu_count()) + # outs = pool.map(partial(self.profile, repeat=1), configs) + + profile_dict = {} + latency = 1e8 + best_config = None + for config in configs: + lib_latency = self.profile(config) + if lib_latency == 1e8: + # print(cresult.config) + pprint.pprint(config) + print("profile runtime error") + if lib_latency < latency: + latency = lib_latency + best_config = config + profile_dict[config] = lib_latency + + end = time.time() + + print("##########################################################") + print("Operation type: ", best_config.operation) + print("Best config: ")# , best_config) + pprint.pprint(best_config) + print("Latency: ", latency) + + file_name = "profile_result_{}_{}_{}_p{}_{}_{}_{}_c{}.txt".format(best_config.operation,dim_qk, dim_v, self.profile_config.batch_size, self.profile_config.seqlen_q, self.profile_config.nheads, self.profile_config.dropout_p,self.shape_config.is_causal) + os.makedirs(log_path,exist_ok=True) + with open(os.path.join(log_path,file_name),"a") as f: + for config in profile_dict: + f.write(repr(config)+"\n") + f.write(str(profile_dict[config])+"\n") + f.write("\n") + f.write("best config: \n") + f.write(repr(best_config)+"\n") + f.write(str(latency)+"\n") + f.write("\nsearch time: "+str(end-st)+"s" + "\n\n") + + cache_path = self.cache_path + os.makedirs(cache_path,exist_ok=True) + with 
open(os.path.join(cache_path,"best_config_{}_{}_{}.json".format(self.op_name,dim_qk, dim_v)),"w") as f: + json.dump(best_config.__dict__,f) + + return best_config + + def check_cache(self): + cache_path = self.cache_path + op_name = self.op_name + dim_qk = self.problem_key["dim_qk"] + dim_v = self.problem_key["dim_v"] + if os.path.exists(os.path.join(cache_path, "best_config_{}_{}_{}.json".format(op_name,dim_qk, dim_v))): + with open(os.path.join(cache_path,"best_config_{}_{}_{}.json".format(op_name,dim_qk, dim_v)),"r") as f: + best_config_dict = json.load(f) + best_config = supported_configs[best_config_dict["operation"]].from_dict(best_config_dict) + return best_config + + return None + + + def validate_shared_fuse(self, config): + return False + def validate_register_fuse(self, config): + return False + def validate_kernel(self, config): + return False + def generate_configs(self,Br:int,Bc:int,dim_qk:int,dim_v:int): + configs = [] + return configs + +if __name__=="__main__": + import torch + from configs.fwd_config import FlashFwdConfig + batch_size = 4 + seqlen = 2048 + nheads = 8 + headdim = 192 + v_headdim = 128 + device = 'cuda' + dtype = torch.bfloat16 + q = torch.randn(batch_size, seqlen, nheads, headdim, device=device, dtype=dtype, + requires_grad=True) + k = torch.randn(batch_size, seqlen, nheads, headdim, device=device, dtype=dtype, + requires_grad=True) + v = torch.randn(batch_size, seqlen, nheads, v_headdim, device=device, dtype=dtype, + requires_grad=True) + base_tunner = BaseTunner(arch=None, torch_array=[q,k,v], op_name="flash_fwd", shape_config=ShapeConfig(headdim,v_headdim), profle_config=ProfileConfig(batch_size,seqlen,seqlen,nheads,nheads,nheads,device,dtype,0), tempdir="autotuner/temp") + + config = FlashFwdConfig(headdim,v_headdim,64,64) + base_tunner.compile([config]) + base_tunner.profile(config) \ No newline at end of file diff --git a/autotuner/code_emitter.py b/autotuner/code_emitter.py new file mode 100644 index 000000000..d595d51c8 --- /dev/null +++ b/autotuner/code_emitter.py @@ -0,0 +1,104 @@ +from configs.base_config import BaseConfig +from pathlib import Path +import os +import tempfile + +class ShapeConfig: + def __init__(self, Kd, D, is_bf16: bool=False, is_causal: bool=False) -> None: + self.Kd = Kd + self.D = D + self.is_bf16 = is_bf16 + self.is_causal = is_causal + +class ProfileConfig: + def __init__(self, batch_size, seqlen_q, seqlen_kv, nheads, nheads_k, nheads_v, device, dtype, dropout_p) -> None: + self.batch_size = batch_size + self.seqlen_q = seqlen_q + self.seqlen_kv = seqlen_kv + self.nheads = nheads + self.nheads_k = nheads_k + self.nheads_v = nheads_v + self.device = device + self.dtype = dtype + self.dropout_p = dropout_p + + +class CodeEmitter: + def __init__(self, template_dir, output_dir) -> None: + self.template_dir = template_dir + self.output_dir = output_dir + + self.profile_api_file_list = [ + "flash_fwd.cu", + "flash_profile_api.cpp", + ] + self.kernel_file_list = [ + "flash_fwd.h", + "flash_profile.h", + "flash_fwd_launch_template_profile.h" + ] + + def generate_code(self, shape_config:ShapeConfig, configs:list[BaseConfig]): + template_dir = self.template_dir + output_dir = self.output_dir + + skip_api_code = False + if not Path(output_dir).exists(): + os.mkdir(output_dir) + else: + skip_api_code = True + + # generate api code + if not skip_api_code: + for file_name in self.profile_api_file_list: + with open(Path(template_dir) / Path(file_name)) as f: + code_template = f.read() + code_template = 
self.emit_code_profile_api(code_template, shape_config) + + with open(Path(output_dir) / Path(file_name), "w") as f: + f.write(code_template) + + # generate kernel code + for config in configs: + kernel_code_dir = Path(output_dir) / Path(config.output_dir) + if not kernel_code_dir.exists(): + os.mkdir(kernel_code_dir) + else: + continue + + for file_name in self.kernel_file_list: + with open(Path(template_dir) / Path(file_name)) as f: + code_template = f.read() + code_template = self.emit_code_kernel(code_template, config) + + with open(kernel_code_dir / Path(file_name), "w") as f: + f.write(code_template) + + # flash_attn_profile_interface.py + with open(Path(template_dir) / Path("flash_attn_profile_interface.py")) as f: + code_template = f.read() + code_template = code_template.replace("OUTPUT_DIR", f"\"{str(output_dir)}\"") + code_template = code_template.replace("OUTPUT_KERNEL_DIR", f"\"{str(kernel_code_dir)}\"") + code_template = code_template.replace("CONFIG_NAME", f"\"{str(config)}\"") + with open(Path(kernel_code_dir) / Path("flash_attn_profile_interface.py"), "w") as f: + f.write(code_template) + + + def emit_code_kernel(self, code_template:str, config:BaseConfig): + kv = config.__dict__ + for k,v in kv.items(): + code_template = code_template.replace(f"/*{{{k}}}*/",str(v)) + return code_template + + def emit_code_profile_api(self, code_template:str, shape_config: ShapeConfig): + kv = shape_config.__dict__ + for k,v in kv.items(): + code_template = code_template.replace(f"/*{{{k}}}*/",str(v)) + return code_template + + +if __name__ == "__main__": + from configs.fwd_config import FlashFwdConfig + config = FlashFwdConfig(1,2,3,4) + ce = CodeEmitter("autotuner/template/", "autotuner/template/output/") + ce.generate_code(ShapeConfig(64,128), [config]) diff --git a/autotuner/configs/__init__.py b/autotuner/configs/__init__.py new file mode 100644 index 000000000..b5e57de67 --- /dev/null +++ b/autotuner/configs/__init__.py @@ -0,0 +1,6 @@ +from .base_config import BaseConfig +from .fwd_config import FlashFwdConfig + +supported_configs = { + "flash_fwd": FlashFwdConfig, +} \ No newline at end of file diff --git a/autotuner/configs/base_config.py b/autotuner/configs/base_config.py new file mode 100644 index 000000000..4c2e6ed1b --- /dev/null +++ b/autotuner/configs/base_config.py @@ -0,0 +1,37 @@ +class BaseConfig: + def __init__(self, Kd, D, Br, Bc, Nwarps=8) -> None: + self.Br = Br + self.Bc = Bc + self.Kd = Kd + self.D = D + self.Nwarps = Nwarps + + self.operation = None + self.template_dir = None + + def __repr__(self) -> str: + return "Config(Kd={}, D={}, Br={}, Bc={}, Nwarps={})".format(self.Kd, self.D, self.Br, self.Bc, self.Nwarps) + + def __str__(self) -> str: + return f"{self.Kd}_{self.D}_{self.Br}_{self.Bc}_{self.Nwarps}" + + @classmethod + def from_dict(cls, dd:dict): + cc = cls.__new__(cls) # cls: the subclass + cc.__dict__.update(dd) + return cc + + @property + def output_dir(self): + return str(self) + +if __name__ == "__main__": + cc = BaseConfig(1,2,3,4) + print(cc) + print(repr(cc)) + print(cc.__dict__) + dd = cc.__dict__ + cc2 = BaseConfig.from_dict(dd) + print(cc2) + print(repr(cc2)) + print(cc2.__dict__) \ No newline at end of file diff --git a/autotuner/configs/fwd_config.py b/autotuner/configs/fwd_config.py new file mode 100644 index 000000000..b8f5b4d6e --- /dev/null +++ b/autotuner/configs/fwd_config.py @@ -0,0 +1,18 @@ +import os +from .base_config import BaseConfig + +class FlashFwdConfig(BaseConfig): + def __init__(self, Kd, D, Br, Bc, Nwarps=8, isQinRegs:bool = False,
SharedQKSmem:bool = False) -> None: + super().__init__(Kd, D, Br, Bc, Nwarps) + + self.isQinRegs = isQinRegs or SharedQKSmem + self.SharedQKSmem = SharedQKSmem + + self.operation = "flash_fwd" + self.template_dir = os.path.join(os.path.dirname(__file__), "../../../csrc/kernels/attention") + + def __repr__(self) -> str: + return "Config(Kd={}, D={}, Br={}, Bc={}, Nwarps={}, isQinRegs={}, SharedQKSmem={})".format(self.Kd, self.D, self.Br, self.Bc, self.Nwarps, self.isQinRegs, self.SharedQKSmem) + + def __str__(self) -> str: + return f"{self.Kd}_{self.D}_{self.Br}_{self.Bc}_{self.Nwarps}_{self.isQinRegs}_{self.SharedQKSmem}" \ No newline at end of file diff --git a/autotuner/profile_attn.py b/autotuner/profile_attn.py new file mode 100644 index 000000000..f967e675c --- /dev/null +++ b/autotuner/profile_attn.py @@ -0,0 +1,29 @@ +import torch +from flash_attn.utils.benchmark import benchmark_forward + +# batch_size = 4 +# seqlen = 2048 +# nheads = 8 +# headdim = QKHeadDim +# v_headdim = VHeadDim +# device = 'cuda' +# dtype = torch.bfloat16 if is_bf16 else torch.float16 + +# dropout_p = 0.0 +# causal = is_causal +# repeats = 30 + + +def profile_fwd(fn,headdim, v_headdim, batch_size=4, seqlen=2048, nheads=8, device='cuda', is_bf16=False, causal=False, dropout_p=0.0, repeats=30): + dtype = torch.bfloat16 if is_bf16 else torch.float16 + # print(batch_size, seqlen, nheads, headdim, v_headdim, device, dtype, dropout_p, causal, repeats) + q = torch.randn(batch_size, seqlen, nheads, headdim, device=device, dtype=dtype, + requires_grad=True) + k = torch.randn(batch_size, seqlen, nheads, headdim, device=device, dtype=dtype, + requires_grad=True) + v = torch.randn(batch_size, seqlen, nheads, v_headdim, device=device, dtype=dtype, + requires_grad=True) + f = benchmark_forward(fn, q, k, v, dropout_p, causal=causal, repeats=repeats, verbose=False) + time_f = f[1].mean + # print(time_f) + return time_f \ No newline at end of file diff --git a/autotuner/template/flash_attn_profile_interface.py b/autotuner/template/flash_attn_profile_interface.py new file mode 100644 index 000000000..5cf4fa300 --- /dev/null +++ b/autotuner/template/flash_attn_profile_interface.py @@ -0,0 +1,1357 @@ +from typing import Optional, Union + +import torch +import torch.nn as nn + +import torch.utils.cpp_extension + +import os +from pathlib import Path + +# Check, if ATen/CUDAGeneratorImpl.h is found, otherwise use ATen/cuda/CUDAGeneratorImpl.h +# See https://github.com/pytorch/pytorch/pull/70650 +generator_flag = [] +torch_dir = torch.__path__[0] +if os.path.exists(os.path.join(torch_dir, "include", "ATen", "CUDAGeneratorImpl.h")): + generator_flag = ["-DOLD_GENERATOR_PATH"] + + +include_path = [ + "csrc/flash_attn", + "csrc/flash_attn/src", + "csrc/cutlass/include", + OUTPUT_KERNEL_DIR, +] + +cc_flag = [] +# cc_flag.append("-gencode") +# cc_flag.append("arch=compute_75,code=sm_75") +cc_flag.append("-gencode") +cc_flag.append("arch=compute_80,code=sm_80") +# cc_flag.append("-gencode") +# cc_flag.append("arch=compute_90,code=sm_90") + +build_dir = OUTPUT_KERNEL_DIR + "/build" +if not os.path.exists(build_dir): + os.makedirs(build_dir) + + +flash_attn_cuda = torch.utils.cpp_extension.load( + name="flash_attn_cuda"+CONFIG_NAME, + sources=[ + OUTPUT_DIR + "/flash_profile_api.cpp", # "csrc/flash_attn/flash_api.cpp", + OUTPUT_DIR + "/flash_fwd.cu", + # "csrc/flash_attn/src/flash_fwd_qkdim192_vdim128_fp16_sm80.cu", + # "csrc/flash_attn/src/flash_bwd_qkdim192_vdim128_fp16_sm80.cu", + # 
"csrc/flash_attn/src/flash_fwd_split_qkdim192_vdim128_fp16_sm80.cu", + ], + extra_cflags=[ + "-O3", "-std=c++17" + ] + generator_flag, + extra_cuda_cflags=[ + "-O3", + "-std=c++17", + "-U__CUDA_NO_HALF_OPERATORS__", + "-U__CUDA_NO_HALF_CONVERSIONS__", + "-U__CUDA_NO_HALF2_OPERATORS__", + "-U__CUDA_NO_BFLOAT16_CONVERSIONS__", + "--expt-relaxed-constexpr", + "--expt-extended-lambda", + "--use_fast_math", + # "--ptxas-options=-v", + # "--ptxas-options=-O2", + # "-lineinfo", + # "-DFLASHATTENTION_DISABLE_BACKWARD", + # "-DFLASHATTENTION_DISABLE_DROPOUT", + # "-DFLASHATTENTION_DISABLE_ALIBI", + # "-DFLASHATTENTION_DISABLE_SOFTCAP", + # "-DFLASHATTENTION_DISABLE_UNEVEN_K", + # "-DFLASHATTENTION_DISABLE_LOCAL", + ] + + generator_flag + + cc_flag, + extra_include_paths=include_path, + build_directory=build_dir, +) + +# isort: off +# We need to import the CUDA kernels after importing torch +# import flash_attn_2_cuda as flash_attn_cuda + +# isort: on + +def maybe_contiguous(x): + return x.contiguous() if x is not None and x.stride(-1) != 1 else x + +def _get_block_size_n(device, head_dim, is_dropout, is_causal): + # This should match the block sizes in the CUDA kernel + assert head_dim <= 256 + major, minor = torch.cuda.get_device_capability(device) + is_sm8x = major == 8 and minor > 0 # Only include sm86 and sm89, exclude sm80 (A100) + is_sm80 = major == 8 and minor == 0 + is_sm90 = major == 9 and minor == 0 + if head_dim <= 32: + return 128 + if head_dim <= 64: + return 128 if not is_dropout else 64 + elif head_dim <= 96: + return 64 + elif head_dim <= 128: + if is_sm8x: + return 64 if (not is_dropout and is_causal) else 32 + else: + return 64 if not is_dropout else 32 + elif head_dim <= 160: + if is_sm8x: + return 64 + else: + return 32 + elif head_dim <= 192: + return 64 + elif head_dim <= 224: + return 64 + elif head_dim <= 256: + return 64 + + +def _flash_attn_forward( + q, k, v, dropout_p, softmax_scale, causal, window_size, softcap, alibi_slopes, return_softmax +): + q, k, v = [maybe_contiguous(x) for x in (q, k, v)] + out, q, k, v, out_padded, softmax_lse, S_dmask, rng_state = flash_attn_cuda.fwd( + q, + k, + v, + None, + alibi_slopes, + dropout_p, + softmax_scale, + causal, + window_size[0], + window_size[1], + softcap, + return_softmax, + None, + ) + return out, q, k, v, out_padded, softmax_lse, S_dmask, rng_state + + +def _flash_attn_varlen_forward( + q, + k, + v, + cu_seqlens_q, + cu_seqlens_k, + max_seqlen_q, + max_seqlen_k, + dropout_p, + softmax_scale, + causal, + window_size=(-1, -1), + softcap=0.0, + alibi_slopes=None, + return_softmax=False, + block_table=None, + leftpad_k=None, + seqused_k=None, +): + q, k, v = [maybe_contiguous(x) for x in (q, k, v)] + out, q, k, v, out_padded, softmax_lse, S_dmask, rng_state = flash_attn_cuda.varlen_fwd( + q, + k, + v, + None, + cu_seqlens_q, + cu_seqlens_k, + seqused_k, + leftpad_k, + block_table, + alibi_slopes, + max_seqlen_q, + max_seqlen_k, + dropout_p, + softmax_scale, + False, + causal, + window_size[0], + window_size[1], + softcap, + return_softmax, + None, + ) + # if out.isnan().any() or softmax_lse.isnan().any(): + # breakpoint() + return out, q, k, v, out_padded, softmax_lse, S_dmask, rng_state + + +def _flash_attn_backward( + dout, + q, + k, + v, + out, + softmax_lse, + dq, + dk, + dv, + dropout_p, + softmax_scale, + causal, + window_size, + softcap, + alibi_slopes, + deterministic, + rng_state=None, +): + # dq, dk, dv are allocated by us so they should already be contiguous + dout, q, k, v, out = [maybe_contiguous(x) for x in 
(dout, q, k, v, out)] + ( + dq, + dk, + dv, + softmax_d, + ) = flash_attn_cuda.bwd( + dout, + q, + k, + v, + out, + softmax_lse, + dq, + dk, + dv, + alibi_slopes, + dropout_p, + softmax_scale, + causal, + window_size[0], + window_size[1], + softcap, + deterministic, + None, + rng_state, + ) + return dq, dk, dv, softmax_d + + +def _flash_attn_varlen_backward( + dout, + q, + k, + v, + out, + softmax_lse, + dq, + dk, + dv, + cu_seqlens_q, + cu_seqlens_k, + max_seqlen_q, + max_seqlen_k, + dropout_p, + softmax_scale, + causal, + window_size, + softcap, + alibi_slopes, + deterministic, + rng_state=None, +): + # dq, dk, dv are allocated by us so they should already be contiguous + dout, q, k, v, out = [maybe_contiguous(x) for x in (dout, q, k, v, out)] + ( + dq, + dk, + dv, + softmax_d, + ) = flash_attn_cuda.varlen_bwd( + dout, + q, + k, + v, + out, + softmax_lse, + dq, + dk, + dv, + cu_seqlens_q, + cu_seqlens_k, + alibi_slopes, + max_seqlen_q, + max_seqlen_k, + dropout_p, + softmax_scale, + False, + causal, + window_size[0], + window_size[1], + softcap, + deterministic, + None, + rng_state, + ) + # if dk.isnan().any() or dk.isnan().any() or dv.isnan().any() or softmax_d.isnan().any(): + # breakpoint() + return dq, dk, dv, softmax_d + + +class FlashAttnQKVPackedFunc(torch.autograd.Function): + @staticmethod + def forward( + ctx, + qkv, + dropout_p, + softmax_scale, + causal, + window_size, + softcap, + alibi_slopes, + deterministic, + return_softmax, + ): + if softmax_scale is None: + softmax_scale = qkv.shape[-1] ** (-0.5) + out, q, k, v, out_padded, softmax_lse, S_dmask, rng_state = _flash_attn_forward( + qkv[:, :, 0], + qkv[:, :, 1], + qkv[:, :, 2], + dropout_p, + softmax_scale, + causal=causal, + window_size=window_size, + softcap=softcap, + alibi_slopes=alibi_slopes, + return_softmax=return_softmax and dropout_p > 0, + ) + ctx.save_for_backward(q, k, v, out_padded, softmax_lse, rng_state) + ctx.dropout_p = dropout_p + ctx.softmax_scale = softmax_scale + ctx.causal = causal + ctx.window_size = window_size + ctx.softcap = softcap + ctx.alibi_slopes = alibi_slopes + ctx.deterministic = deterministic + return out if not return_softmax else (out, softmax_lse, S_dmask) + + @staticmethod + def backward(ctx, dout, *args): + q, k, v, out, softmax_lse, rng_state = ctx.saved_tensors + qkv_shape = q.shape[:-2] + (3, *q.shape[-2:]) + dqkv = torch.empty(qkv_shape, dtype=q.dtype, device=q.device) + _flash_attn_backward( + dout, + q, + k, + v, + out, + softmax_lse, + dqkv[:, :, 0], + dqkv[:, :, 1], + dqkv[:, :, 2], + ctx.dropout_p, + ctx.softmax_scale, + ctx.causal, + ctx.window_size, + ctx.softcap, + ctx.alibi_slopes, + ctx.deterministic, + rng_state=rng_state, + ) + dqkv = dqkv[..., : dout.shape[-1]] # We could have padded the head dimension + return dqkv, None, None, None, None, None, None, None, None + + +class FlashAttnVarlenQKVPackedFunc(torch.autograd.Function): + @staticmethod + def forward( + ctx, + qkv, + cu_seqlens, + max_seqlen, + dropout_p, + softmax_scale, + causal, + window_size, + softcap, + alibi_slopes, + deterministic, + return_softmax, + ): + if softmax_scale is None: + softmax_scale = qkv.shape[-1] ** (-0.5) + out, q, k, v, out_padded, softmax_lse, S_dmask, rng_state = _flash_attn_varlen_forward( + qkv[:, 0], + qkv[:, 1], + qkv[:, 2], + cu_seqlens, + cu_seqlens, + max_seqlen, + max_seqlen, + dropout_p, + softmax_scale, + causal=causal, + window_size=window_size, + softcap=softcap, + alibi_slopes=alibi_slopes, + return_softmax=return_softmax and dropout_p > 0, + block_table=None, + ) + 
ctx.save_for_backward(q, k, v, out_padded, softmax_lse, cu_seqlens, rng_state) + ctx.dropout_p = dropout_p + ctx.max_seqlen = max_seqlen + ctx.softmax_scale = softmax_scale + ctx.causal = causal + ctx.window_size = window_size + ctx.softcap = softcap + ctx.alibi_slopes = alibi_slopes + ctx.deterministic = deterministic + return out if not return_softmax else (out, softmax_lse, S_dmask) + + @staticmethod + def backward(ctx, dout, *args): + q, k, v, out, softmax_lse, cu_seqlens, rng_state = ctx.saved_tensors + qkv_shape = q.shape[:-2] + (3, *q.shape[-2:]) + dqkv = torch.empty(qkv_shape, dtype=q.dtype, device=q.device) + _flash_attn_varlen_backward( + dout, + q, + k, + v, + out, + softmax_lse, + dqkv[:, 0], + dqkv[:, 1], + dqkv[:, 2], + cu_seqlens, + cu_seqlens, + ctx.max_seqlen, + ctx.max_seqlen, + ctx.dropout_p, + ctx.softmax_scale, + ctx.causal, + ctx.window_size, + ctx.softcap, + ctx.alibi_slopes, + ctx.deterministic, + rng_state=rng_state, + ) + dqkv = dqkv[..., : dout.shape[-1]] # We could have padded the head dimension + return dqkv, None, None, None, None, None, None, None, None, None, None + + +class FlashAttnKVPackedFunc(torch.autograd.Function): + @staticmethod + def forward( + ctx, + q, + kv, + dropout_p, + softmax_scale, + causal, + window_size, + softcap, + alibi_slopes, + deterministic, + return_softmax, + ): + if softmax_scale is None: + softmax_scale = q.shape[-1] ** (-0.5) + out, q, k, v, out_padded, softmax_lse, S_dmask, rng_state = _flash_attn_forward( + q, + kv[:, :, 0], + kv[:, :, 1], + dropout_p, + softmax_scale, + causal=causal, + window_size=window_size, + softcap=softcap, + alibi_slopes=alibi_slopes, + return_softmax=return_softmax and dropout_p > 0, + ) + ctx.save_for_backward(q, k, v, out_padded, softmax_lse, rng_state) + ctx.dropout_p = dropout_p + ctx.softmax_scale = softmax_scale + ctx.causal = causal + ctx.window_size = window_size + ctx.softcap = softcap + ctx.alibi_slopes = alibi_slopes + ctx.deterministic = deterministic + return out if not return_softmax else (out, softmax_lse, S_dmask) + + @staticmethod + def backward(ctx, dout, *args): + q, k, v, out, softmax_lse, rng_state = ctx.saved_tensors + dq = torch.empty_like(q) + kv_shape = k.shape[:-2] + (2, *k.shape[-2:]) + dkv = torch.empty(kv_shape, dtype=k.dtype, device=k.device) + _flash_attn_backward( + dout, + q, + k, + v, + out, + softmax_lse, + dq, + dkv[:, :, 0], + dkv[:, :, 1], + ctx.dropout_p, + ctx.softmax_scale, + ctx.causal, + ctx.window_size, + ctx.softcap, + ctx.alibi_slopes, + ctx.deterministic, + rng_state=rng_state, + ) + dq = dq[..., : dout.shape[-1]] # We could have padded the head dimension + dkv = dkv[..., : dout.shape[-1]] + return dq, dkv, None, None, None, None, None, None, None, None + + +class FlashAttnVarlenKVPackedFunc(torch.autograd.Function): + @staticmethod + def forward( + ctx, + q, + kv, + cu_seqlens_q, + cu_seqlens_k, + max_seqlen_q, + max_seqlen_k, + dropout_p, + softmax_scale, + causal, + window_size, + softcap, + alibi_slopes, + deterministic, + return_softmax, + ): + if softmax_scale is None: + softmax_scale = q.shape[-1] ** (-0.5) + out, q, k, v, out_padded, softmax_lse, S_dmask, rng_state = _flash_attn_varlen_forward( + q, + kv[:, 0], + kv[:, 1], + cu_seqlens_q, + cu_seqlens_k, + max_seqlen_q, + max_seqlen_k, + dropout_p, + softmax_scale, + causal=causal, + window_size=window_size, + softcap=softcap, + alibi_slopes=alibi_slopes, + return_softmax=return_softmax and dropout_p > 0, + block_table=None, + ) + ctx.save_for_backward( + q, k, v, out_padded, softmax_lse, 
cu_seqlens_q, cu_seqlens_k, rng_state + ) + ctx.dropout_p = dropout_p + ctx.max_seqlen_q = max_seqlen_q + ctx.max_seqlen_k = max_seqlen_k + ctx.softmax_scale = softmax_scale + ctx.causal = causal + ctx.window_size = window_size + ctx.softcap = softcap + ctx.alibi_slopes = alibi_slopes + ctx.deterministic = deterministic + return out if not return_softmax else (out, softmax_lse, S_dmask) + + @staticmethod + def backward(ctx, dout, *args): + q, k, v, out, softmax_lse, cu_seqlens_q, cu_seqlens_k, rng_state = ctx.saved_tensors + dq = torch.empty_like(q) + kv_shape = k.shape[:-2] + (2, *k.shape[-2:]) + dkv = torch.empty(kv_shape, dtype=k.dtype, device=k.device) + _flash_attn_varlen_backward( + dout, + q, + k, + v, + out, + softmax_lse, + dq, + dkv[:, 0], + dkv[:, 1], + cu_seqlens_q, + cu_seqlens_k, + ctx.max_seqlen_q, + ctx.max_seqlen_k, + ctx.dropout_p, + ctx.softmax_scale, + ctx.causal, + ctx.window_size, + ctx.softcap, + ctx.alibi_slopes, + ctx.deterministic, + rng_state=rng_state, + ) + dq = dq[..., : dout.shape[-1]] # We could have padded the head dimension + dkv = dkv[..., : dout.shape[-1]] + return dq, dkv, None, None, None, None, None, None, None, None, None, None, None, None + + +class FlashAttnFunc(torch.autograd.Function): + @staticmethod + def forward( + ctx, + q, + k, + v, + dropout_p, + softmax_scale, + causal, + window_size, + softcap, + alibi_slopes, + deterministic, + return_softmax, + ): + if softmax_scale is None: + softmax_scale = q.shape[-1] ** (-0.5) + ctx.headdim_qk = q.shape[-1] # before padding + out, q, k, v, out_padded, softmax_lse, S_dmask, rng_state = _flash_attn_forward( + q, + k, + v, + dropout_p, + softmax_scale, + causal=causal, + window_size=window_size, + softcap=softcap, + alibi_slopes=alibi_slopes, + return_softmax=return_softmax and dropout_p > 0, + ) + ctx.save_for_backward(q, k, v, out_padded, softmax_lse, rng_state) + ctx.dropout_p = dropout_p + ctx.softmax_scale = softmax_scale + ctx.causal = causal + ctx.window_size = window_size + ctx.softcap = softcap + ctx.alibi_slopes = alibi_slopes + ctx.deterministic = deterministic + return out if not return_softmax else (out, softmax_lse, S_dmask) + + @staticmethod + def backward(ctx, dout, *args): + q, k, v, out, softmax_lse, rng_state = ctx.saved_tensors + dq, dk, dv = torch.empty_like(q), torch.empty_like(k), torch.empty_like(v) + _flash_attn_backward( + dout, + q, + k, + v, + out, + softmax_lse, + dq, + dk, + dv, + ctx.dropout_p, + ctx.softmax_scale, + ctx.causal, + ctx.window_size, + ctx.softcap, + ctx.alibi_slopes, + ctx.deterministic, + rng_state=rng_state, + ) + dq = dq[..., : ctx.headdim_qk] # We could have padded the head dimension + dk = dk[..., : ctx.headdim_qk] + dv = dv[..., : dout.shape[-1]] + return dq, dk, dv, None, None, None, None, None, None, None, None + + +class FlashAttnVarlenFunc(torch.autograd.Function): + @staticmethod + def forward( + ctx, + q, + k, + v, + cu_seqlens_q, + cu_seqlens_k, + max_seqlen_q, + max_seqlen_k, + dropout_p, + softmax_scale, + causal, + window_size, + softcap, + alibi_slopes, + deterministic, + return_softmax, + block_table, + ): + if softmax_scale is None: + softmax_scale = q.shape[-1] ** (-0.5) + ctx.headdim_qk = q.shape[-1] # before padding + out, q, k, v, out_padded, softmax_lse, S_dmask, rng_state = _flash_attn_varlen_forward( + q, + k, + v, + cu_seqlens_q, + cu_seqlens_k, + max_seqlen_q, + max_seqlen_k, + dropout_p, + softmax_scale, + causal=causal, + window_size=window_size, + softcap=softcap, + alibi_slopes=alibi_slopes, + return_softmax=return_softmax 
and dropout_p > 0, + block_table=block_table, + ) + ctx.save_for_backward( + q, k, v, out_padded, softmax_lse, cu_seqlens_q, cu_seqlens_k, rng_state + ) + ctx.dropout_p = dropout_p + ctx.max_seqlen_q = max_seqlen_q + ctx.max_seqlen_k = max_seqlen_k + ctx.softmax_scale = softmax_scale + ctx.causal = causal + ctx.window_size = window_size + ctx.softcap = softcap + ctx.alibi_slopes = alibi_slopes + ctx.deterministic = deterministic + return out if not return_softmax else (out, softmax_lse, S_dmask) + + @staticmethod + def backward(ctx, dout, *args): + q, k, v, out, softmax_lse, cu_seqlens_q, cu_seqlens_k, rng_state = ctx.saved_tensors + dq, dk, dv = torch.empty_like(q), torch.empty_like(k), torch.empty_like(v) + _flash_attn_varlen_backward( + dout, + q, + k, + v, + out, + softmax_lse, + dq, + dk, + dv, + cu_seqlens_q, + cu_seqlens_k, + ctx.max_seqlen_q, + ctx.max_seqlen_k, + ctx.dropout_p, + ctx.softmax_scale, + ctx.causal, + ctx.window_size, + ctx.softcap, + ctx.alibi_slopes, + ctx.deterministic, + rng_state=rng_state, + ) + dq = dq[..., : ctx.headdim_qk] # We could have padded the head dimension + dk = dk[..., : ctx.headdim_qk] + dv = dv[..., : dout.shape[-1]] + return dq, dk, dv, None, None, None, None, None, None, None, None, None, None, None, None, None + + +def flash_attn_qkvpacked_func( + qkv, + dropout_p=0.0, + softmax_scale=None, + causal=False, + window_size=(-1, -1), # -1 means infinite context window + softcap=0.0, # <=0.0 means deactivate + alibi_slopes=None, + deterministic=False, + return_attn_probs=False, +): + """dropout_p should be set to 0.0 during evaluation + If Q, K, V are already stacked into 1 tensor, this function will be faster than + calling flash_attn_func on Q, K, V since the backward pass avoids explicit concatenation + of the gradients of Q, K, V. + For multi-query and grouped-query attention (MQA/GQA), please see + flash_attn_kvpacked_func and flash_attn_func. + + If window_size != (-1, -1), implements sliding window local attention. Query at position i + will only attend to keys between [i - window_size[0], i + window_size[1]] inclusive. + + Arguments: + qkv: (batch_size, seqlen, 3, nheads, headdim) + dropout_p: float. Dropout probability. + softmax_scale: float. The scaling of QK^T before applying softmax. + Default to 1 / sqrt(headdim). + causal: bool. Whether to apply causal attention mask (e.g., for auto-regressive modeling). + window_size: (left, right). If not (-1, -1), implements sliding window local attention. + softcap: float. Anything > 0 activates softcapping attention. + alibi_slopes: (nheads,) or (batch_size, nheads), fp32. A bias of (-alibi_slope * |i - j|) is added to + the attention score of query i and key j. + deterministic: bool. Whether to use the deterministic implementation of the backward pass, + which is slightly slower and uses more memory. The forward pass is always deterministic. + return_attn_probs: bool. Whether to return the attention probabilities. This option is for + testing only. The returned probabilities are not guaranteed to be correct + (they might not have the right scaling). + Return: + out: (batch_size, seqlen, nheads, headdim). + softmax_lse [optional, if return_attn_probs=True]: (batch_size, nheads, seqlen). The + logsumexp of each row of the matrix QK^T * scaling (e.g., log of the softmax + normalization factor). + S_dmask [optional, if return_attn_probs=True]: (batch_size, nheads, seqlen, seqlen). + The output of softmax (possibly with different scaling). 
It also encodes the dropout + pattern (negative means that location was dropped, nonnegative means it was kept). + """ + return FlashAttnQKVPackedFunc.apply( + qkv, + dropout_p, + softmax_scale, + causal, + window_size, + softcap, + alibi_slopes, + deterministic, + return_attn_probs, + ) + + +def flash_attn_kvpacked_func( + q, + kv, + dropout_p=0.0, + softmax_scale=None, + causal=False, + window_size=(-1, -1), # -1 means infinite context window + softcap=0.0, # 0.0 means deactivated + alibi_slopes=None, + deterministic=False, + return_attn_probs=False, +): + """dropout_p should be set to 0.0 during evaluation + If K, V are already stacked into 1 tensor, this function will be faster than + calling flash_attn_func on Q, K, V since the backward pass avoids explicit concatenation + of the gradients of K, V. + Supports multi-query and grouped-query attention (MQA/GQA) by passing in KV with fewer heads + than Q. Note that the number of heads in Q must be divisible by the number of heads in KV. + For example, if Q has 6 heads and K, V have 2 heads, head 0, 1, 2 of Q will attention to head + 0 of K, V, and head 3, 4, 5 of Q will attention to head 1 of K, V. + + If causal=True, the causal mask is aligned to the bottom right corner of the attention matrix. + For example, if seqlen_q = 2 and seqlen_k = 5, the causal mask (1 = keep, 0 = masked out) is: + 1 1 1 1 0 + 1 1 1 1 1 + If seqlen_q = 5 and seqlen_k = 2, the causal mask is: + 0 0 + 0 0 + 0 0 + 1 0 + 1 1 + If the row of the mask is all zero, the output will be zero. + + If window_size != (-1, -1), implements sliding window local attention. Query at position i + will only attend to keys between + [i + seqlen_k - seqlen_q - window_size[0], i + seqlen_k - seqlen_q + window_size[1]] inclusive. + + Arguments: + q: (batch_size, seqlen, nheads, headdim) + kv: (batch_size, seqlen, 2, nheads_k, headdim) + dropout_p: float. Dropout probability. + softmax_scale: float. The scaling of QK^T before applying softmax. + Default to 1 / sqrt(headdim). + causal: bool. Whether to apply causal attention mask (e.g., for auto-regressive modeling). + window_size: (left, right). If not (-1, -1), implements sliding window local attention. + softcap: float. Anything > 0 activates softcapping attention. + alibi_slopes: (nheads,) or (batch_size, nheads), fp32. A bias of + (-alibi_slope * |i + seqlen_k - seqlen_q - j|) + is added to the attention score of query i and key j. + deterministic: bool. Whether to use the deterministic implementation of the backward pass, + which is slightly slower and uses more memory. The forward pass is always deterministic. + return_attn_probs: bool. Whether to return the attention probabilities. This option is for + testing only. The returned probabilities are not guaranteed to be correct + (they might not have the right scaling). + Return: + out: (batch_size, seqlen, nheads, headdim). + softmax_lse [optional, if return_attn_probs=True]: (batch_size, nheads, seqlen). The + logsumexp of each row of the matrix QK^T * scaling (e.g., log of the softmax + normalization factor). + S_dmask [optional, if return_attn_probs=True]: (batch_size, nheads, seqlen, seqlen). + The output of softmax (possibly with different scaling). It also encodes the dropout + pattern (negative means that location was dropped, nonnegative means it was kept). 
+ """ + return FlashAttnKVPackedFunc.apply( + q, + kv, + dropout_p, + softmax_scale, + causal, + window_size, + softcap, + alibi_slopes, + deterministic, + return_attn_probs, + ) + + +def flash_attn_func( + q, + k, + v, + dropout_p=0.0, + softmax_scale=None, + causal=False, + window_size=(-1, -1), # -1 means infinite context window + softcap=0.0, # 0.0 means deactivated + alibi_slopes=None, + deterministic=False, + return_attn_probs=False, +): + """dropout_p should be set to 0.0 during evaluation + Supports multi-query and grouped-query attention (MQA/GQA) by passing in KV with fewer heads + than Q. Note that the number of heads in Q must be divisible by the number of heads in KV. + For example, if Q has 6 heads and K, V have 2 heads, head 0, 1, 2 of Q will attention to head + 0 of K, V, and head 3, 4, 5 of Q will attention to head 1 of K, V. + + If causal=True, the causal mask is aligned to the bottom right corner of the attention matrix. + For example, if seqlen_q = 2 and seqlen_k = 5, the causal mask (1 = keep, 0 = masked out) is: + 1 1 1 1 0 + 1 1 1 1 1 + If seqlen_q = 5 and seqlen_k = 2, the causal mask is: + 0 0 + 0 0 + 0 0 + 1 0 + 1 1 + If the row of the mask is all zero, the output will be zero. + + If window_size != (-1, -1), implements sliding window local attention. Query at position i + will only attend to keys between + [i + seqlen_k - seqlen_q - window_size[0], i + seqlen_k - seqlen_q + window_size[1]] inclusive. + + Arguments: + q: (batch_size, seqlen, nheads, headdim) + k: (batch_size, seqlen, nheads_k, headdim) + v: (batch_size, seqlen, nheads_k, headdim) + dropout_p: float. Dropout probability. + softmax_scale: float. The scaling of QK^T before applying softmax. + Default to 1 / sqrt(headdim). + causal: bool. Whether to apply causal attention mask (e.g., for auto-regressive modeling). + window_size: (left, right). If not (-1, -1), implements sliding window local attention. + alibi_slopes: (nheads,) or (batch_size, nheads), fp32. A bias of + (-alibi_slope * |i + seqlen_k - seqlen_q - j|) + is added to the attention score of query i and key j. + deterministic: bool. Whether to use the deterministic implementation of the backward pass, + which is slightly slower and uses more memory. The forward pass is always deterministic. + return_attn_probs: bool. Whether to return the attention probabilities. This option is for + testing only. The returned probabilities are not guaranteed to be correct + (they might not have the right scaling). + Return: + out: (batch_size, seqlen, nheads, headdim). + softmax_lse [optional, if return_attn_probs=True]: (batch_size, nheads, seqlen). The + logsumexp of each row of the matrix QK^T * scaling (e.g., log of the softmax + normalization factor). + S_dmask [optional, if return_attn_probs=True]: (batch_size, nheads, seqlen, seqlen). + The output of softmax (possibly with different scaling). It also encodes the dropout + pattern (negative means that location was dropped, nonnegative means it was kept). 
+ """ + return FlashAttnFunc.apply( + q, + k, + v, + dropout_p, + softmax_scale, + causal, + window_size, + softcap, + alibi_slopes, + deterministic, + return_attn_probs, + ) + + +def flash_attn_varlen_qkvpacked_func( + qkv, + cu_seqlens, + max_seqlen, + dropout_p=0.0, + softmax_scale=None, + causal=False, + window_size=(-1, -1), # -1 means infinite context window + softcap=0.0, # 0.0 means deactivated + alibi_slopes=None, + deterministic=False, + return_attn_probs=False, +): + """dropout_p should be set to 0.0 during evaluation + If Q, K, V are already stacked into 1 tensor, this function will be faster than + calling flash_attn_varlen_func on Q, K, V since the backward pass avoids explicit concatenation + of the gradients of Q, K, V. + For multi-query and grouped-query attention (MQA/GQA), please see + flash_attn_varlen_kvpacked_func and flash_attn_varlen_func. + + If window_size != (-1, -1), implements sliding window local attention. Query at position i + will only attend to keys between [i - window_size[0], i + window_size[1]] inclusive. + + Arguments: + qkv: (total, 3, nheads, headdim), where total = total number of tokens in the batch. + cu_seqlens: (batch_size + 1,), dtype torch.int32. The cumulative sequence lengths + of the sequences in the batch, used to index into qkv. + max_seqlen: int. Maximum sequence length in the batch. + dropout_p: float. Dropout probability. + softmax_scale: float. The scaling of QK^T before applying softmax. + Default to 1 / sqrt(headdim). + causal: bool. Whether to apply causal attention mask (e.g., for auto-regressive modeling). + window_size: (left, right). If not (-1, -1), implements sliding window local attention. + softcap: float. Anything > 0 activates softcapping attention. + alibi_slopes: (nheads,) or (batch_size, nheads), fp32. A bias of (-alibi_slope * |i - j|) + is added to the attention score of query i and key j. + deterministic: bool. Whether to use the deterministic implementation of the backward pass, + which is slightly slower and uses more memory. The forward pass is always deterministic. + return_attn_probs: bool. Whether to return the attention probabilities. This option is for + testing only. The returned probabilities are not guaranteed to be correct + (they might not have the right scaling). + Return: + out: (total, nheads, headdim). + softmax_lse [optional, if return_attn_probs=True]: (nheads, total_q_seqlen). The + logsumexp of each row of the matrix QK^T * scaling (e.g., log of the softmax + normalization factor). + S_dmask [optional, if return_attn_probs=True]: (batch_size, nheads, seqlen, seqlen). + The output of softmax (possibly with different scaling). It also encodes the dropout + pattern (negative means that location was dropped, nonnegative means it was kept). 
+ """ + return FlashAttnVarlenQKVPackedFunc.apply( + qkv, + cu_seqlens, + max_seqlen, + dropout_p, + softmax_scale, + causal, + window_size, + softcap, + alibi_slopes, + deterministic, + return_attn_probs, + ) + + +def flash_attn_varlen_kvpacked_func( + q, + kv, + cu_seqlens_q, + cu_seqlens_k, + max_seqlen_q, + max_seqlen_k, + dropout_p=0.0, + softmax_scale=None, + causal=False, + window_size=(-1, -1), # -1 means infinite context window + softcap=0.0, # 0.0 means deactivated + alibi_slopes=None, + deterministic=False, + return_attn_probs=False, +): + """dropout_p should be set to 0.0 during evaluation + If K, V are already stacked into 1 tensor, this function will be faster than + calling flash_attn_func on Q, K, V since the backward pass avoids explicit concatenation + of the gradients of K, V. + Supports multi-query and grouped-query attention (MQA/GQA) by passing in KV with fewer heads + than Q. Note that the number of heads in Q must be divisible by the number of heads in KV. + For example, if Q has 6 heads and K, V have 2 heads, head 0, 1, 2 of Q will attention to head + 0 of K, V, and head 3, 4, 5 of Q will attention to head 1 of K, V. + + If causal=True, the causal mask is aligned to the bottom right corner of the attention matrix. + For example, if seqlen_q = 2 and seqlen_k = 5, the causal mask (1 = keep, 0 = masked out) is: + 1 1 1 1 0 + 1 1 1 1 1 + If seqlen_q = 5 and seqlen_k = 2, the causal mask is: + 0 0 + 0 0 + 0 0 + 1 0 + 1 1 + If the row of the mask is all zero, the output will be zero. + + If window_size != (-1, -1), implements sliding window local attention. Query at position i + will only attend to keys between + [i + seqlen_k - seqlen_q - window_size[0], i + seqlen_k - seqlen_q + window_size[1]] inclusive. + + Arguments: + q: (total_q, nheads, headdim), where total_q = total number of query tokens in the batch. + kv: (total_k, 2, nheads_k, headdim), where total_k = total number of key tokens in the batch. + cu_seqlens_q: (batch_size + 1,), dtype torch.int32. The cumulative sequence lengths + of the sequences in the batch, used to index into q. + cu_seqlens_k: (batch_size + 1,), dtype torch.int32. The cumulative sequence lengths + of the sequences in the batch, used to index into kv. + max_seqlen_q: int. Maximum query sequence length in the batch. + max_seqlen_k: int. Maximum key sequence length in the batch. + dropout_p: float. Dropout probability. + softmax_scale: float. The scaling of QK^T before applying softmax. + Default to 1 / sqrt(headdim). + causal: bool. Whether to apply causal attention mask (e.g., for auto-regressive modeling). + window_size: (left, right). If not (-1, -1), implements sliding window local attention. + softcap: float. Anything > 0 activates softcapping attention. + alibi_slopes: (nheads,) or (batch_size, nheads), fp32. A bias of + (-alibi_slope * |i + seqlen_k - seqlen_q - j|) + is added to the attention score of query i and key j. + deterministic: bool. Whether to use the deterministic implementation of the backward pass, + which is slightly slower and uses more memory. The forward pass is always deterministic. + return_attn_probs: bool. Whether to return the attention probabilities. This option is for + testing only. The returned probabilities are not guaranteed to be correct + (they might not have the right scaling). + Return: + out: (total, nheads, headdim). + softmax_lse [optional, if return_attn_probs=True]: (nheads, total_q_seqlen). 
The + logsumexp of each row of the matrix QK^T * scaling (e.g., log of the softmax + normalization factor). + S_dmask [optional, if return_attn_probs=True]: (batch_size, nheads, seqlen, seqlen). + The output of softmax (possibly with different scaling). It also encodes the dropout + pattern (negative means that location was dropped, nonnegative means it was kept). + """ + return FlashAttnVarlenKVPackedFunc.apply( + q, + kv, + cu_seqlens_q, + cu_seqlens_k, + max_seqlen_q, + max_seqlen_k, + dropout_p, + softmax_scale, + causal, + window_size, + softcap, + alibi_slopes, + deterministic, + return_attn_probs, + ) + + +def flash_attn_varlen_func( + q, + k, + v, + cu_seqlens_q, + cu_seqlens_k, + max_seqlen_q, + max_seqlen_k, + dropout_p=0.0, + softmax_scale=None, + causal=False, + window_size=(-1, -1), # -1 means infinite context window + softcap=0.0, # 0.0 means deactivated + alibi_slopes=None, + deterministic=False, + return_attn_probs=False, + block_table=None, +): + """dropout_p should be set to 0.0 during evaluation + Supports multi-query and grouped-query attention (MQA/GQA) by passing in K, V with fewer heads + than Q. Note that the number of heads in Q must be divisible by the number of heads in KV. + For example, if Q has 6 heads and K, V have 2 heads, head 0, 1, 2 of Q will attention to head + 0 of K, V, and head 3, 4, 5 of Q will attention to head 1 of K, V. + + If causal=True, the causal mask is aligned to the bottom right corner of the attention matrix. + For example, if seqlen_q = 2 and seqlen_k = 5, the causal mask (1 = keep, 0 = masked out) is: + 1 1 1 1 0 + 1 1 1 1 1 + If seqlen_q = 5 and seqlen_k = 2, the causal mask is: + 0 0 + 0 0 + 0 0 + 1 0 + 1 1 + If the row of the mask is all zero, the output will be zero. + + If window_size != (-1, -1), implements sliding window local attention. Query at position i + will only attend to keys between + [i + seqlen_k - seqlen_q - window_size[0], i + seqlen_k - seqlen_q + window_size[1]] inclusive. + + Arguments: + q: (total_q, nheads, headdim), where total_q = total number of query tokens in the batch. + k: (total_k, nheads_k, headdim), where total_k = total number of key tokens in the batch. + v: (total_k, nheads_k, headdim), where total_k = total number of key tokens in the batch. + cu_seqlens_q: (batch_size + 1,), dtype torch.int32. The cumulative sequence lengths + of the sequences in the batch, used to index into q. + cu_seqlens_k: (batch_size + 1,), dtype torch.int32. The cumulative sequence lengths + of the sequences in the batch, used to index into kv. + max_seqlen_q: int. Maximum query sequence length in the batch. + max_seqlen_k: int. Maximum key sequence length in the batch. + dropout_p: float. Dropout probability. + softmax_scale: float. The scaling of QK^T before applying softmax. + Default to 1 / sqrt(headdim). + causal: bool. Whether to apply causal attention mask (e.g., for auto-regressive modeling). + window_size: (left, right). If not (-1, -1), implements sliding window local attention. + softcap: float. Anything > 0 activates softcapping attention. + alibi_slopes: (nheads,) or (batch_size, nheads), fp32. A bias of + (-alibi_slope * |i + seqlen_k - seqlen_q - j|) + is added to the attention score of query i and key j. + deterministic: bool. Whether to use the deterministic implementation of the backward pass, + which is slightly slower and uses more memory. The forward pass is always deterministic. + return_attn_probs: bool. Whether to return the attention probabilities. This option is for + testing only. 
The returned probabilities are not guaranteed to be correct + (they might not have the right scaling). + Return: + out: (total, nheads, headdim). + softmax_lse [optional, if return_attn_probs=True]: (nheads, total_q_seqlen). The + logsumexp of each row of the matrix QK^T * scaling (e.g., log of the softmax + normalization factor). + S_dmask [optional, if return_attn_probs=True]: (batch_size, nheads, seqlen, seqlen). + The output of softmax (possibly with different scaling). It also encodes the dropout + pattern (negative means that location was dropped, nonnegative means it was kept). + """ + return FlashAttnVarlenFunc.apply( + q, + k, + v, + cu_seqlens_q, + cu_seqlens_k, + max_seqlen_q, + max_seqlen_k, + dropout_p, + softmax_scale, + causal, + window_size, + softcap, + alibi_slopes, + deterministic, + return_attn_probs, + block_table, + ) + + +def flash_attn_with_kvcache( + q, + k_cache, + v_cache, + k=None, + v=None, + rotary_cos=None, + rotary_sin=None, + cache_seqlens: Optional[Union[(int, torch.Tensor)]] = None, + cache_batch_idx: Optional[torch.Tensor] = None, + cache_leftpad: Optional[torch.Tensor] = None, + block_table: Optional[torch.Tensor] = None, + softmax_scale=None, + causal=False, + window_size=(-1, -1), # -1 means infinite context window + softcap=0.0, # 0.0 means deactivated + rotary_interleaved=True, + alibi_slopes=None, + num_splits=0, + return_softmax_lse=False, +): + """ + If k and v are not None, k_cache and v_cache will be updated *inplace* with the new values from + k and v. This is useful for incremental decoding: you can pass in the cached keys/values from + the previous step, and update them with the new keys/values from the current step, and do + attention with the updated cache, all in 1 kernel. + + If you pass in k / v, you must make sure that the cache is large enough to hold the new values. + For example, the KV cache could be pre-allocated with the max sequence length, and you can use + cache_seqlens to keep track of the current sequence lengths of each sequence in the batch. + + Also apply rotary embedding if rotary_cos and rotary_sin are passed in. The key @k will be + rotated by rotary_cos and rotary_sin at indices cache_seqlens, cache_seqlens + 1, etc. + If causal or local (i.e., window_size != (-1, -1)), the query @q will be rotated by rotary_cos + and rotary_sin at indices cache_seqlens, cache_seqlens + 1, etc. + If not causal and not local, the query @q will be rotated by rotary_cos and rotary_sin at + indices cache_seqlens only (i.e. we consider all tokens in @q to be at position cache_seqlens). + + See tests/test_flash_attn.py::test_flash_attn_kvcache for examples of how to use this function. + + Supports multi-query and grouped-query attention (MQA/GQA) by passing in KV with fewer heads + than Q. Note that the number of heads in Q must be divisible by the number of heads in KV. + For example, if Q has 6 heads and K, V have 2 heads, head 0, 1, 2 of Q will attention to head + 0 of K, V, and head 3, 4, 5 of Q will attention to head 1 of K, V. + + If causal=True, the causal mask is aligned to the bottom right corner of the attention matrix. + For example, if seqlen_q = 2 and seqlen_k = 5, the causal mask (1 = keep, 0 = masked out) is: + 1 1 1 1 0 + 1 1 1 1 1 + If seqlen_q = 5 and seqlen_k = 2, the causal mask is: + 0 0 + 0 0 + 0 0 + 1 0 + 1 1 + If the row of the mask is all zero, the output will be zero. + + If window_size != (-1, -1), implements sliding window local attention. 
Query at position i + will only attend to keys between + [i + seqlen_k - seqlen_q - window_size[0], i + seqlen_k - seqlen_q + window_size[1]] inclusive. + + Note: Does not support backward pass. + + Arguments: + q: (batch_size, seqlen, nheads, headdim) + k_cache: (batch_size_cache, seqlen_cache, nheads_k, headdim) if there's no block_table, + or (num_blocks, page_block_size, nheads_k, headdim) if there's a block_table (i.e. paged KV cache) + page_block_size must be a multiple of 256. + v_cache: (batch_size_cache, seqlen_cache, nheads_k, headdim) if there's no block_table, + or (num_blocks, page_block_size, nheads_k, headdim) if there's a block_table (i.e. paged KV cache) + k [optional]: (batch_size, seqlen_new, nheads_k, headdim). If not None, we concatenate + k with k_cache, starting at the indices specified by cache_seqlens. + v [optional]: (batch_size, seqlen_new, nheads_k, headdim). Similar to k. + rotary_cos [optional]: (seqlen_ro, rotary_dim / 2). If not None, we apply rotary embedding + to k and q. Only applicable if k and v are passed in. rotary_dim must be divisible by 16. + rotary_sin [optional]: (seqlen_ro, rotary_dim / 2). Similar to rotary_cos. + cache_seqlens: int, or (batch_size,), dtype torch.int32. The sequence lengths of the + KV cache. + cache_batch_idx: (batch_size,), dtype torch.int32. The indices used to index into the KV cache. + If None, we assume that the batch indices are [0, 1, 2, ..., batch_size - 1]. + If the indices are not distinct, and k and v are provided, the values updated in the cache + might come from any of the duplicate indices. + cache_leftpad: (batch_size,), dtype torch.int32. The index that the KV cache starts. If None, assume 0. + block_table [optional]: (batch_size, max_num_blocks_per_seq), dtype torch.int32. + softmax_scale: float. The scaling of QK^T before applying softmax. + Default to 1 / sqrt(headdim). + causal: bool. Whether to apply causal attention mask (e.g., for auto-regressive modeling). + window_size: (left, right). If not (-1, -1), implements sliding window local attention. + softcap: float. Anything > 0 activates softcapping attention. + rotary_interleaved: bool. Only applicable if rotary_cos and rotary_sin are passed in. + If True, rotary embedding will combine dimensions 0 & 1, 2 & 3, etc. If False, + rotary embedding will combine dimensions 0 & rotary_dim / 2, 1 & rotary_dim / 2 + 1 + (i.e. GPT-NeoX style). + alibi_slopes: (nheads,) or (batch_size, nheads), fp32. A bias of + (-alibi_slope * |i + seqlen_k - seqlen_q - j|) + is added to the attention score of query i and key j. + num_splits: int. If > 1, split the key/value into this many chunks along the sequence. + If num_splits == 1, we don't split the key/value. If num_splits == 0, we use a heuristic + to automatically determine the number of splits. + Don't change this unless you know what you are doing. + return_softmax_lse: bool. Whether to return the logsumexp of the attention scores. + + Return: + out: (batch_size, seqlen, nheads, headdim). + softmax_lse [optional, if return_softmax_lse=True]: (batch_size, nheads, seqlen). The + logsumexp of each row of the matrix QK^T * scaling (e.g., log of the softmax + normalization factor). 
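+    Example (illustrative sketch; the names, shapes and values below are assumptions, not part of this API):
+        # Decoding one token per step with a pre-allocated cache:
+        #   q: (batch, 1, nheads, headdim), k_new / v_new: (batch, 1, nheads_k, headdim),
+        #   k_cache / v_cache: (batch, max_seqlen, nheads_k, headdim), cache_seqlens: (batch,) int32.
+        out = flash_attn_with_kvcache(q, k_cache, v_cache, k=k_new, v=v_new,
+                                      cache_seqlens=cache_seqlens, causal=True)
+        # k_cache / v_cache are updated in place at positions cache_seqlens, cache_seqlens + 1, ...,
+        # and out has shape (batch, 1, nheads, headdim); advance cache_seqlens by 1 for the next step.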
+ """ + assert k_cache.stride(-1) == 1, "k_cache must have contiguous last dimension" + assert v_cache.stride(-1) == 1, "v_cache must have contiguous last dimension" + q, k, v = [maybe_contiguous(x) for x in (q, k, v)] + if softmax_scale is None: + softmax_scale = q.shape[-1] ** (-0.5) + if cache_seqlens is not None and isinstance(cache_seqlens, int): + cache_seqlens = torch.full( + (k_cache.shape[0],), cache_seqlens, dtype=torch.int32, device=k_cache.device + ) + cache_seqlens = maybe_contiguous(cache_seqlens) + cache_batch_idx = maybe_contiguous(cache_batch_idx) + block_table = maybe_contiguous(block_table) + out, softmax_lse = flash_attn_cuda.fwd_kvcache( + q, + k_cache, + v_cache, + k, + v, + cache_seqlens, + rotary_cos, + rotary_sin, + cache_batch_idx, + cache_leftpad, + block_table, + alibi_slopes, + None, + softmax_scale, + causal, + window_size[0], + window_size[1], + softcap, + rotary_interleaved, + num_splits, + ) + return (out, softmax_lse) if return_softmax_lse else out diff --git a/autotuner/template/flash_fwd.cu b/autotuner/template/flash_fwd.cu new file mode 100644 index 000000000..a940d81fb --- /dev/null +++ b/autotuner/template/flash_fwd.cu @@ -0,0 +1,6 @@ +#include "flash_fwd.h" + +template<> +void run_mha_fwd_, /*{Kd}*/ , /*{D}*/, /*{is_causal}*/>(Flash_fwd_params ¶ms, cudaStream_t stream) { + run_mha_fwd_qkdim/*{Kd}*/_vdim/*{D}*/, /*{is_causal}*/>(params, stream); +} diff --git a/autotuner/template/flash_fwd.h b/autotuner/template/flash_fwd.h new file mode 100644 index 000000000..e7915205c --- /dev/null +++ b/autotuner/template/flash_fwd.h @@ -0,0 +1,18 @@ +#include "flash_fwd_launch_template_profile.h" + +#define False false +#define True true + +template +void run_mha_fwd_qkdim/*{Kd}*/_vdim/*{D}*/(Flash_fwd_params ¶ms, cudaStream_t stream) { + constexpr static int QKHeaddim = /*{Kd}*/; + constexpr static int VHeaddim = /*{D}*/; + constexpr static int Br = /*{Br}*/; + constexpr static int Bc = /*{Bc}*/; + constexpr static int Nwarps = /*{Nwarps}*/; + constexpr static bool IsQinRegs = /*{isQinRegs}*/; + constexpr static bool SharedQKSmem = /*{SharedQKSmem}*/; + DROPOUT_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] { + run_flash_fwd, Is_dropout, Is_causal>(params, stream); + }); +} \ No newline at end of file diff --git a/autotuner/template/flash_fwd_launch_template_profile.h b/autotuner/template/flash_fwd_launch_template_profile.h new file mode 100644 index 000000000..1bb2fa8d0 --- /dev/null +++ b/autotuner/template/flash_fwd_launch_template_profile.h @@ -0,0 +1,168 @@ +/****************************************************************************** + * Copyright (c) 2023, Tri Dao. + ******************************************************************************/ + +#pragma once + +#include + +#include "static_switch.h" +#include "flash_profile.h" +#include "flash_fwd_kernel.h" + +// Determine if the architecture supports FLASH and define a macro to handle parameter modifiers +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800 +#define ARCH_SUPPORTS_FLASH +#define KERNEL_PARAM_MODIFIER __grid_constant__ +#else +#define KERNEL_PARAM_MODIFIER +#endif + +// Define a macro for unsupported architecture handling to centralize the error message +#define FLASH_UNSUPPORTED_ARCH printf("FATAL: FlashAttention requires building with sm version sm80-sm90, but was built for < 8.0!"); + +// Use a macro to clean up kernel definitions +#define DEFINE_FLASH_FORWARD_KERNEL(kernelName, ...) 
\ +template \ +__global__ void kernelName(KERNEL_PARAM_MODIFIER const Flash_fwd_params params) + +DEFINE_FLASH_FORWARD_KERNEL(flash_fwd_kernel, bool Is_dropout, bool Is_causal, bool Is_local, bool Has_alibi, bool Is_even_MN, bool Is_even_K, bool Is_softcap, bool Return_softmax) { + #if defined(ARCH_SUPPORTS_FLASH) + static_assert(!(Is_causal && Is_local)); // Enforce constraints + flash::compute_attn(params); + #else + FLASH_UNSUPPORTED_ARCH + #endif +} + +DEFINE_FLASH_FORWARD_KERNEL(flash_fwd_splitkv_kernel, bool Is_causal, bool Is_local, bool Has_alibi, bool Is_even_MN, bool Is_even_K, bool Is_softcap, bool Split, bool Append_KV) { + #if defined(ARCH_SUPPORTS_FLASH) + flash::compute_attn_splitkv(params); + #else + FLASH_UNSUPPORTED_ARCH + #endif +} + +DEFINE_FLASH_FORWARD_KERNEL(flash_fwd_splitkv_combine_kernel, int kBlockM, int Log_max_splits, bool Is_even_K) { + static_assert(Log_max_splits >= 1); + flash::combine_attn_seqk_parallel(params); +} + +template +void run_flash_fwd(Flash_fwd_params ¶ms, cudaStream_t stream) { + constexpr size_t smem_size = Kernel_traits::kSmemSize; + // printf("smem_size = %d\n", smem_size); + + // Work-around for gcc 7. It doesn't like nested BOOL_SWITCH. + // https://github.com/kokkos/kokkos-kernels/issues/349 + // https://github.com/HazyResearch/flash-attention/issues/21 + + const int num_m_block = (params.seqlen_q + Kernel_traits::kBlockM - 1) / Kernel_traits::kBlockM; + dim3 grid(num_m_block, params.b, params.h); + const bool is_even_MN = params.cu_seqlens_q == nullptr && params.cu_seqlens_k == nullptr && params.seqlen_k % Kernel_traits::kBlockN == 0 && params.seqlen_q % Kernel_traits::kBlockM == 0; + const bool is_even_K = params.d == Kernel_traits::kQKHeadDim; //TODO: Check if this is correct + const bool return_softmax = params.p_ptr != nullptr; + BOOL_SWITCH(is_even_MN, IsEvenMNConst, [&] { + EVENK_SWITCH(is_even_K, IsEvenKConst, [&] { + LOCAL_SWITCH((params.window_size_left >= 0 || params.window_size_right >= 0) && !Is_causal, Is_local, [&] { + BOOL_SWITCH(return_softmax, ReturnSoftmaxConst, [&] { + ALIBI_SWITCH(params.alibi_slopes_ptr != nullptr, Has_alibi, [&] { + SOFTCAP_SWITCH(params.softcap > 0.0, Is_softcap, [&] { + // Will only return softmax if dropout, to reduce compilation time. + // If not IsEvenKConst, we also set IsEvenMNConst to false to reduce number of templates. 
+ // If return_softmax, set IsEvenMNConst to false to reduce number of templates + // If head dim > 128, set IsEvenMNConst to false to reduce number of templates + // If Is_local, set Is_causal to false + auto kernel = &flash_fwd_kernel;// TODO: Check if this is correct + // auto kernel = &flash_fwd_kernel; + // printf("IsEvenMNConst = %d, IsEvenKConst = %d, Is_local = %d, Is_causal = %d, ReturnSoftmaxConst = %d, Is_dropout = %d\n", int(IsEvenMNConst), int(IsEvenKConst), int(Is_local), int(Is_causal), int(ReturnSoftmaxConst), int(Is_dropout)); + // auto kernel = &flash_fwd_kernel; + if (smem_size >= 48 * 1024) { + C10_CUDA_CHECK(cudaFuncSetAttribute( + kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size)); + } + // int ctas_per_sm; + // cudaError status_ = cudaOccupancyMaxActiveBlocksPerMultiprocessor( + // &ctas_per_sm, kernel, Kernel_traits::kNThreads, smem_size); + // printf("smem_size = %d, CTAs per SM = %d\n", int(smem_size), ctas_per_sm); + kernel<<>>(params); + C10_CUDA_KERNEL_LAUNCH_CHECK(); + }); + }); + }); + }); + }); + }); +} + +template +void run_flash_splitkv_fwd(Flash_fwd_params ¶ms, cudaStream_t stream) { + static_assert(!Kernel_traits::Is_Q_in_regs, "SplitKV implementation does not support Is_Q_in_regs"); + static_assert(!Kernel_traits::Share_Q_K_smem, "SplitKV implementation does not support Share_Q_K_smem"); + constexpr size_t smem_size = Kernel_traits::kSmemSize; + const int num_m_block = (params.seqlen_q + Kernel_traits::kBlockM - 1) / Kernel_traits::kBlockM; + dim3 grid(num_m_block, params.num_splits > 1 ? params.num_splits : params.b, params.num_splits > 1 ? params.b * params.h : params.h); + const bool is_even_MN = params.cu_seqlens_q == nullptr && params.cu_seqlens_k == nullptr && params.seqlen_k % Kernel_traits::kBlockN == 0 && params.seqlen_q % Kernel_traits::kBlockM == 0; + const bool is_even_K = params.d == Kernel_traits::kQKHeadDim; //TODO: Check if this is correct + BOOL_SWITCH(is_even_MN, IsEvenMNConst, [&] { + EVENK_SWITCH(is_even_K, IsEvenKConst, [&] { + LOCAL_SWITCH((params.window_size_left >= 0 || params.window_size_right >= 0) && !Is_causal, Is_local, [&] { + BOOL_SWITCH(params.num_splits > 1, Split, [&] { + BOOL_SWITCH(params.knew_ptr != nullptr, Append_KV, [&] { + ALIBI_SWITCH(params.alibi_slopes_ptr != nullptr, Has_alibi, [&] { + SOFTCAP_SWITCH(params.softcap > 0.0, Is_softcap, [&] { + // If Append_KV, then we must have seqlen_offsets, which means cu_seqlens_k != nullptr. + // If not IsEvenKConst, we also set IsEvenMNConst to false to reduce number of templates. + // If Is_local, set Is_causal to false + auto kernel = &flash_fwd_splitkv_kernel; // TODO: Check if this is correct + // auto kernel = &flash_fwd_splitkv_kernel; + // auto kernel = &flash_fwd_splitkv_kernel; + if (smem_size >= 48 * 1024) { + C10_CUDA_CHECK(cudaFuncSetAttribute( + kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size)); + } + kernel<<>>(params); + C10_CUDA_KERNEL_LAUNCH_CHECK(); + }); + }); + }); + }); + }); + }); + }); + if (params.num_splits > 1) { + // We want kBlockM to be as small as possible for more parallelism. + // With 128 threads we can load 512 elements at a time, so if headdim is divisible by 128, kBlockM = 4. + // If headdim is divisible by 64, then we set kBlockM = 8, etc. + constexpr static int kBlockM = Kernel_traits::kQKHeadDim % 128 == 0 ? 4 : (Kernel_traits::kQKHeadDim % 64 == 0 ? 
8 : 16); // TODO: Check if this is correct + dim3 grid_combine((params.b * params.h * params.seqlen_q + kBlockM - 1) / kBlockM); + EVENK_SWITCH(is_even_K, IsEvenKConst, [&] { + if (params.num_splits <= 2) { + flash_fwd_splitkv_combine_kernel<<>>(params); + } else if (params.num_splits <= 4) { + flash_fwd_splitkv_combine_kernel<<>>(params); + } else if (params.num_splits <= 8) { + flash_fwd_splitkv_combine_kernel<<>>(params); + } else if (params.num_splits <= 16) { + flash_fwd_splitkv_combine_kernel<<>>(params); + } else if (params.num_splits <= 32) { + flash_fwd_splitkv_combine_kernel<<>>(params); + } else if (params.num_splits <= 64) { + flash_fwd_splitkv_combine_kernel<<>>(params); + } else if (params.num_splits <= 128) { + flash_fwd_splitkv_combine_kernel<<>>(params); + } + C10_CUDA_KERNEL_LAUNCH_CHECK(); + }); + } +} + +template +void run_mha_fwd_splitkv_dispatch(Flash_fwd_params ¶ms, cudaStream_t stream) { + constexpr static int kBlockM = 64; // Fixed for all head dimensions + // TD [2023-08-28]: nvcc segfaults for headdim 96 with block size 64 x 256, + // and for headdim 192 with block size 64 x 128. + // Also for headdim 160 with block size 64 x 128 after the rotary addition. + constexpr static int kBlockN = QKHeaddim <= 64 ? 256 : (QKHeaddim <= 128 ? 128 : 64); + run_flash_splitkv_fwd, Is_causal>(params, stream); +} diff --git a/autotuner/template/flash_profile.h b/autotuner/template/flash_profile.h new file mode 100644 index 000000000..51e04ced5 --- /dev/null +++ b/autotuner/template/flash_profile.h @@ -0,0 +1,196 @@ +/****************************************************************************** + * Copyright (c) 2023, Tri Dao. + ******************************************************************************/ + +#pragma once + +#include +#include + +#ifdef OLD_GENERATOR_PATH +#include +#else +#include +#endif + +#include // For at::cuda::philox::unpack + +constexpr int TOTAL_DIM = 0; +constexpr int H_DIM = 1; +constexpr int D_DIM = 2; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +struct Qkv_params { + using index_t = int64_t; + // The QKV matrices. + void *__restrict__ q_ptr; + void *__restrict__ k_ptr; + void *__restrict__ v_ptr; + + // The stride between rows of the Q, K and V matrices. + index_t q_batch_stride; + index_t k_batch_stride; + index_t v_batch_stride; + index_t q_row_stride; + index_t k_row_stride; + index_t v_row_stride; + index_t q_head_stride; + index_t k_head_stride; + index_t v_head_stride; + + // The number of heads. + int h, h_k, h_v; + // In the case of multi-query and grouped-query attention (MQA/GQA), nheads_k could be + // different from nheads (query). + int h_h_k_ratio; // precompute h / h_k, + int h_h_v_ratio; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +struct Flash_fwd_params : public Qkv_params { + + // The O matrix (output). + void * __restrict__ o_ptr; + void * __restrict__ oaccum_ptr; + + // The stride between rows of O. + index_t o_batch_stride; + index_t o_row_stride; + index_t o_head_stride; + + // The pointer to the P matrix. + void * __restrict__ p_ptr; + + // The pointer to the softmax sum. + void * __restrict__ softmax_lse_ptr; + void * __restrict__ softmax_lseaccum_ptr; + + // The dimensions. + int b, seqlen_q, seqlen_k, seqlen_knew, d, vd, seqlen_q_rounded, seqlen_k_rounded, d_rounded, vd_rounded, rotary_dim, total_q; + + // The scaling factors for the kernel. 
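+    // scale_softmax is the softmax scaling applied to QK^T; scale_softmax_log2 is the same scale
+    // pre-multiplied by log2(e) (softmax_scale * M_LOG2E in set_params_fprop), so the kernel can
+    // evaluate the softmax with base-2 exponentials.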
+ float scale_softmax; + float scale_softmax_log2; + + // array of length b+1 holding starting offset of each sequence. + int * __restrict__ cu_seqlens_q; + int * __restrict__ cu_seqlens_k; + int * __restrict__ leftpad_k; + + // If provided, the actual length of each k sequence. + int * __restrict__ seqused_k; + + int *__restrict__ blockmask; + + // The K_new and V_new matrices. + void * __restrict__ knew_ptr; + void * __restrict__ vnew_ptr; + + // The stride between rows of the Q, K and V matrices. + index_t knew_batch_stride; + index_t vnew_batch_stride; + index_t knew_row_stride; + index_t vnew_row_stride; + index_t knew_head_stride; + index_t vnew_head_stride; + + // The cos and sin matrices for rotary embedding. + void * __restrict__ rotary_cos_ptr; + void * __restrict__ rotary_sin_ptr; + + // The indices to index into the KV cache. + int * __restrict__ cache_batch_idx; + + // Paged KV cache + int * __restrict__ block_table; + index_t block_table_batch_stride; + int page_block_size; + + // The dropout probability (probability of keeping an activation). + float p_dropout; + // uint32_t p_dropout_in_uint; + // uint16_t p_dropout_in_uint16_t; + uint8_t p_dropout_in_uint8_t; + + // Scale factor of 1 / (1 - p_dropout). + float rp_dropout; + float scale_softmax_rp_dropout; + + // Local window size + int window_size_left, window_size_right; + float softcap; + + // Random state. + at::PhiloxCudaState philox_args; + + // Pointer to the RNG seed (idx 0) and offset (idx 1). + uint64_t * rng_state; + + bool is_bf16; + bool is_causal; + + // If is_seqlens_k_cumulative, then seqlen_k is cu_seqlens_k[bidb + 1] - cu_seqlens_k[bidb]. + // Otherwise it's cu_seqlens_k[bidb], i.e., we use cu_seqlens_k to store the sequence lengths of K. + bool is_seqlens_k_cumulative; + + bool is_rotary_interleaved; + + int num_splits; // For split-KV version + + void * __restrict__ alibi_slopes_ptr; + index_t alibi_slopes_batch_stride; + + bool unpadded_lse; // For varlen paths: LSE is in [nheads, total_seqlen_q] format instead of [b, nheads, seqlen_q]. + bool seqlenq_ngroups_swapped; // q has been transposed from (b, 1, (nheads_kv ngroups), d) to (b, ngroups, nheads_kv, d). +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +struct Flash_bwd_params : public Flash_fwd_params { + + // The dO and dQKV matrices. + void *__restrict__ do_ptr; + void *__restrict__ dq_ptr; + void *__restrict__ dk_ptr; + void *__restrict__ dv_ptr; + + // To accumulate dQ + void *__restrict__ dq_accum_ptr; + void *__restrict__ dk_accum_ptr; + void *__restrict__ dv_accum_ptr; + + // // To accumulate dK and dV in case we're splitting the bwd along seqlen_q + // dimension void *__restrict__ dk_accum_ptr; void *__restrict__ + // dv_accum_ptr; + + // The stride between rows of the dO, dQ, dK and dV matrices. + // TD [2022-04-16]: We're using 32-bit indexing to save registers. + // The code probably won't work for arrays larger than 2GB. + index_t do_batch_stride; + index_t do_row_stride; + index_t do_head_stride; + index_t dq_batch_stride; + index_t dk_batch_stride; + index_t dv_batch_stride; + index_t dq_row_stride; + index_t dk_row_stride; + index_t dv_row_stride; + index_t dq_head_stride; + index_t dk_head_stride; + index_t dv_head_stride; + + // The pointer to the softmax d sum. 
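+    // In the FlashAttention backward pass this holds rowsum(dO * O), one value per query row.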
+ void *__restrict__ dsoftmax_sum; + + bool deterministic; + index_t dq_accum_split_stride; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template void run_mha_fwd_(Flash_fwd_params ¶ms, cudaStream_t stream); +// template void run_mha_fwd_splitkv_dispatch(Flash_fwd_params ¶ms, cudaStream_t stream); + +// template void run_mha_bwd_(Flash_bwd_params ¶ms, cudaStream_t stream); diff --git a/autotuner/template/flash_profile_api.cpp b/autotuner/template/flash_profile_api.cpp new file mode 100644 index 000000000..48c84a5d6 --- /dev/null +++ b/autotuner/template/flash_profile_api.cpp @@ -0,0 +1,1694 @@ +/****************************************************************************** + * Copyright (c) 2024, Tri Dao. + ******************************************************************************/ + +// Include these 2 headers instead of torch/extension.h since we don't need all of the torch headers. +#include +#include +#include +#include + +#include + +#include "flash_profile.h" +// #include "static_switch.h" +// #include "static_switch_headdim.h" +#define False false +#define True true + +#define CHECK_DEVICE(x) TORCH_CHECK(x.is_cuda(), #x " must be on CUDA") +#define CHECK_SHAPE(x, ...) TORCH_CHECK(x.sizes() == torch::IntArrayRef({__VA_ARGS__}), #x " must have shape (" #__VA_ARGS__ ")") +#define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be contiguous") + + +void set_params_fprop(Flash_fwd_params ¶ms, + // sizes + const size_t b, + const size_t seqlen_q, + const size_t seqlen_k, + const size_t seqlen_q_rounded, + const size_t seqlen_k_rounded, + const size_t h, + const size_t h_k, + const size_t h_v, + const size_t d, + const size_t d_rounded, + const size_t vd, + const size_t vd_rounded, + // device pointers + const at::Tensor q, + const at::Tensor k, + const at::Tensor v, + at::Tensor out, + void *cu_seqlens_q_d, + void *cu_seqlens_k_d, + void *seqused_k, + void *p_d, + void *softmax_lse_d, + float p_dropout, + float softmax_scale, + int window_size_left, + int window_size_right, + const float softcap, + bool seqlenq_ngroups_swapped=false, + const bool unpadded_lse=false) { + + // Reset the parameters + params = {}; + + params.is_bf16 = q.dtype() == torch::kBFloat16; + + // Set the pointers and strides. + params.q_ptr = q.data_ptr(); + params.k_ptr = k.data_ptr(); + params.v_ptr = v.data_ptr(); + // All stride are in elements, not bytes. + params.q_row_stride = q.stride(-3); + params.k_row_stride = k.stride(-3); + params.v_row_stride = v.stride(-3); + params.q_head_stride = q.stride(-2); + params.k_head_stride = k.stride(-2); + params.v_head_stride = v.stride(-2); + params.o_ptr = out.data_ptr(); + params.o_row_stride = out.stride(-3); + params.o_head_stride = out.stride(-2); + + if (cu_seqlens_q_d == nullptr) { + params.q_batch_stride = q.stride(0); + params.k_batch_stride = k.stride(0); + params.v_batch_stride = v.stride(0); + params.o_batch_stride = out.stride(0); + if (seqlenq_ngroups_swapped) { + params.q_batch_stride *= seqlen_q; + params.o_batch_stride *= seqlen_q; + } + } + + params.cu_seqlens_q = static_cast(cu_seqlens_q_d); + params.cu_seqlens_k = static_cast(cu_seqlens_k_d); + params.seqused_k = static_cast(seqused_k); + + // P = softmax(QK^T) + params.p_ptr = p_d; + + // Softmax sum + params.softmax_lse_ptr = softmax_lse_d; + + // Set the dimensions. 
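+    // (h_h_k_ratio / h_h_v_ratio, set below, let the kernel map each query head to its shared
+    // K / V head for MQA/GQA without recomputing the division.)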
+ params.b = b; + params.h = h; + params.h_k = h_k; + params.h_v = h_v; + params.h_h_k_ratio = h / h_k; + params.h_h_v_ratio = h / h_v; + params.seqlen_q = seqlen_q; + params.seqlen_k = seqlen_k; + params.seqlen_q_rounded = seqlen_q_rounded; + params.seqlen_k_rounded = seqlen_k_rounded; + params.d = d; + params.d_rounded = d_rounded; + params.vd = vd; + params.vd_rounded = vd_rounded; + + // Set the different scale values. + #ifdef FLASHATTENTION_DISABLE_SOFTCAP + TORCH_CHECK(softcap <= 0.0, "This flash attention build does not support softcap."); + #endif + if (softcap > 0.0) { + params.softcap = softmax_scale / softcap; + params.scale_softmax = softcap; + params.scale_softmax_log2 = softcap * M_LOG2E; + } else{ + // Remove potential NaN + params.softcap = 0.0; + params.scale_softmax = softmax_scale; + params.scale_softmax_log2 = softmax_scale * M_LOG2E; + } + + // Set this to probability of keeping an element to simplify things. + params.p_dropout = 1.f - p_dropout; + // Convert p from float to int so we don't have to convert the random uint to float to compare. + // [Minor] We want to round down since when we do the comparison we use <= instead of < + // params.p_dropout_in_uint = uint32_t(std::floor(params.p_dropout * 4294967295.0)); + // params.p_dropout_in_uint16_t = uint16_t(std::floor(params.p_dropout * 65535.0)); + params.p_dropout_in_uint8_t = uint8_t(std::floor(params.p_dropout * 255.0)); + params.rp_dropout = 1.f / params.p_dropout; + params.scale_softmax_rp_dropout = params.rp_dropout * params.scale_softmax; + TORCH_CHECK(p_dropout < 1.f); + #ifdef FLASHATTENTION_DISABLE_DROPOUT + TORCH_CHECK(p_dropout == 0.0f, "This flash attention build does not support dropout."); + #endif + + // Causal is the special case where window_size_right == 0 and window_size_left < 0. + // Local is the more general case where window_size_right >= 0 or window_size_left >= 0. 
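+    // Examples: window_size = (-1, 0) is causal, (-1, -1) is full attention, and (256, 0) is
+    // causal attention where each query sees only itself and the 256 keys before it.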
+ params.is_causal = window_size_left < 0 && window_size_right == 0; + + if (window_size_left < 0 && window_size_right >= 0) { window_size_left = seqlen_k; } + if (window_size_left >= 0 && window_size_right < 0) { window_size_right = seqlen_k; } + params.window_size_left = window_size_left; + params.window_size_right = window_size_right; + + #ifdef FLASHATTENTION_DISABLE_LOCAL + TORCH_CHECK(params.is_causal || (window_size_left < 0 && window_size_right < 0), + "This flash attention build does not support local attention."); + #endif + + params.is_seqlens_k_cumulative = true; + + #ifdef FLASHATTENTION_DISABLE_UNEVEN_K + TORCH_CHECK(d == d_rounded, "This flash attention build does not support headdim not being a multiple of 32."); + #endif + + params.unpadded_lse = unpadded_lse; + params.seqlenq_ngroups_swapped = seqlenq_ngroups_swapped; +} + +void set_params_dgrad(Flash_bwd_params ¶ms, + // sizes + const size_t b, + const size_t seqlen_q, + const size_t seqlen_k, + const size_t seqlen_q_rounded, + const size_t seqlen_k_rounded, + const size_t h, + const size_t h_k, + const size_t h_v, + const size_t d, + const size_t d_rounded, + const size_t vd, + // device pointers + const at::Tensor q, + const at::Tensor k, + const at::Tensor v, + const at::Tensor out, + const at::Tensor dout, + at::Tensor dq, + at::Tensor dk, + at::Tensor dv, + void *cu_seqlens_q_d, + void *cu_seqlens_k_d, + void *dq_accum_d, + void *dk_accum_d, + void *dv_accum_d, + void *softmax_lse_d, + void *dsoftmax_sum_d, + float p_dropout, + float softmax_scale, + int window_size_left, + int window_size_right, + const float softcap, + bool deterministic, + const bool unpadded_lse) { + + set_params_fprop(params, + b, seqlen_q, seqlen_k, seqlen_q_rounded, seqlen_k_rounded, h, h_k, h_v, d, d_rounded,vd, vd, + q, k, v, out, + cu_seqlens_q_d, + cu_seqlens_k_d, + nullptr, + nullptr, + softmax_lse_d, + p_dropout, + softmax_scale, + window_size_left, + window_size_right, + softcap, + false, // seqlenq_ngroups_swapped + unpadded_lse); + + // Set the pointers and strides. 
+ params.do_ptr = dout.data_ptr(); + params.do_row_stride = dout.stride(-3); + params.do_head_stride = dout.stride(-2); + params.dq_ptr = dq.data_ptr(); + params.dk_ptr = dk.data_ptr(); + params.dv_ptr = dv.data_ptr(); + params.dq_row_stride = dq.stride(-3); + params.dk_row_stride = dk.stride(-3); + params.dv_row_stride = dv.stride(-3); + params.dq_head_stride = dq.stride(-2); + params.dk_head_stride = dk.stride(-2); + params.dv_head_stride = dv.stride(-2); + + if (cu_seqlens_q_d == nullptr) { + params.do_batch_stride = dout.stride(0); + params.dq_batch_stride = dq.stride(0); + params.dk_batch_stride = dk.stride(0); + params.dv_batch_stride = dv.stride(0); + } + + params.dq_accum_ptr = dq_accum_d; + params.dk_accum_ptr = dk_accum_d; + params.dv_accum_ptr = dv_accum_d; + + // Softmax sum + params.dsoftmax_sum = dsoftmax_sum_d; + + params.deterministic = deterministic; +} + +void run_mha_fwd(Flash_fwd_params ¶ms, cudaStream_t stream, bool force_split_kernel=false) { + constexpr bool is_bf16 = /*{is_bf16}*/; + constexpr bool is_causal = /*{is_causal}*/; + constexpr int kQKHeadDim = /*{Kd}*/; + constexpr int kVHeadDim = /*{D}*/; + assert(params.is_bf16 == is_bf16); + assert(params.is_causal == is_causal); + assert(params.d == kQKHeadDim); + assert(params.vd == kVHeadDim); + + if (params.num_splits <= 1 && !force_split_kernel) { // If we don't set it num_splits == 0 + run_mha_fwd_, kQKHeadDim, kVHeadDim, is_causal>(params, stream); + } else { + // TODO: temporary workaround + // run_mha_fwd_splitkv_dispatch, kQKHeadDim, kVHeadDim, is_causal>(params, stream); + } + +} + +// Find the number of splits that maximizes the occupancy. For example, if we have +// batch * n_heads = 48 and we have 108 SMs, having 2 splits (efficiency = 0.89) is +// better than having 3 splits (efficiency = 0.67). However, we also don't want too many +// splits as that would incur more HBM reads/writes. +// So we find the best efficiency, then find the smallest number of splits that gets 85% +// of the best efficiency. +inline int num_splits_heuristic(int batch_nheads_mblocks, int num_SMs, int num_n_blocks, int max_splits) { + // If we have enough to almost fill the SMs, then just use 1 split + if (batch_nheads_mblocks >= 0.8f * num_SMs) { return 1; } + max_splits = std::min({max_splits, num_SMs, num_n_blocks}); + float max_efficiency = 0.f; + std::vector efficiency; + efficiency.reserve(max_splits); + auto ceildiv = [](int a, int b) { return (a + b - 1) / b; }; + // Some splits are not eligible. For example, if we have 64 blocks and choose 11 splits, + // we'll have 6 * 10 + 4 blocks. If we choose 12 splits, we'll have 6 * 11 + (-2) blocks + // (i.e. it's 11 splits anyway). + // So we check if the number of blocks per split is the same as the previous num_splits. 
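+    // Worked example with num_n_blocks = 64: ceildiv(64, 10) = 7 and ceildiv(64, 11) = 6, so
+    // 11 splits is eligible (7 != 6), while ceildiv(64, 12) = 6 == ceildiv(64, 11), so 12 splits
+    // is skipped.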
+ auto is_split_eligible = [&ceildiv, &num_n_blocks](int num_splits) { + return num_splits == 1 || ceildiv(num_n_blocks, num_splits) != ceildiv(num_n_blocks, num_splits - 1); + }; + for (int num_splits = 1; num_splits <= max_splits; num_splits++) { + if (!is_split_eligible(num_splits)) { + efficiency.push_back(0.f); + } else { + float n_waves = float(batch_nheads_mblocks * num_splits) / num_SMs; + float eff = n_waves / ceil(n_waves); + // printf("num_splits = %d, eff = %f\n", num_splits, eff); + if (eff > max_efficiency) { max_efficiency = eff; } + efficiency.push_back(eff); + } + } + for (int num_splits = 1; num_splits <= max_splits; num_splits++) { + if (!is_split_eligible(num_splits)) { continue; } + if (efficiency[num_splits - 1] >= 0.85 * max_efficiency) { + // printf("num_splits chosen = %d\n", num_splits); + return num_splits; + } + } + return 1; +} + +std::tuple set_params_splitkv(Flash_fwd_params ¶ms, const int batch_size, + const int num_heads, const int head_size, const int v_head_size, const int max_seqlen_k, const int max_seqlen_q, + const int head_size_rounded, const int v_head_size_rounded,const float p_dropout, + const int num_splits, cudaDeviceProp *dprops, struct c10::TensorOptions opts) { + + // This needs to match with run_mha_fwd_splitkv_dispatch + const int max_head_size = head_size > v_head_size ? head_size : v_head_size; + const int block_n = max_head_size <= 64 ? 256 : (max_head_size <= 128 ? 128 : 64); + const int num_n_blocks = (max_seqlen_k + block_n - 1) / block_n; + // Technically kBlockM = 64 only for the splitKV kernels, not the standard kernel. + // In any case we don't expect seqlen_q to be larger than 64 for inference. + const int num_m_blocks = (max_seqlen_q + 64 - 1) / 64; + params.num_splits = num_splits; + at::Tensor softmax_lse_accum; + at::Tensor out_accum; + + if (p_dropout == 0.0f) { // SplitKV is not implemented for dropout + if (num_splits < 1) { + // We multiply number of SMs by 2 to hard-code the fact that we're using 128 threads per block. + params.num_splits = num_splits_heuristic(batch_size * num_heads * num_m_blocks, dprops->multiProcessorCount * 2, num_n_blocks, 128); + } + if (params.num_splits > 1) { + softmax_lse_accum = torch::empty({params.num_splits, batch_size, num_heads, max_seqlen_q}, opts.dtype(at::kFloat)); + out_accum = torch::empty({params.num_splits, batch_size, num_heads, max_seqlen_q, v_head_size_rounded}, opts.dtype(at::kFloat)); + params.softmax_lseaccum_ptr = softmax_lse_accum.data_ptr(); + params.oaccum_ptr = out_accum.data_ptr(); + } + TORCH_CHECK(params.num_splits <= 128, "num_splits > 128 not supported"); + } + + return std::make_tuple(softmax_lse_accum, out_accum); +} + +void set_params_alibi(Flash_fwd_params ¶ms, c10::optional &alibi_slopes_, int batch_size, int num_heads){ +#ifdef FLASHATTENTION_DISABLE_ALIBI + TORCH_CHECK(!alibi_slopes_.has_value(), "This flash attention build does not support alibi."); + params.alibi_slopes_ptr = nullptr; +#else + if (alibi_slopes_.has_value()) { + auto alibi_slopes = alibi_slopes_.value(); + TORCH_CHECK(alibi_slopes.dtype() == torch::kFloat32, "ALiBi slopes must have dtype fp32"); + CHECK_DEVICE(alibi_slopes); + TORCH_CHECK(alibi_slopes.stride(-1) == 1, "ALiBi slopes tensor must have contiguous last dimension"); + TORCH_CHECK(alibi_slopes.sizes() == torch::IntArrayRef({num_heads}) || alibi_slopes.sizes() == torch::IntArrayRef({batch_size, num_heads})); + params.alibi_slopes_ptr = alibi_slopes.data_ptr(); + params.alibi_slopes_batch_stride = alibi_slopes.dim() == 2 ? 
alibi_slopes.stride(0) : 0; + } else { + params.alibi_slopes_ptr = nullptr; + } +#endif +} + +std::vector +mha_fwd(at::Tensor &q, // batch_size x seqlen_q x num_heads x head_size + const at::Tensor &k, // batch_size x seqlen_k x num_heads_k x head_size + const at::Tensor &v, // batch_size x seqlen_k x num_heads_v x head_size + c10::optional &out_, // batch_size x seqlen_q x num_heads x head_size + c10::optional &alibi_slopes_, // num_heads or batch_size x num_heads + const float p_dropout, + const float softmax_scale, + bool is_causal, + int window_size_left, + int window_size_right, + const float softcap, + const bool return_softmax, + c10::optional gen_) { + + auto dprops = at::cuda::getCurrentDeviceProperties(); + // bool is_sm75 = dprops->major == 7 && dprops->minor == 5; + bool is_sm8x = dprops->major == 8 && dprops->minor >= 0; + bool is_sm90 = dprops->major == 9 && dprops->minor == 0; + TORCH_CHECK(is_sm90 || is_sm8x, "FlashAttention only supports Ampere GPUs or newer."); + // We will support Turing in the near future + // TORCH_CHECK(is_sm90 || is_sm8x || is_sm75, "FlashAttention only supports Turing GPUs or newer."); + + auto q_dtype = q.dtype(); + + TORCH_CHECK(q_dtype == torch::kFloat16 || q_dtype == torch::kBFloat16, + "FlashAttention only support fp16 and bf16 data type"); + if (q_dtype == torch::kBFloat16) { + TORCH_CHECK(is_sm90 || is_sm8x, "bfloat16 is only supported on Ampere GPUs or newer"); + } + TORCH_CHECK(k.dtype() == q_dtype, "query and key must have the same dtype"); + TORCH_CHECK(v.dtype() == q_dtype, "query and value must have the same dtype"); + + CHECK_DEVICE(q); CHECK_DEVICE(k); CHECK_DEVICE(v); + + TORCH_CHECK(q.stride(-1) == 1, "Input tensor must have contiguous last dimension"); + TORCH_CHECK(k.stride(-1) == 1, "Input tensor must have contiguous last dimension"); + TORCH_CHECK(v.stride(-1) == 1, "Input tensor must have contiguous last dimension"); + + const auto sizes = q.sizes(); + const int v_head_size_og = v.sizes()[3]; + const int batch_size = sizes[0]; + int seqlen_q = sizes[1]; + int num_heads = sizes[2]; + const int head_size_og = sizes[3]; + const int seqlen_k = k.size(1); + const int num_heads_k = k.size(2); + const int num_heads_v = v.size(2); + TORCH_CHECK(batch_size > 0, "batch size must be positive"); + TORCH_CHECK(head_size_og <= 256, "FlashAttention forward only supports head dimension at most 256"); + TORCH_CHECK(v_head_size_og <= 256, "FlashAttention forward only supports head dimension at most 256"); + TORCH_CHECK(num_heads % num_heads_k == 0, "Number of heads in key/value must divide number of heads in query"); + TORCH_CHECK(num_heads % num_heads_v == 0, "Number of heads in value must divide number of heads in query"); + + if (softcap > 0.f) { TORCH_CHECK(p_dropout == 0.f, "Softcapping does not support dropout for now"); } + + if (window_size_left >= seqlen_k) { window_size_left = -1; } + if (window_size_right >= seqlen_k) { window_size_right = -1; } + + // causal=true is the same as causal=false in this case + if (seqlen_q == 1 && !alibi_slopes_.has_value()) { is_causal = false; } + if (is_causal) { window_size_right = 0; } + + // Faster to transpose q from (b, 1, (nheads_kv ngroups), d) to (b, ngroups, nheads_kv, d) in this case + // H/t Daniel Haziza + const int num_heads_maxkv = num_heads_k > num_heads_v ? 
num_heads_k : num_heads_v; + const int seqlenq_ngroups_swapped = seqlen_q == 1 && num_heads > num_heads_maxkv && window_size_left < 0 && window_size_right < 0 && p_dropout == 0.f && head_size_og % 8 == 0 && v_head_size_og % 8 == 0 && !alibi_slopes_.has_value(); + const int ngroups = num_heads / num_heads_maxkv; + if (seqlenq_ngroups_swapped) { + q = q.reshape({batch_size, num_heads_maxkv, ngroups, head_size_og}).transpose(1, 2); + seqlen_q = ngroups; + num_heads = num_heads_maxkv; + } + + + CHECK_SHAPE(q, batch_size, seqlen_q, num_heads, head_size_og); + CHECK_SHAPE(k, batch_size, seqlen_k, num_heads_k, head_size_og); + CHECK_SHAPE(v, batch_size, seqlen_k, num_heads_v, v_head_size_og); + + at::Tensor q_padded, k_padded, v_padded; + if (head_size_og % 8 != 0) { + q_padded = torch::nn::functional::pad(q, torch::nn::functional::PadFuncOptions({0, 8 - head_size_og % 8})); + k_padded = torch::nn::functional::pad(k, torch::nn::functional::PadFuncOptions({0, 8 - head_size_og % 8})); + // v_padded = torch::nn::functional::pad(v, torch::nn::functional::PadFuncOptions({0, 8 - head_size_og % 8})); + } else { + q_padded = q; + k_padded = k; + // v_padded = v; + } + if (v_head_size_og % 8 != 0) { + v_padded = torch::nn::functional::pad(v, torch::nn::functional::PadFuncOptions({0, 8 - v_head_size_og % 8})); + } else { + v_padded = v; + } + + at::Tensor out; + if (out_.has_value()) { + out = out_.value(); + TORCH_CHECK(out.dtype() == q_dtype, "Output must have the same dtype as inputs"); + CHECK_DEVICE(out); + TORCH_CHECK(out.stride(-1) == 1, "Output tensor must have contiguous last dimension"); + CHECK_SHAPE(out, batch_size, sizes[1], sizes[2], v_head_size_og); + if (seqlenq_ngroups_swapped) { + out = out.reshape({batch_size, num_heads_maxkv, ngroups, v_head_size_og}).transpose(1, 2); + } + if (v_head_size_og % 8 != 0) { + out = torch::empty({batch_size, seqlen_q, num_heads, v_head_size_og}, q.options()); + out = torch::nn::functional::pad(out, torch::nn::functional::PadFuncOptions({0, 8 - v_head_size_og % 8})); + } + } else { + out = torch::empty({batch_size, seqlen_q, num_heads, v_head_size_og}, q.options()); + if (v_head_size_og % 8 != 0) { + out = torch::nn::functional::pad(out, torch::nn::functional::PadFuncOptions({0, 8 - v_head_size_og % 8})); + } + } + + auto round_multiple = [](int x, int m) { return (x + m - 1) / m * m; }; + const int head_size = round_multiple(head_size_og, 8); + const int head_size_rounded = head_size <= 192 ? round_multiple(head_size, 32) : 256; + const int v_head_size = round_multiple(v_head_size_og, 8); + const int v_head_size_rounded = v_head_size <= 192 ? 
round_multiple(v_head_size, 32) : 256; + const int seqlen_q_rounded = round_multiple(seqlen_q, 128); + const int seqlen_k_rounded = round_multiple(seqlen_k, 128); + + // Otherwise the kernel will be launched from cuda:0 device + // Cast to char to avoid compiler warning about narrowing + at::cuda::CUDAGuard device_guard{(char)q.get_device()}; + + auto opts = q.options(); + + auto softmax_lse = torch::empty({batch_size, num_heads, seqlen_q}, opts.dtype(at::kFloat)); + at::Tensor p; + // Only return softmax if there's dropout to reduce compilation time + if (return_softmax) { + TORCH_CHECK(p_dropout > 0.0f, "return_softmax is only supported when p_dropout > 0.0"); + p = torch::empty({ batch_size, num_heads, seqlen_q_rounded, seqlen_k_rounded }, opts); + } + + Flash_fwd_params params; + set_params_fprop(params, + batch_size, + seqlen_q, seqlen_k, + seqlen_q_rounded, seqlen_k_rounded, + num_heads, num_heads_k, num_heads_v, + head_size, head_size_rounded, + v_head_size, v_head_size_rounded, + q_padded, k_padded, v_padded, out, + /*cu_seqlens_q_d=*/nullptr, + /*cu_seqlens_k_d=*/nullptr, + /*seqused_k=*/nullptr, + return_softmax ? p.data_ptr() : nullptr, + softmax_lse.data_ptr(), + p_dropout, + softmax_scale, + window_size_left, + window_size_right, + softcap + ); + + // Keep references to these tensors to extend their lifetime + at::Tensor softmax_lse_accum, out_accum; + std::tie(softmax_lse_accum, out_accum) = set_params_splitkv( + params, batch_size, num_heads, head_size, v_head_size, seqlen_k, seqlen_q, + head_size_rounded, v_head_size_rounded, p_dropout, /*num_splits*/ 0, dprops, opts); + + // number of times random will be generated per thread, to offset philox counter in thc random + // state + // We use a custom RNG that increases the offset by batch_size * nheads * 32. + int64_t counter_offset = params.b * params.h * 32; + auto options = torch::TensorOptions().dtype(torch::kFloat32).device(torch::kCUDA); + auto rng_state = torch::empty({2}, options.dtype(torch::kInt64)); + // Forward kernel will populate memory with the seed and offset. + params.rng_state = reinterpret_cast(rng_state.data_ptr()); + + if (p_dropout > 0.0) { + auto gen = at::get_generator_or_default( + gen_, at::cuda::detail::getDefaultCUDAGenerator()); + // See Note [Acquire lock when using random generators] + std::lock_guard lock(gen->mutex_); + params.philox_args = gen->philox_cuda_state(counter_offset); + } + + set_params_alibi(params, alibi_slopes_, batch_size, num_heads); + + if (seqlen_k > 0) { + auto stream = at::cuda::getCurrentCUDAStream().stream(); + run_mha_fwd(params, stream); + } else { + // If seqlen_k == 0, then we have an empty tensor. We need to set the output to 0. 
+ out.zero_(); + softmax_lse.fill_(std::numeric_limits::infinity()); + } + + at::Tensor out_padded = out; + if (v_head_size_og % 8 != 0) { + out = out.index({"...", torch::indexing::Slice(torch::indexing::None, v_head_size_og)}); + if (out_.has_value()) { out_.value().copy_(out); } + } + + if (seqlenq_ngroups_swapped) { + out = out.transpose(1, 2).reshape({batch_size, 1, num_heads_maxkv * seqlen_q, v_head_size_og}); + out_padded = out_padded.transpose(1, 2).reshape({batch_size, 1, num_heads_maxkv * seqlen_q, v_head_size_og}); + q_padded = q_padded.transpose(1, 2).reshape({batch_size, 1, num_heads_maxkv * seqlen_q, head_size_og}); + softmax_lse = softmax_lse.reshape({batch_size, num_heads_maxkv * seqlen_q, 1}); + } + return {out, q_padded, k_padded, v_padded, out_padded, softmax_lse, p, rng_state}; +} + +std::vector +mha_varlen_fwd(at::Tensor &q, // total_q x num_heads x head_size, total_q := \sum_{i=0}^{b} s_i + const at::Tensor &k, // total_k x num_heads_k x head_size, total_k := \sum_{i=0}^{b} s_i or num_blocks x page_block_size x num_heads_k x head_size if there's a block_table. + const at::Tensor &v, // total_k x num_heads_v x head_size, total_k := \sum_{i=0}^{b} s_i or num_blocks x page_block_size x num_heads_v x head_size if there's a block_table. + c10::optional &out_, // total_q x num_heads x head_size, total_k := \sum_{i=0}^{b} s_i + const at::Tensor &cu_seqlens_q, // b+1 + const at::Tensor &cu_seqlens_k, // b+1 + c10::optional &seqused_k, // b. If given, only this many elements of each batch element's keys are used. + c10::optional &leftpad_k_, // batch_size + c10::optional &block_table_, // batch_size x max_num_blocks_per_seq + c10::optional &alibi_slopes_, // num_heads or b x num_heads + int max_seqlen_q, + const int max_seqlen_k, + const float p_dropout, + const float softmax_scale, + const bool zero_tensors, + bool is_causal, + int window_size_left, + int window_size_right, + const float softcap, + const bool return_softmax, + c10::optional gen_) { + + auto dprops = at::cuda::getCurrentDeviceProperties(); + // bool is_sm75 = dprops->major == 7 && dprops->minor == 5; + bool is_sm8x = dprops->major == 8 && dprops->minor >= 0; + bool is_sm90 = dprops->major == 9 && dprops->minor == 0; + TORCH_CHECK(is_sm90 || is_sm8x, "FlashAttention only supports Ampere GPUs or newer."); + // We will support Turing in the near future + // TORCH_CHECK(is_sm90 || is_sm8x || is_sm75, "FlashAttention only supports Turing GPUs or newer."); + + auto q_dtype = q.dtype(); + TORCH_CHECK(q_dtype == torch::kFloat16 || q_dtype == torch::kBFloat16, + "FlashAttention only support fp16 and bf16 data type"); + if (q_dtype == torch::kBFloat16) { + TORCH_CHECK(is_sm90 || is_sm8x, "bfloat16 is only supported on Ampere GPUs or newer"); + } + TORCH_CHECK(k.dtype() == q_dtype, "query and key must have the same dtype"); + TORCH_CHECK(v.dtype() == q_dtype, "query and value must have the same dtype"); + TORCH_CHECK(cu_seqlens_q.dtype() == torch::kInt32, "cu_seqlens_q must have dtype int32"); + TORCH_CHECK(cu_seqlens_k.dtype() == torch::kInt32, "cu_seqlens_k must have dtype int32"); + + CHECK_DEVICE(q); CHECK_DEVICE(k); CHECK_DEVICE(v); + CHECK_DEVICE(cu_seqlens_q); + CHECK_DEVICE(cu_seqlens_k); + + at::Tensor block_table; + const bool paged_KV = block_table_.has_value(); + if (paged_KV) { + block_table = block_table_.value(); + CHECK_DEVICE(block_table); + TORCH_CHECK(block_table.dtype() == torch::kInt32, "block_table must have dtype torch.int32"); + TORCH_CHECK(block_table.stride(-1) == 1, "block_table must have 
contiguous last dimension"); + } + + TORCH_CHECK(q.stride(-1) == 1, "Input tensor must have contiguous last dimension"); + TORCH_CHECK(k.stride(-1) == 1, "Input tensor must have contiguous last dimension"); + TORCH_CHECK(v.stride(-1) == 1, "Input tensor must have contiguous last dimension"); + CHECK_CONTIGUOUS(cu_seqlens_q); + CHECK_CONTIGUOUS(cu_seqlens_k); + + const auto sizes = q.sizes(); + const int v_head_size_og = v.sizes()[2]; + const int batch_size = cu_seqlens_q.numel() - 1; + int num_heads = sizes[1]; + const int head_size_og = sizes[2]; + const int num_heads_k = paged_KV ? k.size(2) : k.size(1); + // TODO: check here + const int num_heads_v = paged_KV ? v.size(2) : v.size(1); + + if (softcap > 0.f) { TORCH_CHECK(p_dropout == 0.f, "Softcapping does not support dropout for now"); } + + const int max_num_blocks_per_seq = !paged_KV ? 0 : block_table.size(1); + const int num_blocks = !paged_KV ? 0 : k.size(0); + const int page_block_size = !paged_KV ? 1 : k.size(1); + TORCH_CHECK(!paged_KV || page_block_size % 256 == 0, "Paged KV cache block size must be divisible by 256"); + + if (max_seqlen_q == 1 && !alibi_slopes_.has_value()) { is_causal = false; } // causal=true is the same as causal=false in this case + if (is_causal) { window_size_right = 0; } + + void *cu_seqlens_q_d = cu_seqlens_q.data_ptr(); + + // Faster to transpose q from (b, 1, (nheads_kv ngroups), d) to (b, ngroups, nheads_kv, d) in this case + // H/t Daniel Haziza + const int num_heads_maxkv = num_heads_k > num_heads_v ? num_heads_k : num_heads_v; + const int seqlenq_ngroups_swapped = max_seqlen_q == 1 && num_heads > num_heads_maxkv && window_size_left < 0 && window_size_right < 0 && p_dropout == 0.f && head_size_og % 8 == 0 && v_head_size_og % 8 == 0 && !alibi_slopes_.has_value(); + const int ngroups = num_heads / num_heads_maxkv; + if (seqlenq_ngroups_swapped) { + q = q.reshape({batch_size, num_heads_maxkv, ngroups, head_size_og}).transpose(1, 2).reshape({batch_size * ngroups, num_heads_maxkv, head_size_og}); + max_seqlen_q = ngroups; + num_heads = num_heads_maxkv; + cu_seqlens_q_d = nullptr; + } + + const int total_q = q.sizes()[0]; + + TORCH_CHECK(batch_size > 0, "batch size must be positive"); + TORCH_CHECK(head_size_og <= 256, "FlashAttention forward only supports head dimension at most 256"); + TORCH_CHECK(v_head_size_og <= 256, "FlashAttention forward only supports head dimension at most 256"); + TORCH_CHECK(num_heads % num_heads_k == 0, "Number of heads in key/value must divide number of heads in query"); + TORCH_CHECK(num_heads % num_heads_v == 0, "Number of heads in value must divide number of heads in query"); + + if (window_size_left >= max_seqlen_k) { window_size_left = -1; } + if (window_size_right >= max_seqlen_k) { window_size_right = -1; } + + CHECK_SHAPE(q, total_q, num_heads, head_size_og); + if (!paged_KV) { + const int total_k = k.size(0); + CHECK_SHAPE(k, total_k, num_heads_k, head_size_og); + CHECK_SHAPE(v, total_k, num_heads_v, v_head_size_og); + } else { + CHECK_SHAPE(k, num_blocks, page_block_size, num_heads_k, head_size_og); + CHECK_SHAPE(v, num_blocks, page_block_size, num_heads_v, v_head_size_og); + CHECK_SHAPE(block_table, batch_size, max_num_blocks_per_seq); + } + + CHECK_SHAPE(cu_seqlens_q, batch_size + 1); + CHECK_SHAPE(cu_seqlens_k, batch_size + 1); + if (seqused_k.has_value()){ + auto seqused_k_ = seqused_k.value(); + TORCH_CHECK(seqused_k_.dtype() == torch::kInt32, "seqused_k must have dtype int32"); + TORCH_CHECK(seqused_k_.is_cuda(), "seqused_k must be on CUDA device"); + 
TORCH_CHECK(seqused_k_.is_contiguous(), "seqused_k must be contiguous"); + CHECK_SHAPE(seqused_k_, batch_size); + } + + at::Tensor q_padded, k_padded, v_padded; + if (head_size_og % 8 != 0) { + q_padded = torch::nn::functional::pad(q, torch::nn::functional::PadFuncOptions({0, 8 - head_size_og % 8})); + k_padded = torch::nn::functional::pad(k, torch::nn::functional::PadFuncOptions({0, 8 - head_size_og % 8})); + // v_padded = torch::nn::functional::pad(v, torch::nn::functional::PadFuncOptions({0, 8 - head_size_og % 8})); + } else { + q_padded = q; + k_padded = k; + // v_padded = v; + } + if (v_head_size_og % 8 != 0) { + v_padded = torch::nn::functional::pad(v, torch::nn::functional::PadFuncOptions({0, 8 - v_head_size_og % 8})); + } else { + v_padded = v; + } + at::Tensor out; + if (out_.has_value()) { + out = out_.value(); + TORCH_CHECK(out.dtype() == q_dtype, "Output must have the same dtype as inputs"); + CHECK_DEVICE(out); + TORCH_CHECK(out.stride(-1) == 1, "Output tensor must have contiguous last dimension"); + CHECK_SHAPE(out, sizes[0], sizes[1], v_head_size_og); + if (seqlenq_ngroups_swapped) { + out = out.reshape({batch_size, num_heads_maxkv, ngroups, v_head_size_og}).transpose(1, 2).reshape({batch_size * ngroups, num_heads_maxkv, head_size_og}); + } + if (v_head_size_og % 8 != 0) { + out = torch::empty({total_q, num_heads, v_head_size_og}, q.options()); + out = torch::nn::functional::pad(out, torch::nn::functional::PadFuncOptions({0, 8 - v_head_size_og % 8})); + } + } else { + out = torch::empty({total_q, num_heads, v_head_size_og}, q.options()); + if (v_head_size_og % 8 != 0) { + out = torch::nn::functional::pad(out, torch::nn::functional::PadFuncOptions({0, 8 - v_head_size_og % 8})); + } + } + + auto round_multiple = [](int x, int m) { return (x + m - 1) / m * m; }; + const int head_size = round_multiple(head_size_og, 8); + const int head_size_rounded = head_size <= 192 ? round_multiple(head_size, 32) : 256; + const int v_head_size = round_multiple(v_head_size_og, 8); + const int v_head_size_rounded = v_head_size <= 192 ? round_multiple(v_head_size, 32) : 256; + const int seqlen_q_rounded = round_multiple(max_seqlen_q, 128); + const int seqlen_k_rounded = round_multiple(max_seqlen_k, 128); + + // Otherwise the kernel will be launched from cuda:0 device + // Cast to char to avoid compiler warning about narrowing + at::cuda::CUDAGuard device_guard{(char)q.get_device()}; + + auto opts = q.options(); + auto softmax_lse = torch::empty({num_heads, total_q}, opts.dtype(at::kFloat)); + at::Tensor p; + // Only return softmax if there's dropout to reduce compilation time + if (return_softmax) { + TORCH_CHECK(p_dropout > 0.0f, "return_softmax is only supported when p_dropout > 0.0"); + p = torch::empty({ batch_size, num_heads, seqlen_q_rounded, seqlen_k_rounded }, opts); + } + + if (zero_tensors) { + out.zero_(); + softmax_lse.fill_(-std::numeric_limits::infinity()); + if (return_softmax) {p.zero_();} + } + + Flash_fwd_params params; + set_params_fprop(params, + batch_size, + max_seqlen_q, max_seqlen_k, + seqlen_q_rounded, seqlen_k_rounded, + num_heads, num_heads_k, num_heads_v, + head_size, head_size_rounded, + v_head_size, v_head_size_rounded, + q_padded, k_padded, v_padded, out, + cu_seqlens_q_d, + cu_seqlens_k.data_ptr(), + seqused_k.has_value() ? seqused_k.value().data_ptr() : nullptr, + return_softmax ? 
p.data_ptr() : nullptr, + softmax_lse.data_ptr(), + p_dropout, + softmax_scale, + window_size_left, + window_size_right, + softcap, + seqlenq_ngroups_swapped, + /*unpadded_lse*/true); + params.total_q = total_q; + + if (paged_KV) { + params.block_table = block_table.data_ptr(); + params.block_table_batch_stride = block_table.stride(0); + params.k_batch_stride = k_padded.stride(0); + params.v_batch_stride = v_padded.stride(0); + } + params.page_block_size = page_block_size; + // Keep references to these tensors to extend their lifetime + at::Tensor softmax_lse_accum, out_accum; + if (seqlenq_ngroups_swapped) { + // Only apply split-k for decoding + std::tie(softmax_lse_accum, out_accum) = + set_params_splitkv(params, batch_size, num_heads, head_size, v_head_size, + max_seqlen_k, max_seqlen_q, head_size_rounded,v_head_size_rounded, + p_dropout, /*num_splits*/ 0, dprops, opts); + } + + if (leftpad_k_.has_value()) { + auto leftpad_k = leftpad_k_.value(); + TORCH_CHECK(!paged_KV, "We don't support Paged KV and leftpad_k running at the same time yet"); + TORCH_CHECK(leftpad_k.dtype() == torch::kInt32, "leftpad_k must have dtype int32"); + CHECK_DEVICE(leftpad_k); + CHECK_CONTIGUOUS(leftpad_k); + CHECK_SHAPE(leftpad_k, batch_size); + params.leftpad_k = static_cast(leftpad_k.data_ptr()); + } + + // number of times random will be generated per thread, to offset philox counter in thc random + // state + // We use a custom RNG that increases the offset by batch_size * nheads * 32. + int64_t counter_offset = params.b * params.h * 32; + auto options = torch::TensorOptions().dtype(torch::kFloat32).device(torch::kCUDA); + auto rng_state = torch::empty({2}, options.dtype(torch::kInt64)); + // Forward kernel will populate memory with the seed and offset. + params.rng_state = reinterpret_cast(rng_state.data_ptr()); + + if (p_dropout > 0.0) { + auto gen = at::get_generator_or_default( + gen_, at::cuda::detail::getDefaultCUDAGenerator()); + // See Note [Acquire lock when using random generators] + std::lock_guard lock(gen->mutex_); + params.philox_args = gen->philox_cuda_state(counter_offset); + } + + set_params_alibi(params, alibi_slopes_, batch_size, num_heads); + + if (max_seqlen_k > 0) { + auto stream = at::cuda::getCurrentCUDAStream().stream(); + run_mha_fwd(params, stream, paged_KV); + } else { + // If seqlen_k == 0, then we have an empty tensor. We need to set the output to 0. 
+ out.zero_(); + softmax_lse.fill_(std::numeric_limits::infinity()); + } + + at::Tensor out_padded = out; + if (v_head_size_og % 8 != 0) { + out = out.index({"...", torch::indexing::Slice(torch::indexing::None, v_head_size_og)}); + if (out_.has_value()) { out_.value().copy_(out); } + } + + if (seqlenq_ngroups_swapped) { + int64_t size_before[] = {batch_size, max_seqlen_q, num_heads_maxkv, head_size_og}; + int64_t size_after[] = {batch_size, num_heads_maxkv * max_seqlen_q, head_size_og}; + int64_t o_size_before[] = {batch_size, max_seqlen_q, num_heads_maxkv, v_head_size_og}; + int64_t o_size_after[] = {batch_size, num_heads_maxkv * max_seqlen_q, v_head_size_og}; + out = out.reshape(o_size_before).transpose(1, 2).reshape(o_size_after); + out_padded = out_padded.reshape(o_size_before).transpose(1, 2).reshape(o_size_after); + q_padded = q_padded.reshape(size_before).transpose(1, 2).reshape(size_after); + softmax_lse = softmax_lse.reshape({num_heads * max_seqlen_q, batch_size}); + } + + return {out, q_padded, k_padded, v_padded, out_padded, softmax_lse, p, rng_state}; +} + +void run_mha_bwd(Flash_bwd_params ¶ms, cudaStream_t stream) { + + constexpr bool is_bf16 = /*{is_bf16}*/; + constexpr bool is_causal = /*{is_causal}*/; + constexpr int kQKHeadDim = /*{Kd}*/; + constexpr int kVHeadDim = /*{D}*/; + + assert(params.is_bf16 == is_bf16); + assert(params.is_causal == is_causal); + assert(params.d == kQKHeadDim); + assert(params.vd == kVHeadDim); + + // TODO: temporary workaround + // run_mha_bwd_, kQKHeadDim, kVHeadDim, is_causal>(params, stream); +} +std::vector +mha_bwd(const at::Tensor &dout, // batch_size x seqlen_q x num_heads, x head_size_og + const at::Tensor &q, // batch_size x seqlen_q x num_heads x head_size + const at::Tensor &k, // batch_size x seqlen_k x num_heads_k x head_size + const at::Tensor &v, // batch_size x seqlen_k x num_heads_v x head_size + const at::Tensor &out, // batch_size x seqlen_q x num_heads x head_size + const at::Tensor &softmax_lse, // b x h x seqlen_q + c10::optional &dq_, // batch_size x seqlen_q x num_heads x head_size + c10::optional &dk_, // batch_size x seqlen_k x num_heads_k x head_size + c10::optional &dv_, // batch_size x seqlen_k x num_heads_v x head_size + c10::optional &alibi_slopes_, // num_heads or batch_size x num_heads + const float p_dropout, // probability to drop + const float softmax_scale, + const bool is_causal, + int window_size_left, + int window_size_right, + const float softcap, + const bool deterministic, + c10::optional gen_, + c10::optional &rng_state) { + + #ifdef FLASHATTENTION_DISABLE_BACKWARD + TORCH_CHECK(false, "This flash attention build does not support backward."); + #endif + if (is_causal) { window_size_right = 0; } + auto dprops = at::cuda::getCurrentDeviceProperties(); + // bool is_sm75 = dprops->major == 7 && dprops->minor == 5; + bool is_sm8x = dprops->major == 8 && dprops->minor >= 0; + bool is_sm80 = dprops->major == 8 && dprops->minor == 0; + bool is_sm90 = dprops->major == 9 && dprops->minor == 0; + TORCH_CHECK(is_sm90 || is_sm8x, "FlashAttention only supports Ampere GPUs or newer."); + // We will support Turing in the near future + // TORCH_CHECK(is_sm90 || is_sm8x || is_sm75, "FlashAttention only supports Turing GPUs or newer."); + + bool is_dropout = p_dropout > 0.0; + auto stream = at::cuda::getCurrentCUDAStream().stream(); + + auto q_dtype = q.dtype(); + TORCH_CHECK(q_dtype == torch::kFloat16 || q_dtype == torch::kBFloat16, + "FlashAttention only support fp16 and bf16 data type"); + if (q_dtype == 
torch::kBFloat16) { + TORCH_CHECK(is_sm90 || is_sm8x, "bfloat16 is only supported on Ampere GPUs or newer"); + } + TORCH_CHECK(k.dtype() == q_dtype, "query and key must have the same dtype"); + TORCH_CHECK(v.dtype() == q_dtype, "query and value must have the same dtype"); + TORCH_CHECK(out.dtype() == q_dtype, "query and out must have the same dtype"); + TORCH_CHECK(dout.dtype() == q_dtype, "query and dout must have the same dtype"); + + CHECK_DEVICE(q); CHECK_DEVICE(k); CHECK_DEVICE(v); + CHECK_DEVICE(out); CHECK_DEVICE(dout); CHECK_DEVICE(softmax_lse); + + TORCH_CHECK(q.stride(-1) == 1, "Input tensor must have contiguous last dimension"); + TORCH_CHECK(k.stride(-1) == 1, "Input tensor must have contiguous last dimension"); + TORCH_CHECK(v.stride(-1) == 1, "Input tensor must have contiguous last dimension"); + TORCH_CHECK(out.stride(-1) == 1, "out tensor must have contiguous last dimension"); + TORCH_CHECK(dout.stride(-1) == 1, "dout tensor must have contiguous last dimension"); + + const auto sizes = q.sizes(); + const int v_head_size_og = v.sizes()[3]; + const int batch_size = sizes[0]; + const int seqlen_q = sizes[1]; + const int num_heads = sizes[2]; + const int head_size_og = dout.size(3); + const int head_size = sizes[3]; + const int seqlen_k = k.size(1); + const int num_heads_k = k.size(2); + const int num_heads_v = v.size(2); + TORCH_CHECK(batch_size > 0, "batch size must be positive"); + TORCH_CHECK(head_size % 8 == 0, "head_size should be a multiple of 8"); + TORCH_CHECK(head_size <= 256, "FlashAttention backward only supports head dimension at most 256"); + TORCH_CHECK(v_head_size_og % 8 == 0, " v head_size should be a multiple of 8"); + TORCH_CHECK(v_head_size_og <= 256, "FlashAttention backward only supports head dimension at most 256"); + if ((head_size > 192 || v_head_size_og > 192) && is_dropout) { + TORCH_CHECK(is_sm80 || is_sm90, "FlashAttention backward for head dim > 192 with dropout requires A100/A800 or H100/H800"); + } + TORCH_CHECK(num_heads % num_heads_k == 0, "Number of heads in key/value must divide number of heads in query"); + TORCH_CHECK(num_heads % num_heads_v == 0, "Number of heads in value must divide number of heads in query"); + + auto round_multiple = [](int x, int m) { return (x + m - 1) / m * m; }; + const int head_size_rounded = head_size <= 192 ? 
round_multiple(head_size, 32) : 256; + const int seqlen_q_rounded = round_multiple(seqlen_q, 128); + const int seqlen_k_rounded = round_multiple(seqlen_k, 128); + + // TORCH_CHECK(head_size == round_multiple(head_size_og, 8), "head_size must be head_size_og rounded to a multiple of 8"); + if (softcap > 0.f) { TORCH_CHECK(p_dropout == 0.f, "Softcapping does not support dropout for now"); } + + if (window_size_left >= seqlen_k) { window_size_left = -1; } + if (window_size_right >= seqlen_k) { window_size_right = -1; } + + CHECK_SHAPE(q, batch_size, seqlen_q, num_heads, head_size); + CHECK_SHAPE(k, batch_size, seqlen_k, num_heads_k, head_size); + CHECK_SHAPE(v, batch_size, seqlen_k, num_heads_v, v_head_size_og); + CHECK_SHAPE(out, batch_size, seqlen_q, num_heads, v_head_size_og); + CHECK_SHAPE(dout, batch_size, seqlen_q, num_heads, head_size_og); + + at::Tensor dq, dk, dv; + if (dq_.has_value()) { + dq = dq_.value(); + TORCH_CHECK(dq.dtype() == q_dtype, "dq must have the same dtype as q"); + CHECK_DEVICE(dq); + TORCH_CHECK(dq.stride(-1) == 1, "dq must have contiguous last dimension"); + CHECK_SHAPE(dq, batch_size, seqlen_q, num_heads, head_size); + } else { + dq = torch::empty_like(q); + } + if (dk_.has_value()) { + dk = dk_.value(); + TORCH_CHECK(dk.dtype() == q_dtype, "dk must have the same dtype as q"); + CHECK_DEVICE(dk); + TORCH_CHECK(dk.stride(-1) == 1, "dk must have contiguous last dimension"); + CHECK_SHAPE(dk, batch_size, seqlen_k, num_heads_k, head_size); + } else { + dk = torch::empty_like(k); + } + if (dv_.has_value()) { + dv = dv_.value(); + TORCH_CHECK(dv.dtype() == q_dtype, "dv must have the same dtype as q"); + CHECK_DEVICE(dv); + TORCH_CHECK(dv.stride(-1) == 1, "dv must have contiguous last dimension"); + CHECK_SHAPE(dv, batch_size, seqlen_k, num_heads_v, v_head_size_og); + } else { + dv = torch::empty_like(v); + } + + at::Tensor dout_padded; + if (head_size_og % 8 != 0) { + dout_padded = torch::nn::functional::pad(dout, torch::nn::functional::PadFuncOptions({0, 8 - head_size_og % 8})); + } else { + dout_padded = dout; + } + + // bool loop = seqlen_k > blocksize_c; + // TODO: change later, for now set to true for simplicity + bool loop = true; + + // Otherwise the kernel will be launched from cuda:0 device + // Cast to char to avoid compiler warning about narrowing + at::cuda::CUDAGuard device_guard{(char)q.get_device()}; + + auto opts = q.options(); + auto softmax_d = torch::empty({batch_size, num_heads, seqlen_q_rounded}, opts.dtype(at::kFloat)); + at::Tensor dq_accum; + at::Tensor dk_accum, dv_accum; + if (loop) { + if (!deterministic) { + dq_accum = torch::empty({batch_size, seqlen_q_rounded, num_heads, head_size_rounded}, opts.dtype(at::kFloat)); + } else { + const int nsplits = (dprops->multiProcessorCount + batch_size * num_heads - 1) / (batch_size * num_heads); + dq_accum = torch::zeros({nsplits, batch_size, seqlen_q_rounded, num_heads, head_size_rounded}, opts.dtype(at::kFloat)); + } + // dk_accum = torch::empty({batch_size, num_heads_k, seqlen_k_rounded, head_size_rounded}, opts.dtype(at::kFloat)); + // dv_accum = torch::empty({batch_size, num_heads_k, seqlen_k_rounded, head_size_rounded}, opts.dtype(at::kFloat)); + } + + at::Tensor dk_expanded, dv_expanded; + if (num_heads_k != num_heads) { // MQA / GQA + dk_expanded = torch::empty({batch_size, seqlen_k, num_heads, head_size}, opts); + } else { + dk_expanded = dk; + } + if (num_heads_v != num_heads) { + dv_expanded = torch::empty({batch_size, seqlen_k, num_heads, v_head_size_og}, opts); + } else { + dv_expanded = 
dv; + } + + Flash_bwd_params params; + + set_params_dgrad(params, + batch_size, + seqlen_q, seqlen_k, + seqlen_q_rounded, seqlen_k_rounded, + num_heads, num_heads_k, num_heads_v, + head_size, head_size_rounded, + v_head_size_og, + q, k, v, out, + dout_padded, dq, dk_expanded, dv_expanded, + nullptr, + nullptr, + loop ? dq_accum.data_ptr() : nullptr, + // loop ? dk_accum.data_ptr() : nullptr, + // loop ? dv_accum.data_ptr() : nullptr, + nullptr, + nullptr, + softmax_lse.data_ptr(), + softmax_d.data_ptr(), + p_dropout, + softmax_scale, + window_size_left, + window_size_right, + softcap, + deterministic, + /*unpadded_lse*/false); + params.dq_accum_split_stride = !deterministic ? 0 : dq_accum.stride(0); + + auto launch = &run_mha_bwd; + + auto gen = at::get_generator_or_default( + gen_, at::cuda::detail::getDefaultCUDAGenerator()); + + // We use a custom RNG that increases the offset by batch_size * nheads * 32. + int64_t counter_offset = params.b * params.h * 32; + + if ( rng_state.has_value() ) { + params.rng_state = reinterpret_cast(rng_state.value().data_ptr()); + } else if( is_dropout ) { + // See Note [Acquire lock when using random generators] + std::lock_guard lock(gen->mutex_); + params.philox_args = gen->philox_cuda_state(counter_offset); + auto seeds = at::cuda::philox::unpack(params.philox_args); + params.rng_state[0] = std::get<0>(seeds); + params.rng_state[1] = std::get<1>(seeds); + } + + set_params_alibi(params, alibi_slopes_, batch_size, num_heads); + + if (seqlen_q > 0) { + launch(params, stream); + } else { + // If seqlen_q == 0, then we have an empty tensor. We need to set the output to 0. + dk_expanded.zero_(); + dv_expanded.zero_(); + softmax_d.zero_(); + } + + // For MQA/GQA we need to sum dK and dV across the groups + if (num_heads_k != num_heads) { + at::sum_out(dk, at::reshape(dk_expanded, {batch_size, seqlen_k, num_heads_k, num_heads / num_heads_k, head_size}), {3}); + } + if (num_heads_v != num_heads) { + at::sum_out(dv, at::reshape(dv_expanded, {batch_size, seqlen_k, num_heads_v, num_heads / num_heads_v, v_head_size_og}), {3}); + } + if (head_size_og % 8 != 0) { + dq = dq.index({"...", torch::indexing::Slice(torch::indexing::None, head_size_og)}); + dk = dk.index({"...", torch::indexing::Slice(torch::indexing::None, head_size_og)}); + // dv = dv.index({"...", torch::indexing::Slice(torch::indexing::None, head_size_og)}); + } + + return { dq, dk, dv, softmax_d }; +} + +std::vector +mha_varlen_bwd(const at::Tensor &dout, // total_q x num_heads, x head_size + const at::Tensor &q, // total_q x num_heads x head_size, total_q := \sum_{i=0}^{b} s_i + const at::Tensor &k, // total_k x num_heads_k x head_size, total_k := \sum_{i=0}^{b} s_i + const at::Tensor &v, // total_k x num_heads_k x head_size, total_k := \sum_{i=0}^{b} s_i + const at::Tensor &out, // total_q x num_heads x head_size + const at::Tensor &softmax_lse, // h x total_q, softmax logsumexp + c10::optional &dq_, // total_q x num_heads x head_size, total_q := \sum_{i=0}^{b} s_i + c10::optional &dk_, // total_k x num_heads_k x head_size, total_k := \sum_{i=0}^{b} s_i + c10::optional &dv_, // total_k x num_heads_k x head_size, total_k := \sum_{i=0}^{b} s_i + const at::Tensor &cu_seqlens_q, // b+1 + const at::Tensor &cu_seqlens_k, // b+1 + c10::optional &alibi_slopes_, // num_heads or b x num_heads + const int max_seqlen_q, + const int max_seqlen_k, // max sequence length to choose the kernel + const float p_dropout, // probability to drop + const float softmax_scale, + const bool zero_tensors, + const bool 
is_causal, + int window_size_left, + int window_size_right, + const float softcap, + const bool deterministic, + c10::optional gen_, + c10::optional &rng_state) { + + #ifdef FLASHATTENTION_DISABLE_BACKWARD + TORCH_CHECK(false, "This flash attention build does not support backward."); + #endif + + if (is_causal) { window_size_right = 0; } + auto dprops = at::cuda::getCurrentDeviceProperties(); + // bool is_sm75 = dprops->major == 7 && dprops->minor == 5; + bool is_sm8x = dprops->major == 8 && dprops->minor >= 0; + bool is_sm80 = dprops->major == 8 && dprops->minor == 0; + bool is_sm90 = dprops->major == 9 && dprops->minor == 0; + TORCH_CHECK(is_sm90 || is_sm8x, "FlashAttention only supports Ampere GPUs or newer."); + // We will support Turing in the near future + // TORCH_CHECK(is_sm90 || is_sm8x || is_sm75, "FlashAttention only supports Turing GPUs or newer."); + bool is_dropout = p_dropout > 0.0; + auto stream = at::cuda::getCurrentCUDAStream().stream(); + + auto q_dtype = q.dtype(); + TORCH_CHECK(q_dtype == torch::kFloat16 || q_dtype == torch::kBFloat16, + "FlashAttention only support fp16 and bf16 data type"); + if (q_dtype == torch::kBFloat16) { + TORCH_CHECK(is_sm90 || is_sm8x, "bfloat16 is only supported on Ampere GPUs or newer"); + } + TORCH_CHECK(k.dtype() == q_dtype, "query and key must have the same dtype"); + TORCH_CHECK(v.dtype() == q_dtype, "query and value must have the same dtype"); + TORCH_CHECK(out.dtype() == q_dtype, "query and out must have the same dtype"); + TORCH_CHECK(dout.dtype() == q_dtype, "query and dout must have the same dtype"); + TORCH_CHECK(cu_seqlens_q.dtype() == torch::kInt32, "cu_seqlens_q must have dtype int32"); + TORCH_CHECK(cu_seqlens_k.dtype() == torch::kInt32, "cu_seqlens_k must have dtype int32"); + + CHECK_DEVICE(q); CHECK_DEVICE(k); CHECK_DEVICE(v); + CHECK_DEVICE(out); CHECK_DEVICE(dout); CHECK_DEVICE(softmax_lse); + CHECK_DEVICE(cu_seqlens_q); CHECK_DEVICE(cu_seqlens_k); + + TORCH_CHECK(q.stride(-1) == 1, "Input tensor must have contiguous last dimension"); + TORCH_CHECK(k.stride(-1) == 1, "Input tensor must have contiguous last dimension"); + TORCH_CHECK(v.stride(-1) == 1, "Input tensor must have contiguous last dimension"); + TORCH_CHECK(out.stride(-1) == 1, "out tensor must have contiguous last dimension"); + TORCH_CHECK(dout.stride(-1) == 1, "dout tensor must have contiguous last dimension"); + CHECK_CONTIGUOUS(cu_seqlens_q); + CHECK_CONTIGUOUS(cu_seqlens_k); + + const auto sizes = q.sizes(); + const int v_head_size_og = v.sizes()[2]; + const int total_q = sizes[0]; + const int batch_size = cu_seqlens_q.numel() - 1; + const int num_heads = sizes[1]; + const int head_size_og = dout.size(2); + const int head_size = sizes[2]; + const int total_k = k.size(0); + const int num_heads_k = k.size(1); + const int num_heads_v = v.size(1); + TORCH_CHECK(batch_size > 0, "batch size must be positive"); + TORCH_CHECK(head_size % 8 == 0, "head_size should be a multiple of 8"); + TORCH_CHECK(head_size <= 256, "FlashAttention backward only supports head dimension at most 256"); + TORCH_CHECK(v_head_size_og % 8 == 0, " v head_size should be a multiple of 8"); + TORCH_CHECK(v_head_size_og <= 256, "FlashAttention backward only supports head dimension at most 256"); + + if ((head_size > 192 || v_head_size_og > 192) && is_dropout) { + TORCH_CHECK(is_sm80 || is_sm90, "FlashAttention backward for head dim > 192 with dropout requires A100/A800 or H100/H800"); + } + TORCH_CHECK(num_heads % num_heads_k == 0, "Number of heads in key/value must divide number of heads in 
query"); + TORCH_CHECK(num_heads % num_heads_v == 0, "Number of heads in value must divide number of heads in query"); + if (softcap > 0.f) { TORCH_CHECK(p_dropout == 0.f, "Softcapping does not support dropout for now"); } + + auto round_multiple = [](int x, int m) { return (x + m - 1) / m * m; }; + const int head_size_rounded = head_size <= 192 ? round_multiple(head_size, 32) : 256; + const int seqlen_q_rounded = round_multiple(max_seqlen_q, 128); + const int seqlen_k_rounded = round_multiple(max_seqlen_k, 128); + + TORCH_CHECK(head_size == round_multiple(head_size_og, 8), "head_size must be head_size_og rounded to a multiple of 8"); + + if (window_size_left >= max_seqlen_k) { window_size_left = -1; } + if (window_size_right >= max_seqlen_k) { window_size_right = -1; } + + CHECK_SHAPE(q, total_q, num_heads, head_size); + CHECK_SHAPE(k, total_k, num_heads_k, head_size); + CHECK_SHAPE(v, total_k, num_heads_v, v_head_size_og); + CHECK_SHAPE(out, total_q, num_heads, v_head_size_og); + CHECK_SHAPE(dout, total_q, num_heads, head_size_og); + CHECK_SHAPE(cu_seqlens_q, batch_size + 1); + CHECK_SHAPE(cu_seqlens_k, batch_size + 1); + + at::Tensor dq, dk, dv; + if (dq_.has_value()) { + dq = dq_.value(); + TORCH_CHECK(dq.dtype() == q_dtype, "dq must have the same dtype as q"); + CHECK_DEVICE(dq); + TORCH_CHECK(dq.stride(-1) == 1, "dq must have contiguous last dimension"); + CHECK_SHAPE(dq, total_q, num_heads, head_size); + } else { + dq = torch::empty_like(q); + } + if (dk_.has_value()) { + dk = dk_.value(); + TORCH_CHECK(dk.dtype() == q_dtype, "dk must have the same dtype as q"); + CHECK_DEVICE(dk); + TORCH_CHECK(dk.stride(-1) == 1, "dk must have contiguous last dimension"); + CHECK_SHAPE(dk, total_k, num_heads_k, head_size); + } else { + dk = torch::empty_like(k); + } + if (dv_.has_value()) { + dv = dv_.value(); + TORCH_CHECK(dv.dtype() == q_dtype, "dv must have the same dtype as q"); + CHECK_DEVICE(dv); + TORCH_CHECK(dv.stride(-1) == 1, "dv must have contiguous last dimension"); + CHECK_SHAPE(dv, total_k, num_heads_v, v_head_size_og); + } else { + dv = torch::empty_like(v); + } + + at::Tensor dout_padded; + if (head_size_og % 8 != 0) { + dout_padded = torch::nn::functional::pad(dout, torch::nn::functional::PadFuncOptions({0, 8 - head_size_og % 8})); + } else { + dout_padded = dout; + } + + // bool loop = max_seqlen_k > blocksize_c; + // TODO: change later, for now set to true for simplicity + bool loop = true; + + // Otherwise the kernel will be launched from cuda:0 device + // Cast to char to avoid compiler warning about narrowing + at::cuda::CUDAGuard device_guard{(char)q.get_device()}; + + auto opts = q.options(); + auto softmax_d = torch::empty({num_heads, total_q + 128 * batch_size}, opts.dtype(at::kFloat)); + at::Tensor dq_accum; + if (loop) { + // We don't want to allocate dq_accum of size (batch, seqlen_q_rounded, num_heads, head_size_rounded) + // because that would be too large if there is a very long sequence and the rest of the sequences are short. + // Instead, we allocate dq_accum of size (total_q + 128 * batch, num_heads, head_size_rounded). + // Note that 128 is the max block size on the seqlen_q dimension. + // For dQ, the i-th sequence is stored in indices from cu_seqlens[i] + 128 * i to + // cu_seqlens[i + 1] * 128 * i - 1. This ensures that the i-th sequence and (i + 1)-th sequence will + // be at least 128 apart. It's ok for us to do atomicAdds up to 128 rows beyond what we're normally + // allowed to do. 
So we won't have to do any bound checking, and performance should stay the same. + // Same holds for softmax_d, since LSE is stored in unpadded format. + if (!deterministic) { + dq_accum = torch::empty({total_q + 128 * batch_size, num_heads, head_size_rounded}, opts.dtype(at::kFloat)); + } else { + const int nsplits = (dprops->multiProcessorCount + batch_size * num_heads - 1) / (batch_size * num_heads); + dq_accum = torch::zeros({nsplits, total_q + 128 * batch_size, num_heads, head_size_rounded}, opts.dtype(at::kFloat)); + } + } + + at::Tensor dk_expanded, dv_expanded; + if (num_heads_k != num_heads) { // MQA / GQA + dk_expanded = torch::empty({total_k, num_heads, head_size}, opts); + } else { + dk_expanded = dk; + } + if (num_heads_v != num_heads) { + dv_expanded = torch::empty({total_k, num_heads, v_head_size_og}, opts); + } else { + dv_expanded = dv; + } + + if( zero_tensors ) { + dq.zero_(); + dk_expanded.zero_(); + dv_expanded.zero_(); + softmax_d.zero_(); + } + + Flash_bwd_params params; + + set_params_dgrad(params, + batch_size, + max_seqlen_q, max_seqlen_k, + seqlen_q_rounded, seqlen_k_rounded, + num_heads, num_heads_k, num_heads_v, + head_size, head_size_rounded, + v_head_size_og, + q, k, v, out, + dout_padded, dq, dk_expanded, dv_expanded, + cu_seqlens_q.data_ptr(), + cu_seqlens_k.data_ptr(), + loop ? dq_accum.data_ptr() : nullptr, + nullptr, + nullptr, + softmax_lse.data_ptr(), + softmax_d.data_ptr(), + p_dropout, + softmax_scale, + window_size_left, + window_size_right, + softcap, + deterministic, + /*unpadded_lse*/true); + params.dq_accum_split_stride = !deterministic ? 0 : dq_accum.stride(0); + params.total_q = total_q; + + auto launch = &run_mha_bwd; + + auto gen = at::get_generator_or_default( + gen_, at::cuda::detail::getDefaultCUDAGenerator()); + + // We use a custom RNG that increases the offset by batch_size * nheads * 32. + int64_t counter_offset = params.b * params.h * 32; + + if ( rng_state.has_value() ) { + params.rng_state = reinterpret_cast(rng_state.value().data_ptr()); + } else if( is_dropout ) { + // See Note [Acquire lock when using random generators] + std::lock_guard lock(gen->mutex_); + params.philox_args = gen->philox_cuda_state(counter_offset); + auto seeds = at::cuda::philox::unpack(params.philox_args); + params.rng_state[0] = std::get<0>(seeds); + params.rng_state[1] = std::get<1>(seeds); + } + + set_params_alibi(params, alibi_slopes_, batch_size, num_heads); + + if (max_seqlen_q > 0) { + launch(params, stream); + } else { + // If seqlen_q == 0, then we have an empty tensor. We need to set the output to 0. 
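The allocation comment above fixes the layout of `dq_accum` (and the unpadded `softmax_d`) in the varlen backward: each sequence is shifted by one extra 128-row block, so out-of-bounds atomicAdds from the last query block of one sequence can never touch the next one. A small sketch of the resulting row offsets (illustrative only; `BLOCK_M` and the helper name are ours):

```python
# Row offsets inside dq_accum for the varlen backward, per the comment above.
BLOCK_M = 128  # max block size along seqlen_q

def dq_accum_row_offset(cu_seqlens_q, i):
    return cu_seqlens_q[i] + BLOCK_M * i

# Example: 3 sequences of lengths 5, 300, 17 -> cu_seqlens_q = [0, 5, 305, 322]
cu = [0, 5, 305, 322]
print([dq_accum_row_offset(cu, i) for i in range(3)])   # [0, 133, 561] -> gaps of at least 128 rows
print(cu[-1] + BLOCK_M * 3)                             # 706 rows = total_q + 128 * batch_size
```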
+ dk_expanded.zero_(); + dv_expanded.zero_(); + softmax_d.zero_(); + } + + // For MQA/GQA we need to sum dK and dV across the groups + if (num_heads_k != num_heads) { + at::sum_out(dk, at::reshape(dk_expanded, {total_k, num_heads_k, num_heads / num_heads_k, head_size}), {2}); + } + if (num_heads_v != num_heads) { + at::sum_out(dv, at::reshape(dv_expanded, {total_k, num_heads_v, num_heads / num_heads_v, v_head_size_og}), {2}); + } + if (head_size_og % 8 != 0) { + dq = dq.index({"...", torch::indexing::Slice(torch::indexing::None, head_size_og)}); + dk = dk.index({"...", torch::indexing::Slice(torch::indexing::None, head_size_og)}); + // dv = dv.index({"...", torch::indexing::Slice(torch::indexing::None, head_size_og)}); + } + + return { dq, dk, dv, softmax_d }; +} + +std::vector +mha_fwd_kvcache(at::Tensor &q, // batch_size x seqlen_q x num_heads x head_size + const at::Tensor &kcache, // batch_size_c x seqlen_k x num_heads_k x head_size or num_blocks x page_block_size x num_heads_k x head_size if there's a block_table. + const at::Tensor &vcache, // batch_size_c x seqlen_k x num_heads_k x head_size or num_blocks x page_block_size x num_heads_k x head_size if there's a block_table. + c10::optional &k_, // batch_size x seqlen_knew x num_heads_k x head_size + c10::optional &v_, // batch_size x seqlen_knew x num_heads_k x head_size + c10::optional &seqlens_k_, // batch_size + c10::optional &rotary_cos_, // seqlen_ro x (rotary_dim / 2) + c10::optional &rotary_sin_, // seqlen_ro x (rotary_dim / 2) + c10::optional &cache_batch_idx_, // indices to index into the KV cache + c10::optional &leftpad_k_, // batch_size + c10::optional &block_table_, // batch_size x max_num_blocks_per_seq + c10::optional &alibi_slopes_, // num_heads or batch_size x num_heads + c10::optional &out_, // batch_size x seqlen_q x num_heads x head_size + const float softmax_scale, + bool is_causal, + int window_size_left, + int window_size_right, + const float softcap, + bool is_rotary_interleaved, // if true, rotary combines indices 0 & 1, else indices 0 & rotary_dim / 2 + int num_splits + ) { + + auto dprops = at::cuda::getCurrentDeviceProperties(); + // bool is_sm75 = dprops->major == 7 && dprops->minor == 5; + bool is_sm8x = dprops->major == 8 && dprops->minor >= 0; + bool is_sm90 = dprops->major == 9 && dprops->minor == 0; + TORCH_CHECK(is_sm90 || is_sm8x, "FlashAttention only supports Ampere GPUs or newer."); + // We will support Turing in the near future + // TORCH_CHECK(is_sm90 || is_sm8x || is_sm75, "FlashAttention only supports Turing GPUs or newer."); + + auto q_dtype = q.dtype(); + TORCH_CHECK(q_dtype == torch::kFloat16 || q_dtype == torch::kBFloat16, + "FlashAttention only support fp16 and bf16 data type"); + if (q_dtype == torch::kBFloat16) { + TORCH_CHECK(is_sm90 || is_sm8x, "bfloat16 is only supported on Ampere GPUs or newer"); + } + TORCH_CHECK(kcache.dtype() == q_dtype, "query and key must have the same dtype"); + TORCH_CHECK(vcache.dtype() == q_dtype, "query and value must have the same dtype"); + + CHECK_DEVICE(q); CHECK_DEVICE(kcache); CHECK_DEVICE(vcache); + + TORCH_CHECK(q.stride(-1) == 1, "Input tensor must have contiguous last dimension"); + TORCH_CHECK(kcache.stride(-1) == 1, "Input tensor must have contiguous last dimension"); + TORCH_CHECK(vcache.stride(-1) == 1, "Input tensor must have contiguous last dimension"); + + at::Tensor block_table; + const bool paged_KV = block_table_.has_value(); + if (paged_KV) { + TORCH_CHECK(!cache_batch_idx_.has_value(), "Paged KVcache does not support 
cache_batch_idx"); + block_table = block_table_.value(); + CHECK_DEVICE(block_table); + TORCH_CHECK(block_table.dtype() == torch::kInt32, "block_table must have dtype torch.int32"); + TORCH_CHECK(block_table.stride(-1) == 1, "block_table must have contiguous last dimension"); + } + + const auto sizes = q.sizes(); + const int v_head_size_og = vcache.sizes()[3]; + const int batch_size = sizes[0]; + int seqlen_q = sizes[1]; + int num_heads = sizes[2]; + const int head_size_og = sizes[3]; + + const int max_num_blocks_per_seq = !paged_KV ? 0 : block_table.size(1); + const int num_blocks = !paged_KV ? 0 : kcache.size(0); + const int page_block_size = !paged_KV ? 1 : kcache.size(1); + TORCH_CHECK(!paged_KV || page_block_size % 256 == 0, "Paged KV cache block size must be divisible by 256"); + const int seqlen_k = !paged_KV ? kcache.size(1) : max_num_blocks_per_seq * page_block_size; + const int num_heads_k = kcache.size(2); + const int num_heads_v = vcache.size(2); + const int batch_size_c = !paged_KV ? kcache.size(0) : batch_size; + TORCH_CHECK(batch_size > 0, "batch size must be positive"); + TORCH_CHECK(head_size_og <= 256, "FlashAttention forward only supports head dimension at most 256"); + TORCH_CHECK(num_heads % num_heads_k == 0, "Number of heads in key/value must divide number of heads in query"); + TORCH_CHECK(num_heads % num_heads_v == 0, "Number of heads in value must divide number of heads in query"); + TORCH_CHECK(v_head_size_og <= 256, "FlashAttention backward only supports head dimension at most 256"); + + // causal=true is the same as causal=false in this case + if (seqlen_q == 1 && !alibi_slopes_.has_value()) { is_causal = false; } + if (is_causal) { window_size_right = 0; } + + // Faster to transpose q from (b, 1, (nheads_kv ngroups), d) to (b, ngroups, nheads_kv, d) in this case + // H/t Daniel Haziza + const int num_heads_maxkv = num_heads_k > num_heads_v ? 
num_heads_k : num_heads_v; + const int seqlenq_ngroups_swapped = seqlen_q == 1 && num_heads > num_heads_maxkv && window_size_left < 0 && window_size_right < 0 && head_size_og % 8 == 0 && v_head_size_og % 8 == 0 && !alibi_slopes_.has_value(); + if (seqlenq_ngroups_swapped) { + const int ngroups = num_heads / num_heads_maxkv; + q = q.reshape({batch_size, num_heads_maxkv, ngroups, head_size_og}).transpose(1, 2); + seqlen_q = ngroups; + num_heads = num_heads_maxkv; + } + + if (window_size_left >= seqlen_k) { window_size_left = -1; } + if (window_size_right >= seqlen_k) { window_size_right = -1; } + + CHECK_SHAPE(q, batch_size, seqlen_q, num_heads, head_size_og); + if (!paged_KV) { + CHECK_SHAPE(kcache, batch_size_c, seqlen_k, num_heads_k, head_size_og); + CHECK_SHAPE(vcache, batch_size_c, seqlen_k, num_heads_v, v_head_size_og); + } else { + CHECK_SHAPE(kcache, num_blocks, page_block_size, num_heads_k, head_size_og); + CHECK_SHAPE(vcache, num_blocks, page_block_size, num_heads_v, v_head_size_og); + CHECK_SHAPE(block_table, batch_size, max_num_blocks_per_seq); + } + + at::Tensor q_padded, kcache_padded, vcache_padded; + if (head_size_og % 8 != 0) { + q_padded = torch::nn::functional::pad(q, torch::nn::functional::PadFuncOptions({0, 8 - head_size_og % 8})); + kcache_padded = torch::nn::functional::pad(kcache, torch::nn::functional::PadFuncOptions({0, 8 - head_size_og % 8})); + vcache_padded = torch::nn::functional::pad(vcache, torch::nn::functional::PadFuncOptions({0, 8 - v_head_size_og % 8})); + } else { + q_padded = q; + kcache_padded = kcache; + vcache_padded = vcache; + } + + at::Tensor out; + if (out_.has_value()) { + out = out_.value(); + TORCH_CHECK(out.dtype() == q_dtype, "Output must have the same dtype as inputs"); + CHECK_DEVICE(out); + TORCH_CHECK(out.stride(-1) == 1, "Output tensor must have contiguous last dimension"); + // TODO: check here for seqlenq_ngroups_swapped + CHECK_SHAPE(out, batch_size, seqlen_q, num_heads, v_head_size_og); + if (v_head_size_og % 8 != 0) { + out = torch::empty({batch_size, seqlen_q, num_heads, v_head_size_og}, q.options()); + out = torch::nn::functional::pad(out, torch::nn::functional::PadFuncOptions({0, 8 - v_head_size_og % 8})); + } + } else { + out = torch::empty({batch_size, seqlen_q, num_heads, v_head_size_og}, q.options()); + if (v_head_size_og % 8 != 0) { + out = torch::nn::functional::pad(out, torch::nn::functional::PadFuncOptions({0, 8 - v_head_size_og % 8})); + } + } + + auto round_multiple = [](int x, int m) { return (x + m - 1) / m * m; }; + const int head_size = round_multiple(head_size_og, 8); + const int head_size_rounded = head_size <= 192 ? round_multiple(head_size, 32) : 256; + const int v_head_size = round_multiple(v_head_size_og, 8); + const int v_head_size_rounded = v_head_size <= 192 ? 
round_multiple(v_head_size, 32) : 256; + const int seqlen_q_rounded = round_multiple(seqlen_q, 128); + const int seqlen_k_rounded = round_multiple(seqlen_k, 128); + + // Otherwise the kernel will be launched from cuda:0 device + // Cast to char to avoid compiler warning about narrowing + at::cuda::CUDAGuard device_guard{(char)q.get_device()}; + + auto opts = q.options(); + + auto softmax_lse = torch::empty({batch_size, num_heads, seqlen_q}, opts.dtype(at::kFloat)); + + Flash_fwd_params params; + set_params_fprop(params, + batch_size, + seqlen_q, seqlen_k, + seqlen_q_rounded, seqlen_k_rounded, + num_heads, num_heads_k, num_heads_v, + head_size, head_size_rounded, + v_head_size, v_head_size_rounded, + q_padded, kcache_padded, vcache_padded, out, + /*cu_seqlens_q_d=*/nullptr, + /*cu_seqlens_k_d=*/nullptr, + /*seqused_k=*/nullptr, + /*p_ptr=*/nullptr, + softmax_lse.data_ptr(), + /*p_dropout=*/0.f, + softmax_scale, + window_size_left, + window_size_right, + softcap + ); + + at::Tensor k, v, k_padded, v_padded; + if (k_.has_value()) { + TORCH_CHECK(v_.has_value(), "If key is supplied, value must also be passed in"); + TORCH_CHECK(seqlens_k_.has_value(), "If key is supplied, seqlens_k must also be passed in"); + TORCH_CHECK(seqlen_q <= seqlen_k, "If key is supplied, it must have seqlen <= the seqlen of the KV cache"); + k = k_.value(); + v = v_.value(); + TORCH_CHECK(k.dtype() == q_dtype, "Key must have the same dtype as query"); + TORCH_CHECK(v.dtype() == q_dtype, "Value must have the same dtype as query"); + CHECK_DEVICE(k); CHECK_DEVICE(v); + TORCH_CHECK(k.stride(-1) == 1, "Key tensor must have contiguous last dimension"); + TORCH_CHECK(v.stride(-1) == 1, "Value tensor must have contiguous last dimension"); + int seqlen_knew = k.size(1); + CHECK_SHAPE(k, batch_size, seqlen_knew, num_heads_k, head_size_og); + CHECK_SHAPE(v, batch_size, seqlen_knew, num_heads_v, v_head_size_og); + if (head_size_og % 8 != 0) { + k_padded = torch::nn::functional::pad(k, torch::nn::functional::PadFuncOptions({0, 8 - head_size_og % 8})); + v_padded = torch::nn::functional::pad(v, torch::nn::functional::PadFuncOptions({0, 8 - v_head_size_og % 8})); + } else { + k_padded = k; + v_padded = v; + } + params.seqlen_knew = seqlen_knew; + params.knew_ptr = k_padded.data_ptr(); + params.vnew_ptr = v_padded.data_ptr(); + // All stride are in elements, not bytes. 
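A quick illustration of the head-size handling in this function: the original head sizes are padded to a multiple of 8 (via the pad calls above), and the rounded sizes used for the accumulation buffers go up to the next multiple of 32, capped at 256 once the padded size exceeds 192. A minimal Python restatement with a few spot checks (illustrative only):

```python
# Restatement of the rounding rules used in mha_fwd_kvcache above.
def round_multiple(x, m):
    return (x + m - 1) // m * m

def rounded_head_sizes(head_size_og):
    head_size = round_multiple(head_size_og, 8)          # kernels require a multiple of 8
    head_size_rounded = round_multiple(head_size, 32) if head_size <= 192 else 256
    return head_size, head_size_rounded

assert rounded_head_sizes(96)  == (96, 96)
assert rounded_head_sizes(100) == (104, 128)
assert rounded_head_sizes(200) == (200, 256)
```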
+ params.knew_batch_stride = k_padded.stride(0); + params.vnew_batch_stride = v_padded.stride(0); + params.knew_row_stride = k_padded.stride(-3); + params.vnew_row_stride = v_padded.stride(-3); + params.knew_head_stride = k_padded.stride(-2); + params.vnew_head_stride = v_padded.stride(-2); + } + + if (seqlens_k_.has_value()) { + auto seqlens_k = seqlens_k_.value(); + TORCH_CHECK(seqlens_k.dtype() == torch::kInt32, "seqlens_k must have dtype int32"); + CHECK_DEVICE(seqlens_k); + CHECK_CONTIGUOUS(seqlens_k); + CHECK_SHAPE(seqlens_k, batch_size); + params.cu_seqlens_k = static_cast(seqlens_k.data_ptr()); + } + params.is_seqlens_k_cumulative = !(seqlens_k_.has_value()); + if (leftpad_k_.has_value()) { + TORCH_CHECK(!paged_KV, "We don't support Paged KV and leftpad_k running at the same time yet"); + auto leftpad_k = leftpad_k_.value(); + TORCH_CHECK(leftpad_k.dtype() == torch::kInt32, "leftpad_k must have dtype int32"); + CHECK_DEVICE(leftpad_k); + CHECK_CONTIGUOUS(leftpad_k); + CHECK_SHAPE(leftpad_k, batch_size); + params.leftpad_k = static_cast(leftpad_k.data_ptr()); + } + + if (rotary_cos_.has_value()) { + TORCH_CHECK(k_.has_value(), "If rotary cos/sin are provided, new key / value to be appended to KV cache must also be provided"); + auto rotary_cos = rotary_cos_.value(); + CHECK_DEVICE(rotary_cos); + params.rotary_dim = rotary_cos.size(1) * 2; + TORCH_CHECK(params.rotary_dim <= head_size, "rotary_dim must be <= headdim"); + TORCH_CHECK(params.rotary_dim % 16 == 0, "Only rotary dimensions divisible by 16 are currently supported"); + const int seqlen_ro = rotary_cos.size(0); + TORCH_CHECK(seqlen_ro >= seqlen_k, "cos/sin seqlen must be at least the seqlen of KV cache"); + CHECK_SHAPE(rotary_cos, seqlen_ro, params.rotary_dim / 2); + CHECK_CONTIGUOUS(rotary_cos); + TORCH_CHECK(rotary_cos.scalar_type() == q_dtype, "rotary_cos must have the same dtype as query"); + + TORCH_CHECK(rotary_sin_.has_value(), "If rotary cos is provided, rotary sin must also be provided"); + auto rotary_sin = rotary_sin_.value(); + CHECK_DEVICE(rotary_sin); + CHECK_SHAPE(rotary_sin, seqlen_ro, params.rotary_dim / 2); + CHECK_CONTIGUOUS(rotary_sin); + TORCH_CHECK(rotary_sin.scalar_type() == q_dtype, "rotary_cos must have the same dtype as query"); + params.rotary_cos_ptr = rotary_cos.data_ptr(); + params.rotary_sin_ptr = rotary_sin.data_ptr(); + params.is_rotary_interleaved = is_rotary_interleaved; + } else { + params.rotary_dim = 0; + } + + if (cache_batch_idx_.has_value()) { + auto cache_batch_idx = cache_batch_idx_.value(); + CHECK_DEVICE(cache_batch_idx); + CHECK_CONTIGUOUS(cache_batch_idx); + TORCH_CHECK(cache_batch_idx.scalar_type() == torch::kInt32, "cache_batch_idx must have dtype int32"); + params.cache_batch_idx = reinterpret_cast(cache_batch_idx.data_ptr()); + } + + // Keep references to these tensors to extend their lifetime + at::Tensor softmax_lse_accum, out_accum; + std::tie(softmax_lse_accum, out_accum) = set_params_splitkv( + params, batch_size, num_heads, head_size, v_head_size, seqlen_k, seqlen_q, + head_size_rounded, v_head_size_rounded, /*dropout*/ 0.f, num_splits, dprops, opts); + + if (paged_KV) { + params.block_table = block_table.data_ptr(); + params.block_table_batch_stride = block_table.stride(0); + } + params.page_block_size = page_block_size; + + + set_params_alibi(params, alibi_slopes_, batch_size, num_heads); + + auto stream = at::cuda::getCurrentCUDAStream().stream(); + // Only split kernel supports appending to KV cache, or indexing to the cache with cache_batch_idx, + // or paged KV 
cache + run_mha_fwd(params, stream, /*force_split_kernel=*/k_.has_value() || cache_batch_idx_.has_value() || paged_KV); + + if (head_size_og % 8 != 0) { + // out = out.index({"...", torch::indexing::Slice(torch::indexing::None, head_size_og)}); + // if (out_.has_value()) { out_.value().copy_(out); } + if (k_.has_value()) { + // It's expensive to copy the KV cache here for the case where head size not divisible by 8, + // but we don't expect to get this case in practice. This is just so that the code works for that case. + kcache.copy_(kcache_padded.index({"...", torch::indexing::Slice(torch::indexing::None, head_size_og)})); + // vcache.copy_(vcache_padded.index({"...", torch::indexing::Slice(torch::indexing::None, head_size_og)})); + } + } + if (v_head_size_og % 8 != 0) { + out = out.index({"...", torch::indexing::Slice(torch::indexing::None, v_head_size_og)}); + if (out_.has_value()) { out_.value().copy_(out); } + if (k_.has_value()) { + // It's expensive to copy the KV cache here for the case where head size not divisible by 8, + // but we don't expect to get this case in practice. This is just so that the code works for that case. + // kcache.copy_(kcache_padded.index({"...", torch::indexing::Slice(torch::indexing::None, head_size_og)})); + vcache.copy_(vcache_padded.index({"...", torch::indexing::Slice(torch::indexing::None, v_head_size_og)})); + } + } + + if (seqlenq_ngroups_swapped) { + out = out.transpose(1, 2).reshape({batch_size, 1, num_heads_maxkv * seqlen_q, v_head_size_og}); + softmax_lse = softmax_lse.reshape({batch_size, num_heads_maxkv * seqlen_q, 1}); + } + return {out, softmax_lse}; +} + +PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { + m.doc() = "FlashAttention"; + m.def("fwd", &mha_fwd, "Forward pass"); + m.def("varlen_fwd", &mha_varlen_fwd, "Forward pass (variable length)"); + // m.def("bwd", &mha_bwd, "Backward pass"); + // m.def("varlen_bwd", &mha_varlen_bwd, "Backward pass (variable length)"); + // m.def("fwd_kvcache", &mha_fwd_kvcache, "Forward pass, with KV-cache"); +} diff --git a/autotuner/test_run_tunner.py b/autotuner/test_run_tunner.py new file mode 100644 index 000000000..61af4456d --- /dev/null +++ b/autotuner/test_run_tunner.py @@ -0,0 +1,23 @@ +import torch +from tunner import FlashFwdTunner +from arch import A100 +from code_emitter import ShapeConfig,ProfileConfig + +batch_size = 4 +seqlen = 2048 +nheads = 8 +headdim = 128# 192 +v_headdim = 256# 128 +device = 'cuda:0' +dtype = torch.bfloat16 +dropout_p = 0.0 # 0.0 + +q = torch.randn(batch_size, seqlen, nheads, headdim, device=device, dtype=dtype, + requires_grad=True) +k = torch.randn(batch_size, seqlen, nheads, headdim, device=device, dtype=dtype, + requires_grad=True) +v = torch.randn(batch_size, seqlen, nheads, v_headdim, device=device, dtype=dtype, + requires_grad=True) + +tunner = FlashFwdTunner(A100(), [q,k,v], ShapeConfig(headdim,v_headdim), ProfileConfig(batch_size,seqlen,seqlen,nheads,nheads,nheads,device,dtype,dropout_p), "autotuner/temp128_256") # "autotuner/temp192_128" +tunner.tune() diff --git a/autotuner/tunner.py b/autotuner/tunner.py new file mode 100644 index 000000000..4b701996a --- /dev/null +++ b/autotuner/tunner.py @@ -0,0 +1,66 @@ + +import ctypes +import os +import torch + +from base_tunner import BaseTunner +from configs.fwd_config import FlashFwdConfig + +class FlashFwdTunner(BaseTunner): + def __init__(self, arch, torch_array: list, shape_config, profile_config, tempdir: str): + super().__init__(arch, torch_array, "flash_fwd", shape_config, profile_config, tempdir) + + def 
validate_register_fuse(self, config): + Br = config.Br + Bc = config.Bc + Kd = config.Kd + D = config.D + Nthreads = config.Nwarps * 32 + mmam, mman, mmak = self.arch.cutlass_mma + belem_per_thread = mman*mmak/self.arch.warp_size + + # check tile size + if Br % (mmam*Nthreads/self.arch.warp_size) != 0: + return False + # check shared memory + smem_size_q = config.Br * config.Kd * 2 + smem_size_k = config.Bc * config.Kd * 2 + smem_size_qk = smem_size_q + smem_size_k + smem_size_v = config.Bc * config.D * 2 + smem_out = config.Br * config.D * 2 + if config.SharedQKSmem: + smem_size = max(smem_size_q, smem_size_k+smem_size_v) + else: + smem_size = smem_size_qk + smem_size_v + smem_size = max(smem_size, smem_out) + if smem_size > self.arch.smem_cap: + return False + # check register + reg_used_accum = (Br * D * 4 + Br*Bc*4)/(Nthreads * 4) + reg_used_matmul2 = (Br * D * 4 + Br*Bc*2)/(Nthreads * 4) + (D/(mman*1) * belem_per_thread*2) / 4 + reg_used_matmul1 = (Br * D * 4 + Br * Bc * 4)/(Nthreads * 4) + (Bc/(mman*1) * belem_per_thread*2) / 4 + reg_used_qinregs = (Br * Kd * 2)/(Nthreads * 4) + if config.isQinRegs: + reg_used = reg_used_accum + reg_used_qinregs + else: + reg_used = reg_used_accum # max(reg_used_accum, reg_used_matmul2, reg_used_matmul1) + if reg_used > min(self.arch.register_per_thread, self.arch.reg_cap/Nthreads): + return False + return True + + def generate_configs(self,Br:int,Bc:int,dim_qk:int,dim_v:int): + configs = [] + for Nthreads in [128, 256]: + # TODO: more general + # global load atom + load_atom = 64 if (dim_qk % 64 == 0 and dim_v % 64 == 0 ) else 32 + NthreadsPerRow = load_atom / (128/16) + if Br % (Nthreads / NthreadsPerRow) != 0 or Bc % (Nthreads / NthreadsPerRow) != 0: + continue + config1 = FlashFwdConfig(dim_qk,dim_v,Br,Bc,Nthreads//32,False,False) + config2 = FlashFwdConfig(dim_qk,dim_v,Br,Bc,Nthreads//32,True,False) + config3 = FlashFwdConfig(dim_qk,dim_v,Br,Bc,Nthreads//32,True,True) + configs.append(config1) + configs.append(config2) + configs.append(config3) + return configs diff --git a/autotunner.md b/autotunner.md new file mode 100644 index 000000000..165b16d99 --- /dev/null +++ b/autotunner.md @@ -0,0 +1,20 @@ +# Autotuner + +The autotuner automatically generates the best config for the flash-attention kernel, either for head-dimension pairs (headdim qk, headdim v) that are not implemented yet, or for existing headdims on different hardware such as NVIDIA Ampere and Ada Lovelace. + +Currently, the autotuner only supports the flash-attention forward pass. We plan to support backward and forward_split soon. + +## Usage + +First, install flashattn from source. Then run the autotuner with the qk and v head dimensions you want to tune. After that, modify or create `csrc/flash_attn/src/flash_fwd_qkdim*_vdim*_sm80.h` with the tuned config. Finally, rebuild flashattn from source. + + + +The detailed steps are as follows: + +- Install flashattn from source. +- Run `python autotuner/test_run_tunner.py` with the problem size you want to tune (see the sketch below). +- If the headdim already exists in `csrc/flash_attn/src`, modify `csrc/flash_attn/src/flash_fwd_qkdim*_vdim*_sm80.h` with the tuned best config. If the headdim does not exist, create `csrc/flash_attn/src/flash_fwd_qkdim*_vdim*_sm80.h` and `csrc/flash_attn/src/flash_bwd_qkdim*_vdim*_sm80.h` with the tuned best config and the corresponding `.cu` files; after that, add the headdim to `headdim.json`. +- Rebuild flashattn from source. 
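For concreteness, the whole workflow reduces to a short driver script. The sketch below mirrors `autotuner/test_run_tunner.py` from this diff; the shapes, dtype, dropout and temp directory are example values that you would adapt to your own problem size, and it assumes `autotuner/` is on the Python path, as in the test script.

```python
# Minimal driver sketch, following autotuner/test_run_tunner.py in this diff.
import torch
from tunner import FlashFwdTunner            # autotuner/tunner.py
from arch import A100                        # autotuner/arch/A100.py
from code_emitter import ShapeConfig, ProfileConfig

batch_size, seqlen, nheads = 4, 2048, 8
headdim, v_headdim = 128, 256                # qk head dim / v head dim to tune for
device, dtype, dropout_p = 'cuda:0', torch.bfloat16, 0.0

q = torch.randn(batch_size, seqlen, nheads, headdim, device=device, dtype=dtype, requires_grad=True)
k = torch.randn(batch_size, seqlen, nheads, headdim, device=device, dtype=dtype, requires_grad=True)
v = torch.randn(batch_size, seqlen, nheads, v_headdim, device=device, dtype=dtype, requires_grad=True)

tunner = FlashFwdTunner(A100(), [q, k, v],
                        ShapeConfig(headdim, v_headdim),
                        ProfileConfig(batch_size, seqlen, seqlen, nheads, nheads, nheads,
                                      device, dtype, dropout_p),
                        "autotuner/temp128_256")   # scratch dir for generated kernels
tunner.tune()
```

`tune()` is expected to emit, compile and profile the candidate configs produced by `FlashFwdTunner.generate_configs`; the winning config is what you then copy into the matching `flash_fwd_qkdim*_vdim*_sm80.h` header.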
+ + diff --git a/benchmarks/benchmark_head_headdim.py b/benchmarks/benchmark_head_headdim.py new file mode 100644 index 000000000..b254aff26 --- /dev/null +++ b/benchmarks/benchmark_head_headdim.py @@ -0,0 +1,208 @@ +# Install the newest triton version with +# pip install "git+https://github.com/openai/triton.git#egg=triton&subdirectory=python" +import csv +import pickle +import math +import torch +import torch.nn as nn +import torch.nn.functional as F + +from einops import rearrange, repeat + +from flash_attn.utils.benchmark import benchmark_all, benchmark_forward, benchmark_backward +from flash_attn.utils.benchmark import benchmark_fwd_bwd, benchmark_combined + +from flash_attn import flash_attn_qkvpacked_func, flash_attn_func + +try: + from triton.ops.flash_attention import attention as attention_triton +except ImportError: + attention_triton = None + +try: + import xformers.ops as xops +except ImportError: + xops = None + + +def flops(batch, seqlen, headdim, v_headdim, nheads, causal, mode="fwd"): + assert mode in ["fwd", "bwd", "fwd_bwd"] + f = 2 * batch * seqlen**2 * nheads * (headdim+v_headdim) // (2 if causal else 1) + b = 2 * batch * seqlen**2 * nheads * (3*headdim+2*v_headdim) // (2 if causal else 1) + return f if mode == "fwd" else (b if mode == "bwd" else f+b) + +def efficiency(flop, time): + return (flop / time / 10**12) if not math.isnan(time) else 0.0 + + +def attention_pytorch(q, k, v, dropout_p=0.0, causal=True): + """ + Arguments: + qkv: (batch_size, seqlen, 3, nheads, head_dim) + dropout_p: float + Output: + output: (batch_size, seqlen, nheads, head_dim) + """ + batch_size, seqlen, nheads, d = q.shape + nheads_k = k.shape[2] + nheads_v = v.shape[2] + if nheads_k < nheads: + k = repeat(k, 'b s h d -> b s (h g) d', g=nheads//nheads_k) + if nheads_v < nheads: + v = repeat(v, 'b s h d -> b s (h g) d', g=nheads//nheads_v) + v_d = v.shape[-1] + q = rearrange(q, 'b t h d -> (b h) t d') + k = rearrange(k, 'b s h d -> (b h) d s') + softmax_scale = 1.0 / math.sqrt(d) + # Preallocate attn_weights for `baddbmm` + scores = torch.empty(batch_size * nheads, seqlen, seqlen, dtype=q.dtype, device=q.device) + scores = rearrange(torch.baddbmm(scores, q, k, beta=0, alpha=softmax_scale), + '(b h) t s -> b h t s', h=nheads) + if causal: + # "triu_tril_cuda_template" not implemented for 'BFloat16' + # So we have to construct the mask in float + causal_mask = torch.triu(torch.full((seqlen, seqlen), -10000.0, device=scores.device), 1) + # TD [2022-09-30]: Adding is faster than masked_fill_ (idk why, just better kernel I guess) + scores = scores + causal_mask.to(dtype=scores.dtype) + attention = torch.softmax(scores, dim=-1) + attention_drop = F.dropout(attention, dropout_p) + output = torch.einsum('bhts,bshd->bthd', attention_drop , v) + return output.to(dtype=q.dtype) + + +def flash_attention_pad(q,k,v, dropout_p=0.0, causal=True): + batch_size, seqlen, nheads, d = q.shape + nheads_k = k.shape[2] + nheads_v = v.shape[2] + if nheads_k < nheads_v: + k = repeat(k, 'b s h d -> b s (h g) d', g=nheads_v//nheads_k) + elif nheads_k > nheads_v: + v = repeat(v, 'b s h d -> b s (h g) d', g=nheads_k//nheads_v) + v_d = v.shape[-1] + if d == v_d: + return flash_attn_func(q, k, v, dropout_p, causal) + if d < v_d: + q = F.pad(q, (0, v_d-d)) + k = F.pad(k, (0, v_d-d)) + return flash_attn_func(q, k, v, dropout_p, causal) + elif d > v_d: + v = F.pad(v, (0, d-v_d)) + o = flash_attn_func(q, k, v, dropout_p, causal) + return o[:,:,:,:v_d] + + + +def time_fwd_bwd(func, *args, **kwargs): + time_f, time_b = 
benchmark_fwd_bwd(func, *args, **kwargs) + return time_f[1].mean, time_b[1].mean + +save_csv = True + +repeats = 30 +device = 'cuda' +dtype = torch.float16 +torch.cuda.set_device(0) + +bs_seqlen_vals = [(4, 512), (4, 1024), (4, 2048), (4, 4096), (2, 8192), (1, 16384)] +causal_vals = [False, True] +headdim_vals = [ (32,64),(64,128)] +nheads_qkv = (32, 4, 16) +dropout_p = 0.0 + +methods = (["CustomFlash2", "Pytorch", "Flash2_Pad"]) + +if save_csv: + csvfile = open('flash2_attn_time.csv', 'w', newline='') + writer = csv.writer(csvfile) + writer.writerow([ + "causal", "qk_headdim", "v_headdim","nheads_q", "nheads_k", "nheads_v", "batch_size", "seqlen", + "time_fwd_CustomFlash2", "time_bwd_CustomFlash2", "time_fwd_bwd_CustomFlash2", + "time_fwd_Pytorch", "time_bwd_Pytorch", "time_fwd_bwd_Pytorch", + "time_fwd_Flash2_Pad", "time_bwd_Flash2_Pad", "time_fwd_bwd_Flash2_Pad", + "flops_fwd_CustomFlash2", "flops_bwd_CustomFlash2", "flops_fwd_bwd_CustomFlash2", + "flops_fwd_Pytorch", "flops_bwd_Pytorch", "flops_fwd_bwd_Pytorch", + "flops_fwd_Flash2_Pad", "flops_bwd_Flash2_Pad", "flops_fwd_bwd_Flash2_Pad", + ]) + +time_f = {} +time_b = {} +time_f_b = {} +speed_f = {} +speed_b = {} +speed_f_b = {} +for causal in causal_vals: + for headdim,v_headdim in headdim_vals: + for batch_size, seqlen in bs_seqlen_vals: + config = (causal, headdim, batch_size, seqlen) + nheads_q, nheads_k, nheads_v = nheads_qkv + q = torch.randn(batch_size, seqlen, nheads_q, headdim, device=device, dtype=dtype, + requires_grad=True) + k = torch.randn(batch_size, seqlen, nheads_k, headdim, device=device, dtype=dtype, + requires_grad=True) + v = torch.randn(batch_size, seqlen, nheads_v, v_headdim, device=device, dtype=dtype, + requires_grad=True) + f, b = time_fwd_bwd( + flash_attn_func, q, k, v, dropout_p, causal=causal, repeats=repeats, verbose=False + ) + time_f[config, "CustomFlash2"] = f + time_b[config, "CustomFlash2"] = b + + try: + q = q.detach().requires_grad_(True) + k = k.detach().requires_grad_(True) + v = v.detach().requires_grad_(True) + f, b = time_fwd_bwd( + attention_pytorch, q, k, v, dropout_p, causal=causal, repeats=repeats, verbose=False + ) + except: # Skip if OOM + f, b = float('nan'), float('nan') + time_f[config, "Pytorch"] = f + time_b[config, "Pytorch"] = b + + q = q.detach().requires_grad_(True) + k = k.detach().requires_grad_(True) + v = v.detach().requires_grad_(True) + f, b = time_fwd_bwd( + flash_attention_pad, q, k, v, dropout_p, causal=causal, repeats=repeats, verbose=False + ) + time_f[config, "Flash2_Pad"] = f + time_b[config, "Flash2_Pad"] = b + + print(f"### causal={causal}, qk_headdim={headdim}, v_headdim={v_headdim}, batch_size={batch_size}, seqlen={seqlen}, head_qkv={nheads_qkv} ###") + for method in methods: + time_f_b[config, method] = time_f[config, method] + time_b[config, method] + speed_f[config, method] = efficiency( + flops(batch_size, seqlen, headdim, v_headdim, nheads_q, causal, mode="fwd"), + time_f[config, method] + ) + speed_b[config, method] = efficiency( + flops(batch_size, seqlen, headdim, v_headdim, nheads_q, causal, mode="bwd"), + time_b[config, method] + ) + speed_f_b[config, method] = efficiency( + flops(batch_size, seqlen, headdim, v_headdim, nheads_q, causal, mode="fwd_bwd"), + time_f_b[config, method] + ) + print( + f"{method} fwd: {speed_f[config, method]:.2f} TFLOPs/s, " + f"bwd: {speed_b[config, method]:.2f} TFLOPs/s, " + f"fwd + bwd: {speed_f_b[config, method]:.2f} TFLOPs/s" + ) + if save_csv: + writer.writerow([ + causal, headdim, v_headdim, *nheads_qkv, batch_size, 
seqlen, + time_f[config, "CustomFlash2"], time_b[config, "CustomFlash2"], time_f_b[config, "CustomFlash2"], + time_f[config, "Pytorch"], time_b[config, "Pytorch"], time_f_b[config, "Pytorch"], + time_f[config, "Flash2_Pad"], time_b[config, "Flash2_Pad"], time_f_b[config, "Flash2_Pad"], + speed_f[config, "CustomFlash2"], speed_b[config, "CustomFlash2"], speed_f_b[config, "CustomFlash2"], + speed_f[config, "Pytorch"], speed_b[config, "Pytorch"], speed_f_b[config, "Pytorch"], + speed_f[config, "Flash2_Pad"], speed_b[config, "Flash2_Pad"], speed_f_b[config, "Flash2_Pad"], + ]) + +if save_csv: + csvfile.close() + + + +# with open('flash2_attn_time.plk', 'wb') as fp: +# pickle.dump((speed_f, speed_b, speed_f_b), fp, protocol=pickle.HIGHEST_PROTOCOL) diff --git a/benchmarks/benchmark_headdim.py b/benchmarks/benchmark_headdim.py new file mode 100644 index 000000000..5e5ceb2f3 --- /dev/null +++ b/benchmarks/benchmark_headdim.py @@ -0,0 +1,196 @@ +# Install the newest triton version with +# pip install "git+https://github.com/openai/triton.git#egg=triton&subdirectory=python" +import csv +import pickle +import math +import torch +import torch.nn as nn +import torch.nn.functional as F + +from einops import rearrange, repeat + +from flash_attn.utils.benchmark import benchmark_all, benchmark_forward, benchmark_backward +from flash_attn.utils.benchmark import benchmark_fwd_bwd, benchmark_combined + +from flash_attn import flash_attn_qkvpacked_func, flash_attn_func + +try: + from triton.ops.flash_attention import attention as attention_triton +except ImportError: + attention_triton = None + +try: + import xformers.ops as xops +except ImportError: + xops = None + + +def flops(batch, seqlen, headdim, v_headdim, nheads, causal, mode="fwd"): + assert mode in ["fwd", "bwd", "fwd_bwd"] + f = 2 * batch * seqlen**2 * nheads * (headdim+v_headdim) // (2 if causal else 1) + b = 2 * batch * seqlen**2 * nheads * (3*headdim+2*v_headdim) // (2 if causal else 1) + return f if mode == "fwd" else (b if mode == "bwd" else f+b) + +def efficiency(flop, time): + return (flop / time / 10**12) if not math.isnan(time) else 0.0 + + +def attention_pytorch(q, k, v, dropout_p=0.0, causal=True): + """ + Arguments: + qkv: (batch_size, seqlen, 3, nheads, head_dim) + dropout_p: float + Output: + output: (batch_size, seqlen, nheads, head_dim) + """ + batch_size, seqlen, nheads, d = q.shape + v_d = v.shape[-1] + q = rearrange(q, 'b t h d -> (b h) t d') + k = rearrange(k, 'b s h d -> (b h) d s') + softmax_scale = 1.0 / math.sqrt(d) + # Preallocate attn_weights for `baddbmm` + scores = torch.empty(batch_size * nheads, seqlen, seqlen, dtype=q.dtype, device=q.device) + scores = rearrange(torch.baddbmm(scores, q, k, beta=0, alpha=softmax_scale), + '(b h) t s -> b h t s', h=nheads) + if causal: + # "triu_tril_cuda_template" not implemented for 'BFloat16' + # So we have to construct the mask in float + causal_mask = torch.triu(torch.full((seqlen, seqlen), -10000.0, device=scores.device), 1) + # TD [2022-09-30]: Adding is faster than masked_fill_ (idk why, just better kernel I guess) + scores = scores + causal_mask.to(dtype=scores.dtype) + attention = torch.softmax(scores, dim=-1) + attention_drop = F.dropout(attention, dropout_p) + output = torch.einsum('bhts,bshd->bthd', attention_drop , v) + return output.to(dtype=q.dtype) + + +def flash_attention_pad(q,k,v, dropout_p=0.0, causal=True): + batch_size, seqlen, nheads, d = q.shape + v_d = v.shape[-1] + if d == v_d: + return flash_attn_func(q, k, v, dropout_p, causal) + if d < v_d: + q = F.pad(q, 
(0, v_d-d)) + k = F.pad(k, (0, v_d-d)) + return flash_attn_func(q, k, v, dropout_p, causal) + elif d > v_d: + v = F.pad(v, (0, d-v_d)) + o = flash_attn_func(q, k, v, dropout_p, causal) + return o[:,:,:,:v_d] + + + +def time_fwd_bwd(func, *args, **kwargs): + time_f, time_b = benchmark_fwd_bwd(func, *args, **kwargs) + return time_f[1].mean, time_b[1].mean + +save_csv = True + +repeats = 30 +device = 'cuda' +dtype = torch.float16 +# torch.cuda.set_device(5) + +bs_seqlen_vals = [(32, 512), (16, 1024), (8, 2048), (4, 4096), (2, 8192), (1, 16384)] +causal_vals = [False, True] +headdim_vals = [(32,64),(64,128),(96,192), (128,256)] +dim = 2048 # qk dim +dropout_p = 0.0 + +methods = (["CustomFlash2", "Pytorch", "Flash2_Pad"]) + +if save_csv: + csvfile = open('flash2_attn_time.csv', 'w', newline='') + writer = csv.writer(csvfile) + writer.writerow([ + "causal", "qk_headdim", "v_headdim", "batch_size", "seqlen", + "time_fwd_CustomFlash2", "time_bwd_CustomFlash2", "time_fwd_bwd_CustomFlash2", + "time_fwd_Pytorch", "time_bwd_Pytorch", "time_fwd_bwd_Pytorch", + "time_fwd_Flash2_Pad", "time_bwd_Flash2_Pad", "time_fwd_bwd_Flash2_Pad", + "flops_fwd_CustomFlash2", "flops_bwd_CustomFlash2", "flops_fwd_bwd_CustomFlash2", + "flops_fwd_Pytorch", "flops_bwd_Pytorch", "flops_fwd_bwd_Pytorch", + "flops_fwd_Flash2_Pad", "flops_bwd_Flash2_Pad", "flops_fwd_bwd_Flash2_Pad", + ]) + +time_f = {} +time_b = {} +time_f_b = {} +speed_f = {} +speed_b = {} +speed_f_b = {} +for causal in causal_vals: + for headdim,v_headdim in headdim_vals: + for batch_size, seqlen in bs_seqlen_vals: + config = (causal, headdim, batch_size, seqlen) + nheads = dim // headdim + q = torch.randn(batch_size, seqlen, nheads, headdim, device=device, dtype=dtype, + requires_grad=True) + k = torch.randn(batch_size, seqlen, nheads, headdim, device=device, dtype=dtype, + requires_grad=True) + v = torch.randn(batch_size, seqlen, nheads, v_headdim, device=device, dtype=dtype, + requires_grad=True) + f, b = time_fwd_bwd( + flash_attn_func, q, k, v, dropout_p, causal=causal, repeats=repeats, verbose=False + ) + time_f[config, "CustomFlash2"] = f + time_b[config, "CustomFlash2"] = b + + try: + q = q.detach().requires_grad_(True) + k = k.detach().requires_grad_(True) + v = v.detach().requires_grad_(True) + f, b = time_fwd_bwd( + attention_pytorch, q, k, v, dropout_p, causal=causal, repeats=repeats, verbose=False + ) + except: # Skip if OOM + f, b = float('nan'), float('nan') + time_f[config, "Pytorch"] = f + time_b[config, "Pytorch"] = b + + q = q.detach().requires_grad_(True) + k = k.detach().requires_grad_(True) + v = v.detach().requires_grad_(True) + f, b = time_fwd_bwd( + flash_attention_pad, q, k, v, dropout_p, causal=causal, repeats=repeats, verbose=False + ) + time_f[config, "Flash2_Pad"] = f + time_b[config, "Flash2_Pad"] = b + + print(f"### causal={causal}, qk_headdim={headdim}, v_headdim={v_headdim}, batch_size={batch_size}, seqlen={seqlen} ###") + for method in methods: + time_f_b[config, method] = time_f[config, method] + time_b[config, method] + speed_f[config, method] = efficiency( + flops(batch_size, seqlen, headdim, v_headdim, nheads, causal, mode="fwd"), + time_f[config, method] + ) + speed_b[config, method] = efficiency( + flops(batch_size, seqlen, headdim, v_headdim, nheads, causal, mode="bwd"), + time_b[config, method] + ) + speed_f_b[config, method] = efficiency( + flops(batch_size, seqlen, headdim, v_headdim, nheads, causal, mode="fwd_bwd"), + time_f_b[config, method] + ) + print( + f"{method} fwd: {speed_f[config, method]:.2f} TFLOPs/s, 
" + f"bwd: {speed_b[config, method]:.2f} TFLOPs/s, " + f"fwd + bwd: {speed_f_b[config, method]:.2f} TFLOPs/s" + ) + if save_csv: + writer.writerow([ + causal, headdim, v_headdim, batch_size, seqlen, + time_f[config, "CustomFlash2"], time_b[config, "CustomFlash2"], time_f_b[config, "CustomFlash2"], + time_f[config, "Pytorch"], time_b[config, "Pytorch"], time_f_b[config, "Pytorch"], + time_f[config, "Flash2_Pad"], time_b[config, "Flash2_Pad"], time_f_b[config, "Flash2_Pad"], + speed_f[config, "CustomFlash2"], speed_b[config, "CustomFlash2"], speed_f_b[config, "CustomFlash2"], + speed_f[config, "Pytorch"], speed_b[config, "Pytorch"], speed_f_b[config, "Pytorch"], + speed_f[config, "Flash2_Pad"], speed_b[config, "Flash2_Pad"], speed_f_b[config, "Flash2_Pad"], + ]) + +if save_csv: + csvfile.close() + + + +# with open('flash2_attn_time.plk', 'wb') as fp: +# pickle.dump((speed_f, speed_b, speed_f_b), fp, protocol=pickle.HIGHEST_PROTOCOL) diff --git a/compute_sm.py b/compute_sm.py new file mode 100644 index 000000000..f65cf3bc6 --- /dev/null +++ b/compute_sm.py @@ -0,0 +1,9 @@ +Br = 128 +Bc = 64 +QKHeaddim = 128 +VHeaddim = 256 +bwdsmem =2 *(Br * QKHeaddim * 2 + Br * VHeaddim + Bc * QKHeaddim + Bc * VHeaddim + Br * Bc * 2) +bwdsmem = bwdsmem/1024 +fwdsmem = (Br * QKHeaddim + Bc * QKHeaddim + Bc * VHeaddim)*2 +fwdsmem = fwdsmem/1024 +print("fwdsmem:", fwdsmem) diff --git a/csrc/cutlass b/csrc/cutlass index 756c351b4..f7b19de32 160000 --- a/csrc/cutlass +++ b/csrc/cutlass @@ -1 +1 @@ -Subproject commit 756c351b4994854b2f8c6dded3821ebbb580876b +Subproject commit f7b19de32c5d1f3cedfc735c2849f12b537522ee diff --git a/csrc/flash_attn/flash_api.cpp b/csrc/flash_attn/flash_api.cpp index a928ec1ec..8c639a4c0 100644 --- a/csrc/flash_attn/flash_api.cpp +++ b/csrc/flash_attn/flash_api.cpp @@ -12,6 +12,7 @@ #include "flash.h" #include "static_switch.h" +#include "static_switch_headdim.h" #define CHECK_DEVICE(x) TORCH_CHECK(x.is_cuda(), #x " must be on CUDA") #define CHECK_SHAPE(x, ...) TORCH_CHECK(x.sizes() == torch::IntArrayRef({__VA_ARGS__}), #x " must have shape (" #__VA_ARGS__ ")") @@ -27,8 +28,11 @@ void set_params_fprop(Flash_fwd_params ¶ms, const size_t seqlen_k_rounded, const size_t h, const size_t h_k, + const size_t h_v, const size_t d, const size_t d_rounded, + const size_t vd, + const size_t vd_rounded, // device pointers const at::Tensor q, const at::Tensor k, @@ -92,13 +96,17 @@ void set_params_fprop(Flash_fwd_params ¶ms, params.b = b; params.h = h; params.h_k = h_k; + params.h_v = h_v; params.h_h_k_ratio = h / h_k; + params.h_h_v_ratio = h / h_v; params.seqlen_q = seqlen_q; params.seqlen_k = seqlen_k; params.seqlen_q_rounded = seqlen_q_rounded; params.seqlen_k_rounded = seqlen_k_rounded; params.d = d; params.d_rounded = d_rounded; + params.vd = vd; + params.vd_rounded = vd_rounded; // Set the different scale values. 
#ifdef FLASHATTENTION_DISABLE_SOFTCAP @@ -162,8 +170,10 @@ void set_params_dgrad(Flash_bwd_params ¶ms, const size_t seqlen_k_rounded, const size_t h, const size_t h_k, + const size_t h_v, const size_t d, const size_t d_rounded, + const size_t vd, // device pointers const at::Tensor q, const at::Tensor k, @@ -189,7 +199,7 @@ void set_params_dgrad(Flash_bwd_params ¶ms, const bool unpadded_lse) { set_params_fprop(params, - b, seqlen_q, seqlen_k, seqlen_q_rounded, seqlen_k_rounded, h, h_k, d, d_rounded, + b, seqlen_q, seqlen_k, seqlen_q_rounded, seqlen_k_rounded, h, h_k, h_v, d, d_rounded,vd, vd, q, k, v, out, cu_seqlens_q_d, cu_seqlens_k_d, @@ -237,14 +247,14 @@ void set_params_dgrad(Flash_bwd_params ¶ms, void run_mha_fwd(Flash_fwd_params ¶ms, cudaStream_t stream, bool force_split_kernel=false) { FP16_SWITCH(!params.is_bf16, [&] { - HEADDIM_SWITCH(params.d, [&] { - BOOL_SWITCH(params.is_causal, Is_causal, [&] { - if (params.num_splits <= 1 && !force_split_kernel) { // If we don't set it num_splits == 0 - run_mha_fwd_(params, stream); - } else { - run_mha_fwd_splitkv_dispatch(params, stream); - } - }); + QKHEADDIM_VHEADDIM_SWITCH(params.d, params.vd, [&] { + BOOL_SWITCH(params.is_causal, Is_causal, [&] { + if (params.num_splits <= 1 && !force_split_kernel) { // If we don't set it num_splits == 0 + run_mha_fwd_(params, stream); + } else { + run_mha_fwd_splitkv_dispatch(params, stream); + } + }); }); }); } @@ -292,12 +302,13 @@ inline int num_splits_heuristic(int batch_nheads_mblocks, int num_SMs, int num_n } std::tuple set_params_splitkv(Flash_fwd_params ¶ms, const int batch_size, - const int num_heads, const int head_size, const int max_seqlen_k, const int max_seqlen_q, - const int head_size_rounded, const float p_dropout, + const int num_heads, const int head_size, const int v_head_size, const int max_seqlen_k, const int max_seqlen_q, + const int head_size_rounded, const int v_head_size_rounded,const float p_dropout, const int num_splits, cudaDeviceProp *dprops, struct c10::TensorOptions opts) { // This needs to match with run_mha_fwd_splitkv_dispatch - const int block_n = head_size <= 64 ? 256 : (head_size <= 128 ? 128 : 64); + const int max_head_size = head_size > v_head_size ? head_size : v_head_size; + const int block_n = max_head_size <= 64 ? 256 : (max_head_size <= 128 ? 128 : 64); const int num_n_blocks = (max_seqlen_k + block_n - 1) / block_n; // Technically kBlockM = 64 only for the splitKV kernels, not the standard kernel. // In any case we don't expect seqlen_q to be larger than 64 for inference. 
@@ -313,7 +324,7 @@ std::tuple set_params_splitkv(Flash_fwd_params ¶ms, } if (params.num_splits > 1) { softmax_lse_accum = torch::empty({params.num_splits, batch_size, num_heads, max_seqlen_q}, opts.dtype(at::kFloat)); - out_accum = torch::empty({params.num_splits, batch_size, num_heads, max_seqlen_q, head_size_rounded}, opts.dtype(at::kFloat)); + out_accum = torch::empty({params.num_splits, batch_size, num_heads, max_seqlen_q, v_head_size_rounded}, opts.dtype(at::kFloat)); params.softmax_lseaccum_ptr = softmax_lse_accum.data_ptr(); params.oaccum_ptr = out_accum.data_ptr(); } @@ -345,7 +356,7 @@ void set_params_alibi(Flash_fwd_params ¶ms, c10::optional &alibi std::vector mha_fwd(at::Tensor &q, // batch_size x seqlen_q x num_heads x head_size const at::Tensor &k, // batch_size x seqlen_k x num_heads_k x head_size - const at::Tensor &v, // batch_size x seqlen_k x num_heads_k x head_size + const at::Tensor &v, // batch_size x seqlen_k x num_heads_v x head_size c10::optional &out_, // batch_size x seqlen_q x num_heads x head_size c10::optional &alibi_slopes_, // num_heads or batch_size x num_heads const float p_dropout, @@ -366,6 +377,7 @@ mha_fwd(at::Tensor &q, // batch_size x seqlen_q x num_heads x head_size // TORCH_CHECK(is_sm90 || is_sm8x || is_sm75, "FlashAttention only supports Turing GPUs or newer."); auto q_dtype = q.dtype(); + TORCH_CHECK(q_dtype == torch::kFloat16 || q_dtype == torch::kBFloat16, "FlashAttention only support fp16 and bf16 data type"); if (q_dtype == torch::kBFloat16) { @@ -381,16 +393,19 @@ mha_fwd(at::Tensor &q, // batch_size x seqlen_q x num_heads x head_size TORCH_CHECK(v.stride(-1) == 1, "Input tensor must have contiguous last dimension"); const auto sizes = q.sizes(); - + const int v_head_size_og = v.sizes()[3]; const int batch_size = sizes[0]; int seqlen_q = sizes[1]; int num_heads = sizes[2]; const int head_size_og = sizes[3]; const int seqlen_k = k.size(1); const int num_heads_k = k.size(2); + const int num_heads_v = v.size(2); TORCH_CHECK(batch_size > 0, "batch size must be positive"); TORCH_CHECK(head_size_og <= 256, "FlashAttention forward only supports head dimension at most 256"); + TORCH_CHECK(v_head_size_og <= 256, "FlashAttention forward only supports head dimension at most 256"); TORCH_CHECK(num_heads % num_heads_k == 0, "Number of heads in key/value must divide number of heads in query"); + TORCH_CHECK(num_heads % num_heads_v == 0, "Number of heads in value must divide number of heads in query"); if (softcap > 0.f) { TORCH_CHECK(p_dropout == 0.f, "Softcapping does not support dropout for now"); } @@ -403,26 +418,33 @@ mha_fwd(at::Tensor &q, // batch_size x seqlen_q x num_heads x head_size // Faster to transpose q from (b, 1, (nheads_kv ngroups), d) to (b, ngroups, nheads_kv, d) in this case // H/t Daniel Haziza - const int seqlenq_ngroups_swapped = seqlen_q == 1 && num_heads > num_heads_k && window_size_left < 0 && window_size_right < 0 && p_dropout == 0.f && head_size_og % 8 == 0 && !alibi_slopes_.has_value(); - const int ngroups = num_heads / num_heads_k; + const int num_heads_maxkv = num_heads_k > num_heads_v ? 
num_heads_k : num_heads_v; + const int seqlenq_ngroups_swapped = seqlen_q == 1 && num_heads > num_heads_maxkv && window_size_left < 0 && window_size_right < 0 && p_dropout == 0.f && head_size_og % 8 == 0 && v_head_size_og % 8 == 0 && !alibi_slopes_.has_value(); + const int ngroups = num_heads / num_heads_maxkv; if (seqlenq_ngroups_swapped) { - q = q.reshape({batch_size, num_heads_k, ngroups, head_size_og}).transpose(1, 2); + q = q.reshape({batch_size, num_heads_maxkv, ngroups, head_size_og}).transpose(1, 2); seqlen_q = ngroups; - num_heads = num_heads_k; + num_heads = num_heads_maxkv; } + CHECK_SHAPE(q, batch_size, seqlen_q, num_heads, head_size_og); CHECK_SHAPE(k, batch_size, seqlen_k, num_heads_k, head_size_og); - CHECK_SHAPE(v, batch_size, seqlen_k, num_heads_k, head_size_og); + CHECK_SHAPE(v, batch_size, seqlen_k, num_heads_v, v_head_size_og); at::Tensor q_padded, k_padded, v_padded; if (head_size_og % 8 != 0) { q_padded = torch::nn::functional::pad(q, torch::nn::functional::PadFuncOptions({0, 8 - head_size_og % 8})); k_padded = torch::nn::functional::pad(k, torch::nn::functional::PadFuncOptions({0, 8 - head_size_og % 8})); - v_padded = torch::nn::functional::pad(v, torch::nn::functional::PadFuncOptions({0, 8 - head_size_og % 8})); + // v_padded = torch::nn::functional::pad(v, torch::nn::functional::PadFuncOptions({0, 8 - head_size_og % 8})); } else { q_padded = q; k_padded = k; + // v_padded = v; + } + if (v_head_size_og % 8 != 0) { + v_padded = torch::nn::functional::pad(v, torch::nn::functional::PadFuncOptions({0, 8 - v_head_size_og % 8})); + } else { v_padded = v; } @@ -432,18 +454,26 @@ mha_fwd(at::Tensor &q, // batch_size x seqlen_q x num_heads x head_size TORCH_CHECK(out.dtype() == q_dtype, "Output must have the same dtype as inputs"); CHECK_DEVICE(out); TORCH_CHECK(out.stride(-1) == 1, "Output tensor must have contiguous last dimension"); - CHECK_SHAPE(out, batch_size, sizes[1], sizes[2], head_size_og); + CHECK_SHAPE(out, batch_size, sizes[1], sizes[2], v_head_size_og); if (seqlenq_ngroups_swapped) { - out = out.reshape({batch_size, num_heads_k, ngroups, head_size_og}).transpose(1, 2); + out = out.reshape({batch_size, num_heads_maxkv, ngroups, v_head_size_og}).transpose(1, 2); + } + if (v_head_size_og % 8 != 0) { + out = torch::empty({batch_size, seqlen_q, num_heads, v_head_size_og}, q.options()); + out = torch::nn::functional::pad(out, torch::nn::functional::PadFuncOptions({0, 8 - v_head_size_og % 8})); } - if (head_size_og % 8 != 0) { out = torch::empty_like(q_padded); } } else { - out = torch::empty_like(q_padded); + out = torch::empty({batch_size, seqlen_q, num_heads, v_head_size_og}, q.options()); + if (v_head_size_og % 8 != 0) { + out = torch::nn::functional::pad(out, torch::nn::functional::PadFuncOptions({0, 8 - v_head_size_og % 8})); + } } auto round_multiple = [](int x, int m) { return (x + m - 1) / m * m; }; const int head_size = round_multiple(head_size_og, 8); const int head_size_rounded = head_size <= 192 ? round_multiple(head_size, 32) : 256; + const int v_head_size = round_multiple(v_head_size_og, 8); + const int v_head_size_rounded = v_head_size <= 192 ? 
round_multiple(v_head_size, 32) : 256; const int seqlen_q_rounded = round_multiple(seqlen_q, 128); const int seqlen_k_rounded = round_multiple(seqlen_k, 128); @@ -466,8 +496,9 @@ mha_fwd(at::Tensor &q, // batch_size x seqlen_q x num_heads x head_size batch_size, seqlen_q, seqlen_k, seqlen_q_rounded, seqlen_k_rounded, - num_heads, num_heads_k, + num_heads, num_heads_k, num_heads_v, head_size, head_size_rounded, + v_head_size, v_head_size_rounded, q_padded, k_padded, v_padded, out, /*cu_seqlens_q_d=*/nullptr, /*cu_seqlens_k_d=*/nullptr, @@ -484,8 +515,8 @@ mha_fwd(at::Tensor &q, // batch_size x seqlen_q x num_heads x head_size // Keep references to these tensors to extend their lifetime at::Tensor softmax_lse_accum, out_accum; std::tie(softmax_lse_accum, out_accum) = set_params_splitkv( - params, batch_size, num_heads, head_size, seqlen_k, seqlen_q, - head_size_rounded, p_dropout, /*num_splits*/ 0, dprops, opts); + params, batch_size, num_heads, head_size, v_head_size, seqlen_k, seqlen_q, + head_size_rounded, v_head_size_rounded, p_dropout, /*num_splits*/ 0, dprops, opts); // number of times random will be generated per thread, to offset philox counter in thc random // state @@ -516,16 +547,16 @@ mha_fwd(at::Tensor &q, // batch_size x seqlen_q x num_heads x head_size } at::Tensor out_padded = out; - if (head_size_og % 8 != 0) { - out = out.index({"...", torch::indexing::Slice(torch::indexing::None, head_size_og)}); + if (v_head_size_og % 8 != 0) { + out = out.index({"...", torch::indexing::Slice(torch::indexing::None, v_head_size_og)}); if (out_.has_value()) { out_.value().copy_(out); } } if (seqlenq_ngroups_swapped) { - out = out.transpose(1, 2).reshape({batch_size, 1, num_heads_k * seqlen_q, head_size_og}); - out_padded = out_padded.transpose(1, 2).reshape({batch_size, 1, num_heads_k * seqlen_q, head_size_og}); - q_padded = q_padded.transpose(1, 2).reshape({batch_size, 1, num_heads_k * seqlen_q, head_size_og}); - softmax_lse = softmax_lse.reshape({batch_size, num_heads_k * seqlen_q, 1}); + out = out.transpose(1, 2).reshape({batch_size, 1, num_heads_maxkv * seqlen_q, v_head_size_og}); + out_padded = out_padded.transpose(1, 2).reshape({batch_size, 1, num_heads_maxkv * seqlen_q, v_head_size_og}); + q_padded = q_padded.transpose(1, 2).reshape({batch_size, 1, num_heads_maxkv * seqlen_q, head_size_og}); + softmax_lse = softmax_lse.reshape({batch_size, num_heads_maxkv * seqlen_q, 1}); } return {out, q_padded, k_padded, v_padded, out_padded, softmax_lse, p, rng_state}; } @@ -533,7 +564,7 @@ mha_fwd(at::Tensor &q, // batch_size x seqlen_q x num_heads x head_size std::vector mha_varlen_fwd(at::Tensor &q, // total_q x num_heads x head_size, total_q := \sum_{i=0}^{b} s_i const at::Tensor &k, // total_k x num_heads_k x head_size, total_k := \sum_{i=0}^{b} s_i or num_blocks x page_block_size x num_heads_k x head_size if there's a block_table. - const at::Tensor &v, // total_k x num_heads_k x head_size, total_k := \sum_{i=0}^{b} s_i or num_blocks x page_block_size x num_heads_k x head_size if there's a block_table. + const at::Tensor &v, // total_k x num_heads_v x head_size, total_k := \sum_{i=0}^{b} s_i or num_blocks x page_block_size x num_heads_v x head_size if there's a block_table. 
c10::optional &out_, // total_q x num_heads x head_size, total_k := \sum_{i=0}^{b} s_i const at::Tensor &cu_seqlens_q, // b+1 const at::Tensor &cu_seqlens_k, // b+1 @@ -592,11 +623,13 @@ mha_varlen_fwd(at::Tensor &q, // total_q x num_heads x head_size, total_q := \s CHECK_CONTIGUOUS(cu_seqlens_k); const auto sizes = q.sizes(); - + const int v_head_size_og = v.sizes()[2]; const int batch_size = cu_seqlens_q.numel() - 1; int num_heads = sizes[1]; const int head_size_og = sizes[2]; const int num_heads_k = paged_KV ? k.size(2) : k.size(1); + // TODO: check here + const int num_heads_v = paged_KV ? v.size(2) : v.size(1); if (softcap > 0.f) { TORCH_CHECK(p_dropout == 0.f, "Softcapping does not support dropout for now"); } @@ -612,12 +645,13 @@ mha_varlen_fwd(at::Tensor &q, // total_q x num_heads x head_size, total_q := \s // Faster to transpose q from (b, 1, (nheads_kv ngroups), d) to (b, ngroups, nheads_kv, d) in this case // H/t Daniel Haziza - const int seqlenq_ngroups_swapped = max_seqlen_q == 1 && num_heads > num_heads_k && window_size_left < 0 && window_size_right < 0 && p_dropout == 0.f && head_size_og % 8 == 0 && !alibi_slopes_.has_value(); - const int ngroups = num_heads / num_heads_k; + const int num_heads_maxkv = num_heads_k > num_heads_v ? num_heads_k : num_heads_v; + const int seqlenq_ngroups_swapped = max_seqlen_q == 1 && num_heads > num_heads_maxkv && window_size_left < 0 && window_size_right < 0 && p_dropout == 0.f && head_size_og % 8 == 0 && v_head_size_og % 8 == 0 && !alibi_slopes_.has_value(); + const int ngroups = num_heads / num_heads_maxkv; if (seqlenq_ngroups_swapped) { - q = q.reshape({batch_size, num_heads_k, ngroups, head_size_og}).transpose(1, 2).reshape({batch_size * ngroups, num_heads_k, head_size_og}); + q = q.reshape({batch_size, num_heads_maxkv, ngroups, head_size_og}).transpose(1, 2).reshape({batch_size * ngroups, num_heads_maxkv, head_size_og}); max_seqlen_q = ngroups; - num_heads = num_heads_k; + num_heads = num_heads_maxkv; cu_seqlens_q_d = nullptr; } @@ -625,7 +659,9 @@ mha_varlen_fwd(at::Tensor &q, // total_q x num_heads x head_size, total_q := \s TORCH_CHECK(batch_size > 0, "batch size must be positive"); TORCH_CHECK(head_size_og <= 256, "FlashAttention forward only supports head dimension at most 256"); + TORCH_CHECK(v_head_size_og <= 256, "FlashAttention forward only supports head dimension at most 256"); TORCH_CHECK(num_heads % num_heads_k == 0, "Number of heads in key/value must divide number of heads in query"); + TORCH_CHECK(num_heads % num_heads_v == 0, "Number of heads in value must divide number of heads in query"); if (window_size_left >= max_seqlen_k) { window_size_left = -1; } if (window_size_right >= max_seqlen_k) { window_size_right = -1; } @@ -634,10 +670,10 @@ mha_varlen_fwd(at::Tensor &q, // total_q x num_heads x head_size, total_q := \s if (!paged_KV) { const int total_k = k.size(0); CHECK_SHAPE(k, total_k, num_heads_k, head_size_og); - CHECK_SHAPE(v, total_k, num_heads_k, head_size_og); + CHECK_SHAPE(v, total_k, num_heads_v, v_head_size_og); } else { CHECK_SHAPE(k, num_blocks, page_block_size, num_heads_k, head_size_og); - CHECK_SHAPE(v, num_blocks, page_block_size, num_heads_k, head_size_og); + CHECK_SHAPE(v, num_blocks, page_block_size, num_heads_v, v_head_size_og); CHECK_SHAPE(block_table, batch_size, max_num_blocks_per_seq); } @@ -655,31 +691,43 @@ mha_varlen_fwd(at::Tensor &q, // total_q x num_heads x head_size, total_q := \s if (head_size_og % 8 != 0) { q_padded = torch::nn::functional::pad(q, 
torch::nn::functional::PadFuncOptions({0, 8 - head_size_og % 8})); k_padded = torch::nn::functional::pad(k, torch::nn::functional::PadFuncOptions({0, 8 - head_size_og % 8})); - v_padded = torch::nn::functional::pad(v, torch::nn::functional::PadFuncOptions({0, 8 - head_size_og % 8})); + // v_padded = torch::nn::functional::pad(v, torch::nn::functional::PadFuncOptions({0, 8 - head_size_og % 8})); } else { q_padded = q; k_padded = k; + // v_padded = v; + } + if (v_head_size_og % 8 != 0) { + v_padded = torch::nn::functional::pad(v, torch::nn::functional::PadFuncOptions({0, 8 - v_head_size_og % 8})); + } else { v_padded = v; } - at::Tensor out; if (out_.has_value()) { out = out_.value(); TORCH_CHECK(out.dtype() == q_dtype, "Output must have the same dtype as inputs"); CHECK_DEVICE(out); TORCH_CHECK(out.stride(-1) == 1, "Output tensor must have contiguous last dimension"); - CHECK_SHAPE(out, sizes[0], sizes[1], head_size_og); + CHECK_SHAPE(out, sizes[0], sizes[1], v_head_size_og); if (seqlenq_ngroups_swapped) { - out = out.reshape({batch_size, num_heads_k, ngroups, head_size_og}).transpose(1, 2).reshape({batch_size * ngroups, num_heads_k, head_size_og}); + out = out.reshape({batch_size, num_heads_maxkv, ngroups, v_head_size_og}).transpose(1, 2).reshape({batch_size * ngroups, num_heads_maxkv, head_size_og}); + } + if (v_head_size_og % 8 != 0) { + out = torch::empty({total_q, num_heads, v_head_size_og}, q.options()); + out = torch::nn::functional::pad(out, torch::nn::functional::PadFuncOptions({0, 8 - v_head_size_og % 8})); } - if (head_size_og % 8 != 0) { out = torch::empty_like(q_padded); } } else { - out = torch::empty_like(q_padded); + out = torch::empty({total_q, num_heads, v_head_size_og}, q.options()); + if (v_head_size_og % 8 != 0) { + out = torch::nn::functional::pad(out, torch::nn::functional::PadFuncOptions({0, 8 - v_head_size_og % 8})); + } } auto round_multiple = [](int x, int m) { return (x + m - 1) / m * m; }; const int head_size = round_multiple(head_size_og, 8); const int head_size_rounded = head_size <= 192 ? round_multiple(head_size, 32) : 256; + const int v_head_size = round_multiple(v_head_size_og, 8); + const int v_head_size_rounded = v_head_size <= 192 ? 
round_multiple(v_head_size, 32) : 256; const int seqlen_q_rounded = round_multiple(max_seqlen_q, 128); const int seqlen_k_rounded = round_multiple(max_seqlen_k, 128); @@ -707,8 +755,9 @@ mha_varlen_fwd(at::Tensor &q, // total_q x num_heads x head_size, total_q := \s batch_size, max_seqlen_q, max_seqlen_k, seqlen_q_rounded, seqlen_k_rounded, - num_heads, num_heads_k, + num_heads, num_heads_k, num_heads_v, head_size, head_size_rounded, + v_head_size, v_head_size_rounded, q_padded, k_padded, v_padded, out, cu_seqlens_q_d, cu_seqlens_k.data_ptr(), @@ -736,8 +785,8 @@ mha_varlen_fwd(at::Tensor &q, // total_q x num_heads x head_size, total_q := \s if (seqlenq_ngroups_swapped) { // Only apply split-k for decoding std::tie(softmax_lse_accum, out_accum) = - set_params_splitkv(params, batch_size, num_heads, head_size, - max_seqlen_k, max_seqlen_q, head_size_rounded, + set_params_splitkv(params, batch_size, num_heads, head_size, v_head_size, + max_seqlen_k, max_seqlen_q, head_size_rounded,v_head_size_rounded, p_dropout, /*num_splits*/ 0, dprops, opts); } @@ -780,16 +829,18 @@ mha_varlen_fwd(at::Tensor &q, // total_q x num_heads x head_size, total_q := \s } at::Tensor out_padded = out; - if (head_size_og % 8 != 0) { - out = out.index({"...", torch::indexing::Slice(torch::indexing::None, head_size_og)}); + if (v_head_size_og % 8 != 0) { + out = out.index({"...", torch::indexing::Slice(torch::indexing::None, v_head_size_og)}); if (out_.has_value()) { out_.value().copy_(out); } } if (seqlenq_ngroups_swapped) { - int64_t size_before[] = {batch_size, max_seqlen_q, num_heads_k, head_size_og}; - int64_t size_after[] = {batch_size, num_heads_k * max_seqlen_q, head_size_og}; - out = out.reshape(size_before).transpose(1, 2).reshape(size_after); - out_padded = out_padded.reshape(size_before).transpose(1, 2).reshape(size_after); + int64_t size_before[] = {batch_size, max_seqlen_q, num_heads_maxkv, head_size_og}; + int64_t size_after[] = {batch_size, num_heads_maxkv * max_seqlen_q, head_size_og}; + int64_t o_size_before[] = {batch_size, max_seqlen_q, num_heads_maxkv, v_head_size_og}; + int64_t o_size_after[] = {batch_size, num_heads_maxkv * max_seqlen_q, v_head_size_og}; + out = out.reshape(o_size_before).transpose(1, 2).reshape(o_size_after); + out_padded = out_padded.reshape(o_size_before).transpose(1, 2).reshape(o_size_after); q_padded = q_padded.reshape(size_before).transpose(1, 2).reshape(size_after); softmax_lse = softmax_lse.reshape({num_heads * max_seqlen_q, batch_size}); } @@ -799,24 +850,23 @@ mha_varlen_fwd(at::Tensor &q, // total_q x num_heads x head_size, total_q := \s void run_mha_bwd(Flash_bwd_params ¶ms, cudaStream_t stream) { FP16_SWITCH(!params.is_bf16, [&] { - HEADDIM_SWITCH(params.d, [&] { + QKHEADDIM_VHEADDIM_SWITCH(params.d, params.vd, [&] { BOOL_SWITCH(params.is_causal, Is_causal, [&] { - run_mha_bwd_(params, stream); + run_mha_bwd_(params, stream); }); }); }); } - std::vector mha_bwd(const at::Tensor &dout, // batch_size x seqlen_q x num_heads, x head_size_og const at::Tensor &q, // batch_size x seqlen_q x num_heads x head_size const at::Tensor &k, // batch_size x seqlen_k x num_heads_k x head_size - const at::Tensor &v, // batch_size x seqlen_k x num_heads_k x head_size + const at::Tensor &v, // batch_size x seqlen_k x num_heads_v x head_size const at::Tensor &out, // batch_size x seqlen_q x num_heads x head_size const at::Tensor &softmax_lse, // b x h x seqlen_q c10::optional &dq_, // batch_size x seqlen_q x num_heads x head_size c10::optional &dk_, // batch_size x seqlen_k x num_heads_k 
x head_size - c10::optional &dv_, // batch_size x seqlen_k x num_heads_k x head_size + c10::optional &dv_, // batch_size x seqlen_k x num_heads_v x head_size c10::optional &alibi_slopes_, // num_heads or batch_size x num_heads const float p_dropout, // probability to drop const float softmax_scale, @@ -865,7 +915,7 @@ mha_bwd(const at::Tensor &dout, // batch_size x seqlen_q x num_heads, x head_si TORCH_CHECK(dout.stride(-1) == 1, "dout tensor must have contiguous last dimension"); const auto sizes = q.sizes(); - + const int v_head_size_og = v.sizes()[3]; const int batch_size = sizes[0]; const int seqlen_q = sizes[1]; const int num_heads = sizes[2]; @@ -873,20 +923,24 @@ mha_bwd(const at::Tensor &dout, // batch_size x seqlen_q x num_heads, x head_si const int head_size = sizes[3]; const int seqlen_k = k.size(1); const int num_heads_k = k.size(2); + const int num_heads_v = v.size(2); TORCH_CHECK(batch_size > 0, "batch size must be positive"); TORCH_CHECK(head_size % 8 == 0, "head_size should be a multiple of 8"); TORCH_CHECK(head_size <= 256, "FlashAttention backward only supports head dimension at most 256"); - if (head_size > 192 && is_dropout) { + TORCH_CHECK(v_head_size_og % 8 == 0, " v head_size should be a multiple of 8"); + TORCH_CHECK(v_head_size_og <= 256, "FlashAttention backward only supports head dimension at most 256"); + if ((head_size > 192 || v_head_size_og > 192) && is_dropout) { TORCH_CHECK(is_sm80 || is_sm90, "FlashAttention backward for head dim > 192 with dropout requires A100/A800 or H100/H800"); } TORCH_CHECK(num_heads % num_heads_k == 0, "Number of heads in key/value must divide number of heads in query"); + TORCH_CHECK(num_heads % num_heads_v == 0, "Number of heads in value must divide number of heads in query"); auto round_multiple = [](int x, int m) { return (x + m - 1) / m * m; }; const int head_size_rounded = head_size <= 192 ? 
round_multiple(head_size, 32) : 256; const int seqlen_q_rounded = round_multiple(seqlen_q, 128); const int seqlen_k_rounded = round_multiple(seqlen_k, 128); - TORCH_CHECK(head_size == round_multiple(head_size_og, 8), "head_size must be head_size_og rounded to a multiple of 8"); + // TORCH_CHECK(head_size == round_multiple(head_size_og, 8), "head_size must be head_size_og rounded to a multiple of 8"); if (softcap > 0.f) { TORCH_CHECK(p_dropout == 0.f, "Softcapping does not support dropout for now"); } if (window_size_left >= seqlen_k) { window_size_left = -1; } @@ -894,8 +948,8 @@ mha_bwd(const at::Tensor &dout, // batch_size x seqlen_q x num_heads, x head_si CHECK_SHAPE(q, batch_size, seqlen_q, num_heads, head_size); CHECK_SHAPE(k, batch_size, seqlen_k, num_heads_k, head_size); - CHECK_SHAPE(v, batch_size, seqlen_k, num_heads_k, head_size); - CHECK_SHAPE(out, batch_size, seqlen_q, num_heads, head_size); + CHECK_SHAPE(v, batch_size, seqlen_k, num_heads_v, v_head_size_og); + CHECK_SHAPE(out, batch_size, seqlen_q, num_heads, v_head_size_og); CHECK_SHAPE(dout, batch_size, seqlen_q, num_heads, head_size_og); at::Tensor dq, dk, dv; @@ -922,7 +976,7 @@ mha_bwd(const at::Tensor &dout, // batch_size x seqlen_q x num_heads, x head_si TORCH_CHECK(dv.dtype() == q_dtype, "dv must have the same dtype as q"); CHECK_DEVICE(dv); TORCH_CHECK(dv.stride(-1) == 1, "dv must have contiguous last dimension"); - CHECK_SHAPE(dv, batch_size, seqlen_k, num_heads_k, head_size); + CHECK_SHAPE(dv, batch_size, seqlen_k, num_heads_v, v_head_size_og); } else { dv = torch::empty_like(v); } @@ -960,9 +1014,12 @@ mha_bwd(const at::Tensor &dout, // batch_size x seqlen_q x num_heads, x head_si at::Tensor dk_expanded, dv_expanded; if (num_heads_k != num_heads) { // MQA / GQA dk_expanded = torch::empty({batch_size, seqlen_k, num_heads, head_size}, opts); - dv_expanded = torch::empty({batch_size, seqlen_k, num_heads, head_size}, opts); } else { dk_expanded = dk; + } + if (num_heads_v != num_heads) { + dv_expanded = torch::empty({batch_size, seqlen_k, num_heads, v_head_size_og}, opts); + } else { dv_expanded = dv; } @@ -972,8 +1029,9 @@ mha_bwd(const at::Tensor &dout, // batch_size x seqlen_q x num_heads, x head_si batch_size, seqlen_q, seqlen_k, seqlen_q_rounded, seqlen_k_rounded, - num_heads, num_heads_k, + num_heads, num_heads_k, num_heads_v, head_size, head_size_rounded, + v_head_size_og, q, k, v, out, dout_padded, dq, dk_expanded, dv_expanded, nullptr, @@ -1027,12 +1085,14 @@ mha_bwd(const at::Tensor &dout, // batch_size x seqlen_q x num_heads, x head_si // For MQA/GQA we need to sum dK and dV across the groups if (num_heads_k != num_heads) { at::sum_out(dk, at::reshape(dk_expanded, {batch_size, seqlen_k, num_heads_k, num_heads / num_heads_k, head_size}), {3}); - at::sum_out(dv, at::reshape(dv_expanded, {batch_size, seqlen_k, num_heads_k, num_heads / num_heads_k, head_size}), {3}); + } + if (num_heads_v != num_heads) { + at::sum_out(dv, at::reshape(dv_expanded, {batch_size, seqlen_k, num_heads_v, num_heads / num_heads_v, v_head_size_og}), {3}); } if (head_size_og % 8 != 0) { dq = dq.index({"...", torch::indexing::Slice(torch::indexing::None, head_size_og)}); dk = dk.index({"...", torch::indexing::Slice(torch::indexing::None, head_size_og)}); - dv = dv.index({"...", torch::indexing::Slice(torch::indexing::None, head_size_og)}); + // dv = dv.index({"...", torch::indexing::Slice(torch::indexing::None, head_size_og)}); } return { dq, dk, dv, softmax_d }; @@ -1106,7 +1166,7 @@ mha_varlen_bwd(const at::Tensor &dout, // total_q x 
num_heads, x head_size CHECK_CONTIGUOUS(cu_seqlens_k); const auto sizes = q.sizes(); - + const int v_head_size_og = v.sizes()[2]; const int total_q = sizes[0]; const int batch_size = cu_seqlens_q.numel() - 1; const int num_heads = sizes[1]; @@ -1114,13 +1174,18 @@ mha_varlen_bwd(const at::Tensor &dout, // total_q x num_heads, x head_size const int head_size = sizes[2]; const int total_k = k.size(0); const int num_heads_k = k.size(1); + const int num_heads_v = v.size(1); TORCH_CHECK(batch_size > 0, "batch size must be positive"); TORCH_CHECK(head_size % 8 == 0, "head_size should be a multiple of 8"); TORCH_CHECK(head_size <= 256, "FlashAttention backward only supports head dimension at most 256"); - if (head_size > 192 && is_dropout) { + TORCH_CHECK(v_head_size_og % 8 == 0, " v head_size should be a multiple of 8"); + TORCH_CHECK(v_head_size_og <= 256, "FlashAttention backward only supports head dimension at most 256"); + + if ((head_size > 192 || v_head_size_og > 192) && is_dropout) { TORCH_CHECK(is_sm80 || is_sm90, "FlashAttention backward for head dim > 192 with dropout requires A100/A800 or H100/H800"); } TORCH_CHECK(num_heads % num_heads_k == 0, "Number of heads in key/value must divide number of heads in query"); + TORCH_CHECK(num_heads % num_heads_v == 0, "Number of heads in value must divide number of heads in query"); if (softcap > 0.f) { TORCH_CHECK(p_dropout == 0.f, "Softcapping does not support dropout for now"); } auto round_multiple = [](int x, int m) { return (x + m - 1) / m * m; }; @@ -1135,8 +1200,8 @@ mha_varlen_bwd(const at::Tensor &dout, // total_q x num_heads, x head_size CHECK_SHAPE(q, total_q, num_heads, head_size); CHECK_SHAPE(k, total_k, num_heads_k, head_size); - CHECK_SHAPE(v, total_k, num_heads_k, head_size); - CHECK_SHAPE(out, total_q, num_heads, head_size); + CHECK_SHAPE(v, total_k, num_heads_v, v_head_size_og); + CHECK_SHAPE(out, total_q, num_heads, v_head_size_og); CHECK_SHAPE(dout, total_q, num_heads, head_size_og); CHECK_SHAPE(cu_seqlens_q, batch_size + 1); CHECK_SHAPE(cu_seqlens_k, batch_size + 1); @@ -1165,7 +1230,7 @@ mha_varlen_bwd(const at::Tensor &dout, // total_q x num_heads, x head_size TORCH_CHECK(dv.dtype() == q_dtype, "dv must have the same dtype as q"); CHECK_DEVICE(dv); TORCH_CHECK(dv.stride(-1) == 1, "dv must have contiguous last dimension"); - CHECK_SHAPE(dv, total_k, num_heads_k, head_size); + CHECK_SHAPE(dv, total_k, num_heads_v, v_head_size_og); } else { dv = torch::empty_like(v); } @@ -1209,9 +1274,12 @@ mha_varlen_bwd(const at::Tensor &dout, // total_q x num_heads, x head_size at::Tensor dk_expanded, dv_expanded; if (num_heads_k != num_heads) { // MQA / GQA dk_expanded = torch::empty({total_k, num_heads, head_size}, opts); - dv_expanded = torch::empty({total_k, num_heads, head_size}, opts); } else { dk_expanded = dk; + } + if (num_heads_v != num_heads) { + dv_expanded = torch::empty({total_k, num_heads, v_head_size_og}, opts); + } else { dv_expanded = dv; } @@ -1228,8 +1296,9 @@ mha_varlen_bwd(const at::Tensor &dout, // total_q x num_heads, x head_size batch_size, max_seqlen_q, max_seqlen_k, seqlen_q_rounded, seqlen_k_rounded, - num_heads, num_heads_k, + num_heads, num_heads_k, num_heads_v, head_size, head_size_rounded, + v_head_size_og, q, k, v, out, dout_padded, dq, dk_expanded, dv_expanded, cu_seqlens_q.data_ptr(), @@ -1282,12 +1351,14 @@ mha_varlen_bwd(const at::Tensor &dout, // total_q x num_heads, x head_size // For MQA/GQA we need to sum dK and dV across the groups if (num_heads_k != num_heads) { at::sum_out(dk, 
at::reshape(dk_expanded, {total_k, num_heads_k, num_heads / num_heads_k, head_size}), {2}); - at::sum_out(dv, at::reshape(dv_expanded, {total_k, num_heads_k, num_heads / num_heads_k, head_size}), {2}); + } + if (num_heads_v != num_heads) { + at::sum_out(dv, at::reshape(dv_expanded, {total_k, num_heads_v, num_heads / num_heads_v, v_head_size_og}), {2}); } if (head_size_og % 8 != 0) { dq = dq.index({"...", torch::indexing::Slice(torch::indexing::None, head_size_og)}); dk = dk.index({"...", torch::indexing::Slice(torch::indexing::None, head_size_og)}); - dv = dv.index({"...", torch::indexing::Slice(torch::indexing::None, head_size_og)}); + // dv = dv.index({"...", torch::indexing::Slice(torch::indexing::None, head_size_og)}); } return { dq, dk, dv, softmax_d }; @@ -1350,7 +1421,7 @@ mha_fwd_kvcache(at::Tensor &q, // batch_size x seqlen_q x num_he } const auto sizes = q.sizes(); - + const int v_head_size_og = vcache.sizes()[3]; const int batch_size = sizes[0]; int seqlen_q = sizes[1]; int num_heads = sizes[2]; @@ -1362,10 +1433,13 @@ mha_fwd_kvcache(at::Tensor &q, // batch_size x seqlen_q x num_he TORCH_CHECK(!paged_KV || page_block_size % 256 == 0, "Paged KV cache block size must be divisible by 256"); const int seqlen_k = !paged_KV ? kcache.size(1) : max_num_blocks_per_seq * page_block_size; const int num_heads_k = kcache.size(2); + const int num_heads_v = vcache.size(2); const int batch_size_c = !paged_KV ? kcache.size(0) : batch_size; TORCH_CHECK(batch_size > 0, "batch size must be positive"); TORCH_CHECK(head_size_og <= 256, "FlashAttention forward only supports head dimension at most 256"); TORCH_CHECK(num_heads % num_heads_k == 0, "Number of heads in key/value must divide number of heads in query"); + TORCH_CHECK(num_heads % num_heads_v == 0, "Number of heads in value must divide number of heads in query"); + TORCH_CHECK(v_head_size_og <= 256, "FlashAttention backward only supports head dimension at most 256"); // causal=true is the same as causal=false in this case if (seqlen_q == 1 && !alibi_slopes_.has_value()) { is_causal = false; } @@ -1373,12 +1447,13 @@ mha_fwd_kvcache(at::Tensor &q, // batch_size x seqlen_q x num_he // Faster to transpose q from (b, 1, (nheads_kv ngroups), d) to (b, ngroups, nheads_kv, d) in this case // H/t Daniel Haziza - const int seqlenq_ngroups_swapped = seqlen_q == 1 && num_heads > num_heads_k && window_size_left < 0 && window_size_right < 0 && head_size_og % 8 == 0 && !alibi_slopes_.has_value(); + const int num_heads_maxkv = num_heads_k > num_heads_v ? 
num_heads_k : num_heads_v; + const int seqlenq_ngroups_swapped = seqlen_q == 1 && num_heads > num_heads_maxkv && window_size_left < 0 && window_size_right < 0 && head_size_og % 8 == 0 && v_head_size_og % 8 == 0 && !alibi_slopes_.has_value(); if (seqlenq_ngroups_swapped) { - const int ngroups = num_heads / num_heads_k; - q = q.reshape({batch_size, num_heads_k, ngroups, head_size_og}).transpose(1, 2); + const int ngroups = num_heads / num_heads_maxkv; + q = q.reshape({batch_size, num_heads_maxkv, ngroups, head_size_og}).transpose(1, 2); seqlen_q = ngroups; - num_heads = num_heads_k; + num_heads = num_heads_maxkv; } if (window_size_left >= seqlen_k) { window_size_left = -1; } @@ -1387,10 +1462,10 @@ mha_fwd_kvcache(at::Tensor &q, // batch_size x seqlen_q x num_he CHECK_SHAPE(q, batch_size, seqlen_q, num_heads, head_size_og); if (!paged_KV) { CHECK_SHAPE(kcache, batch_size_c, seqlen_k, num_heads_k, head_size_og); - CHECK_SHAPE(vcache, batch_size_c, seqlen_k, num_heads_k, head_size_og); + CHECK_SHAPE(vcache, batch_size_c, seqlen_k, num_heads_v, v_head_size_og); } else { CHECK_SHAPE(kcache, num_blocks, page_block_size, num_heads_k, head_size_og); - CHECK_SHAPE(vcache, num_blocks, page_block_size, num_heads_k, head_size_og); + CHECK_SHAPE(vcache, num_blocks, page_block_size, num_heads_v, v_head_size_og); CHECK_SHAPE(block_table, batch_size, max_num_blocks_per_seq); } @@ -1398,7 +1473,7 @@ mha_fwd_kvcache(at::Tensor &q, // batch_size x seqlen_q x num_he if (head_size_og % 8 != 0) { q_padded = torch::nn::functional::pad(q, torch::nn::functional::PadFuncOptions({0, 8 - head_size_og % 8})); kcache_padded = torch::nn::functional::pad(kcache, torch::nn::functional::PadFuncOptions({0, 8 - head_size_og % 8})); - vcache_padded = torch::nn::functional::pad(vcache, torch::nn::functional::PadFuncOptions({0, 8 - head_size_og % 8})); + vcache_padded = torch::nn::functional::pad(vcache, torch::nn::functional::PadFuncOptions({0, 8 - v_head_size_og % 8})); } else { q_padded = q; kcache_padded = kcache; @@ -1411,15 +1486,24 @@ mha_fwd_kvcache(at::Tensor &q, // batch_size x seqlen_q x num_he TORCH_CHECK(out.dtype() == q_dtype, "Output must have the same dtype as inputs"); CHECK_DEVICE(out); TORCH_CHECK(out.stride(-1) == 1, "Output tensor must have contiguous last dimension"); - CHECK_SHAPE(out, batch_size, seqlen_q, num_heads, head_size_og); - if (head_size_og % 8 != 0) { out = torch::empty_like(q_padded); } + // TODO: check here for seqlenq_ngroups_swapped + CHECK_SHAPE(out, batch_size, seqlen_q, num_heads, v_head_size_og); + if (v_head_size_og % 8 != 0) { + out = torch::empty({batch_size, seqlen_q, num_heads, v_head_size_og}, q.options()); + out = torch::nn::functional::pad(out, torch::nn::functional::PadFuncOptions({0, 8 - v_head_size_og % 8})); + } } else { - out = torch::empty_like(q_padded); + out = torch::empty({batch_size, seqlen_q, num_heads, v_head_size_og}, q.options()); + if (v_head_size_og % 8 != 0) { + out = torch::nn::functional::pad(out, torch::nn::functional::PadFuncOptions({0, 8 - v_head_size_og % 8})); + } } auto round_multiple = [](int x, int m) { return (x + m - 1) / m * m; }; const int head_size = round_multiple(head_size_og, 8); const int head_size_rounded = head_size <= 192 ? round_multiple(head_size, 32) : 256; + const int v_head_size = round_multiple(v_head_size_og, 8); + const int v_head_size_rounded = v_head_size <= 192 ? 
round_multiple(v_head_size, 32) : 256; const int seqlen_q_rounded = round_multiple(seqlen_q, 128); const int seqlen_k_rounded = round_multiple(seqlen_k, 128); @@ -1436,8 +1520,9 @@ mha_fwd_kvcache(at::Tensor &q, // batch_size x seqlen_q x num_he batch_size, seqlen_q, seqlen_k, seqlen_q_rounded, seqlen_k_rounded, - num_heads, num_heads_k, + num_heads, num_heads_k, num_heads_v, head_size, head_size_rounded, + v_head_size, v_head_size_rounded, q_padded, kcache_padded, vcache_padded, out, /*cu_seqlens_q_d=*/nullptr, /*cu_seqlens_k_d=*/nullptr, @@ -1465,10 +1550,10 @@ mha_fwd_kvcache(at::Tensor &q, // batch_size x seqlen_q x num_he TORCH_CHECK(v.stride(-1) == 1, "Value tensor must have contiguous last dimension"); int seqlen_knew = k.size(1); CHECK_SHAPE(k, batch_size, seqlen_knew, num_heads_k, head_size_og); - CHECK_SHAPE(v, batch_size, seqlen_knew, num_heads_k, head_size_og); + CHECK_SHAPE(v, batch_size, seqlen_knew, num_heads_v, v_head_size_og); if (head_size_og % 8 != 0) { k_padded = torch::nn::functional::pad(k, torch::nn::functional::PadFuncOptions({0, 8 - head_size_og % 8})); - v_padded = torch::nn::functional::pad(v, torch::nn::functional::PadFuncOptions({0, 8 - head_size_og % 8})); + v_padded = torch::nn::functional::pad(v, torch::nn::functional::PadFuncOptions({0, 8 - v_head_size_og % 8})); } else { k_padded = k; v_padded = v; @@ -1541,8 +1626,8 @@ mha_fwd_kvcache(at::Tensor &q, // batch_size x seqlen_q x num_he // Keep references to these tensors to extend their lifetime at::Tensor softmax_lse_accum, out_accum; std::tie(softmax_lse_accum, out_accum) = set_params_splitkv( - params, batch_size, num_heads, head_size, seqlen_k, seqlen_q, - head_size_rounded, /*dropout*/ 0.f, num_splits, dprops, opts); + params, batch_size, num_heads, head_size, v_head_size, seqlen_k, seqlen_q, + head_size_rounded, v_head_size_rounded, /*dropout*/ 0.f, num_splits, dprops, opts); if (paged_KV) { params.block_table = block_table.data_ptr(); @@ -1559,19 +1644,29 @@ mha_fwd_kvcache(at::Tensor &q, // batch_size x seqlen_q x num_he run_mha_fwd(params, stream, /*force_split_kernel=*/k_.has_value() || cache_batch_idx_.has_value() || paged_KV); if (head_size_og % 8 != 0) { - out = out.index({"...", torch::indexing::Slice(torch::indexing::None, head_size_og)}); - if (out_.has_value()) { out_.value().copy_(out); } + // out = out.index({"...", torch::indexing::Slice(torch::indexing::None, head_size_og)}); + // if (out_.has_value()) { out_.value().copy_(out); } if (k_.has_value()) { // It's expensive to copy the KV cache here for the case where head size not divisible by 8, // but we don't expect to get this case in practice. This is just so that the code works for that case. kcache.copy_(kcache_padded.index({"...", torch::indexing::Slice(torch::indexing::None, head_size_og)})); - vcache.copy_(vcache_padded.index({"...", torch::indexing::Slice(torch::indexing::None, head_size_og)})); + // vcache.copy_(vcache_padded.index({"...", torch::indexing::Slice(torch::indexing::None, head_size_og)})); + } + } + if (v_head_size_og % 8 != 0) { + out = out.index({"...", torch::indexing::Slice(torch::indexing::None, v_head_size_og)}); + if (out_.has_value()) { out_.value().copy_(out); } + if (k_.has_value()) { + // It's expensive to copy the KV cache here for the case where head size not divisible by 8, + // but we don't expect to get this case in practice. This is just so that the code works for that case. 
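// --- Illustration (editor's note, not part of the patch) ----------------------------
// With this change the V head dimension gets its own pad-to-a-multiple-of-8 / slice-back
// round trip, independent of the QK head dimension. A minimal libtorch sketch of that
// round trip using the same pad / index calls as the code above; shapes are hypothetical.
#include <torch/torch.h>

int main() {
    const int v_head_size_og = 100;                            // not a multiple of 8
    auto v = torch::randn({2, 64, 4, v_head_size_og});
    // Pad V on its last dimension up to the next multiple of 8 (here 104).
    auto v_padded = torch::nn::functional::pad(
        v, torch::nn::functional::PadFuncOptions({0, 8 - v_head_size_og % 8}));
    // The kernel produces an output with the padded V head dim ...
    auto out_padded = torch::zeros({2, 64, 4, v_padded.size(-1)}, v.options());
    // ... which is sliced back to the original V head dim before being returned.
    auto out = out_padded.index(
        {"...", torch::indexing::Slice(torch::indexing::None, v_head_size_og)});
    TORCH_CHECK(out.size(-1) == v_head_size_og);
    return 0;
}
// -------------------------------------------------------------------------------------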
+ // kcache.copy_(kcache_padded.index({"...", torch::indexing::Slice(torch::indexing::None, head_size_og)})); + vcache.copy_(vcache_padded.index({"...", torch::indexing::Slice(torch::indexing::None, v_head_size_og)})); } } if (seqlenq_ngroups_swapped) { - out = out.transpose(1, 2).reshape({batch_size, 1, num_heads_k * seqlen_q, head_size_og}); - softmax_lse = softmax_lse.reshape({batch_size, num_heads_k * seqlen_q, 1}); + out = out.transpose(1, 2).reshape({batch_size, 1, num_heads_maxkv * seqlen_q, v_head_size_og}); + softmax_lse = softmax_lse.reshape({batch_size, num_heads_maxkv * seqlen_q, 1}); } return {out, softmax_lse}; } diff --git a/csrc/flash_attn/src/flash.h b/csrc/flash_attn/src/flash.h index 6f597fbee..fd8dea467 100644 --- a/csrc/flash_attn/src/flash.h +++ b/csrc/flash_attn/src/flash.h @@ -40,10 +40,11 @@ struct Qkv_params { index_t v_head_stride; // The number of heads. - int h, h_k; + int h, h_k, h_v; // In the case of multi-query and grouped-query attention (MQA/GQA), nheads_k could be // different from nheads (query). int h_h_k_ratio; // precompute h / h_k, + int h_h_v_ratio; }; //////////////////////////////////////////////////////////////////////////////////////////////////// @@ -67,7 +68,7 @@ struct Flash_fwd_params : public Qkv_params { void * __restrict__ softmax_lseaccum_ptr; // The dimensions. - int b, seqlen_q, seqlen_k, seqlen_knew, d, seqlen_q_rounded, seqlen_k_rounded, d_rounded, rotary_dim, total_q; + int b, seqlen_q, seqlen_k, seqlen_knew, d, vd, seqlen_q_rounded, seqlen_k_rounded, d_rounded, vd_rounded, rotary_dim, total_q; // The scaling factors for the kernel. float scale_softmax; @@ -189,7 +190,7 @@ struct Flash_bwd_params : public Flash_fwd_params { //////////////////////////////////////////////////////////////////////////////////////////////////// -template void run_mha_fwd_(Flash_fwd_params ¶ms, cudaStream_t stream); -template void run_mha_fwd_splitkv_dispatch(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_splitkv_dispatch(Flash_fwd_params ¶ms, cudaStream_t stream); -template void run_mha_bwd_(Flash_bwd_params ¶ms, cudaStream_t stream); +template void run_mha_bwd_(Flash_bwd_params ¶ms, cudaStream_t stream); diff --git a/csrc/flash_attn/src/flash_bwd_hdim128_bf16_causal_sm80.cu b/csrc/flash_attn/src/flash_bwd_hdim128_bf16_causal_sm80.cu index 13132e86d..3597cd8fc 100644 --- a/csrc/flash_attn/src/flash_bwd_hdim128_bf16_causal_sm80.cu +++ b/csrc/flash_attn/src/flash_bwd_hdim128_bf16_causal_sm80.cu @@ -5,6 +5,6 @@ #include "flash_bwd_launch_template.h" template<> -void run_mha_bwd_(Flash_bwd_params ¶ms, cudaStream_t stream) { +void run_mha_bwd_(Flash_bwd_params ¶ms, cudaStream_t stream) { run_mha_bwd_hdim128(params, stream); } diff --git a/csrc/flash_attn/src/flash_bwd_hdim128_bf16_sm80.cu b/csrc/flash_attn/src/flash_bwd_hdim128_bf16_sm80.cu index 85a5dc88e..a2155d523 100644 --- a/csrc/flash_attn/src/flash_bwd_hdim128_bf16_sm80.cu +++ b/csrc/flash_attn/src/flash_bwd_hdim128_bf16_sm80.cu @@ -5,6 +5,6 @@ #include "flash_bwd_launch_template.h" template<> -void run_mha_bwd_(Flash_bwd_params ¶ms, cudaStream_t stream) { +void run_mha_bwd_(Flash_bwd_params ¶ms, cudaStream_t stream) { run_mha_bwd_hdim128(params, stream); } diff --git a/csrc/flash_attn/src/flash_bwd_hdim128_fp16_causal_sm80.cu b/csrc/flash_attn/src/flash_bwd_hdim128_fp16_causal_sm80.cu index 5d27cd97b..ee32c0aa9 100644 --- a/csrc/flash_attn/src/flash_bwd_hdim128_fp16_causal_sm80.cu +++ 
b/csrc/flash_attn/src/flash_bwd_hdim128_fp16_causal_sm80.cu @@ -5,6 +5,6 @@ #include "flash_bwd_launch_template.h" template<> -void run_mha_bwd_(Flash_bwd_params ¶ms, cudaStream_t stream) { +void run_mha_bwd_(Flash_bwd_params ¶ms, cudaStream_t stream) { run_mha_bwd_hdim128(params, stream); } diff --git a/csrc/flash_attn/src/flash_bwd_hdim128_fp16_sm80.cu b/csrc/flash_attn/src/flash_bwd_hdim128_fp16_sm80.cu index 2d7ddf46b..968f07ac0 100644 --- a/csrc/flash_attn/src/flash_bwd_hdim128_fp16_sm80.cu +++ b/csrc/flash_attn/src/flash_bwd_hdim128_fp16_sm80.cu @@ -5,6 +5,6 @@ #include "flash_bwd_launch_template.h" template<> -void run_mha_bwd_(Flash_bwd_params ¶ms, cudaStream_t stream) { +void run_mha_bwd_(Flash_bwd_params ¶ms, cudaStream_t stream) { run_mha_bwd_hdim128(params, stream); } diff --git a/csrc/flash_attn/src/flash_bwd_hdim160_bf16_causal_sm80.cu b/csrc/flash_attn/src/flash_bwd_hdim160_bf16_causal_sm80.cu index c18a78c76..7ee4d45f2 100644 --- a/csrc/flash_attn/src/flash_bwd_hdim160_bf16_causal_sm80.cu +++ b/csrc/flash_attn/src/flash_bwd_hdim160_bf16_causal_sm80.cu @@ -5,6 +5,6 @@ #include "flash_bwd_launch_template.h" template<> -void run_mha_bwd_(Flash_bwd_params ¶ms, cudaStream_t stream) { +void run_mha_bwd_(Flash_bwd_params ¶ms, cudaStream_t stream) { run_mha_bwd_hdim160(params, stream); } diff --git a/csrc/flash_attn/src/flash_bwd_hdim160_bf16_sm80.cu b/csrc/flash_attn/src/flash_bwd_hdim160_bf16_sm80.cu index 1b6173725..e3697365b 100644 --- a/csrc/flash_attn/src/flash_bwd_hdim160_bf16_sm80.cu +++ b/csrc/flash_attn/src/flash_bwd_hdim160_bf16_sm80.cu @@ -5,6 +5,6 @@ #include "flash_bwd_launch_template.h" template<> -void run_mha_bwd_(Flash_bwd_params ¶ms, cudaStream_t stream) { +void run_mha_bwd_(Flash_bwd_params ¶ms, cudaStream_t stream) { run_mha_bwd_hdim160(params, stream); } diff --git a/csrc/flash_attn/src/flash_bwd_hdim160_fp16_causal_sm80.cu b/csrc/flash_attn/src/flash_bwd_hdim160_fp16_causal_sm80.cu index a511162dc..5bdea8f5a 100644 --- a/csrc/flash_attn/src/flash_bwd_hdim160_fp16_causal_sm80.cu +++ b/csrc/flash_attn/src/flash_bwd_hdim160_fp16_causal_sm80.cu @@ -5,6 +5,6 @@ #include "flash_bwd_launch_template.h" template<> -void run_mha_bwd_(Flash_bwd_params ¶ms, cudaStream_t stream) { +void run_mha_bwd_(Flash_bwd_params ¶ms, cudaStream_t stream) { run_mha_bwd_hdim160(params, stream); } diff --git a/csrc/flash_attn/src/flash_bwd_hdim160_fp16_sm80.cu b/csrc/flash_attn/src/flash_bwd_hdim160_fp16_sm80.cu index c9ce19acb..0194aa487 100644 --- a/csrc/flash_attn/src/flash_bwd_hdim160_fp16_sm80.cu +++ b/csrc/flash_attn/src/flash_bwd_hdim160_fp16_sm80.cu @@ -5,6 +5,6 @@ #include "flash_bwd_launch_template.h" template<> -void run_mha_bwd_(Flash_bwd_params ¶ms, cudaStream_t stream) { +void run_mha_bwd_(Flash_bwd_params ¶ms, cudaStream_t stream) { run_mha_bwd_hdim160(params, stream); } diff --git a/csrc/flash_attn/src/flash_bwd_hdim192_bf16_causal_sm80.cu b/csrc/flash_attn/src/flash_bwd_hdim192_bf16_causal_sm80.cu index f492a7171..f55649e53 100644 --- a/csrc/flash_attn/src/flash_bwd_hdim192_bf16_causal_sm80.cu +++ b/csrc/flash_attn/src/flash_bwd_hdim192_bf16_causal_sm80.cu @@ -5,6 +5,6 @@ #include "flash_bwd_launch_template.h" template<> -void run_mha_bwd_(Flash_bwd_params ¶ms, cudaStream_t stream) { +void run_mha_bwd_(Flash_bwd_params ¶ms, cudaStream_t stream) { run_mha_bwd_hdim192(params, stream); } diff --git a/csrc/flash_attn/src/flash_bwd_hdim192_bf16_sm80.cu b/csrc/flash_attn/src/flash_bwd_hdim192_bf16_sm80.cu index 2df58daa2..8758a0f00 100644 --- 
a/csrc/flash_attn/src/flash_bwd_hdim192_bf16_sm80.cu +++ b/csrc/flash_attn/src/flash_bwd_hdim192_bf16_sm80.cu @@ -5,6 +5,6 @@ #include "flash_bwd_launch_template.h" template<> -void run_mha_bwd_(Flash_bwd_params ¶ms, cudaStream_t stream) { +void run_mha_bwd_(Flash_bwd_params ¶ms, cudaStream_t stream) { run_mha_bwd_hdim192(params, stream); } diff --git a/csrc/flash_attn/src/flash_bwd_hdim192_fp16_causal_sm80.cu b/csrc/flash_attn/src/flash_bwd_hdim192_fp16_causal_sm80.cu index 69cad5ae4..a9cb850de 100644 --- a/csrc/flash_attn/src/flash_bwd_hdim192_fp16_causal_sm80.cu +++ b/csrc/flash_attn/src/flash_bwd_hdim192_fp16_causal_sm80.cu @@ -5,6 +5,6 @@ #include "flash_bwd_launch_template.h" template<> -void run_mha_bwd_(Flash_bwd_params ¶ms, cudaStream_t stream) { +void run_mha_bwd_(Flash_bwd_params ¶ms, cudaStream_t stream) { run_mha_bwd_hdim192(params, stream); } diff --git a/csrc/flash_attn/src/flash_bwd_hdim192_fp16_sm80.cu b/csrc/flash_attn/src/flash_bwd_hdim192_fp16_sm80.cu index 3d4cab58b..66e7029af 100644 --- a/csrc/flash_attn/src/flash_bwd_hdim192_fp16_sm80.cu +++ b/csrc/flash_attn/src/flash_bwd_hdim192_fp16_sm80.cu @@ -5,6 +5,6 @@ #include "flash_bwd_launch_template.h" template<> -void run_mha_bwd_(Flash_bwd_params ¶ms, cudaStream_t stream) { +void run_mha_bwd_(Flash_bwd_params ¶ms, cudaStream_t stream) { run_mha_bwd_hdim192(params, stream); } diff --git a/csrc/flash_attn/src/flash_bwd_hdim224_bf16_sm80.cu b/csrc/flash_attn/src/flash_bwd_hdim224_bf16_sm80.cu deleted file mode 100644 index b2b58e2ab..000000000 --- a/csrc/flash_attn/src/flash_bwd_hdim224_bf16_sm80.cu +++ /dev/null @@ -1,10 +0,0 @@ -// Copyright (c) 2023, Tri Dao. -// Splitting the different head dimensions to different files to speed up compilation. -// This file is auto-generated. See "generate_kernels.py" - -#include "flash_bwd_launch_template.h" - -template<> -void run_mha_bwd_(Flash_bwd_params ¶ms, cudaStream_t stream) { - run_mha_bwd_hdim224(params, stream); -} diff --git a/csrc/flash_attn/src/flash_bwd_hdim224_fp16_sm80.cu b/csrc/flash_attn/src/flash_bwd_hdim224_fp16_sm80.cu deleted file mode 100644 index e65cdaede..000000000 --- a/csrc/flash_attn/src/flash_bwd_hdim224_fp16_sm80.cu +++ /dev/null @@ -1,10 +0,0 @@ -// Copyright (c) 2023, Tri Dao. -// Splitting the different head dimensions to different files to speed up compilation. -// This file is auto-generated. 
See "generate_kernels.py" - -#include "flash_bwd_launch_template.h" - -template<> -void run_mha_bwd_(Flash_bwd_params ¶ms, cudaStream_t stream) { - run_mha_bwd_hdim224(params, stream); -} diff --git a/csrc/flash_attn/src/flash_bwd_hdim256_bf16_causal_sm80.cu b/csrc/flash_attn/src/flash_bwd_hdim256_bf16_causal_sm80.cu index 692744597..972b24972 100644 --- a/csrc/flash_attn/src/flash_bwd_hdim256_bf16_causal_sm80.cu +++ b/csrc/flash_attn/src/flash_bwd_hdim256_bf16_causal_sm80.cu @@ -5,6 +5,6 @@ #include "flash_bwd_launch_template.h" template<> -void run_mha_bwd_(Flash_bwd_params ¶ms, cudaStream_t stream) { +void run_mha_bwd_(Flash_bwd_params ¶ms, cudaStream_t stream) { run_mha_bwd_hdim256(params, stream); } diff --git a/csrc/flash_attn/src/flash_bwd_hdim256_bf16_sm80.cu b/csrc/flash_attn/src/flash_bwd_hdim256_bf16_sm80.cu index d718ec88b..632b15b8f 100644 --- a/csrc/flash_attn/src/flash_bwd_hdim256_bf16_sm80.cu +++ b/csrc/flash_attn/src/flash_bwd_hdim256_bf16_sm80.cu @@ -5,6 +5,6 @@ #include "flash_bwd_launch_template.h" template<> -void run_mha_bwd_(Flash_bwd_params ¶ms, cudaStream_t stream) { +void run_mha_bwd_(Flash_bwd_params ¶ms, cudaStream_t stream) { run_mha_bwd_hdim256(params, stream); } diff --git a/csrc/flash_attn/src/flash_bwd_hdim256_fp16_causal_sm80.cu b/csrc/flash_attn/src/flash_bwd_hdim256_fp16_causal_sm80.cu index 551c695e0..16edaf9ae 100644 --- a/csrc/flash_attn/src/flash_bwd_hdim256_fp16_causal_sm80.cu +++ b/csrc/flash_attn/src/flash_bwd_hdim256_fp16_causal_sm80.cu @@ -5,6 +5,6 @@ #include "flash_bwd_launch_template.h" template<> -void run_mha_bwd_(Flash_bwd_params ¶ms, cudaStream_t stream) { +void run_mha_bwd_(Flash_bwd_params ¶ms, cudaStream_t stream) { run_mha_bwd_hdim256(params, stream); } diff --git a/csrc/flash_attn/src/flash_bwd_hdim256_fp16_sm80.cu b/csrc/flash_attn/src/flash_bwd_hdim256_fp16_sm80.cu index a58770026..4aaa83cb6 100644 --- a/csrc/flash_attn/src/flash_bwd_hdim256_fp16_sm80.cu +++ b/csrc/flash_attn/src/flash_bwd_hdim256_fp16_sm80.cu @@ -5,6 +5,6 @@ #include "flash_bwd_launch_template.h" template<> -void run_mha_bwd_(Flash_bwd_params ¶ms, cudaStream_t stream) { +void run_mha_bwd_(Flash_bwd_params ¶ms, cudaStream_t stream) { run_mha_bwd_hdim256(params, stream); } diff --git a/csrc/flash_attn/src/flash_bwd_hdim32_bf16_causal_sm80.cu b/csrc/flash_attn/src/flash_bwd_hdim32_bf16_causal_sm80.cu index 1282939a0..cef067c57 100644 --- a/csrc/flash_attn/src/flash_bwd_hdim32_bf16_causal_sm80.cu +++ b/csrc/flash_attn/src/flash_bwd_hdim32_bf16_causal_sm80.cu @@ -5,6 +5,6 @@ #include "flash_bwd_launch_template.h" template<> -void run_mha_bwd_(Flash_bwd_params ¶ms, cudaStream_t stream) { +void run_mha_bwd_(Flash_bwd_params ¶ms, cudaStream_t stream) { run_mha_bwd_hdim32(params, stream); } diff --git a/csrc/flash_attn/src/flash_bwd_hdim32_bf16_sm80.cu b/csrc/flash_attn/src/flash_bwd_hdim32_bf16_sm80.cu index d6d403638..6e723a55f 100644 --- a/csrc/flash_attn/src/flash_bwd_hdim32_bf16_sm80.cu +++ b/csrc/flash_attn/src/flash_bwd_hdim32_bf16_sm80.cu @@ -5,6 +5,6 @@ #include "flash_bwd_launch_template.h" template<> -void run_mha_bwd_(Flash_bwd_params ¶ms, cudaStream_t stream) { +void run_mha_bwd_(Flash_bwd_params ¶ms, cudaStream_t stream) { run_mha_bwd_hdim32(params, stream); } diff --git a/csrc/flash_attn/src/flash_bwd_hdim32_fp16_causal_sm80.cu b/csrc/flash_attn/src/flash_bwd_hdim32_fp16_causal_sm80.cu index 60aa2d60b..87460d632 100644 --- a/csrc/flash_attn/src/flash_bwd_hdim32_fp16_causal_sm80.cu +++ b/csrc/flash_attn/src/flash_bwd_hdim32_fp16_causal_sm80.cu @@ -5,6 +5,6 @@ 
#include "flash_bwd_launch_template.h" template<> -void run_mha_bwd_(Flash_bwd_params ¶ms, cudaStream_t stream) { +void run_mha_bwd_(Flash_bwd_params ¶ms, cudaStream_t stream) { run_mha_bwd_hdim32(params, stream); } diff --git a/csrc/flash_attn/src/flash_bwd_hdim32_fp16_sm80.cu b/csrc/flash_attn/src/flash_bwd_hdim32_fp16_sm80.cu index b06d50eaa..439489c14 100644 --- a/csrc/flash_attn/src/flash_bwd_hdim32_fp16_sm80.cu +++ b/csrc/flash_attn/src/flash_bwd_hdim32_fp16_sm80.cu @@ -5,6 +5,6 @@ #include "flash_bwd_launch_template.h" template<> -void run_mha_bwd_(Flash_bwd_params ¶ms, cudaStream_t stream) { +void run_mha_bwd_(Flash_bwd_params ¶ms, cudaStream_t stream) { run_mha_bwd_hdim32(params, stream); } diff --git a/csrc/flash_attn/src/flash_bwd_hdim64_bf16_causal_sm80.cu b/csrc/flash_attn/src/flash_bwd_hdim64_bf16_causal_sm80.cu index 52b93be9d..af11800ff 100644 --- a/csrc/flash_attn/src/flash_bwd_hdim64_bf16_causal_sm80.cu +++ b/csrc/flash_attn/src/flash_bwd_hdim64_bf16_causal_sm80.cu @@ -5,6 +5,6 @@ #include "flash_bwd_launch_template.h" template<> -void run_mha_bwd_(Flash_bwd_params ¶ms, cudaStream_t stream) { +void run_mha_bwd_(Flash_bwd_params ¶ms, cudaStream_t stream) { run_mha_bwd_hdim64(params, stream); } diff --git a/csrc/flash_attn/src/flash_bwd_hdim64_bf16_sm80.cu b/csrc/flash_attn/src/flash_bwd_hdim64_bf16_sm80.cu index 09d9e2b75..b2fc12156 100644 --- a/csrc/flash_attn/src/flash_bwd_hdim64_bf16_sm80.cu +++ b/csrc/flash_attn/src/flash_bwd_hdim64_bf16_sm80.cu @@ -5,6 +5,6 @@ #include "flash_bwd_launch_template.h" template<> -void run_mha_bwd_(Flash_bwd_params ¶ms, cudaStream_t stream) { +void run_mha_bwd_(Flash_bwd_params ¶ms, cudaStream_t stream) { run_mha_bwd_hdim64(params, stream); } diff --git a/csrc/flash_attn/src/flash_bwd_hdim64_fp16_causal_sm80.cu b/csrc/flash_attn/src/flash_bwd_hdim64_fp16_causal_sm80.cu index 5a4ea5f46..b2d08eaa1 100644 --- a/csrc/flash_attn/src/flash_bwd_hdim64_fp16_causal_sm80.cu +++ b/csrc/flash_attn/src/flash_bwd_hdim64_fp16_causal_sm80.cu @@ -5,6 +5,6 @@ #include "flash_bwd_launch_template.h" template<> -void run_mha_bwd_(Flash_bwd_params ¶ms, cudaStream_t stream) { +void run_mha_bwd_(Flash_bwd_params ¶ms, cudaStream_t stream) { run_mha_bwd_hdim64(params, stream); } diff --git a/csrc/flash_attn/src/flash_bwd_hdim64_fp16_sm80.cu b/csrc/flash_attn/src/flash_bwd_hdim64_fp16_sm80.cu index fb115ff76..d479f07ac 100644 --- a/csrc/flash_attn/src/flash_bwd_hdim64_fp16_sm80.cu +++ b/csrc/flash_attn/src/flash_bwd_hdim64_fp16_sm80.cu @@ -5,6 +5,6 @@ #include "flash_bwd_launch_template.h" template<> -void run_mha_bwd_(Flash_bwd_params ¶ms, cudaStream_t stream) { +void run_mha_bwd_(Flash_bwd_params ¶ms, cudaStream_t stream) { run_mha_bwd_hdim64(params, stream); } diff --git a/csrc/flash_attn/src/flash_bwd_hdim96_bf16_causal_sm80.cu b/csrc/flash_attn/src/flash_bwd_hdim96_bf16_causal_sm80.cu index 5f4c26a47..01c74893e 100644 --- a/csrc/flash_attn/src/flash_bwd_hdim96_bf16_causal_sm80.cu +++ b/csrc/flash_attn/src/flash_bwd_hdim96_bf16_causal_sm80.cu @@ -5,6 +5,6 @@ #include "flash_bwd_launch_template.h" template<> -void run_mha_bwd_(Flash_bwd_params ¶ms, cudaStream_t stream) { +void run_mha_bwd_(Flash_bwd_params ¶ms, cudaStream_t stream) { run_mha_bwd_hdim96(params, stream); } diff --git a/csrc/flash_attn/src/flash_bwd_hdim96_bf16_sm80.cu b/csrc/flash_attn/src/flash_bwd_hdim96_bf16_sm80.cu index 224213d79..4b17006a6 100644 --- a/csrc/flash_attn/src/flash_bwd_hdim96_bf16_sm80.cu +++ b/csrc/flash_attn/src/flash_bwd_hdim96_bf16_sm80.cu @@ -5,6 +5,6 @@ #include 
"flash_bwd_launch_template.h" template<> -void run_mha_bwd_(Flash_bwd_params ¶ms, cudaStream_t stream) { +void run_mha_bwd_(Flash_bwd_params ¶ms, cudaStream_t stream) { run_mha_bwd_hdim96(params, stream); } diff --git a/csrc/flash_attn/src/flash_bwd_hdim96_fp16_causal_sm80.cu b/csrc/flash_attn/src/flash_bwd_hdim96_fp16_causal_sm80.cu index d0349014f..68c299d29 100644 --- a/csrc/flash_attn/src/flash_bwd_hdim96_fp16_causal_sm80.cu +++ b/csrc/flash_attn/src/flash_bwd_hdim96_fp16_causal_sm80.cu @@ -5,6 +5,6 @@ #include "flash_bwd_launch_template.h" template<> -void run_mha_bwd_(Flash_bwd_params ¶ms, cudaStream_t stream) { +void run_mha_bwd_(Flash_bwd_params ¶ms, cudaStream_t stream) { run_mha_bwd_hdim96(params, stream); } diff --git a/csrc/flash_attn/src/flash_bwd_hdim96_fp16_sm80.cu b/csrc/flash_attn/src/flash_bwd_hdim96_fp16_sm80.cu index 663fc8592..75d6d7822 100644 --- a/csrc/flash_attn/src/flash_bwd_hdim96_fp16_sm80.cu +++ b/csrc/flash_attn/src/flash_bwd_hdim96_fp16_sm80.cu @@ -5,6 +5,6 @@ #include "flash_bwd_launch_template.h" template<> -void run_mha_bwd_(Flash_bwd_params ¶ms, cudaStream_t stream) { +void run_mha_bwd_(Flash_bwd_params ¶ms, cudaStream_t stream) { run_mha_bwd_hdim96(params, stream); } diff --git a/csrc/flash_attn/src/flash_bwd_kernel.h b/csrc/flash_attn/src/flash_bwd_kernel.h index 4f95bd34a..1db7a6dc1 100644 --- a/csrc/flash_attn/src/flash_bwd_kernel.h +++ b/csrc/flash_attn/src/flash_bwd_kernel.h @@ -91,7 +91,8 @@ inline __device__ void compute_dq_dk_dv_1colblock(const Params ¶ms, const in constexpr int kBlockM = Kernel_traits::kBlockM; constexpr int kBlockN = Kernel_traits::kBlockN; - constexpr int kHeadDim = Kernel_traits::kHeadDim; + constexpr int kQKHeadDim = Kernel_traits::kQKHeadDim; + constexpr int kVHeadDim = Kernel_traits::kVHeadDim; constexpr int MMA_N_SdP = kBlockN / decltype(typename Kernel_traits::TiledMmaSdP{}.template tile_size_mnk<1>())::value; constexpr int AtomLayoutMS = Kernel_traits::AtomLayoutMSdP; constexpr bool Double_buffer = !Kernel_traits::No_double_buffer; @@ -109,7 +110,7 @@ inline __device__ void compute_dq_dk_dv_1colblock(const Params ¶ms, const in const index_t row_offset_k = binfo.k_offset(params.k_batch_stride, params.k_row_stride, bidb) + n_block * kBlockN * params.k_row_stride + (bidh / params.h_h_k_ratio) * params.k_head_stride; const index_t row_offset_v = binfo.k_offset(params.v_batch_stride, params.v_row_stride, bidb) - + n_block * kBlockN * params.v_row_stride + (bidh / params.h_h_k_ratio) * params.v_head_stride; + + n_block * kBlockN * params.v_row_stride + (bidh / params.h_h_v_ratio) * params.v_head_stride; const index_t row_offset_do = binfo.q_offset(params.do_batch_stride, params.do_row_stride, bidb) + (m_block_max - 1) * kBlockM * params.do_row_stride + bidh * params.do_head_stride; const index_t row_offset_o = binfo.q_offset(params.o_batch_stride, params.o_row_stride, bidb) @@ -125,25 +126,25 @@ inline __device__ void compute_dq_dk_dv_1colblock(const Params ¶ms, const in const index_t row_offset_dpsum = (params.unpadded_lse? 
bidh * (params.total_q + 128 * params.b) + binfo.q_offset(params.seqlen_q_rounded, 1, bidb) + 128 * bidb: (bidb * params.h + bidh) * params.seqlen_q_rounded) + (m_block_max - 1) * kBlockM; Tensor gQ = make_tensor(make_gmem_ptr(reinterpret_cast(params.q_ptr) + row_offset_q), - Shape, Int>{}, + Shape, Int>{}, make_stride(params.q_row_stride, _1{})); Tensor gK = make_tensor(make_gmem_ptr(reinterpret_cast(params.k_ptr) + row_offset_k), - Shape, Int>{}, + Shape, Int>{}, make_stride(params.k_row_stride, _1{})); Tensor gV = make_tensor(make_gmem_ptr(reinterpret_cast(params.v_ptr) + row_offset_v), - Shape, Int>{}, + Shape, Int>{}, make_stride(params.v_row_stride, _1{})); Tensor gdO = make_tensor(make_gmem_ptr(reinterpret_cast(params.do_ptr) + row_offset_do), - Shape, Int>{}, + Shape, Int>{}, make_stride(params.do_row_stride, _1{})); Tensor gO = make_tensor(make_gmem_ptr(reinterpret_cast(params.o_ptr) + row_offset_o), - Shape, Int>{}, + Shape, Int>{}, make_stride(params.o_row_stride, _1{})); Tensor gdQ = make_tensor(make_gmem_ptr(reinterpret_cast(params.dq_ptr) + row_offset_dq), - Shape, Int>{}, + Shape, Int>{}, make_stride(params.dq_row_stride, _1{})); Tensor gdQaccum = make_tensor(make_gmem_ptr(reinterpret_cast(params.dq_accum_ptr) + row_offset_dq_accum), - Shape, Int>{}, + Shape, Int>{}, make_stride(params.h * params.d_rounded, _1{})); Tensor gLSE = make_tensor(make_gmem_ptr(reinterpret_cast(params.softmax_lse_ptr) + row_offset_lse), Shape>{}, Stride<_1>{}); @@ -151,16 +152,16 @@ inline __device__ void compute_dq_dk_dv_1colblock(const Params ¶ms, const in Shape>{}, Stride<_1>{}); Tensor sQ = make_tensor(make_smem_ptr(reinterpret_cast(smem_)), - typename Kernel_traits::SmemLayoutQdO{}); - Tensor sQt = make_tensor(sQ.data(), typename Kernel_traits::SmemLayoutQdOtransposed{}); - Tensor sQtNoSwizzle = make_tensor(sQ.data(), typename Kernel_traits::SmemLayoutQdOtransposedNoSwizzle{}); + typename Kernel_traits::SmemLayoutQ{}); + Tensor sQt = make_tensor(sQ.data(), typename Kernel_traits::SmemLayoutQtransposed{}); + Tensor sQtNoSwizzle = make_tensor(sQ.data(), typename Kernel_traits::SmemLayoutQtransposedNoSwizzle{}); // Double buffer for sQ - Tensor sdO = make_tensor(sQ.data() + (Double_buffer ? 2 : 1) * size(sQ), typename Kernel_traits::SmemLayoutQdO{}); - Tensor sdOt = make_tensor(sdO.data(), typename Kernel_traits::SmemLayoutQdOtransposed{}); + Tensor sdO = make_tensor(sQ.data() + (Double_buffer ? 2 : 1) * size(sQ), typename Kernel_traits::SmemLayoutdO{}); + Tensor sdOt = make_tensor(sdO.data(), typename Kernel_traits::SmemLayoutdOtransposed{}); Tensor sdOtransposedNoSwizzle = make_tensor(sdO.data(), - typename Kernel_traits::SmemLayoutQdOtransposedNoSwizzle{}); - Tensor sK = make_tensor(sdO.data() + size(sdO), typename Kernel_traits::SmemLayoutKV{}); - Tensor sV = make_tensor(sK.data() + size(sK), typename Kernel_traits::SmemLayoutKV{}); + typename Kernel_traits::SmemLayoutdOtransposedNoSwizzle{}); + Tensor sK = make_tensor(sdO.data() + size(sdO), typename Kernel_traits::SmemLayoutK{}); + Tensor sV = make_tensor(sK.data() + size(sK), typename Kernel_traits::SmemLayoutV{}); Tensor sKt = make_tensor(sK.data(), typename Kernel_traits::SmemLayoutKtransposed{}); Tensor sKtNoSwizzle = make_tensor(sK.data(), typename Kernel_traits::SmemLayoutKtransposedNoSwizzle{}); Tensor sdS = make_tensor(!Kernel_traits::Is_V_in_regs ? 
sV.data() + size(sV) : sK.data() + size(sK), @@ -229,8 +230,8 @@ inline __device__ void compute_dq_dk_dv_1colblock(const Params ¶ms, const in Tensor tdQrdS = thr_mma_dq.partition_fragment_A(sdS); // (MMA, MMA_N, MMA_N) Tensor tdQrKt = thr_mma_dq.partition_fragment_B(sKtNoSwizzle); // (MMA, MMA_K, MMA_N) - Tensor acc_dk = partition_fragment_C(tiled_mma_dkv, Shape, Int>{}); // MMA, MMA_N, MMA_K - Tensor acc_dv = partition_fragment_C(tiled_mma_dkv, Shape, Int>{}); // MMA, MMA_N, MMA_K + Tensor acc_dk = partition_fragment_C(tiled_mma_dkv, Shape, Int>{}); // MMA, MMA_N, MMA_K + Tensor acc_dv = partition_fragment_C(tiled_mma_dkv, Shape, Int>{}); // MMA, MMA_N, MMA_K // // Copy Atom retiling @@ -289,20 +290,25 @@ inline __device__ void compute_dq_dk_dv_1colblock(const Params ¶ms, const in // Tensor cQ = make_identity_tensor(make_shape(size<0>(sQ), size<1>(sQ))); // (BLK_M,BLK_K) -> (blk_m,blk_k) - Tensor cKV = make_identity_tensor(make_shape(size<0>(sK), size<1>(sK))); // (BLK_N,BLK_K) -> (blk_n,blk_k) + Tensor cK = make_identity_tensor(make_shape(size<0>(sK), size<1>(sK))); // (BLK_N,BLK_K) -> (blk_n,blk_k) + Tensor cV = make_identity_tensor(make_shape(size<0>(sV), size<1>(sV))); // (BLK_N,BLK_K) -> (blk_n,blk_k) Tensor tQcQ = gmem_thr_copy_QKV.partition_D(cQ); - Tensor tKVcKV = gmem_thr_copy_QKV.partition_D(cKV); + Tensor tKcK = gmem_thr_copy_QKV.partition_D(cK); + Tensor tVcV = gmem_thr_copy_QKV.partition_D(cV); // Allocate predicate tensors for k Tensor tQpQ = make_tensor(make_shape(size<2>(tQsQ))); - Tensor tKVpKV = make_tensor(make_shape(size<2>(tKsK))); + Tensor tKpK = make_tensor(make_shape(size<2>(tKsK))); + Tensor tVpV = make_tensor(make_shape(size<2>(tVsV))); // Set predicates for k bounds if (!Is_even_K) { #pragma unroll for (int k = 0; k < size(tQpQ); ++k) { tQpQ(k) = get<1>(tQcQ(0, 0, k)) < params.d; } #pragma unroll - for (int k = 0; k < size(tKVpKV); ++k) { tKVpKV(k) = get<1>(tKVcKV(0, 0, k)) < params.d; } + for (int k = 0; k < size(tKpK); ++k) { tKpK(k) = get<1>(tKcK(0, 0, k)) < params.d; } + #pragma unroll + for (int k = 0; k < size(tVpV); ++k) { tVpV(k) = get<1>(tVcV(0, 0, k)) < params.vd; } } // Prologue @@ -333,10 +339,10 @@ inline __device__ void compute_dq_dk_dv_1colblock(const Params ¶ms, const in const index_t row_offset_dv = binfo.k_offset(params.dv_batch_stride, params.dv_row_stride, bidb) + n_block * kBlockN * params.dv_row_stride + bidh * params.dv_head_stride; Tensor gdK = make_tensor(make_gmem_ptr(reinterpret_cast(params.dk_ptr) + row_offset_dk), - Shape, Int>{}, + Shape, Int>{}, make_stride(params.dk_row_stride, _1{})); Tensor gdV = make_tensor(make_gmem_ptr(reinterpret_cast(params.dv_ptr) + row_offset_dv), - Shape, Int>{}, + Shape, Int>{}, make_stride(params.dv_row_stride, _1{})); typename Kernel_traits::GmemTiledCopydKV gmem_tiled_copy_dKV; auto gmem_thr_copy_dKV = gmem_tiled_copy_dKV.get_thread_slice(tidx); @@ -346,17 +352,22 @@ inline __device__ void compute_dq_dk_dv_1colblock(const Params ¶ms, const in Tensor tdVrdV = make_tensor(shape(tdVgdV)); clear(tdKrdK); clear(tdVrdV); - Tensor cdKV = make_identity_tensor(make_shape(size<0>(gdK), size<1>(gdK))); // (BLK_N,BLK_K) -> (blk_n,blk_k) - Tensor tdKVcdKV = gmem_thr_copy_dKV.partition_D(cdKV); - Tensor tdKVpdKV = make_tensor(make_shape(size<2>(tdKgdK))); + Tensor cdK = make_identity_tensor(make_shape(size<0>(gdK), size<1>(gdK))); // (BLK_N,BLK_K) -> (blk_n,blk_k) + Tensor cdV = make_identity_tensor(make_shape(size<0>(gdV), size<1>(gdV))); // (BLK_N,BLK_K) -> (blk_n,blk_k) + Tensor tdKcdK = 
gmem_thr_copy_dKV.partition_D(cdK); + Tensor tdVcdV = gmem_thr_copy_dKV.partition_D(cdV); + Tensor tdKpdK = make_tensor(make_shape(size<2>(tdKgdK))); + Tensor tdVpdV = make_tensor(make_shape(size<2>(tdVgdV))); + #pragma unroll + for (int k = 0; k < size(tdKpdK); ++k) { tdKpdK(k) = get<1>(tdKcdK(0, 0, k)) < params.d; } #pragma unroll - for (int k = 0; k < size(tdKVpdKV); ++k) { tdKVpdKV(k) = get<1>(tdKVcdKV(0, 0, k)) < params.d; } + for (int k = 0; k < size(tdVpdV); ++k) { tdVpdV(k) = get<1>(tdVcdV(0, 0, k)) < params.vd; } // Clear_OOB_K must be false since we don't want to write zeros to gmem flash::copy( - gmem_tiled_copy_dKV, tdKrdK, tdKgdK, tdKVcdKV, tdKVpdKV, binfo.actual_seqlen_k - n_block * kBlockN + gmem_tiled_copy_dKV, tdKrdK, tdKgdK, tdKcdK, tdKpdK, binfo.actual_seqlen_k - n_block * kBlockN ); flash::copy( - gmem_tiled_copy_dKV, tdVrdV, tdVgdV, tdKVcdKV, tdKVpdKV, binfo.actual_seqlen_k - n_block * kBlockN + gmem_tiled_copy_dKV, tdVrdV, tdVgdV, tdVcdV, tdVpdV, binfo.actual_seqlen_k - n_block * kBlockN ); return; } @@ -372,7 +383,7 @@ inline __device__ void compute_dq_dk_dv_1colblock(const Params ¶ms, const in if (Kernel_traits::Is_V_in_regs) { // Clear the smem tiles to account for predicated off loads flash::copy( - gmem_tiled_copy_QKV, tVgV, tVsV, tKVcKV, tKVpKV, binfo.actual_seqlen_k - n_block * kBlockN + gmem_tiled_copy_QKV, tVgV, tVsV, tVcV, tVpV, binfo.actual_seqlen_k - n_block * kBlockN ); flash::cp_async_fence(); } @@ -418,11 +429,11 @@ inline __device__ void compute_dq_dk_dv_1colblock(const Params ¶ms, const in // // if (cute::thread(1, 0)) { print(tKrK); } flash::copy( - gmem_tiled_copy_QKV, tKgK, tKsK, tKVcKV, tKVpKV, binfo.actual_seqlen_k - n_block * kBlockN + gmem_tiled_copy_QKV, tKgK, tKsK, tKcK, tKpK, binfo.actual_seqlen_k - n_block * kBlockN ); if (!Kernel_traits::Is_V_in_regs) { flash::copy( - gmem_tiled_copy_QKV, tVgV, tVsV, tKVcKV, tKVpKV, binfo.actual_seqlen_k - n_block * kBlockN + gmem_tiled_copy_QKV, tVgV, tVsV, tVcV, tVpV, binfo.actual_seqlen_k - n_block * kBlockN ); } flash::cp_async_fence(); @@ -592,7 +603,7 @@ inline __device__ void compute_dq_dk_dv_1colblock(const Params ¶ms, const in } // if (cute::thread0()) { print(dS); } - Tensor acc_dq = partition_fragment_C(tiled_mma_dq, Shape, Int>{}); // MMA, MMA_N, MMA_K + Tensor acc_dq = partition_fragment_C(tiled_mma_dq, Shape, Int>{}); // MMA, MMA_N, MMA_K tdQgdQaccum.data() = tdQgdQaccum.data() + (-int(kBlockM * params.h * params.d_rounded)); if (Is_first || Seq_parallel) { clear(acc_dq); @@ -708,7 +719,7 @@ inline __device__ void compute_dq_dk_dv_1colblock(const Params ¶ms, const in Tensor tdQrdQ = make_tensor(shape(tdQgdQ)); cute::copy(gmem_tiled_copy_dQ, tdQsdQ, tdQrdQ); tdQgdQ.data() = tdQgdQ.data() + (-int(kBlockM * params.dq_row_stride)); - Tensor cdQ = make_identity_tensor(Shape, Int>{}); // (BLK_M,BLK_K) -> (blk_m,blk_k) + Tensor cdQ = make_identity_tensor(Shape, Int>{}); // (BLK_M,BLK_K) -> (blk_m,blk_k) Tensor tdQcdQ = gmem_thr_copy_dQ.partition_D(cdQ); #pragma unroll for (int m = 0; m < size<1>(tdQgdQ); ++m) { @@ -733,8 +744,8 @@ inline __device__ void compute_dq_dk_dv_1colblock(const Params ¶ms, const in Tensor rdK = flash::convert_type(acc_dk); Tensor rdV = flash::convert_type(acc_dv); - Tensor sdK = make_tensor(sK.data(), typename Kernel_traits::SmemLayoutdKV{}); // (SMEM_N, SMEM_K) - Tensor sdV = make_tensor(sdK.data() + size(sdK), typename Kernel_traits::SmemLayoutdKV{}); // (SMEM_N, SMEM_K) + Tensor sdK = make_tensor(sK.data(), typename Kernel_traits::SmemLayoutdK{}); // (SMEM_N, SMEM_K) + 
Tensor sdV = make_tensor(sdK.data() + size(sdK), typename Kernel_traits::SmemLayoutdV{}); // (SMEM_N, SMEM_K) // Partition sdV and sdK to match the accumulator partitioning auto smem_tiled_copy_dKV = make_tiled_copy_C(typename Kernel_traits::SmemCopyAtomdKV{}, tiled_mma_dkv); @@ -758,10 +769,10 @@ inline __device__ void compute_dq_dk_dv_1colblock(const Params ¶ms, const in const index_t row_offset_dv = binfo.k_offset(params.dv_batch_stride, params.dv_row_stride, bidb) + n_block * kBlockN * params.dv_row_stride + bidh * params.dv_head_stride; Tensor gdK = make_tensor(make_gmem_ptr(reinterpret_cast(params.dk_ptr) + row_offset_dk), - Shape, Int>{}, + Shape, Int>{}, make_stride(params.dk_row_stride, _1{})); Tensor gdV = make_tensor(make_gmem_ptr(reinterpret_cast(params.dv_ptr) + row_offset_dv), - Shape, Int>{}, + Shape, Int>{}, make_stride(params.dv_row_stride, _1{})); typename Kernel_traits::GmemTiledCopydKV gmem_tiled_copy_dKV; @@ -780,7 +791,7 @@ inline __device__ void compute_dq_dk_dv_1colblock(const Params ¶ms, const in Tensor tdKVcdKV = gmem_thr_copy_dKV.partition_D(cdKV); Tensor tdKVpdKV = make_tensor(make_shape(size<2>(tdKgdK))); #pragma unroll - for (int k = 0; k < size(tdKVpdKV); ++k) { tdKVpdKV(k) = get<1>(tdKVcdKV(0, 0, k)) < params.d; } + for (int k = 0; k < size(tdKVpdKV); ++k) { tdKVpdKV(k) = get<1>(tdKVcdKV(0, 0, k)) < params.vd; } // Clear_OOB_K must be false since we don't want to write zeros to gmem flash::copy( gmem_tiled_copy_dKV, tdKrdK, tdKgdK, tdKVcdKV, tdKVpdKV, binfo.actual_seqlen_k - n_block * kBlockN diff --git a/csrc/flash_attn/src/flash_bwd_launch_template.h b/csrc/flash_attn/src/flash_bwd_launch_template.h index 727d87e93..a1568285f 100644 --- a/csrc/flash_attn/src/flash_bwd_launch_template.h +++ b/csrc/flash_attn/src/flash_bwd_launch_template.h @@ -87,7 +87,7 @@ void run_flash_bwd_seqk_parallel(Flash_bwd_params ¶ms, cudaStream_t stream) // We want to specialize to is_even_MN and not just is_even_M, since in the case where N is not // a multiple of kBlockN, we'll need to apply mask in the loop. const bool is_even_MN = params.cu_seqlens_q == nullptr && params.cu_seqlens_k == nullptr && params.seqlen_q % Kernel_traits::kBlockM == 0 && params.seqlen_k % Kernel_traits::kBlockN == 0; - const bool is_even_K = params.d == Kernel_traits::kHeadDim; + const bool is_even_K = (params.d == Kernel_traits::kQKHeadDim && params.vd == Kernel_traits::kVHeadDim);//TODO check if this is correct constexpr int smem_size_dq_dk_dv = Kernel_traits::kSmemSize1colblock; // printf("smem_size_dq_dk_dv = %d\n", smem_size_dq_dk_dv); BOOL_SWITCH(is_even_MN, IsEvenMNConst, [&] { @@ -98,7 +98,7 @@ void run_flash_bwd_seqk_parallel(Flash_bwd_params ¶ms, cudaStream_t stream) // If not IsEvenKConst, we also set IsEvenMNConst to false to reduce number of templates. 
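
With the head dimensions split, the is_even_K fast path above only applies when both runtime dimensions match their compile-time tile sizes: Q/K columns are predicated against params.d, while V/dO/O columns are predicated against params.vd. A minimal standalone sketch of that condition follows; the names kQKHeadDim, kVHeadDim and params.vd come from this patch, but the helper itself is illustrative and not the kernel code.

    // Sketch only, not the kernel code: the "even K" path is valid only when both
    // the QK and the V head dimension are unpadded at run time.
    template <int kQKHeadDim, int kVHeadDim>
    constexpr bool is_even_k(int d, int vd) {
        return d == kQKHeadDim && vd == kVHeadDim;
    }

    int main() {
        static_assert(is_even_k<128, 256>(128, 256), "both dims match the tile");
        static_assert(!is_even_k<128, 256>(128, 192), "a padded V head dim disables the fast path");
        return 0;
    }
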
// If head dim > 128, set IsEvenMNConst to false to reduce number of templates // If Is_local, set Is_causal to false - auto kernel = &flash_bwd_dq_dk_dv_loop_seqk_parallel_kernel; + auto kernel = &flash_bwd_dq_dk_dv_loop_seqk_parallel_kernel;//TODO check if this is correct // auto kernel = &flash_bwd_dq_dk_dv_loop_seqk_parallel_kernel; if (smem_size_dq_dk_dv >= 48 * 1024) { C10_CUDA_CHECK(cudaFuncSetAttribute( @@ -142,12 +142,12 @@ void run_mha_bwd_hdim32(Flash_bwd_params ¶ms, cudaStream_t stream) { DROPOUT_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] { if (max_smem_per_block >= 2 * ((3 * 128 + 2 * 128) * Headdim + 2 * 128 * 128)) { // 104 KB if constexpr(!Is_dropout) { // We can afford more registers to keep V in registers - run_flash_bwd, Is_dropout, Is_causal>(params, stream); + run_flash_bwd, Is_dropout, Is_causal>(params, stream); } else { - run_flash_bwd, Is_dropout, Is_causal>(params, stream); + run_flash_bwd, Is_dropout, Is_causal>(params, stream); } } else { // 96 KB - run_flash_bwd, Is_dropout, Is_causal>(params, stream); + run_flash_bwd, Is_dropout, Is_causal>(params, stream); } }); } @@ -166,35 +166,35 @@ void run_mha_bwd_hdim64(Flash_bwd_params ¶ms, cudaStream_t stream) { // printf("max_smem_per_block = %d\n", max_smem_per_block); DROPOUT_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] { // Changing AtomLayoutMdQ from 2 to 4 takes the same time - // run_flash_bwd>(params, stream); - // run_flash_bwd>(params, stream); - // run_flash_bwd>(params, stream); - // run_flash_bwd, Is_dropout>(params, stream); + // run_flash_bwd>(params, stream); + // run_flash_bwd>(params, stream); + // run_flash_bwd>(params, stream); + // run_flash_bwd, Is_dropout>(params, stream); // This is slightly faster. We want to split M more so we need fewer registers to store LSE. 
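
For reference, the 104 KB figure quoted in the hdim32 branch above follows directly from its threshold expression with Headdim = 32; a compile-time check of that arithmetic (illustrative only):

    // 2 * ((3*128 + 2*128) * 32 + 2*128*128) = 2 * (20480 + 32768) = 106496 bytes = 104 KB
    static_assert(2 * ((3 * 128 + 2 * 128) * 32 + 2 * 128 * 128) == 104 * 1024,
                  "hdim32 large-tile shared-memory requirement");
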
if (max_smem_per_block >= 144 * 1024) { - run_flash_bwd, Is_dropout, Is_causal>(params, stream); + run_flash_bwd, Is_dropout, Is_causal>(params, stream); // This has a lot of register spilling - // run_flash_bwd, Is_dropout>(params, stream); + // run_flash_bwd, Is_dropout>(params, stream); } else { // if (params.h == params.h_k) { - // run_flash_bwd, Is_dropout>(params, stream); - run_flash_bwd, Is_dropout, Is_causal>(params, stream); - // run_flash_bwd, Is_dropout>(params, stream); - // run_flash_bwd, Is_dropout>(params, stream); + // run_flash_bwd, Is_dropout>(params, stream); + run_flash_bwd, Is_dropout, Is_causal>(params, stream); + // run_flash_bwd, Is_dropout>(params, stream); + // run_flash_bwd, Is_dropout>(params, stream); // } else { // } } }); - // run_flash_bwd>(params, stream); - // run_flash_bwd>(params, stream); - // run_flash_bwd>(params, stream); - // run_flash_bwd>(params, stream); + // run_flash_bwd>(params, stream); + // run_flash_bwd>(params, stream); + // run_flash_bwd>(params, stream); + // run_flash_bwd>(params, stream); // M=128, N=64 is quite slow, I think because we need to read/write dQaccum twice as many times - // run_flash_bwd>(params, stream); - // run_flash_bwd>(params, stream); - // run_flash_bwd>(params, stream); + // run_flash_bwd>(params, stream); + // run_flash_bwd>(params, stream); + // run_flash_bwd>(params, stream); - // run_flash_bwd>(params, stream); + // run_flash_bwd>(params, stream); } template @@ -212,13 +212,13 @@ void run_mha_bwd_hdim96(Flash_bwd_params ¶ms, cudaStream_t stream) { DROPOUT_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] { if (max_smem_per_block >= 116 * 1024) { if constexpr(!Is_dropout) { // 92KB - run_flash_bwd, Is_dropout, Is_causal>(params, stream); + run_flash_bwd, Is_dropout, Is_causal>(params, stream); } else { // 116 KB // This is faster for dropout since we don't have many registers to spare - run_flash_bwd, Is_dropout, Is_causal>(params, stream); + run_flash_bwd, Is_dropout, Is_causal>(params, stream); } } else { - run_flash_bwd, Is_dropout, Is_causal>(params, stream); + run_flash_bwd, Is_dropout, Is_causal>(params, stream); } }); } @@ -236,24 +236,24 @@ void run_mha_bwd_hdim128(Flash_bwd_params ¶ms, cudaStream_t stream) { } // printf("max_smem_per_block = %d\n", max_smem_per_block); DROPOUT_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] { - // run_flash_bwd>(params, stream); + // run_flash_bwd>(params, stream); // This is faster, in the case of sequence-parallel bwd (where we need fewer registers). // Out of these three, the 2nd one is slightly faster (2% faster than the first). Idk why. 
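
All of these run_mha_bwd_* launchers follow the same host-side dispatch pattern: query the opt-in shared-memory-per-block limit once, then take the larger tile configuration only when it fits. A self-contained sketch of that pattern, with the 144 KB threshold taken from the surrounding code and everything else illustrative:

    #include <cstdio>
    #include <cuda_runtime.h>

    // Sketch of the dispatch pattern used by the run_mha_bwd_* launchers: query the
    // opt-in shared-memory-per-block limit and pick the wide-tile path only if it fits.
    int main() {
        int device = 0;
        int max_smem_per_block = 0;
        cudaGetDevice(&device);
        cudaDeviceGetAttribute(&max_smem_per_block,
                               cudaDevAttrMaxSharedMemoryPerBlockOptin, device);
        if (max_smem_per_block >= 144 * 1024) {
            std::printf("wide tile (e.g. double-buffered Q)\n");
        } else {
            std::printf("fallback tile (smaller Br/Bc, no double buffering)\n");
        }
        return 0;
    }

On sm86 and sm89 the per-block limit is about 99 KB (as noted in the hdim256 branch below), so those parts always end up on the fallback path.
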
- // run_flash_bwd>(params, stream); + // run_flash_bwd>(params, stream); if (max_smem_per_block >= 144 * 1024) { - run_flash_bwd, Is_dropout, Is_causal>(params, stream); - // run_flash_bwd_seqk_parallel, Is_dropout>(params, stream); - // run_flash_bwd_seqk_parallel, Is_dropout>(params, stream); - // run_flash_bwd, Is_dropout>(params, stream); - // run_flash_bwd, Is_dropout>(params, stream); - // run_flash_bwd, Is_dropout>(params, stream); + run_flash_bwd, Is_dropout, Is_causal>(params, stream); + // run_flash_bwd_seqk_parallel, Is_dropout>(params, stream); + // run_flash_bwd_seqk_parallel, Is_dropout>(params, stream); + // run_flash_bwd, Is_dropout>(params, stream); + // run_flash_bwd, Is_dropout>(params, stream); + // run_flash_bwd, Is_dropout>(params, stream); } else { - // run_flash_bwd, Is_dropout>(params, stream); - run_flash_bwd, Is_dropout, Is_causal>(params, stream); + // run_flash_bwd, Is_dropout>(params, stream); + run_flash_bwd, Is_dropout, Is_causal>(params, stream); } - // run_flash_bwd>(params, stream); + // run_flash_bwd>(params, stream); - // run_flash_bwd>(params, stream); + // run_flash_bwd>(params, stream); }); } @@ -270,9 +270,9 @@ void run_mha_bwd_hdim160(Flash_bwd_params ¶ms, cudaStream_t stream) { } DROPOUT_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] { if (max_smem_per_block >= 116 * 1024) { - run_flash_bwd, Is_dropout, Is_causal>(params, stream); + run_flash_bwd, Is_dropout, Is_causal>(params, stream); } else { - run_flash_bwd, Is_dropout, Is_causal>(params, stream); + run_flash_bwd, Is_dropout, Is_causal>(params, stream); } }); } @@ -290,9 +290,9 @@ void run_mha_bwd_hdim192(Flash_bwd_params ¶ms, cudaStream_t stream) { } DROPOUT_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] { if (max_smem_per_block >= 136 * 1024) { - run_flash_bwd, Is_dropout, Is_causal>(params, stream); + run_flash_bwd, Is_dropout, Is_causal>(params, stream); } else { - run_flash_bwd, Is_dropout, Is_causal>(params, stream); + run_flash_bwd, Is_dropout, Is_causal>(params, stream); } }); } @@ -310,13 +310,14 @@ void run_mha_bwd_hdim256(Flash_bwd_params ¶ms, cudaStream_t stream) { } DROPOUT_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] { if (max_smem_per_block >= 176 * 1024) { // H100 - run_flash_bwd, Is_dropout, Is_causal>(params, stream); + run_flash_bwd, Is_dropout, Is_causal>(params, stream); } else if (max_smem_per_block >= 144 * 1024) { // A100, we don't do double buffering to save smem - run_flash_bwd, Is_dropout, Is_causal>(params, stream); + run_flash_bwd, Is_dropout, Is_causal>(params, stream); } else { // sm86 and sm89, max smem is 99 KB. Only works without dropout. V in regs and no double buffering. 
if constexpr (!Is_dropout) { - run_flash_bwd, false, Is_causal>(params, stream); + run_flash_bwd, false, Is_causal>(params, stream); } } }); } + diff --git a/csrc/flash_attn/src/flash_bwd_preprocess_kernel.h b/csrc/flash_attn/src/flash_bwd_preprocess_kernel.h index c8e307417..408652342 100644 --- a/csrc/flash_attn/src/flash_bwd_preprocess_kernel.h +++ b/csrc/flash_attn/src/flash_bwd_preprocess_kernel.h @@ -68,7 +68,8 @@ inline __device__ void compute_dot_do_o(const Params ¶ms) { const int tidx = threadIdx.x; constexpr int kBlockM = Kernel_traits::kBlockM; - constexpr int kHeadDim = Kernel_traits::kHeadDim; + constexpr int kQKHeadDim = Kernel_traits::kQKHeadDim; + constexpr int kVHeadDim = Kernel_traits::kVHeadDim; const BlockInfo binfo(params, bidb); if (m_block * kBlockM >= binfo.actual_seqlen_q) return; @@ -83,13 +84,13 @@ inline __device__ void compute_dot_do_o(const Params ¶ms) { const index_t row_offset_dpsum = (params.unpadded_lse ? (bidh * (params.total_q + 128 * params.b) + binfo.q_offset(params.seqlen_q_rounded, 1, bidb) + 128 * bidb): (bidb * params.h + bidh) * params.seqlen_q_rounded) + m_block * kBlockM; Tensor gdO = make_tensor(make_gmem_ptr(reinterpret_cast(params.do_ptr) + row_offset_do), - Shape, Int>{}, + Shape, Int>{}, make_stride(params.do_row_stride, _1{})); Tensor gO = make_tensor(make_gmem_ptr(reinterpret_cast(params.o_ptr) + row_offset_o), - Shape, Int>{}, + Shape, Int>{}, make_stride(params.o_row_stride, _1{})); Tensor gdQaccum = make_tensor(make_gmem_ptr(reinterpret_cast(params.dq_accum_ptr) + row_offset_dq_accum), - Shape, Int>{}, + Shape, Int>{}, make_stride(params.h * params.d_rounded, _1{})); Tensor dP_sum = make_tensor(make_gmem_ptr(reinterpret_cast(params.dsoftmax_sum) + row_offset_dpsum), Shape>{}, Stride<_1>{}); @@ -105,14 +106,14 @@ inline __device__ void compute_dot_do_o(const Params ¶ms) { Tensor tdOgO = gmem_thr_copy_dO.partition_S(gO); Tensor tdQgdQaccum = gmem_thr_copy_dQaccum.partition_D(gdQaccum); - Tensor cdO = make_identity_tensor(Shape, Int>{}); // (BLK_M,BLK_K) -> (blk_m,blk_k) + Tensor cdO = make_identity_tensor(Shape, Int>{}); // (BLK_M,BLK_K) -> (blk_m,blk_k) Tensor tdOcdO = gmem_thr_copy_dO.partition_S(cdO); // Allocate predicate tensors for k Tensor tdOpdO = make_tensor(make_shape(size<2>(tdOgdO))); // Set predicates for k bounds #pragma unroll - for (int k = 0; k < size(tdOpdO); ++k) {tdOpdO(k) = get<1>(tdOcdO(0, 0, k)) < params.d;} + for (int k = 0; k < size(tdOpdO); ++k) {tdOpdO(k) = get<1>(tdOcdO(0, 0, k)) < params.vd;} Tensor tdOrdO = make_fragment_like(tdOgdO); Tensor tdOrO = make_fragment_like(tdOgO); @@ -152,17 +153,19 @@ inline __device__ void clear_dKVaccum(const Params ¶ms) { const int tidx = threadIdx.x; constexpr int kBlockN = Kernel_traits::kBlockN; - constexpr int kHeadDim = Kernel_traits::kHeadDim; + constexpr int kQKHeadDim = Kernel_traits::kQKHeadDim; + constexpr int kVHeadDim = Kernel_traits::kVHeadDim; const BlockInfo binfo(params, bidb); if (n_block * kBlockN >= binfo.actual_seqlen_k) return; - const index_t row_offset_dkv_accum = ((bidb * params.h_k + bidh) * params.seqlen_k_rounded + n_block * kBlockN) * params.d_rounded; + const index_t row_offset_dk_accum = ((bidb * params.h_k + bidh) * params.seqlen_k_rounded + n_block * kBlockN) * params.d_rounded; + const index_t row_offset_dv_accum = ((bidb * params.h_k + bidh) * params.seqlen_k_rounded + n_block * kBlockN) * params.vd_rounded; - Tensor gdKaccum = make_tensor(make_gmem_ptr(reinterpret_cast(params.dk_accum_ptr) + row_offset_dkv_accum), - Shape, Int>{}, Stride, 
_1>{}); - Tensor gdVaccum = make_tensor(make_gmem_ptr(reinterpret_cast(params.dv_accum_ptr) + row_offset_dkv_accum), - Shape, Int>{}, Stride, _1>{}); + Tensor gdKaccum = make_tensor(make_gmem_ptr(reinterpret_cast(params.dk_accum_ptr) + row_offset_dk_accum), + Shape, Int>{}, Stride, _1>{}); + Tensor gdVaccum = make_tensor(make_gmem_ptr(reinterpret_cast(params.dv_accum_ptr) + row_offset_dv_accum), + Shape, Int>{}, Stride, _1>{}); typename Kernel_traits::GmemTiledCopydQaccum gmem_tiled_copy_dKVaccum; auto gmem_thr_copy_dKVaccum = gmem_tiled_copy_dKVaccum.get_thread_slice(tidx); @@ -196,7 +199,7 @@ inline __device__ void convert_dQ(const Params ¶ms, const int nsplits) { const int tidx = threadIdx.x; constexpr int kBlockM = Kernel_traits::kBlockM; - constexpr int kHeadDim = Kernel_traits::kHeadDim; + constexpr int kQKHeadDim = Kernel_traits::kQKHeadDim; const BlockInfo binfo(params, bidb); if (m_block * kBlockM >= binfo.actual_seqlen_q) return; @@ -207,10 +210,10 @@ inline __device__ void convert_dQ(const Params ¶ms, const int nsplits) { + (m_block * kBlockM + (params.cu_seqlens_q == nullptr ? 0 : 128 * bidb)) * params.h * params.d_rounded + bidh * params.d_rounded; Tensor gdQ = make_tensor(make_gmem_ptr(reinterpret_cast(params.dq_ptr) + row_offset_dq), - Shape, Int>{}, + Shape, Int>{}, make_stride(params.dq_row_stride, _1{})); Tensor gdQaccum = make_tensor(make_gmem_ptr(reinterpret_cast(params.dq_accum_ptr) + row_offset_dq_accum), - Shape, Int>{}, + Shape, Int>{}, make_stride(params.h * params.d_rounded, _1{})); Tensor sdQ = make_tensor(make_smem_ptr(reinterpret_cast(smem_)), @@ -230,7 +233,7 @@ inline __device__ void convert_dQ(const Params ¶ms, const int nsplits) { Tensor tdQgdQ = gmem_thr_copy_dQ.partition_D(gdQ); Tensor tdQgdQaccum = gmem_thr_copy_dQaccum.partition_S(gdQaccum); - Tensor acc_dq = partition_fragment_C(tiled_mma_dq, Shape, Int>{}); // MMA, MMA_N, MMA_K + Tensor acc_dq = partition_fragment_C(tiled_mma_dq, Shape, Int>{}); // MMA, MMA_N, MMA_K CUTE_STATIC_ASSERT_V(size(acc_dq) == size(tdQgdQaccum)); Tensor tdQrdQaccum = make_fragment_like(tdQgdQaccum); @@ -251,7 +254,7 @@ inline __device__ void convert_dQ(const Params ¶ms, const int nsplits) { Tensor tdQrdQ = make_tensor(shape(tdQgdQ)); cute::copy(gmem_tiled_copy_dQ, tdQsdQ, tdQrdQ); - Tensor cdQ = make_identity_tensor(Shape, Int>{}); // (BLK_M,BLK_K) -> (blk_m,blk_k) + Tensor cdQ = make_identity_tensor(Shape, Int>{}); // (BLK_M,BLK_K) -> (blk_m,blk_k) Tensor tdQcdQ = gmem_thr_copy_dQ.partition_D(cdQ); Tensor tdQpdQ = make_tensor(make_shape(size<2>(tdQgdQ))); #pragma unroll @@ -284,7 +287,8 @@ inline __device__ void convert_dKV(const Params ¶ms) { const int tidx = threadIdx.x; constexpr int kBlockN = Kernel_traits::kBlockN; - constexpr int kHeadDim = Kernel_traits::kHeadDim; + constexpr int kQKHeadDim = Kernel_traits::kQKHeadDim; + constexpr int kVHeadDim = Kernel_traits::kVHeadDim; const BlockInfo binfo(params, bidb); if (n_block * kBlockN >= binfo.actual_seqlen_k) return; @@ -293,21 +297,23 @@ inline __device__ void convert_dKV(const Params ¶ms) { + n_block * kBlockN * params.dk_row_stride + bidh * params.dk_head_stride; const index_t row_offset_dv = binfo.k_offset(params.dv_batch_stride, params.dv_row_stride, bidb) + n_block * kBlockN * params.dv_row_stride + bidh * params.dv_head_stride; - const index_t row_offset_dkv_accum = ((bidb * params.h_k + bidh) * params.seqlen_k_rounded + const index_t row_offset_dk_accum = ((bidb * params.h_k + bidh) * params.seqlen_k_rounded + n_block * kBlockN) * params.d_rounded; + const index_t 
row_offset_dv_accum = ((bidb * params.h_k + bidh) * params.seqlen_k_rounded + + n_block * kBlockN) * params.vd_rounded; Tensor gdK = make_tensor(make_gmem_ptr(reinterpret_cast(params.dk_ptr) + row_offset_dk), - Shape, Int>{}, + Shape, Int>{}, make_stride(params.dk_row_stride, _1{})); Tensor gdV = make_tensor(make_gmem_ptr(reinterpret_cast(params.dv_ptr) + row_offset_dv), - Shape, Int>{}, + Shape, Int>{}, make_stride(params.dv_row_stride, _1{})); - Tensor gdKaccum = make_tensor(make_gmem_ptr(reinterpret_cast(params.dk_accum_ptr) + row_offset_dkv_accum), - Shape, Int>{}, - Stride, _1>{}); - Tensor gdVaccum = make_tensor(make_gmem_ptr(reinterpret_cast(params.dv_accum_ptr) + row_offset_dkv_accum), - Shape, Int>{}, - Stride, _1>{}); + Tensor gdKaccum = make_tensor(make_gmem_ptr(reinterpret_cast(params.dk_accum_ptr) + row_offset_dk_accum), + Shape, Int>{}, + Stride, _1>{}); + Tensor gdVaccum = make_tensor(make_gmem_ptr(reinterpret_cast(params.dv_accum_ptr) + row_offset_dv_accum), + Shape, Int>{}, + Stride, _1>{}); Tensor sdK = make_tensor(make_smem_ptr(reinterpret_cast(smem_)), typename Kernel_traits::SmemLayoutdKV{}); @@ -331,8 +337,8 @@ inline __device__ void convert_dKV(const Params ¶ms) { Tensor tdKgdKaccum = gmem_thr_copy_dKVaccum.partition_S(gdKaccum); Tensor tdVgdVaccum = gmem_thr_copy_dKVaccum.partition_S(gdVaccum); - Tensor acc_dk = partition_fragment_C(tiled_mma_dkv, Shape, Int>{}); // MMA, MMA_N, MMA_K - Tensor acc_dv = partition_fragment_C(tiled_mma_dkv, Shape, Int>{}); // MMA, MMA_N, MMA_K + Tensor acc_dk = partition_fragment_C(tiled_mma_dkv, Shape, Int>{}); // MMA, MMA_N, MMA_K + Tensor acc_dv = partition_fragment_C(tiled_mma_dkv, Shape, Int>{}); // MMA, MMA_N, MMA_K CUTE_STATIC_ASSERT_V(size(acc_dk) == size(tdKgdKaccum)); CUTE_STATIC_ASSERT_V(size(acc_dv) == size(tdVgdVaccum)); @@ -361,17 +367,22 @@ inline __device__ void convert_dKV(const Params ¶ms) { cute::copy(gmem_tiled_copy_dKV, tdKsdK, tdKrdK); cute::copy(gmem_tiled_copy_dKV, tdVsdV, tdVrdV); - Tensor cdKV = make_identity_tensor(Shape, Int>{}); // (BLK_M,BLK_K) -> (blk_m,blk_k) - Tensor tdKVcdKV = gmem_thr_copy_dKV.partition_D(cdKV); - Tensor tdKVpdKV = make_tensor(make_shape(size<2>(tdKgdK))); + Tensor cdK= make_identity_tensor(Shape, Int>{}); // (BLK_M,BLK_K) -> (blk_m,blk_k) + Tensor cdV = make_identity_tensor(Shape, Int>{}); // (BLK_M,BLK_K) -> (blk_m,blk_k) + Tensor tdKcdK = gmem_thr_copy_dKV.partition_D(cdK); + Tensor tdVcdV = gmem_thr_copy_dKV.partition_D(cdV); + Tensor tdKpdK = make_tensor(make_shape(size<2>(tdKgdK))); + Tensor tdVpdV = make_tensor(make_shape(size<2>(tdVgdV))); #pragma unroll - for (int k = 0; k < size(tdKVpdKV); ++k) { tdKVpdKV(k) = get<1>(tdKVcdKV(0, 0, k)) < params.d; } + for (int k = 0; k < size(tdKpdK); ++k) { tdKpdK(k) = get<1>(tdKcdK(0, 0, k)) < params.d; } + #pragma unroll + for (int k = 0; k < size(tdVpdV); ++k) { tdVpdV(k) = get<1>(tdVcdV(0, 0, k)) < params.vd; } // Clear_OOB_K must be false since we don't want to write zeros to gmem flash::copy( - gmem_tiled_copy_dKV, tdKrdK, tdKgdK, tdKVcdKV, tdKVpdKV, binfo.actual_seqlen_k - n_block * kBlockN + gmem_tiled_copy_dKV, tdKrdK, tdKgdK, tdKcdK, tdKpdK, binfo.actual_seqlen_k - n_block * kBlockN ); flash::copy( - gmem_tiled_copy_dKV, tdVrdV, tdVgdV, tdKVcdKV, tdKVpdKV, binfo.actual_seqlen_k - n_block * kBlockN + gmem_tiled_copy_dKV, tdVrdV, tdVgdV, tdVcdV, tdVpdV, binfo.actual_seqlen_k - n_block * kBlockN ); } diff --git a/csrc/flash_attn/src/flash_bwd_qkdim128_vdim256_bf16_causal_sm80.cu 
b/csrc/flash_attn/src/flash_bwd_qkdim128_vdim256_bf16_causal_sm80.cu new file mode 100644 index 000000000..9834bfbe4 --- /dev/null +++ b/csrc/flash_attn/src/flash_bwd_qkdim128_vdim256_bf16_causal_sm80.cu @@ -0,0 +1,10 @@ +// Copyright (c) 2024, Tri Dao. +// Splitting the different head dimensions to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py" + +#include "flash_bwd_qkdim128_vdim256_sm80.h" + +template<> +void run_mha_bwd_(Flash_bwd_params ¶ms, cudaStream_t stream) { + run_mha_bwd_qkdim128_vdim256(params, stream); +} diff --git a/csrc/flash_attn/src/flash_bwd_qkdim128_vdim256_bf16_sm80.cu b/csrc/flash_attn/src/flash_bwd_qkdim128_vdim256_bf16_sm80.cu new file mode 100644 index 000000000..8bfa8623c --- /dev/null +++ b/csrc/flash_attn/src/flash_bwd_qkdim128_vdim256_bf16_sm80.cu @@ -0,0 +1,10 @@ +// Copyright (c) 2024, Tri Dao. +// Splitting the different head dimensions to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py" + +#include "flash_bwd_qkdim128_vdim256_sm80.h" + +template<> +void run_mha_bwd_(Flash_bwd_params ¶ms, cudaStream_t stream) { + run_mha_bwd_qkdim128_vdim256(params, stream); +} diff --git a/csrc/flash_attn/src/flash_bwd_qkdim128_vdim256_fp16_causal_sm80.cu b/csrc/flash_attn/src/flash_bwd_qkdim128_vdim256_fp16_causal_sm80.cu new file mode 100644 index 000000000..35ce26dbe --- /dev/null +++ b/csrc/flash_attn/src/flash_bwd_qkdim128_vdim256_fp16_causal_sm80.cu @@ -0,0 +1,10 @@ +// Copyright (c) 2024, Tri Dao. +// Splitting the different head dimensions to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py" + +#include "flash_bwd_qkdim128_vdim256_sm80.h" + +template<> +void run_mha_bwd_(Flash_bwd_params ¶ms, cudaStream_t stream) { + run_mha_bwd_qkdim128_vdim256(params, stream); +} diff --git a/csrc/flash_attn/src/flash_bwd_qkdim128_vdim256_fp16_sm80.cu b/csrc/flash_attn/src/flash_bwd_qkdim128_vdim256_fp16_sm80.cu new file mode 100644 index 000000000..17521c9d2 --- /dev/null +++ b/csrc/flash_attn/src/flash_bwd_qkdim128_vdim256_fp16_sm80.cu @@ -0,0 +1,10 @@ +// Copyright (c) 2024, Tri Dao. +// Splitting the different head dimensions to different files to speed up compilation. +// This file is auto-generated. 
See "generate_kernels.py" + +#include "flash_bwd_qkdim128_vdim256_sm80.h" + +template<> +void run_mha_bwd_(Flash_bwd_params ¶ms, cudaStream_t stream) { + run_mha_bwd_qkdim128_vdim256(params, stream); +} diff --git a/csrc/flash_attn/src/flash_bwd_qkdim128_vdim256_sm80.h b/csrc/flash_attn/src/flash_bwd_qkdim128_vdim256_sm80.h new file mode 100644 index 000000000..8fd4acf28 --- /dev/null +++ b/csrc/flash_attn/src/flash_bwd_qkdim128_vdim256_sm80.h @@ -0,0 +1,42 @@ +#include "flash_bwd_launch_template.h" + +template +void run_mha_bwd_qkdim128_vdim256(Flash_bwd_params ¶ms, cudaStream_t stream) { + constexpr static int QKHeaddim = 128; + constexpr static int VHeaddim = 256; + int device; + cudaGetDevice(&device); + int max_smem_per_block; + cudaError status_ = cudaDeviceGetAttribute( + &max_smem_per_block, cudaDevAttrMaxSharedMemoryPerBlockOptin, device); + if (status_ != cudaSuccess) { + C10_CUDA_CHECK(status_); + } + // printf("max_smem_per_block = %d\n", max_smem_per_block); + DROPOUT_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] { + constexpr static int Br = 64; + constexpr static int Bc = 64; + constexpr static int smem_size = 2 *(Br * QKHeaddim * 2 /*Q with double buffer*/ + Br * VHeaddim /* dO*/ + Bc * QKHeaddim /*K, dK*/ + Bc * VHeaddim /*V, dV*/ + + Br * Bc * 2 /*dS, P*/); + // run_flash_bwd>(params, stream); + // This is faster, in the case of sequence-parallel bwd (where we need fewer registers). + // Out of these three, the 2nd one is slightly faster (2% faster than the first). Idk why. + // run_flash_bwd>(params, stream); + if (max_smem_per_block >= 144 * 1024) { + run_flash_bwd, Is_dropout, Is_causal>(params, stream); + // A100 shared memory spill + // run_flash_bwd, Is_dropout, Is_causal>(params, stream); + // run_flash_bwd_seqk_parallel, Is_dropout>(params, stream); + // run_flash_bwd_seqk_parallel, Is_dropout>(params, stream); + // run_flash_bwd, Is_dropout>(params, stream); + // run_flash_bwd, Is_dropout>(params, stream); + // run_flash_bwd, Is_dropout>(params, stream); + } else { + // run_flash_bwd, Is_dropout>(params, stream); + run_flash_bwd, Is_dropout, Is_causal>(params, stream); + } + // run_flash_bwd>(params, stream); + + // run_flash_bwd>(params, stream); + }); +} diff --git a/csrc/flash_attn/src/flash_bwd_qkdim192_vdim128_bf16_causal_sm80.cu b/csrc/flash_attn/src/flash_bwd_qkdim192_vdim128_bf16_causal_sm80.cu new file mode 100644 index 000000000..222fa2cb0 --- /dev/null +++ b/csrc/flash_attn/src/flash_bwd_qkdim192_vdim128_bf16_causal_sm80.cu @@ -0,0 +1,10 @@ +// Copyright (c) 2024, Tri Dao. +// Splitting the different head dimensions to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py" + +#include "flash_bwd_qkdim192_vdim128_sm80.h" + +template<> +void run_mha_bwd_(Flash_bwd_params ¶ms, cudaStream_t stream) { + run_mha_bwd_qkdim192_vdim128(params, stream); +} diff --git a/csrc/flash_attn/src/flash_bwd_qkdim192_vdim128_bf16_sm80.cu b/csrc/flash_attn/src/flash_bwd_qkdim192_vdim128_bf16_sm80.cu new file mode 100644 index 000000000..61e47b53c --- /dev/null +++ b/csrc/flash_attn/src/flash_bwd_qkdim192_vdim128_bf16_sm80.cu @@ -0,0 +1,10 @@ +// Copyright (c) 2024, Tri Dao. +// Splitting the different head dimensions to different files to speed up compilation. +// This file is auto-generated. 
See "generate_kernels.py" + +#include "flash_bwd_qkdim192_vdim128_sm80.h" + +template<> +void run_mha_bwd_(Flash_bwd_params ¶ms, cudaStream_t stream) { + run_mha_bwd_qkdim192_vdim128(params, stream); +} diff --git a/csrc/flash_attn/src/flash_bwd_qkdim192_vdim128_fp16_causal_sm80.cu b/csrc/flash_attn/src/flash_bwd_qkdim192_vdim128_fp16_causal_sm80.cu new file mode 100644 index 000000000..0b5f0c766 --- /dev/null +++ b/csrc/flash_attn/src/flash_bwd_qkdim192_vdim128_fp16_causal_sm80.cu @@ -0,0 +1,10 @@ +// Copyright (c) 2024, Tri Dao. +// Splitting the different head dimensions to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py" + +#include "flash_bwd_qkdim192_vdim128_sm80.h" + +template<> +void run_mha_bwd_(Flash_bwd_params ¶ms, cudaStream_t stream) { + run_mha_bwd_qkdim192_vdim128(params, stream); +} diff --git a/csrc/flash_attn/src/flash_bwd_qkdim192_vdim128_fp16_sm80.cu b/csrc/flash_attn/src/flash_bwd_qkdim192_vdim128_fp16_sm80.cu new file mode 100644 index 000000000..3fe80fd91 --- /dev/null +++ b/csrc/flash_attn/src/flash_bwd_qkdim192_vdim128_fp16_sm80.cu @@ -0,0 +1,10 @@ +// Copyright (c) 2024, Tri Dao. +// Splitting the different head dimensions to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py" + +#include "flash_bwd_qkdim192_vdim128_sm80.h" + +template<> +void run_mha_bwd_(Flash_bwd_params ¶ms, cudaStream_t stream) { + run_mha_bwd_qkdim192_vdim128(params, stream); +} diff --git a/csrc/flash_attn/src/flash_bwd_qkdim192_vdim128_sm80.h b/csrc/flash_attn/src/flash_bwd_qkdim192_vdim128_sm80.h new file mode 100644 index 000000000..71e550db4 --- /dev/null +++ b/csrc/flash_attn/src/flash_bwd_qkdim192_vdim128_sm80.h @@ -0,0 +1,22 @@ +#include "flash_bwd_launch_template.h" + +template +void run_mha_bwd_qkdim192_vdim128(Flash_bwd_params ¶ms, cudaStream_t stream) { + constexpr static int QKHeaddim = 192; + constexpr static int VHeaddim = 128; + int device; + cudaGetDevice(&device); + int max_smem_per_block; + cudaError status_ = cudaDeviceGetAttribute( + &max_smem_per_block, cudaDevAttrMaxSharedMemoryPerBlockOptin, device); + if (status_ != cudaSuccess) { + C10_CUDA_CHECK(status_); + } + DROPOUT_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] { + if (max_smem_per_block >= 136 * 1024) { + run_flash_bwd, Is_dropout, Is_causal>(params, stream); + } else { + run_flash_bwd, Is_dropout, Is_causal>(params, stream); + } + }); +} \ No newline at end of file diff --git a/csrc/flash_attn/src/flash_bwd_qkdim32_vdim64_bf16_causal_sm80.cu b/csrc/flash_attn/src/flash_bwd_qkdim32_vdim64_bf16_causal_sm80.cu new file mode 100644 index 000000000..7023d4741 --- /dev/null +++ b/csrc/flash_attn/src/flash_bwd_qkdim32_vdim64_bf16_causal_sm80.cu @@ -0,0 +1,10 @@ +// Copyright (c) 2024, Tri Dao. +// Splitting the different head dimensions to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py" + +#include "flash_bwd_qkdim32_vdim64_sm80.h" + +template<> +void run_mha_bwd_(Flash_bwd_params ¶ms, cudaStream_t stream) { + run_mha_bwd_qkdim32_vdim64(params, stream); +} diff --git a/csrc/flash_attn/src/flash_bwd_qkdim32_vdim64_bf16_sm80.cu b/csrc/flash_attn/src/flash_bwd_qkdim32_vdim64_bf16_sm80.cu new file mode 100644 index 000000000..0f6371b41 --- /dev/null +++ b/csrc/flash_attn/src/flash_bwd_qkdim32_vdim64_bf16_sm80.cu @@ -0,0 +1,10 @@ +// Copyright (c) 2024, Tri Dao. +// Splitting the different head dimensions to different files to speed up compilation. 
+// This file is auto-generated. See "generate_kernels.py" + +#include "flash_bwd_qkdim32_vdim64_sm80.h" + +template<> +void run_mha_bwd_(Flash_bwd_params ¶ms, cudaStream_t stream) { + run_mha_bwd_qkdim32_vdim64(params, stream); +} diff --git a/csrc/flash_attn/src/flash_bwd_qkdim32_vdim64_fp16_causal_sm80.cu b/csrc/flash_attn/src/flash_bwd_qkdim32_vdim64_fp16_causal_sm80.cu new file mode 100644 index 000000000..285bca814 --- /dev/null +++ b/csrc/flash_attn/src/flash_bwd_qkdim32_vdim64_fp16_causal_sm80.cu @@ -0,0 +1,10 @@ +// Copyright (c) 2024, Tri Dao. +// Splitting the different head dimensions to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py" + +#include "flash_bwd_qkdim32_vdim64_sm80.h" + +template<> +void run_mha_bwd_(Flash_bwd_params ¶ms, cudaStream_t stream) { + run_mha_bwd_qkdim32_vdim64(params, stream); +} diff --git a/csrc/flash_attn/src/flash_bwd_qkdim32_vdim64_fp16_sm80.cu b/csrc/flash_attn/src/flash_bwd_qkdim32_vdim64_fp16_sm80.cu new file mode 100644 index 000000000..8be40bb82 --- /dev/null +++ b/csrc/flash_attn/src/flash_bwd_qkdim32_vdim64_fp16_sm80.cu @@ -0,0 +1,10 @@ +// Copyright (c) 2024, Tri Dao. +// Splitting the different head dimensions to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py" + +#include "flash_bwd_qkdim32_vdim64_sm80.h" + +template<> +void run_mha_bwd_(Flash_bwd_params ¶ms, cudaStream_t stream) { + run_mha_bwd_qkdim32_vdim64(params, stream); +} diff --git a/csrc/flash_attn/src/flash_bwd_qkdim32_vdim64_sm80.h b/csrc/flash_attn/src/flash_bwd_qkdim32_vdim64_sm80.h new file mode 100644 index 000000000..9ce14f6a5 --- /dev/null +++ b/csrc/flash_attn/src/flash_bwd_qkdim32_vdim64_sm80.h @@ -0,0 +1,32 @@ +#include "flash_bwd_launch_template.h" + +template +void run_mha_bwd_qkdim32_vdim64(Flash_bwd_params ¶ms, cudaStream_t stream) { + constexpr static int QKHeaddim = 32; + constexpr static int VHeaddim = 64; + int device; + cudaGetDevice(&device); + int max_smem_per_block; + cudaError status_ = cudaDeviceGetAttribute( + &max_smem_per_block, cudaDevAttrMaxSharedMemoryPerBlockOptin, device); + if (status_ != cudaSuccess) { + C10_CUDA_CHECK(status_); + } + DROPOUT_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] { + constexpr static int Br = 128; + constexpr static int Bc = 128; + constexpr static int smem_size = 2 *(Br * QKHeaddim * 2 /*Q with double buffer*/ + Br * VHeaddim /* dO*/ + Bc * QKHeaddim /*K, dK*/ + Bc * VHeaddim /*V, dV*/ + + Br * Bc * 2 /*dS, P*/); + // if (max_smem_per_block >= 2 * ((3 * 128 + 2 * 128) * Headdim + 2 * 128 * 128)) { // 104 KB + if (max_smem_per_block >= 104 * 1024) { // 104 KB + if constexpr(!Is_dropout) { // We can afford more registers to keep V in registers + run_flash_bwd, Is_dropout, Is_causal>(params, stream); + } else { + run_flash_bwd, Is_dropout, Is_causal>(params, stream); + } + } else { // 96 KB + run_flash_bwd, Is_dropout, Is_causal>(params, stream); + } + }); +} + diff --git a/csrc/flash_attn/src/flash_bwd_qkdim64_vdim128_bf16_causal_sm80.cu b/csrc/flash_attn/src/flash_bwd_qkdim64_vdim128_bf16_causal_sm80.cu new file mode 100644 index 000000000..9d18044d4 --- /dev/null +++ b/csrc/flash_attn/src/flash_bwd_qkdim64_vdim128_bf16_causal_sm80.cu @@ -0,0 +1,10 @@ +// Copyright (c) 2024, Tri Dao. +// Splitting the different head dimensions to different files to speed up compilation. +// This file is auto-generated. 
See "generate_kernels.py" + +#include "flash_bwd_qkdim64_vdim128_sm80.h" + +template<> +void run_mha_bwd_(Flash_bwd_params ¶ms, cudaStream_t stream) { + run_mha_bwd_qkdim64_vdim128(params, stream); +} diff --git a/csrc/flash_attn/src/flash_bwd_qkdim64_vdim128_bf16_sm80.cu b/csrc/flash_attn/src/flash_bwd_qkdim64_vdim128_bf16_sm80.cu new file mode 100644 index 000000000..0ceb99220 --- /dev/null +++ b/csrc/flash_attn/src/flash_bwd_qkdim64_vdim128_bf16_sm80.cu @@ -0,0 +1,10 @@ +// Copyright (c) 2024, Tri Dao. +// Splitting the different head dimensions to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py" + +#include "flash_bwd_qkdim64_vdim128_sm80.h" + +template<> +void run_mha_bwd_(Flash_bwd_params ¶ms, cudaStream_t stream) { + run_mha_bwd_qkdim64_vdim128(params, stream); +} diff --git a/csrc/flash_attn/src/flash_bwd_qkdim64_vdim128_fp16_causal_sm80.cu b/csrc/flash_attn/src/flash_bwd_qkdim64_vdim128_fp16_causal_sm80.cu new file mode 100644 index 000000000..543f16045 --- /dev/null +++ b/csrc/flash_attn/src/flash_bwd_qkdim64_vdim128_fp16_causal_sm80.cu @@ -0,0 +1,10 @@ +// Copyright (c) 2024, Tri Dao. +// Splitting the different head dimensions to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py" + +#include "flash_bwd_qkdim64_vdim128_sm80.h" + +template<> +void run_mha_bwd_(Flash_bwd_params ¶ms, cudaStream_t stream) { + run_mha_bwd_qkdim64_vdim128(params, stream); +} diff --git a/csrc/flash_attn/src/flash_fwd_hdim224_bf16_sm80.cu b/csrc/flash_attn/src/flash_bwd_qkdim64_vdim128_fp16_sm80.cu similarity index 63% rename from csrc/flash_attn/src/flash_fwd_hdim224_bf16_sm80.cu rename to csrc/flash_attn/src/flash_bwd_qkdim64_vdim128_fp16_sm80.cu index 8690bdb1a..771708192 100644 --- a/csrc/flash_attn/src/flash_fwd_hdim224_bf16_sm80.cu +++ b/csrc/flash_attn/src/flash_bwd_qkdim64_vdim128_fp16_sm80.cu @@ -1,10 +1,10 @@ -// Copyright (c) 2023, Tri Dao. +// Copyright (c) 2024, Tri Dao. // Splitting the different head dimensions to different files to speed up compilation. // This file is auto-generated. See "generate_kernels.py" -#include "flash_fwd_launch_template.h" +#include "flash_bwd_qkdim64_vdim128_sm80.h" template<> -void run_mha_fwd_(Flash_fwd_params ¶ms, cudaStream_t stream) { - run_mha_fwd_hdim224(params, stream); +void run_mha_bwd_(Flash_bwd_params ¶ms, cudaStream_t stream) { + run_mha_bwd_qkdim64_vdim128(params, stream); } diff --git a/csrc/flash_attn/src/flash_bwd_qkdim64_vdim128_sm80.h b/csrc/flash_attn/src/flash_bwd_qkdim64_vdim128_sm80.h new file mode 100644 index 000000000..b09d032a5 --- /dev/null +++ b/csrc/flash_attn/src/flash_bwd_qkdim64_vdim128_sm80.h @@ -0,0 +1,57 @@ +#include "flash_bwd_launch_template.h" + +template +void run_mha_bwd_qkdim64_vdim128(Flash_bwd_params ¶ms, cudaStream_t stream) { + constexpr static int QKHeaddim = 64; + constexpr static int VHeaddim = 128; + int device; + cudaGetDevice(&device); + int max_smem_per_block; + cudaError status_ = cudaDeviceGetAttribute( + &max_smem_per_block, cudaDevAttrMaxSharedMemoryPerBlockOptin, device); + if (status_ != cudaSuccess) { + C10_CUDA_CHECK(status_); + } + // printf("max_smem_per_block = %d\n", max_smem_per_block); + DROPOUT_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] { + // Changing AtomLayoutMdQ from 2 to 4 takes the same time + // run_flash_bwd>(params, stream); + // run_flash_bwd>(params, stream); + // run_flash_bwd>(params, stream); + // run_flash_bwd, Is_dropout>(params, stream); + // This is slightly faster. 
We want to split M more so we need fewer registers to store LSE. + constexpr static int Br = 64; + constexpr static int Bc = 128; + constexpr static int smem_size = 2 *(Br * QKHeaddim * 2 /*Q with double buffer*/ + Br * VHeaddim /* dO*/ + Bc * QKHeaddim /*K, dK*/ + Bc * VHeaddim /*V, dV*/ + + Br * Bc * 2 /*dS, P*/); + // printf("smem_size = %d\n", smem_size); + // printf("max_smem_per_block = %d\n", max_smem_per_block); + + if (max_smem_per_block >= 144 * 1024) { + run_flash_bwd, Is_dropout, Is_causal>(params, stream); + // A100 shared memory spill + // run_flash_bwd, Is_dropout, Is_causal>(params, stream); + // This has a lot of register spilling + // run_flash_bwd, Is_dropout>(params, stream); + } else { + // if (params.h == params.h_k) { + // run_flash_bwd, Is_dropout>(params, stream); + run_flash_bwd, Is_dropout, Is_causal>(params, stream); + // run_flash_bwd, Is_dropout>(params, stream); + // run_flash_bwd, Is_dropout>(params, stream); + // } else { + // } + } + }); + // run_flash_bwd>(params, stream); + // run_flash_bwd>(params, stream); + // run_flash_bwd>(params, stream); + // run_flash_bwd>(params, stream); + // M=128, N=64 is quite slow, I think because we need to read/write dQaccum twice as many times + // run_flash_bwd>(params, stream); + // run_flash_bwd>(params, stream); + // run_flash_bwd>(params, stream); + + // run_flash_bwd>(params, stream); +} + diff --git a/csrc/flash_attn/src/flash_bwd_qkdim96_vdim192_bf16_causal_sm80.cu b/csrc/flash_attn/src/flash_bwd_qkdim96_vdim192_bf16_causal_sm80.cu new file mode 100644 index 000000000..4bd2c82dc --- /dev/null +++ b/csrc/flash_attn/src/flash_bwd_qkdim96_vdim192_bf16_causal_sm80.cu @@ -0,0 +1,10 @@ +// Copyright (c) 2024, Tri Dao. +// Splitting the different head dimensions to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py" + +#include "flash_bwd_qkdim96_vdim192_sm80.h" + +template<> +void run_mha_bwd_(Flash_bwd_params ¶ms, cudaStream_t stream) { + run_mha_bwd_qkdim96_vdim192(params, stream); +} diff --git a/csrc/flash_attn/src/flash_bwd_qkdim96_vdim192_bf16_sm80.cu b/csrc/flash_attn/src/flash_bwd_qkdim96_vdim192_bf16_sm80.cu new file mode 100644 index 000000000..7536e95ab --- /dev/null +++ b/csrc/flash_attn/src/flash_bwd_qkdim96_vdim192_bf16_sm80.cu @@ -0,0 +1,10 @@ +// Copyright (c) 2024, Tri Dao. +// Splitting the different head dimensions to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py" + +#include "flash_bwd_qkdim96_vdim192_sm80.h" + +template<> +void run_mha_bwd_(Flash_bwd_params ¶ms, cudaStream_t stream) { + run_mha_bwd_qkdim96_vdim192(params, stream); +} diff --git a/csrc/flash_attn/src/flash_bwd_qkdim96_vdim192_fp16_causal_sm80.cu b/csrc/flash_attn/src/flash_bwd_qkdim96_vdim192_fp16_causal_sm80.cu new file mode 100644 index 000000000..487006b5a --- /dev/null +++ b/csrc/flash_attn/src/flash_bwd_qkdim96_vdim192_fp16_causal_sm80.cu @@ -0,0 +1,10 @@ +// Copyright (c) 2024, Tri Dao. +// Splitting the different head dimensions to different files to speed up compilation. +// This file is auto-generated. 
See "generate_kernels.py" + +#include "flash_bwd_qkdim96_vdim192_sm80.h" + +template<> +void run_mha_bwd_(Flash_bwd_params ¶ms, cudaStream_t stream) { + run_mha_bwd_qkdim96_vdim192(params, stream); +} diff --git a/csrc/flash_attn/src/flash_bwd_qkdim96_vdim192_fp16_sm80.cu b/csrc/flash_attn/src/flash_bwd_qkdim96_vdim192_fp16_sm80.cu new file mode 100644 index 000000000..9544f59ab --- /dev/null +++ b/csrc/flash_attn/src/flash_bwd_qkdim96_vdim192_fp16_sm80.cu @@ -0,0 +1,10 @@ +// Copyright (c) 2024, Tri Dao. +// Splitting the different head dimensions to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py" + +#include "flash_bwd_qkdim96_vdim192_sm80.h" + +template<> +void run_mha_bwd_(Flash_bwd_params ¶ms, cudaStream_t stream) { + run_mha_bwd_qkdim96_vdim192(params, stream); +} diff --git a/csrc/flash_attn/src/flash_bwd_qkdim96_vdim192_sm80.h b/csrc/flash_attn/src/flash_bwd_qkdim96_vdim192_sm80.h new file mode 100644 index 000000000..79ca59f86 --- /dev/null +++ b/csrc/flash_attn/src/flash_bwd_qkdim96_vdim192_sm80.h @@ -0,0 +1,33 @@ +#include "flash_bwd_launch_template.h" + +template +void run_mha_bwd_qkdim96_vdim192(Flash_bwd_params ¶ms, cudaStream_t stream) { + constexpr static int QKHeaddim = 96; + constexpr static int VHeaddim = 192; + int device; + cudaGetDevice(&device); + int max_smem_per_block; + cudaError status_ = cudaDeviceGetAttribute( + &max_smem_per_block, cudaDevAttrMaxSharedMemoryPerBlockOptin, device); + if (status_ != cudaSuccess) { + C10_CUDA_CHECK(status_); + } + // printf("max_smem_per_block = %d\n", max_smem_per_block); + DROPOUT_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] { + constexpr static int Br = 64; + constexpr static int Bc = 128; + constexpr static int smem_size = 2 *(Br * QKHeaddim * 2 /*Q with double buffer*/ + Br * VHeaddim /* dO*/ + Bc * QKHeaddim /*K, dK*/ + Bc * VHeaddim /*V, dV*/ + + Br * Bc * 2 /*dS, P*/); + if (max_smem_per_block >= 116 * 1024) { + if constexpr(!Is_dropout) { // 92KB + run_flash_bwd, Is_dropout, Is_causal>(params, stream); + } else { // 116 KB + // This is faster for dropout since we don't have many registers to spare + run_flash_bwd, Is_dropout, Is_causal>(params, stream); + } + } else { + run_flash_bwd, Is_dropout, Is_causal>(params, stream); + } + }); +} + diff --git a/csrc/flash_attn/src/flash_fwd_hdim128_bf16_causal_sm80.cu b/csrc/flash_attn/src/flash_fwd_hdim128_bf16_causal_sm80.cu index 9383c1024..8085f173b 100644 --- a/csrc/flash_attn/src/flash_fwd_hdim128_bf16_causal_sm80.cu +++ b/csrc/flash_attn/src/flash_fwd_hdim128_bf16_causal_sm80.cu @@ -5,6 +5,6 @@ #include "flash_fwd_launch_template.h" template<> -void run_mha_fwd_(Flash_fwd_params ¶ms, cudaStream_t stream) { +void run_mha_fwd_(Flash_fwd_params ¶ms, cudaStream_t stream) { run_mha_fwd_hdim128(params, stream); } diff --git a/csrc/flash_attn/src/flash_fwd_hdim128_bf16_sm80.cu b/csrc/flash_attn/src/flash_fwd_hdim128_bf16_sm80.cu index f03abda48..49e011fca 100644 --- a/csrc/flash_attn/src/flash_fwd_hdim128_bf16_sm80.cu +++ b/csrc/flash_attn/src/flash_fwd_hdim128_bf16_sm80.cu @@ -5,6 +5,6 @@ #include "flash_fwd_launch_template.h" template<> -void run_mha_fwd_(Flash_fwd_params ¶ms, cudaStream_t stream) { +void run_mha_fwd_(Flash_fwd_params ¶ms, cudaStream_t stream) { run_mha_fwd_hdim128(params, stream); } diff --git a/csrc/flash_attn/src/flash_fwd_hdim128_fp16_causal_sm80.cu b/csrc/flash_attn/src/flash_fwd_hdim128_fp16_causal_sm80.cu index c616628c8..bcccc6b80 100644 --- a/csrc/flash_attn/src/flash_fwd_hdim128_fp16_causal_sm80.cu 
+++ b/csrc/flash_attn/src/flash_fwd_hdim128_fp16_causal_sm80.cu @@ -5,6 +5,6 @@ #include "flash_fwd_launch_template.h" template<> -void run_mha_fwd_(Flash_fwd_params ¶ms, cudaStream_t stream) { +void run_mha_fwd_(Flash_fwd_params ¶ms, cudaStream_t stream) { run_mha_fwd_hdim128(params, stream); } diff --git a/csrc/flash_attn/src/flash_fwd_hdim128_fp16_sm80.cu b/csrc/flash_attn/src/flash_fwd_hdim128_fp16_sm80.cu index 4ff6b9fbf..0779bd8f9 100644 --- a/csrc/flash_attn/src/flash_fwd_hdim128_fp16_sm80.cu +++ b/csrc/flash_attn/src/flash_fwd_hdim128_fp16_sm80.cu @@ -5,6 +5,6 @@ #include "flash_fwd_launch_template.h" template<> -void run_mha_fwd_(Flash_fwd_params ¶ms, cudaStream_t stream) { +void run_mha_fwd_(Flash_fwd_params ¶ms, cudaStream_t stream) { run_mha_fwd_hdim128(params, stream); } diff --git a/csrc/flash_attn/src/flash_fwd_hdim160_bf16_causal_sm80.cu b/csrc/flash_attn/src/flash_fwd_hdim160_bf16_causal_sm80.cu index d6d4371bf..4be6cc5ad 100644 --- a/csrc/flash_attn/src/flash_fwd_hdim160_bf16_causal_sm80.cu +++ b/csrc/flash_attn/src/flash_fwd_hdim160_bf16_causal_sm80.cu @@ -5,6 +5,6 @@ #include "flash_fwd_launch_template.h" template<> -void run_mha_fwd_(Flash_fwd_params ¶ms, cudaStream_t stream) { +void run_mha_fwd_(Flash_fwd_params ¶ms, cudaStream_t stream) { run_mha_fwd_hdim160(params, stream); } diff --git a/csrc/flash_attn/src/flash_fwd_hdim160_bf16_sm80.cu b/csrc/flash_attn/src/flash_fwd_hdim160_bf16_sm80.cu index 5af68ac38..121d4b22e 100644 --- a/csrc/flash_attn/src/flash_fwd_hdim160_bf16_sm80.cu +++ b/csrc/flash_attn/src/flash_fwd_hdim160_bf16_sm80.cu @@ -5,6 +5,6 @@ #include "flash_fwd_launch_template.h" template<> -void run_mha_fwd_(Flash_fwd_params ¶ms, cudaStream_t stream) { +void run_mha_fwd_(Flash_fwd_params ¶ms, cudaStream_t stream) { run_mha_fwd_hdim160(params, stream); } diff --git a/csrc/flash_attn/src/flash_fwd_hdim160_fp16_causal_sm80.cu b/csrc/flash_attn/src/flash_fwd_hdim160_fp16_causal_sm80.cu index 1ef511a6b..f3f0c5f5b 100644 --- a/csrc/flash_attn/src/flash_fwd_hdim160_fp16_causal_sm80.cu +++ b/csrc/flash_attn/src/flash_fwd_hdim160_fp16_causal_sm80.cu @@ -5,6 +5,6 @@ #include "flash_fwd_launch_template.h" template<> -void run_mha_fwd_(Flash_fwd_params ¶ms, cudaStream_t stream) { +void run_mha_fwd_(Flash_fwd_params ¶ms, cudaStream_t stream) { run_mha_fwd_hdim160(params, stream); } diff --git a/csrc/flash_attn/src/flash_fwd_hdim160_fp16_sm80.cu b/csrc/flash_attn/src/flash_fwd_hdim160_fp16_sm80.cu index 96abfbd8a..44d0dab1f 100644 --- a/csrc/flash_attn/src/flash_fwd_hdim160_fp16_sm80.cu +++ b/csrc/flash_attn/src/flash_fwd_hdim160_fp16_sm80.cu @@ -5,6 +5,6 @@ #include "flash_fwd_launch_template.h" template<> -void run_mha_fwd_(Flash_fwd_params ¶ms, cudaStream_t stream) { +void run_mha_fwd_(Flash_fwd_params ¶ms, cudaStream_t stream) { run_mha_fwd_hdim160(params, stream); } diff --git a/csrc/flash_attn/src/flash_fwd_hdim192_bf16_causal_sm80.cu b/csrc/flash_attn/src/flash_fwd_hdim192_bf16_causal_sm80.cu index 077d25d09..478455719 100644 --- a/csrc/flash_attn/src/flash_fwd_hdim192_bf16_causal_sm80.cu +++ b/csrc/flash_attn/src/flash_fwd_hdim192_bf16_causal_sm80.cu @@ -5,6 +5,6 @@ #include "flash_fwd_launch_template.h" template<> -void run_mha_fwd_(Flash_fwd_params ¶ms, cudaStream_t stream) { +void run_mha_fwd_(Flash_fwd_params ¶ms, cudaStream_t stream) { run_mha_fwd_hdim192(params, stream); } diff --git a/csrc/flash_attn/src/flash_fwd_hdim192_bf16_sm80.cu b/csrc/flash_attn/src/flash_fwd_hdim192_bf16_sm80.cu index ea5f265fe..dbc9c3a09 100644 --- 
a/csrc/flash_attn/src/flash_fwd_hdim192_bf16_sm80.cu +++ b/csrc/flash_attn/src/flash_fwd_hdim192_bf16_sm80.cu @@ -5,6 +5,6 @@ #include "flash_fwd_launch_template.h" template<> -void run_mha_fwd_(Flash_fwd_params ¶ms, cudaStream_t stream) { +void run_mha_fwd_(Flash_fwd_params ¶ms, cudaStream_t stream) { run_mha_fwd_hdim192(params, stream); } diff --git a/csrc/flash_attn/src/flash_fwd_hdim192_fp16_causal_sm80.cu b/csrc/flash_attn/src/flash_fwd_hdim192_fp16_causal_sm80.cu index a4a7bc242..f6ad159b6 100644 --- a/csrc/flash_attn/src/flash_fwd_hdim192_fp16_causal_sm80.cu +++ b/csrc/flash_attn/src/flash_fwd_hdim192_fp16_causal_sm80.cu @@ -5,6 +5,6 @@ #include "flash_fwd_launch_template.h" template<> -void run_mha_fwd_(Flash_fwd_params ¶ms, cudaStream_t stream) { +void run_mha_fwd_(Flash_fwd_params ¶ms, cudaStream_t stream) { run_mha_fwd_hdim192(params, stream); } diff --git a/csrc/flash_attn/src/flash_fwd_hdim192_fp16_sm80.cu b/csrc/flash_attn/src/flash_fwd_hdim192_fp16_sm80.cu index c30c4a14f..379d9587a 100644 --- a/csrc/flash_attn/src/flash_fwd_hdim192_fp16_sm80.cu +++ b/csrc/flash_attn/src/flash_fwd_hdim192_fp16_sm80.cu @@ -5,6 +5,6 @@ #include "flash_fwd_launch_template.h" template<> -void run_mha_fwd_(Flash_fwd_params ¶ms, cudaStream_t stream) { +void run_mha_fwd_(Flash_fwd_params ¶ms, cudaStream_t stream) { run_mha_fwd_hdim192(params, stream); } diff --git a/csrc/flash_attn/src/flash_fwd_hdim224_bf16_causal_sm80.cu b/csrc/flash_attn/src/flash_fwd_hdim224_bf16_causal_sm80.cu deleted file mode 100644 index a12a5f4ad..000000000 --- a/csrc/flash_attn/src/flash_fwd_hdim224_bf16_causal_sm80.cu +++ /dev/null @@ -1,10 +0,0 @@ -// Copyright (c) 2023, Tri Dao. -// Splitting the different head dimensions to different files to speed up compilation. -// This file is auto-generated. See "generate_kernels.py" - -#include "flash_fwd_launch_template.h" - -template<> -void run_mha_fwd_(Flash_fwd_params ¶ms, cudaStream_t stream) { - run_mha_fwd_hdim224(params, stream); -} diff --git a/csrc/flash_attn/src/flash_fwd_hdim224_fp16_causal_sm80.cu b/csrc/flash_attn/src/flash_fwd_hdim224_fp16_causal_sm80.cu deleted file mode 100644 index f01dad09c..000000000 --- a/csrc/flash_attn/src/flash_fwd_hdim224_fp16_causal_sm80.cu +++ /dev/null @@ -1,10 +0,0 @@ -// Copyright (c) 2023, Tri Dao. -// Splitting the different head dimensions to different files to speed up compilation. -// This file is auto-generated. See "generate_kernels.py" - -#include "flash_fwd_launch_template.h" - -template<> -void run_mha_fwd_(Flash_fwd_params ¶ms, cudaStream_t stream) { - run_mha_fwd_hdim224(params, stream); -} diff --git a/csrc/flash_attn/src/flash_fwd_hdim224_fp16_sm80.cu b/csrc/flash_attn/src/flash_fwd_hdim224_fp16_sm80.cu deleted file mode 100644 index 7ec1e16b7..000000000 --- a/csrc/flash_attn/src/flash_fwd_hdim224_fp16_sm80.cu +++ /dev/null @@ -1,10 +0,0 @@ -// Copyright (c) 2023, Tri Dao. -// Splitting the different head dimensions to different files to speed up compilation. -// This file is auto-generated. 
See "generate_kernels.py" - -#include "flash_fwd_launch_template.h" - -template<> -void run_mha_fwd_(Flash_fwd_params ¶ms, cudaStream_t stream) { - run_mha_fwd_hdim224(params, stream); -} diff --git a/csrc/flash_attn/src/flash_fwd_hdim256_bf16_causal_sm80.cu b/csrc/flash_attn/src/flash_fwd_hdim256_bf16_causal_sm80.cu index f84e978c9..2755a2571 100644 --- a/csrc/flash_attn/src/flash_fwd_hdim256_bf16_causal_sm80.cu +++ b/csrc/flash_attn/src/flash_fwd_hdim256_bf16_causal_sm80.cu @@ -5,6 +5,6 @@ #include "flash_fwd_launch_template.h" template<> -void run_mha_fwd_(Flash_fwd_params ¶ms, cudaStream_t stream) { +void run_mha_fwd_(Flash_fwd_params ¶ms, cudaStream_t stream) { run_mha_fwd_hdim256(params, stream); } diff --git a/csrc/flash_attn/src/flash_fwd_hdim256_bf16_sm80.cu b/csrc/flash_attn/src/flash_fwd_hdim256_bf16_sm80.cu index c52f0417b..57f431d1b 100644 --- a/csrc/flash_attn/src/flash_fwd_hdim256_bf16_sm80.cu +++ b/csrc/flash_attn/src/flash_fwd_hdim256_bf16_sm80.cu @@ -5,6 +5,6 @@ #include "flash_fwd_launch_template.h" template<> -void run_mha_fwd_(Flash_fwd_params ¶ms, cudaStream_t stream) { +void run_mha_fwd_(Flash_fwd_params ¶ms, cudaStream_t stream) { run_mha_fwd_hdim256(params, stream); } diff --git a/csrc/flash_attn/src/flash_fwd_hdim256_fp16_causal_sm80.cu b/csrc/flash_attn/src/flash_fwd_hdim256_fp16_causal_sm80.cu index f96f7edc6..7781859fd 100644 --- a/csrc/flash_attn/src/flash_fwd_hdim256_fp16_causal_sm80.cu +++ b/csrc/flash_attn/src/flash_fwd_hdim256_fp16_causal_sm80.cu @@ -5,6 +5,6 @@ #include "flash_fwd_launch_template.h" template<> -void run_mha_fwd_(Flash_fwd_params ¶ms, cudaStream_t stream) { +void run_mha_fwd_(Flash_fwd_params ¶ms, cudaStream_t stream) { run_mha_fwd_hdim256(params, stream); } diff --git a/csrc/flash_attn/src/flash_fwd_hdim256_fp16_sm80.cu b/csrc/flash_attn/src/flash_fwd_hdim256_fp16_sm80.cu index 9c7c6b93d..274160793 100644 --- a/csrc/flash_attn/src/flash_fwd_hdim256_fp16_sm80.cu +++ b/csrc/flash_attn/src/flash_fwd_hdim256_fp16_sm80.cu @@ -5,6 +5,6 @@ #include "flash_fwd_launch_template.h" template<> -void run_mha_fwd_(Flash_fwd_params ¶ms, cudaStream_t stream) { +void run_mha_fwd_(Flash_fwd_params ¶ms, cudaStream_t stream) { run_mha_fwd_hdim256(params, stream); } diff --git a/csrc/flash_attn/src/flash_fwd_hdim32_bf16_causal_sm80.cu b/csrc/flash_attn/src/flash_fwd_hdim32_bf16_causal_sm80.cu index e21d0408c..f19c7c1d3 100644 --- a/csrc/flash_attn/src/flash_fwd_hdim32_bf16_causal_sm80.cu +++ b/csrc/flash_attn/src/flash_fwd_hdim32_bf16_causal_sm80.cu @@ -5,6 +5,6 @@ #include "flash_fwd_launch_template.h" template<> -void run_mha_fwd_(Flash_fwd_params ¶ms, cudaStream_t stream) { +void run_mha_fwd_(Flash_fwd_params ¶ms, cudaStream_t stream) { run_mha_fwd_hdim32(params, stream); } diff --git a/csrc/flash_attn/src/flash_fwd_hdim32_bf16_sm80.cu b/csrc/flash_attn/src/flash_fwd_hdim32_bf16_sm80.cu index f377a5b8f..0c1b8f35c 100644 --- a/csrc/flash_attn/src/flash_fwd_hdim32_bf16_sm80.cu +++ b/csrc/flash_attn/src/flash_fwd_hdim32_bf16_sm80.cu @@ -5,6 +5,6 @@ #include "flash_fwd_launch_template.h" template<> -void run_mha_fwd_(Flash_fwd_params ¶ms, cudaStream_t stream) { +void run_mha_fwd_(Flash_fwd_params ¶ms, cudaStream_t stream) { run_mha_fwd_hdim32(params, stream); } diff --git a/csrc/flash_attn/src/flash_fwd_hdim32_fp16_causal_sm80.cu b/csrc/flash_attn/src/flash_fwd_hdim32_fp16_causal_sm80.cu index 74e4d66ae..7f1541051 100644 --- a/csrc/flash_attn/src/flash_fwd_hdim32_fp16_causal_sm80.cu +++ b/csrc/flash_attn/src/flash_fwd_hdim32_fp16_causal_sm80.cu @@ -5,6 +5,6 @@ 
#include "flash_fwd_launch_template.h" template<> -void run_mha_fwd_(Flash_fwd_params ¶ms, cudaStream_t stream) { +void run_mha_fwd_(Flash_fwd_params ¶ms, cudaStream_t stream) { run_mha_fwd_hdim32(params, stream); } diff --git a/csrc/flash_attn/src/flash_fwd_hdim32_fp16_sm80.cu b/csrc/flash_attn/src/flash_fwd_hdim32_fp16_sm80.cu index e85db18e3..0776a30a8 100644 --- a/csrc/flash_attn/src/flash_fwd_hdim32_fp16_sm80.cu +++ b/csrc/flash_attn/src/flash_fwd_hdim32_fp16_sm80.cu @@ -5,6 +5,6 @@ #include "flash_fwd_launch_template.h" template<> -void run_mha_fwd_(Flash_fwd_params ¶ms, cudaStream_t stream) { +void run_mha_fwd_(Flash_fwd_params ¶ms, cudaStream_t stream) { run_mha_fwd_hdim32(params, stream); } diff --git a/csrc/flash_attn/src/flash_fwd_hdim64_bf16_causal_sm80.cu b/csrc/flash_attn/src/flash_fwd_hdim64_bf16_causal_sm80.cu index 9297e8bb6..dbfb66e71 100644 --- a/csrc/flash_attn/src/flash_fwd_hdim64_bf16_causal_sm80.cu +++ b/csrc/flash_attn/src/flash_fwd_hdim64_bf16_causal_sm80.cu @@ -5,6 +5,6 @@ #include "flash_fwd_launch_template.h" template<> -void run_mha_fwd_(Flash_fwd_params ¶ms, cudaStream_t stream) { +void run_mha_fwd_(Flash_fwd_params ¶ms, cudaStream_t stream) { run_mha_fwd_hdim64(params, stream); } diff --git a/csrc/flash_attn/src/flash_fwd_hdim64_bf16_sm80.cu b/csrc/flash_attn/src/flash_fwd_hdim64_bf16_sm80.cu index 8364b1e7e..f03b5a88b 100644 --- a/csrc/flash_attn/src/flash_fwd_hdim64_bf16_sm80.cu +++ b/csrc/flash_attn/src/flash_fwd_hdim64_bf16_sm80.cu @@ -5,6 +5,6 @@ #include "flash_fwd_launch_template.h" template<> -void run_mha_fwd_(Flash_fwd_params ¶ms, cudaStream_t stream) { +void run_mha_fwd_(Flash_fwd_params ¶ms, cudaStream_t stream) { run_mha_fwd_hdim64(params, stream); } diff --git a/csrc/flash_attn/src/flash_fwd_hdim64_fp16_causal_sm80.cu b/csrc/flash_attn/src/flash_fwd_hdim64_fp16_causal_sm80.cu index 1c6ed7ef0..019754bf2 100644 --- a/csrc/flash_attn/src/flash_fwd_hdim64_fp16_causal_sm80.cu +++ b/csrc/flash_attn/src/flash_fwd_hdim64_fp16_causal_sm80.cu @@ -5,6 +5,6 @@ #include "flash_fwd_launch_template.h" template<> -void run_mha_fwd_(Flash_fwd_params ¶ms, cudaStream_t stream) { +void run_mha_fwd_(Flash_fwd_params ¶ms, cudaStream_t stream) { run_mha_fwd_hdim64(params, stream); } diff --git a/csrc/flash_attn/src/flash_fwd_hdim64_fp16_sm80.cu b/csrc/flash_attn/src/flash_fwd_hdim64_fp16_sm80.cu index 3c87573ba..c043773ec 100644 --- a/csrc/flash_attn/src/flash_fwd_hdim64_fp16_sm80.cu +++ b/csrc/flash_attn/src/flash_fwd_hdim64_fp16_sm80.cu @@ -5,6 +5,6 @@ #include "flash_fwd_launch_template.h" template<> -void run_mha_fwd_(Flash_fwd_params ¶ms, cudaStream_t stream) { +void run_mha_fwd_(Flash_fwd_params ¶ms, cudaStream_t stream) { run_mha_fwd_hdim64(params, stream); } diff --git a/csrc/flash_attn/src/flash_fwd_hdim96_bf16_causal_sm80.cu b/csrc/flash_attn/src/flash_fwd_hdim96_bf16_causal_sm80.cu index 49fae856a..9c997288b 100644 --- a/csrc/flash_attn/src/flash_fwd_hdim96_bf16_causal_sm80.cu +++ b/csrc/flash_attn/src/flash_fwd_hdim96_bf16_causal_sm80.cu @@ -5,6 +5,6 @@ #include "flash_fwd_launch_template.h" template<> -void run_mha_fwd_(Flash_fwd_params ¶ms, cudaStream_t stream) { +void run_mha_fwd_(Flash_fwd_params ¶ms, cudaStream_t stream) { run_mha_fwd_hdim96(params, stream); } diff --git a/csrc/flash_attn/src/flash_fwd_hdim96_bf16_sm80.cu b/csrc/flash_attn/src/flash_fwd_hdim96_bf16_sm80.cu index c5af1cf63..443060de4 100644 --- a/csrc/flash_attn/src/flash_fwd_hdim96_bf16_sm80.cu +++ b/csrc/flash_attn/src/flash_fwd_hdim96_bf16_sm80.cu @@ -5,6 +5,6 @@ #include 
"flash_fwd_launch_template.h" template<> -void run_mha_fwd_(Flash_fwd_params ¶ms, cudaStream_t stream) { +void run_mha_fwd_(Flash_fwd_params ¶ms, cudaStream_t stream) { run_mha_fwd_hdim96(params, stream); } diff --git a/csrc/flash_attn/src/flash_fwd_hdim96_fp16_causal_sm80.cu b/csrc/flash_attn/src/flash_fwd_hdim96_fp16_causal_sm80.cu index b0d6c9928..1f02afa83 100644 --- a/csrc/flash_attn/src/flash_fwd_hdim96_fp16_causal_sm80.cu +++ b/csrc/flash_attn/src/flash_fwd_hdim96_fp16_causal_sm80.cu @@ -5,6 +5,6 @@ #include "flash_fwd_launch_template.h" template<> -void run_mha_fwd_(Flash_fwd_params ¶ms, cudaStream_t stream) { +void run_mha_fwd_(Flash_fwd_params ¶ms, cudaStream_t stream) { run_mha_fwd_hdim96(params, stream); } diff --git a/csrc/flash_attn/src/flash_fwd_hdim96_fp16_sm80.cu b/csrc/flash_attn/src/flash_fwd_hdim96_fp16_sm80.cu index c97aa33f8..7bdfa7bcb 100644 --- a/csrc/flash_attn/src/flash_fwd_hdim96_fp16_sm80.cu +++ b/csrc/flash_attn/src/flash_fwd_hdim96_fp16_sm80.cu @@ -5,6 +5,6 @@ #include "flash_fwd_launch_template.h" template<> -void run_mha_fwd_(Flash_fwd_params ¶ms, cudaStream_t stream) { +void run_mha_fwd_(Flash_fwd_params ¶ms, cudaStream_t stream) { run_mha_fwd_hdim96(params, stream); } diff --git a/csrc/flash_attn/src/flash_fwd_kernel.h b/csrc/flash_attn/src/flash_fwd_kernel.h index 788f3790e..82b617f2f 100644 --- a/csrc/flash_attn/src/flash_fwd_kernel.h +++ b/csrc/flash_attn/src/flash_fwd_kernel.h @@ -60,7 +60,8 @@ inline __device__ void compute_attn_1rowblock(const Params ¶ms, const int bi constexpr int kBlockM = Kernel_traits::kBlockM; constexpr int kBlockN = Kernel_traits::kBlockN; - constexpr int kHeadDim = Kernel_traits::kHeadDim; + constexpr int kQKHeadDim = Kernel_traits::kQKHeadDim; + constexpr int kVHeadDim = Kernel_traits::kVHeadDim; constexpr int kNWarps = Kernel_traits::kNWarps; auto seed_offset = at::cuda::philox::unpack(params.philox_args); @@ -91,9 +92,9 @@ inline __device__ void compute_attn_1rowblock(const Params ¶ms, const int bi if ((Is_causal || Is_local || !Is_even_MN) && n_block_max <= n_block_min) { Tensor mO = make_tensor(make_gmem_ptr(reinterpret_cast(params.o_ptr) + binfo.q_offset(params.o_batch_stride, params.o_row_stride, bidb)), - make_shape(binfo.actual_seqlen_q, params.h, params.d), + make_shape(binfo.actual_seqlen_q, params.h, params.vd), make_stride(params.o_row_stride, params.o_head_stride, _1{})); - Tensor gO = local_tile(mO(_, bidh, _), Shape, Int>{}, + Tensor gO = local_tile(mO(_, bidh, _), Shape, Int>{}, make_coord(m_block, 0)); // (kBlockM, kHeadDim) Tensor gLSE = get_lse_tile(params, bidb, bidh, m_block, binfo); @@ -110,7 +111,7 @@ inline __device__ void compute_attn_1rowblock(const Params ¶ms, const int bi Tensor tOpO = make_tensor(make_shape(size<2>(tOgO))); if (!Is_even_K) { #pragma unroll - for (int k = 0; k < size(tOpO); ++k) { tOpO(k) = get<1>(tOcO(0, 0, k)) < params.d; } + for (int k = 0; k < size(tOpO); ++k) { tOpO(k) = get<1>(tOcO(0, 0, k)) < params.vd; } } // Clear_OOB_K must be false since we don't want to write zeros to gmem flash::copy( @@ -136,19 +137,19 @@ inline __device__ void compute_attn_1rowblock(const Params ¶ms, const int bi + binfo.q_offset(params.q_batch_stride, params.q_row_stride, bidb)), make_shape(binfo.actual_seqlen_q, params.h, params.d), make_stride(params.q_row_stride, params.q_head_stride, _1{})); - Tensor gQ = local_tile(mQ(_, bidh, _), Shape, Int>{}, + Tensor gQ = local_tile(mQ(_, bidh, _), Shape, Int>{}, make_coord(m_block, 0)); // (kBlockM, kHeadDim) Tensor mK = 
make_tensor(make_gmem_ptr(reinterpret_cast(params.k_ptr) + binfo.k_offset(params.k_batch_stride, params.k_row_stride, bidb)), make_shape(binfo.actual_seqlen_k, params.h_k, params.d), make_stride(params.k_row_stride, params.k_head_stride, _1{})); - Tensor gK = local_tile(mK(_, bidh / params.h_h_k_ratio, _), Shape, Int>{}, + Tensor gK = local_tile(mK(_, bidh / params.h_h_k_ratio, _), Shape, Int>{}, make_coord(_, 0)); // (kBlockN, kHeadDim, nblocksN) Tensor mV = make_tensor(make_gmem_ptr(reinterpret_cast(params.v_ptr) + binfo.k_offset(params.v_batch_stride, params.v_row_stride, bidb)), - make_shape(binfo.actual_seqlen_k, params.h_k, params.d), + make_shape(binfo.actual_seqlen_k, params.h_v, params.vd), make_stride(params.v_row_stride, params.v_head_stride, _1{})); - Tensor gV = local_tile(mV(_, bidh / params.h_h_k_ratio, _), Shape, Int>{}, + Tensor gV = local_tile(mV(_, bidh / params.h_h_v_ratio, _), Shape, Int>{}, make_coord(_, 0)); // (kBlockN, kHeadDim, nblocksN) Tensor gP = make_tensor(make_gmem_ptr(reinterpret_cast(params.p_ptr) + row_offset_p), Shape, Int>{}, @@ -158,8 +159,8 @@ inline __device__ void compute_attn_1rowblock(const Params ¶ms, const int bi typename Kernel_traits::SmemLayoutQ{}); // Careful we're using the same smem for sQ and sK | sV if Share_Q_K_smem; Tensor sK = make_tensor(sQ.data() + (Kernel_traits::Share_Q_K_smem ? 0 : size(sQ)), - typename Kernel_traits::SmemLayoutKV{}); - Tensor sV = make_tensor(sK.data() + size(sK), typename Kernel_traits::SmemLayoutKV{}); + typename Kernel_traits::SmemLayoutK{}); + Tensor sV = make_tensor(sK.data() + size(sK), typename Kernel_traits::SmemLayoutV{}); Tensor sVt = make_tensor(sV.data(), typename Kernel_traits::SmemLayoutVtransposed{}); Tensor sVtNoSwizzle = make_tensor(sV.data().get(), typename Kernel_traits::SmemLayoutVtransposedNoSwizzle{}); @@ -181,7 +182,7 @@ inline __device__ void compute_attn_1rowblock(const Params ¶ms, const int bi Tensor tSgS = thr_mma.partition_C(gP); - Tensor acc_o = partition_fragment_C(tiled_mma, Shape, Int>{}); // MMA, MMA_M, MMA_K + Tensor acc_o = partition_fragment_C(tiled_mma, Shape, Int>{}); // MMA, MMA_M, MMA_K // // Copy Atom retiling @@ -211,7 +212,7 @@ inline __device__ void compute_attn_1rowblock(const Params ¶ms, const int bi // Construct identity layout for sQ and sK Tensor cQ = make_identity_tensor(make_shape(size<0>(sQ), size<1>(sQ))); // (BLK_M,BLK_K) -> (blk_m,blk_k) - Tensor cKV = make_identity_tensor(make_shape(size<0>(sK), size<1>(sK))); // (BLK_N,BLK_K) -> (blk_n,blk_k) + Tensor cK = make_identity_tensor(make_shape(size<0>(sK), size<1>(sK))); // (BLK_N,BLK_K) -> (blk_n,blk_k) // Tensor tScQ = thr_mma.partition_A(cQ); // (MMA,MMA_M,MMA_K) // if (cute::thread0()) { // print(tScQ.layout()); printf("\n"); @@ -227,18 +228,18 @@ inline __device__ void compute_attn_1rowblock(const Params ¶ms, const int bi // Repeat the partitioning with identity layouts Tensor tQcQ = gmem_thr_copy_QKV.partition_S(cQ); // (ACPY,ACPY_M,ACPY_K) -> (blk_m,blk_k) - Tensor tKVcKV = gmem_thr_copy_QKV.partition_S(cKV); // (BCPY,BCPY_N,BCPY_K) -> (blk_n,blk_k) + Tensor tKcK = gmem_thr_copy_QKV.partition_S(cK); // (BCPY,BCPY_N,BCPY_K) -> (blk_n,blk_k) // Allocate predicate tensors for k Tensor tQpQ = make_tensor(make_shape(size<2>(tQsQ))); - Tensor tKVpKV = make_tensor(make_shape(size<2>(tKsK))); + Tensor tKpK = make_tensor(make_shape(size<2>(tKsK))); // Set predicates for k bounds if (!Is_even_K) { #pragma unroll for (int k = 0; k < size(tQpQ); ++k) { tQpQ(k) = get<1>(tQcQ(0, 0, k)) < params.d; } #pragma unroll - 
for (int k = 0; k < size(tKVpKV); ++k) { tKVpKV(k) = get<1>(tKVcKV(0, 0, k)) < params.d; } + for (int k = 0; k < size(tKpK); ++k) { tKpK(k) = get<1>(tKcK(0, 0, k)) < params.d; } } // Prologue @@ -263,7 +264,7 @@ inline __device__ void compute_attn_1rowblock(const Params ¶ms, const int bi int n_block = n_block_max - 1; // We don't need to clear the sK smem tiles since we'll mask out the scores anyway. - flash::copy(gmem_tiled_copy_QKV, tKgK(_, _, _, n_block), tKsK, tKVcKV, tKVpKV, + flash::copy(gmem_tiled_copy_QKV, tKgK(_, _, _, n_block), tKsK, tKcK, tKpK, binfo.actual_seqlen_k - n_block * kBlockN); cute::cp_async_fence(); // if (threadIdx.x == 0 && blockIdx.y == 0 && blockIdx.z < 2) { print(tKgK); } @@ -284,6 +285,15 @@ inline __device__ void compute_attn_1rowblock(const Params ¶ms, const int bi const float alibi_slope = !Has_alibi || params.alibi_slopes_ptr == nullptr ? 0.0f : reinterpret_cast(params.alibi_slopes_ptr)[bidb * params.alibi_slopes_batch_stride + bidh] / params.scale_softmax; flash::Mask mask(binfo.actual_seqlen_k, binfo.actual_seqlen_q, params.window_size_left, params.window_size_right, alibi_slope); + Tensor cV = make_identity_tensor(make_shape(size<0>(sV), size<1>(sV))); // (BLK_N,BLK_K) -> (blk_n,blk_k) + Tensor tVcV = gmem_thr_copy_QKV.partition_S(cV); // (BCPY,BCPY_N,BCPY_K) -> (blk_n,blk_k) + Tensor tVpV = make_tensor(make_shape(size<2>(tVsV))); + // Set predicates for k bounds + if (!Is_even_K) { + #pragma unroll + for (int k = 0; k < size(tVpV); ++k) { tVpV(k) = get<1>(tVcV(0, 0, k)) < params.vd; } + } + // For performance reason, we separate out two kinds of iterations: // those that need masking on S, and those that don't. // We need masking on S for the very last block when K and V has length not multiple of kBlockN. @@ -304,11 +314,11 @@ inline __device__ void compute_attn_1rowblock(const Params ¶ms, const int bi // Advance gV if (masking_step > 0) { - flash::copy(gmem_tiled_copy_QKV, tVgV(_, _, _, n_block), tVsV, tKVcKV, tKVpKV); + flash::copy(gmem_tiled_copy_QKV, tVgV(_, _, _, n_block), tVsV, tVcV, tVpV); } else { // Clear the smem tiles to account for predicated off loads flash::copy( - gmem_tiled_copy_QKV, tVgV(_, _, _, n_block), tVsV, tKVcKV, tKVpKV, binfo.actual_seqlen_k - n_block * kBlockN + gmem_tiled_copy_QKV, tVgV(_, _, _, n_block), tVsV, tVcV, tVpV, binfo.actual_seqlen_k - n_block * kBlockN ); } cute::cp_async_fence(); @@ -329,7 +339,7 @@ inline __device__ void compute_attn_1rowblock(const Params ¶ms, const int bi flash::cp_async_wait<0>(); __syncthreads(); if (n_block > n_block_min) { - flash::copy(gmem_tiled_copy_QKV, tKgK(_, _, _, n_block - 1), tKsK, tKVcKV, tKVpKV); + flash::copy(gmem_tiled_copy_QKV, tKgK(_, _, _, n_block - 1), tKsK, tKcK, tKpK); // This cp_async_fence needs to be in the if block, otherwise the synchronization // isn't right and we get race conditions. 
cute::cp_async_fence(); @@ -346,7 +356,8 @@ inline __device__ void compute_attn_1rowblock(const Params ¶ms, const int bi int block_col_idx = n_block * (kBlockN / 32); if (Return_softmax) { Tensor rP_drop = make_fragment_like(rP); - cute::copy(rP, rP_drop); + // cutlass'bug on vectorization for tile (192,64) + cute::copy(cute::coalesce(rP), cute::coalesce(rP_drop)); dropout.template apply_dropout( rP_drop, block_row_idx, block_col_idx, kNWarps ); @@ -377,7 +388,7 @@ inline __device__ void compute_attn_1rowblock(const Params ¶ms, const int bi clear(acc_s); flash::cp_async_wait<0>(); __syncthreads(); - flash::copy(gmem_tiled_copy_QKV, tVgV(_, _, _, n_block), tVsV, tKVcKV, tKVpKV); + flash::copy(gmem_tiled_copy_QKV, tVgV(_, _, _, n_block), tVsV, tVcV, tVpV); cute::cp_async_fence(); flash::gemm( @@ -391,7 +402,7 @@ inline __device__ void compute_attn_1rowblock(const Params ¶ms, const int bi flash::cp_async_wait<0>(); __syncthreads(); if (n_block > n_block_min) { - flash::copy(gmem_tiled_copy_QKV, tKgK(_, _, _, n_block - 1), tKsK, tKVcKV, tKVpKV); + flash::copy(gmem_tiled_copy_QKV, tKgK(_, _, _, n_block - 1), tKsK, tKcK, tKpK); // This cp_async_fence needs to be in the if block, otherwise the synchronization // isn't right and we get race conditions. cute::cp_async_fence(); @@ -408,7 +419,8 @@ inline __device__ void compute_attn_1rowblock(const Params ¶ms, const int bi int block_col_idx = n_block * (kBlockN / 32); if (Return_softmax) { Tensor rP_drop = make_fragment_like(rP); - cute::copy(rP, rP_drop); + // cutlass'bug on vectorization for tile (192,64) + cute::copy(cute::coalesce(rP), cute::coalesce(rP_drop)); dropout.template apply_dropout( rP_drop, block_row_idx, block_col_idx, kNWarps ); @@ -439,15 +451,16 @@ inline __device__ void compute_attn_1rowblock(const Params ¶ms, const int bi Tensor taccOsO = smem_thr_copy_O.partition_D(sO); // ((Atom,AtomNum),PIPE_M,PIPE_N) // sO has the same size as sQ, so we don't need to sync here. - if (Kernel_traits::Share_Q_K_smem) { __syncthreads(); } + __syncthreads(); + // if (Kernel_traits::Share_Q_K_smem) { __syncthreads(); } cute::copy(smem_tiled_copy_O, taccOrO, taccOsO); Tensor mO = make_tensor(make_gmem_ptr(reinterpret_cast(params.o_ptr) + binfo.q_offset(params.o_batch_stride, params.o_row_stride, bidb)), - make_shape(binfo.actual_seqlen_q, params.h, params.d), + make_shape(binfo.actual_seqlen_q, params.h, params.vd), make_stride(params.o_row_stride, params.o_head_stride, _1{})); - Tensor gO = local_tile(mO(_, bidh, _), Shape, Int>{}, + Tensor gO = local_tile(mO(_, bidh, _), Shape, Int>{}, make_coord(m_block, 0)); // (kBlockM, kHeadDim) Tensor gLSE = get_lse_tile(params, bidb, bidh, m_block, binfo); @@ -461,7 +474,7 @@ inline __device__ void compute_attn_1rowblock(const Params ¶ms, const int bi Tensor tOrO = make_tensor(shape(tOgO)); cute::copy(gmem_tiled_copy_O, tOsO, tOrO); - Tensor caccO = make_identity_tensor(Shape, Int>{}); // (BLK_M,BLK_K) -> (blk_m,blk_k) + Tensor caccO = make_identity_tensor(Shape, Int>{}); // (BLK_M,BLK_K) -> (blk_m,blk_k) Tensor taccOcO = thr_mma.partition_C(caccO); // (MMA,MMA_M,MMA_K) static_assert(decltype(size<0>(taccOcO))::value == 4); // Convert to ((2, 2), MMA_M, MMA_K) then take only the row indices. 
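// ---------------------------------------------------------------------------
// Editor's note (not part of the patch), summarizing the forward-epilogue hunks
// above: once kQKHeadDim and kVHeadDim can differ, the output tile is
// (kBlockM, kVHeadDim), so the O-side shapes, identity tensors and bound checks
// switch from params.d to params.vd, and the barrier before reusing shared
// memory for sO is made unconditional.  The likely reason (editor's reading,
// hence hedged) is that sO is no longer guaranteed to fit inside sQ's footprint
// when the V head dim is the larger of the two.  The cute::coalesce() call on
// the dropout path is the patch's own workaround for a vectorized-copy problem
// it attributes to CUTLASS on the (192, 64) tile:
#if 0  // shape of the workaround, as it appears in the hunk above
Tensor rP_drop = make_fragment_like(rP);
cute::copy(cute::coalesce(rP), cute::coalesce(rP_drop));  // flatten both layouts before copying
#endif
// ---------------------------------------------------------------------------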
@@ -482,7 +495,7 @@ inline __device__ void compute_attn_1rowblock(const Params ¶ms, const int bi Tensor tOpO = make_tensor(make_shape(size<2>(tOgO))); if (!Is_even_K) { #pragma unroll - for (int k = 0; k < size(tOpO); ++k) { tOpO(k) = get<1>(tOcO(0, 0, k)) < params.d; } + for (int k = 0; k < size(tOpO); ++k) { tOpO(k) = get<1>(tOcO(0, 0, k)) < params.vd; } } // Clear_OOB_K must be false since we don't want to write zeros to gmem flash::copy( @@ -507,7 +520,8 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params ¶ms, cons constexpr int kBlockM = Kernel_traits::kBlockM; constexpr int kBlockN = Kernel_traits::kBlockN; - constexpr int kHeadDim = Kernel_traits::kHeadDim; + constexpr int kQKHeadDim = Kernel_traits::kQKHeadDim; + constexpr int kVHeadDim = Kernel_traits::kVHeadDim; constexpr int kNWarps = Kernel_traits::kNWarps; using GmemTiledCopyO = std::conditional_t< @@ -538,11 +552,11 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params ¶ms, cons const index_t row_offset_o = binfo.q_offset(params.o_batch_stride, params.o_row_stride, bidb) + m_block * kBlockM * params.o_row_stride + bidh * params.o_head_stride; const index_t row_offset_oaccum = (((n_split_idx * params.b + bidb) * params.h + bidh) * params.seqlen_q - + m_block * kBlockM) * params.d_rounded; + + m_block * kBlockM) * params.vd_rounded; const index_t row_offset_lseaccum = ((n_split_idx * params.b + bidb) * params.h + bidh) * params.seqlen_q + m_block * kBlockM; Tensor gOaccum = make_tensor(make_gmem_ptr(reinterpret_cast(Split ? params.oaccum_ptr : params.o_ptr) + (Split ? row_offset_oaccum : row_offset_o)), - Shape, Int>{}, - make_stride(Split ? kHeadDim : params.o_row_stride, _1{})); + Shape, Int>{}, + make_stride(Split ? kVHeadDim : params.o_row_stride, _1{})); Tensor gLSEaccum = make_tensor(make_gmem_ptr(reinterpret_cast(Split ? params.softmax_lseaccum_ptr : params.softmax_lse_ptr) + row_offset_lseaccum), Shape>{}, Stride<_1>{}); @@ -558,7 +572,7 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params ¶ms, cons Tensor tOpO = make_tensor(make_shape(size<2>(tOgOaccum))); if (!Is_even_K) { #pragma unroll - for (int k = 0; k < size(tOpO); ++k) { tOpO(k) = get<1>(tOcO(0, 0, k)) < params.d; } + for (int k = 0; k < size(tOpO); ++k) { tOpO(k) = get<1>(tOcO(0, 0, k)) < params.vd; } } // Clear_OOB_K must be false since we don't want to write zeros to gmem flash::copy( @@ -587,26 +601,26 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params ¶ms, cons : block_table[block_table_idx] * params.k_batch_stride + block_table_offset * params.k_row_stride + (bidh / params.h_h_k_ratio) * params.k_head_stride; const index_t row_offset_v = block_table == nullptr ? 
binfo.k_offset(params.v_batch_stride, params.v_row_stride, bidb_cache) - + (n_block_max - 1) * kBlockN * params.v_row_stride + (bidh / params.h_h_k_ratio) * params.v_head_stride - : block_table[block_table_idx] * params.v_batch_stride + block_table_offset * params.v_row_stride + (bidh / params.h_h_k_ratio) * params.v_head_stride; + + (n_block_max - 1) * kBlockN * params.v_row_stride + (bidh / params.h_h_v_ratio) * params.v_head_stride + : block_table[block_table_idx] * params.v_batch_stride + block_table_offset * params.v_row_stride + (bidh / params.h_h_v_ratio) * params.v_head_stride; Tensor mQ = make_tensor(make_gmem_ptr(reinterpret_cast(params.q_ptr) + binfo.q_offset(params.q_batch_stride, params.q_row_stride, bidb)), make_shape(binfo.actual_seqlen_q, params.h, params.d), make_stride(params.q_row_stride, params.q_head_stride, _1{})); - Tensor gQ = local_tile(mQ(_, bidh, _), Shape, Int>{}, + Tensor gQ = local_tile(mQ(_, bidh, _), Shape, Int>{}, make_coord(m_block, 0)); // (kBlockM, kHeadDim) Tensor gK = make_tensor(make_gmem_ptr(reinterpret_cast(params.k_ptr) + row_offset_k), - Shape, Int>{}, + Shape, Int>{}, make_stride(params.k_row_stride, _1{})); // if (threadIdx.x == 0 && blockIdx.y == 0 && blockIdx.z == 0) { printf("k_ptr = %p, row_offset_k = %d, gK_ptr = %p\n", params.k_ptr, row_offset_k, gK.data()); } Tensor gV = make_tensor(make_gmem_ptr(reinterpret_cast(params.v_ptr) + row_offset_v), - Shape, Int>{}, + Shape, Int>{}, make_stride(params.v_row_stride, _1{})); Tensor sQ = make_tensor(make_smem_ptr(reinterpret_cast(smem_)), typename Kernel_traits::SmemLayoutQ{}); - Tensor sK = make_tensor(sQ.data() + size(sQ), typename Kernel_traits::SmemLayoutKV{}); - Tensor sV = make_tensor(sK.data() + size(sK), typename Kernel_traits::SmemLayoutKV{}); + Tensor sK = make_tensor(sQ.data() + size(sQ), typename Kernel_traits::SmemLayoutK{}); + Tensor sV = make_tensor(sK.data() + size(sK), typename Kernel_traits::SmemLayoutV{}); Tensor sVt = make_tensor(sV.data(), typename Kernel_traits::SmemLayoutVtransposed{}); Tensor sVtNoSwizzle = make_tensor(sV.data().get(), typename Kernel_traits::SmemLayoutVtransposedNoSwizzle{}); @@ -626,7 +640,7 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params ¶ms, cons Tensor tSrK = thr_mma.partition_fragment_B(sK); // (MMA,MMA_N,MMA_K) Tensor tOrVt = thr_mma.partition_fragment_B(sVtNoSwizzle); // (MMA, MMA_K,MMA_N) - Tensor acc_o = partition_fragment_C(tiled_mma, Shape, Int>{}); // MMA, MMA_M, MMA_K + Tensor acc_o = partition_fragment_C(tiled_mma, Shape, Int>{}); // MMA, MMA_M, MMA_K // // Copy Atom retiling @@ -653,22 +667,26 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params ¶ms, cons // Construct identity layout for sQ and sK Tensor cQ = make_identity_tensor(make_shape(size<0>(sQ), size<1>(sQ))); // (BLK_M,BLK_K) -> (blk_m,blk_k) - Tensor cKV = make_identity_tensor(make_shape(size<0>(sK), size<1>(sK))); // (BLK_N,BLK_K) -> (blk_n,blk_k) + Tensor cK = make_identity_tensor(make_shape(size<0>(sK), size<1>(sK))); // (BLK_N,BLK_K) -> (blk_n,blk_k) + Tensor cV = make_identity_tensor(make_shape(size<0>(sV), size<1>(sV))); // (BLK_N,BLK_K) -> (blk_n,blk_k) // Repeat the partitioning with identity layouts Tensor tQcQ = gmem_thr_copy_QKV.partition_S(cQ); // (ACPY,ACPY_M,ACPY_K) -> (blk_m,blk_k) - Tensor tKVcKV = gmem_thr_copy_QKV.partition_S(cKV); // (BCPY,BCPY_N,BCPY_K) -> (blk_n,blk_k) - + Tensor tKcK = gmem_thr_copy_QKV.partition_S(cK); // (BCPY,BCPY_N,BCPY_K) -> (blk_n,blk_k) + Tensor tVcV = gmem_thr_copy_QKV.partition_S(cV); // 
(BCPY,BCPY_N,BCPY_K) -> (blk_n,blk_k) // Allocate predicate tensors for k Tensor tQpQ = make_tensor(make_shape(size<2>(tQsQ))); - Tensor tKVpKV = make_tensor(make_shape(size<2>(tKsK))); + Tensor tKpK = make_tensor(make_shape(size<2>(tKsK))); + Tensor tVpV = make_tensor(make_shape(size<2>(tVsV))); // Set predicates for k bounds if (!Is_even_K) { #pragma unroll for (int k = 0; k < size(tQpQ); ++k) { tQpQ(k) = get<1>(tQcQ(0, 0, k)) < params.d; } #pragma unroll - for (int k = 0; k < size(tKVpKV); ++k) { tKVpKV(k) = get<1>(tKVcKV(0, 0, k)) < params.d; } + for (int k = 0; k < size(tKpK); ++k) { tKpK(k) = get<1>(tKcK(0, 0, k)) < params.d; } + #pragma unroll + for (int k = 0; k < size(tVpV); ++k) { tVpV(k) = get<1>(tVcV(0, 0, k)) < params.vd; } } // Prologue @@ -684,16 +702,16 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params ¶ms, cons // We want to do this so that all threadblocks can proceed right after they finish writing the KV cache. const index_t row_offset_cossin = ((n_block_max - 1) * kBlockN + (params.leftpad_k == nullptr ? 0 : params.leftpad_k[bidb])) * (params.rotary_dim / 2); Tensor gCos = make_tensor(make_gmem_ptr(reinterpret_cast(params.rotary_cos_ptr) + row_offset_cossin), - Shape, Int>{}, + Shape, Int>{}, make_stride(params.rotary_dim / 2, _1{})); Tensor gSin = make_tensor(make_gmem_ptr(reinterpret_cast(params.rotary_sin_ptr) + row_offset_cossin), - Shape, Int>{}, + Shape, Int>{}, make_stride(params.rotary_dim / 2, _1{})); Tensor gCosCont = make_tensor(make_gmem_ptr(reinterpret_cast(params.rotary_cos_ptr) + row_offset_cossin), - Shape, Int>{}, + Shape, Int>{}, make_stride(params.rotary_dim / 2, _1{})); Tensor gSinCont = make_tensor(make_gmem_ptr(reinterpret_cast(params.rotary_sin_ptr) + row_offset_cossin), - Shape, Int>{}, + Shape, Int>{}, make_stride(params.rotary_dim / 2, _1{})); Tensor tRgCos = gmem_thr_copy_rotary.partition_S(gCos); Tensor tRgSin = gmem_thr_copy_rotary.partition_S(gSin); @@ -708,18 +726,18 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params ¶ms, cons + ((n_block_max - 1) * kBlockN) * params.knew_row_stride + (bidh / params.h_h_k_ratio) * params.knew_head_stride; // const index_t row_offset_vnew = binfo.k_offset(params.vnew_batch_stride, params.vnew_row_stride, bidb) const index_t row_offset_vnew = bidb * params.vnew_batch_stride - + ((n_block_max - 1) * kBlockN) * params.vnew_row_stride + (bidh / params.h_h_k_ratio) * params.vnew_head_stride; + + ((n_block_max - 1) * kBlockN) * params.vnew_row_stride + (bidh / params.h_h_v_ratio) * params.vnew_head_stride; // Subtract seqlen_k_cache * row stride so that conceptually gK and gKnew "line up". When we access them, // e.g. if gK has 128 rows and gKnew has 64 rows, we access gK[:128] and gKNew[128:128 + 64]. // This maps to accessing the first 64 rows of knew_ptr. 
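// ---------------------------------------------------------------------------
// Editor's worked example (not part of the patch) for the "line up" trick in the
// comment above: with seqlen_k_cache = 128 and a 64-row gKnew, gKnew is built at
//   knew_ptr + row_offset_knew - 128 * knew_row_stride,
// so indexing gKnew at logical rows [128, 192) dereferences knew rows [0, 64);
// gK[:128] then reads the cache while gKnew[128:192] reads the appended keys,
// using the same n_block loop index for both.  The only change this patch makes
// in this region is that the V and new-V head indices are divided by
// params.h_h_v_ratio rather than params.h_h_k_ratio, since V now has its own
// head count (params.h_v) alongside its own head dim (params.vd).
// ---------------------------------------------------------------------------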
Tensor gKnew = make_tensor(make_gmem_ptr(reinterpret_cast(params.knew_ptr) + row_offset_knew - binfo.seqlen_k_cache * params.knew_row_stride), - Shape, Int>{}, + Shape, Int>{}, make_stride(params.knew_row_stride, _1{})); // if (threadIdx.x == 0 && blockIdx.y == 0 && blockIdx.z == 0) { printf("knew_ptr = %p, row_offset_knew = %d, gKnew_ptr = %p\n", params.knew_ptr, row_offset_knew, gKnew.data()); } Tensor gVnew = make_tensor(make_gmem_ptr(reinterpret_cast(params.vnew_ptr) + row_offset_vnew - binfo.seqlen_k_cache * params.vnew_row_stride), - Shape, Int>{}, + Shape, Int>{}, make_stride(params.vnew_row_stride, _1{})); Tensor tKgKnew = gmem_thr_copy_QKV.partition_S(gKnew); // (KCPY, KCPY_N, KCPY_K) Tensor tVgVnew = gmem_thr_copy_QKV.partition_S(gVnew); // (VCPY, VCPY_N, VCPY_K) @@ -729,18 +747,18 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params ¶ms, cons auto tVgV_data = tVgV.data(); for (int n_block = n_block_max - 1; n_block >= n_block_copy_min; n_block--) { flash::copy_w_min_idx( - tVgVnew, tVgV, tKVcKV, tKVpKV, binfo.actual_seqlen_k - n_block * kBlockN, binfo.seqlen_k_cache - n_block * kBlockN + tVgVnew, tVgV, tVcV, tVpV, binfo.actual_seqlen_k - n_block * kBlockN, binfo.seqlen_k_cache - n_block * kBlockN ); tVgVnew.data() = tVgVnew.data() + (-int(kBlockN * params.vnew_row_stride)); if (params.rotary_dim == 0) { flash::copy_w_min_idx( - tKgKnew, tKgK, tKVcKV, tKVpKV, binfo.actual_seqlen_k - n_block * kBlockN, binfo.seqlen_k_cache - n_block * kBlockN + tKgKnew, tKgK, tKcK, tKpK, binfo.actual_seqlen_k - n_block * kBlockN, binfo.seqlen_k_cache - n_block * kBlockN ); } else { if (params.is_rotary_interleaved) { // Don't clear OOB_K because we're writing to global memory flash::copy_rotary_interleaved( - tKgKnew, tKgK, tRgCos, tRgSin, tKVcKV, binfo.actual_seqlen_k - n_block * kBlockN, + tKgKnew, tKgK, tRgCos, tRgSin, tKcK, binfo.actual_seqlen_k - n_block * kBlockN, binfo.seqlen_k_cache - n_block * kBlockN, params.d, params.rotary_dim ); tRgCos.data() = tRgCos.data() + (-int(kBlockN * params.rotary_dim / 2)); @@ -748,7 +766,7 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params ¶ms, cons } else { // Don't clear OOB_K because we're writing to global memory flash::copy_rotary_contiguous( - tKgKnew, tKgK, tRgCosCont, tRgSinCont, tKVcKV, binfo.actual_seqlen_k - n_block * kBlockN, + tKgKnew, tKgK, tRgCosCont, tRgSinCont, tKcK, binfo.actual_seqlen_k - n_block * kBlockN, binfo.seqlen_k_cache - n_block * kBlockN, params.d, params.rotary_dim ); tRgCosCont.data() = tRgCosCont.data() + (-int(kBlockN * params.rotary_dim / 2)); @@ -789,16 +807,16 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params ¶ms, cons // If not causal, all the queries get the same the cos/sin, taken at location seqlen_k_cache. // We do this by setting the row stride of gCos / gSin to 0. Tensor gCos = make_tensor(make_gmem_ptr(reinterpret_cast(params.rotary_cos_ptr) + row_offset_cossin), - Shape, Int>{}, + Shape, Int>{}, make_stride(Is_causal || Is_local ? params.rotary_dim / 2 : 0, _1{})); Tensor gSin = make_tensor(make_gmem_ptr(reinterpret_cast(params.rotary_sin_ptr) + row_offset_cossin), - Shape, Int>{}, + Shape, Int>{}, make_stride(Is_causal || Is_local ? params.rotary_dim / 2 : 0, _1{})); Tensor gCosCont = make_tensor(make_gmem_ptr(reinterpret_cast(params.rotary_cos_ptr) + row_offset_cossin), - Shape, Int>{}, + Shape, Int>{}, make_stride(Is_causal || Is_local ? 
params.rotary_dim / 2 : 0, _1{})); Tensor gSinCont = make_tensor(make_gmem_ptr(reinterpret_cast(params.rotary_sin_ptr) + row_offset_cossin), - Shape, Int>{}, + Shape, Int>{}, make_stride(Is_causal || Is_local ? params.rotary_dim / 2 : 0, _1{})); Tensor tRgCos = gmem_thr_copy_rotary.partition_S(gCos); Tensor tRgSin = gmem_thr_copy_rotary.partition_S(gSin); @@ -819,7 +837,7 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params ¶ms, cons int n_block = n_block_max - 1; // We don't need to clear the sK smem tiles since we'll mask out the scores anyway. - flash::copy(gmem_tiled_copy_QKV, tKgK, tKsK, tKVcKV, tKVpKV, + flash::copy(gmem_tiled_copy_QKV, tKgK, tKsK, tKcK, tKpK, binfo.actual_seqlen_k - n_block * kBlockN); cute::cp_async_fence(); @@ -864,11 +882,11 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params ¶ms, cons const int block_table_offset_next = n_block * kBlockN - block_table_idx_next * params.page_block_size; tVgV.data() = tVgV.data() + (block_table[block_table_idx_next] - block_table[block_table_idx_cur]) * params.v_batch_stride + (block_table_offset_next - block_table_offset_cur) * params.v_row_stride; } - flash::copy(gmem_tiled_copy_QKV, tVgV, tVsV, tKVcKV, tKVpKV); + flash::copy(gmem_tiled_copy_QKV, tVgV, tVsV, tVcV, tVpV); } else { // Clear the smem tiles to account for predicated off loads flash::copy( - gmem_tiled_copy_QKV, tVgV, tVsV, tKVcKV, tKVpKV, binfo.actual_seqlen_k - n_block * kBlockN + gmem_tiled_copy_QKV, tVgV, tVsV, tVcV, tVpV, binfo.actual_seqlen_k - n_block * kBlockN ); } cute::cp_async_fence(); @@ -903,7 +921,7 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params ¶ms, cons const int block_table_offset_next =(n_block - 1) * kBlockN - block_table_idx_next * params.page_block_size; tKgK.data() = tKgK.data() + (block_table[block_table_idx_next] - block_table[block_table_idx_cur]) * params.k_batch_stride + (block_table_offset_next - block_table_offset_cur) * params.k_row_stride; } - flash::copy(gmem_tiled_copy_QKV, tKgK, tKsK, tKVcKV, tKVpKV); + flash::copy(gmem_tiled_copy_QKV, tKgK, tKsK, tKcK, tKpK); // This cp_async_fence needs to be in the if block, otherwise the synchronization // isn't right and we get race conditions. cute::cp_async_fence(); @@ -946,7 +964,7 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params ¶ms, cons const int block_table_offset_next = n_block * kBlockN - block_table_idx_next * params.page_block_size; tVgV.data() = tVgV.data() + (block_table[block_table_idx_next] - block_table[block_table_idx_cur]) * params.v_batch_stride + (block_table_offset_next - block_table_offset_cur) * params.v_row_stride; } - flash::copy(gmem_tiled_copy_QKV, tVgV, tVsV, tKVcKV, tKVpKV); + flash::copy(gmem_tiled_copy_QKV, tVgV, tVsV, tVcV, tVpV); cute::cp_async_fence(); flash::gemm( @@ -970,7 +988,7 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params ¶ms, cons const int block_table_offset_next = (n_block - 1) * kBlockN - block_table_idx_next * params.page_block_size; tKgK.data() = tKgK.data() + (block_table[block_table_idx_next] - block_table[block_table_idx_cur]) * params.k_batch_stride + (block_table_offset_next - block_table_offset_cur) * params.k_row_stride; } - flash::copy(gmem_tiled_copy_QKV, tKgK, tKsK, tKVcKV, tKVpKV); + flash::copy(gmem_tiled_copy_QKV, tKgK, tKsK, tKcK, tKpK); // This cp_async_fence needs to be in the if block, otherwise the synchronization // isn't right and we get race conditions. 
cute::cp_async_fence(); @@ -1009,21 +1027,22 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params ¶ms, cons // sOaccum is larger than sQ, so we need to syncthreads here // TODO: allocate enough smem for sOaccum - if constexpr (Split) { __syncthreads(); } - + __syncthreads(); + // if constexpr (Split) { __syncthreads(); } + cute::copy(smem_tiled_copy_Oaccum, taccOrOaccum, taccOsOaccum); const index_t row_offset_o = binfo.q_offset(params.o_batch_stride, params.o_row_stride, bidb) + m_block * kBlockM * params.o_row_stride + bidh * params.o_head_stride; const index_t row_offset_oaccum = (((n_split_idx * params.b + bidb) * params.h + bidh) * params.seqlen_q - + m_block * kBlockM) * params.d_rounded; + + m_block * kBlockM) * params.vd_rounded; const index_t row_offset_lseaccum = (Split || !params.unpadded_lse ? ((n_split_idx * params.b + bidb) * params.h + bidh) * params.seqlen_q : bidh * params.total_q + binfo.q_offset(params.seqlen_q, 1, bidb) ) + m_block * kBlockM; Tensor gOaccum = make_tensor(make_gmem_ptr(reinterpret_cast(Split ? params.oaccum_ptr : params.o_ptr) + (Split ? row_offset_oaccum : row_offset_o)), - Shape, Int>{}, - make_stride(Split ? kHeadDim : params.o_row_stride, _1{})); + Shape, Int>{}, + make_stride(Split ? kVHeadDim : params.o_row_stride, _1{})); Tensor gLSEaccum = make_tensor(make_gmem_ptr(reinterpret_cast(Split ? params.softmax_lseaccum_ptr : params.softmax_lse_ptr) + row_offset_lseaccum), Shape>{}, Stride<_1>{}); // if (tidx == 0) { printf("row_offset_o = %d, bidh = %d, gOaccum = %p\n", row_offset_o, bidh, gOaccum.data()); } @@ -1038,7 +1057,7 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params ¶ms, cons Tensor tOrOaccum = make_tensor(shape(tOgOaccum)); cute::copy(gmem_tiled_copy_Oaccum, tOsOaccum, tOrOaccum); - Tensor caccO = make_identity_tensor(Shape, Int>{}); // (BLK_M,BLK_K) -> (blk_m,blk_k) + Tensor caccO = make_identity_tensor(Shape, Int>{}); // (BLK_M,BLK_K) -> (blk_m,blk_k) Tensor taccOcO = thr_mma.partition_C(caccO); // (MMA,MMA_M,MMA_K) static_assert(decltype(size<0>(taccOcO))::value == 4); // Convert to ((2, 2), MMA_M, MMA_K) then take only the row indices. 
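// ---------------------------------------------------------------------------
// Editor's note (not part of the patch): the split-K accumulator Oaccum carries
// the value head dimension, so every offset upstream derived from d_rounded is
// rewritten in terms of vd_rounded.  A hedged sketch of the indexing used above
// (the helper below is hypothetical; field names mirror Flash_fwd_params):
#if 0  // sketch only, never compiled
index_t oaccum_row_offset(const Flash_fwd_params &p, int n_split, int bidb, int bidh, int row) {
    // One (split, batch, head, row) slot, rows contiguous with pitch vd_rounded.
    return (((index_t(n_split) * p.b + bidb) * p.h + bidh) * p.seqlen_q + row) * p.vd_rounded;
}
// Stepping to the next split therefore advances by p.b * p.h * p.seqlen_q * p.vd_rounded,
// exactly the pointer bump applied to tOgOaccum in the combine loop above.
#endif
// ---------------------------------------------------------------------------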
@@ -1059,7 +1078,7 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params ¶ms, cons Tensor tOpO = make_tensor(make_shape(size<2>(tOgOaccum))); if (!Is_even_K) { #pragma unroll - for (int k = 0; k < size(tOpO); ++k) { tOpO(k) = get<1>(tOcO(0, 0, k)) < params.d; } + for (int k = 0; k < size(tOpO); ++k) { tOpO(k) = get<1>(tOcO(0, 0, k)) < params.vd; } } // Clear_OOB_K must be false since we don't want to write zeros to gmem flash::copy( @@ -1110,7 +1129,8 @@ inline __device__ void combine_attn_seqk_parallel(const Params ¶ms) { using ElementAccum = typename Kernel_traits::ElementAccum; using index_t = typename Kernel_traits::index_t; constexpr int kMaxSplits = 1 << Log_max_splits; - constexpr int kHeadDim = Kernel_traits::kHeadDim; + constexpr int kQKHeadDim = Kernel_traits::kQKHeadDim; + constexpr int kVHeadDim = Kernel_traits::kVHeadDim; constexpr int kNThreads = Kernel_traits::kNThreads; static_assert(kMaxSplits <= 128, "kMaxSplits must be <= 128"); @@ -1212,10 +1232,10 @@ inline __device__ void combine_attn_seqk_parallel(const Params ¶ms) { } __syncthreads(); - const index_t row_offset_oaccum = bidx * kBlockM * params.d_rounded; + const index_t row_offset_oaccum = bidx * kBlockM * params.vd_rounded; Tensor gOaccum = make_tensor(make_gmem_ptr(reinterpret_cast(params.oaccum_ptr) + row_offset_oaccum), - Shape, Int>{}, - Stride, _1>{}); + Shape, Int>{}, + Stride, _1>{}); constexpr int kBlockN = kNThreads / kBlockM; using GmemLayoutAtomOaccum = Layout, Int>, Stride, _1>>; using GmemTiledCopyOaccum = decltype( @@ -1230,13 +1250,13 @@ inline __device__ void combine_attn_seqk_parallel(const Params ¶ms) { clear(tOrO); // Predicates - Tensor cOaccum = make_identity_tensor(Shape, Int>{}); + Tensor cOaccum = make_identity_tensor(Shape, Int>{}); // Repeat the partitioning with identity layouts Tensor tOcOaccum = gmem_thr_copy_Oaccum.partition_S(cOaccum); Tensor tOpOaccum = make_tensor(make_shape(size<2>(tOgOaccum))); if (!Is_even_K) { #pragma unroll - for (int k = 0; k < size(tOpOaccum); ++k) { tOpOaccum(k) = get<1>(tOcOaccum(0, 0, k)) < params.d; } + for (int k = 0; k < size(tOpOaccum); ++k) { tOpOaccum(k) = get<1>(tOcOaccum(0, 0, k)) < params.vd; } } // Load Oaccum in then scale and accumulate to O for (int split = 0; split < params.num_splits; ++split) { @@ -1256,7 +1276,7 @@ inline __device__ void combine_attn_seqk_parallel(const Params ¶ms) { } // if (cute::thread0()) { printf("lse_scale = %f, %f\n", sLSE[split][0], sLSE[split][1]); print(tOrOaccum); } } - tOgOaccum.data() = tOgOaccum.data() + params.b * params.h * params.seqlen_q * params.d_rounded; + tOgOaccum.data() = tOgOaccum.data() + params.b * params.h * params.seqlen_q * params.vd_rounded; } // if (cute::thread0()) { print_tensor(tOrO); } diff --git a/csrc/flash_attn/src/flash_fwd_launch_template.h b/csrc/flash_attn/src/flash_fwd_launch_template.h index 900cf4671..ba4c29d8b 100644 --- a/csrc/flash_attn/src/flash_fwd_launch_template.h +++ b/csrc/flash_attn/src/flash_fwd_launch_template.h @@ -24,7 +24,7 @@ // Use a macro to clean up kernel definitions #define DEFINE_FLASH_FORWARD_KERNEL(kernelName, ...) 
\ template \ -__global__ void kernelName(KERNEL_PARAM_MODIFIER const Flash_fwd_params params) +__global__ void kernelName(KERNEL_PARAM_MODIFIER const Flash_fwd_params params) DEFINE_FLASH_FORWARD_KERNEL(flash_fwd_kernel, bool Is_dropout, bool Is_causal, bool Is_local, bool Has_alibi, bool Is_even_MN, bool Is_even_K, bool Is_softcap, bool Return_softmax) { #if defined(ARCH_SUPPORTS_FLASH) @@ -60,7 +60,7 @@ void run_flash_fwd(Flash_fwd_params ¶ms, cudaStream_t stream) { const int num_m_block = (params.seqlen_q + Kernel_traits::kBlockM - 1) / Kernel_traits::kBlockM; dim3 grid(num_m_block, params.b, params.h); const bool is_even_MN = params.cu_seqlens_q == nullptr && params.cu_seqlens_k == nullptr && params.seqlen_k % Kernel_traits::kBlockN == 0 && params.seqlen_q % Kernel_traits::kBlockM == 0; - const bool is_even_K = params.d == Kernel_traits::kHeadDim; + const bool is_even_K = params.d == Kernel_traits::kQKHeadDim; //TODO: Check if this is correct const bool return_softmax = params.p_ptr != nullptr; BOOL_SWITCH(is_even_MN, IsEvenMNConst, [&] { EVENK_SWITCH(is_even_K, IsEvenKConst, [&] { @@ -73,7 +73,7 @@ void run_flash_fwd(Flash_fwd_params ¶ms, cudaStream_t stream) { // If return_softmax, set IsEvenMNConst to false to reduce number of templates // If head dim > 128, set IsEvenMNConst to false to reduce number of templates // If Is_local, set Is_causal to false - auto kernel = &flash_fwd_kernel; + auto kernel = &flash_fwd_kernel;// TODO: Check if this is correct // auto kernel = &flash_fwd_kernel; // printf("IsEvenMNConst = %d, IsEvenKConst = %d, Is_local = %d, Is_causal = %d, ReturnSoftmaxConst = %d, Is_dropout = %d\n", int(IsEvenMNConst), int(IsEvenKConst), int(Is_local), int(Is_causal), int(ReturnSoftmaxConst), int(Is_dropout)); // auto kernel = &flash_fwd_kernel; @@ -103,7 +103,7 @@ void run_flash_splitkv_fwd(Flash_fwd_params ¶ms, cudaStream_t stream) { const int num_m_block = (params.seqlen_q + Kernel_traits::kBlockM - 1) / Kernel_traits::kBlockM; dim3 grid(num_m_block, params.num_splits > 1 ? params.num_splits : params.b, params.num_splits > 1 ? params.b * params.h : params.h); const bool is_even_MN = params.cu_seqlens_q == nullptr && params.cu_seqlens_k == nullptr && params.seqlen_k % Kernel_traits::kBlockN == 0 && params.seqlen_q % Kernel_traits::kBlockM == 0; - const bool is_even_K = params.d == Kernel_traits::kHeadDim; + const bool is_even_K = params.d == Kernel_traits::kQKHeadDim; //TODO: Check if this is correct BOOL_SWITCH(is_even_MN, IsEvenMNConst, [&] { EVENK_SWITCH(is_even_K, IsEvenKConst, [&] { LOCAL_SWITCH((params.window_size_left >= 0 || params.window_size_right >= 0) && !Is_causal, Is_local, [&] { @@ -114,7 +114,7 @@ void run_flash_splitkv_fwd(Flash_fwd_params ¶ms, cudaStream_t stream) { // If Append_KV, then we must have seqlen_offsets, which means cu_seqlens_k != nullptr. // If not IsEvenKConst, we also set IsEvenMNConst to false to reduce number of templates. // If Is_local, set Is_causal to false - auto kernel = &flash_fwd_splitkv_kernel; + auto kernel = &flash_fwd_splitkv_kernel; // TODO: Check if this is correct // auto kernel = &flash_fwd_splitkv_kernel; // auto kernel = &flash_fwd_splitkv_kernel; if (smem_size >= 48 * 1024) { @@ -134,7 +134,7 @@ void run_flash_splitkv_fwd(Flash_fwd_params ¶ms, cudaStream_t stream) { // We want kBlockM to be as small as possible for more parallelism. // With 128 threads we can load 512 elements at a time, so if headdim is divisible by 128, kBlockM = 4. // If headdim is divisible by 64, then we set kBlockM = 8, etc. 
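// ---------------------------------------------------------------------------
// Editor's sketch (not part of the patch) of the rule described in the comment
// above and adjusted just below: 128 threads doing 4-float vectorized loads
// cover 512 elements per pass, so kBlockM is picked to make kBlockM * headdim a
// multiple of 512 whenever the head dim allows it.
#if 0  // sketch only, never compiled
constexpr int combine_kBlockM(int headdim) {
    return headdim % 128 == 0 ? 4 : (headdim % 64 == 0 ? 8 : 16);
}
static_assert(combine_kBlockM(128) * 128 == 512, "");
static_assert(combine_kBlockM(64)  *  64 == 512, "");
#endif
// The patch keys this off kQKHeadDim and flags it with a TODO; since the combine
// kernel iterates over Oaccum, whose width is the V head dim, kVHeadDim is
// arguably the relevant dimension here.  That observation is the editor's, not a
// statement of what the patch does.
// ---------------------------------------------------------------------------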
- constexpr static int kBlockM = Kernel_traits::kHeadDim % 128 == 0 ? 4 : (Kernel_traits::kHeadDim % 64 == 0 ? 8 : 16); + constexpr static int kBlockM = Kernel_traits::kQKHeadDim % 128 == 0 ? 4 : (Kernel_traits::kQKHeadDim % 64 == 0 ? 8 : 16); // TODO: Check if this is correct dim3 grid_combine((params.b * params.h * params.seqlen_q + kBlockM - 1) / kBlockM); EVENK_SWITCH(is_even_K, IsEvenKConst, [&] { if (params.num_splits <= 2) { @@ -157,21 +157,21 @@ void run_flash_splitkv_fwd(Flash_fwd_params ¶ms, cudaStream_t stream) { } } -template +template void run_mha_fwd_splitkv_dispatch(Flash_fwd_params ¶ms, cudaStream_t stream) { constexpr static int kBlockM = 64; // Fixed for all head dimensions // TD [2023-08-28]: nvcc segfaults for headdim 96 with block size 64 x 256, // and for headdim 192 with block size 64 x 128. // Also for headdim 160 with block size 64 x 128 after the rotary addition. - constexpr static int kBlockN = Headdim <= 64 ? 256 : (Headdim <= 128 ? 128 : 64); - run_flash_splitkv_fwd, Is_causal>(params, stream); + constexpr static int kBlockN = QKHeaddim <= 64 ? 256 : (QKHeaddim <= 128 ? 128 : 64); + run_flash_splitkv_fwd, Is_causal>(params, stream); } template void run_mha_fwd_hdim32(Flash_fwd_params ¶ms, cudaStream_t stream) { constexpr static int Headdim = 32; DROPOUT_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] { - run_flash_fwd, Is_dropout, Is_causal>(params, stream); + run_flash_fwd, Is_dropout, Is_causal>(params, stream); }); } @@ -183,14 +183,14 @@ void run_mha_fwd_hdim64(Flash_fwd_params ¶ms, cudaStream_t stream) { // Using 8 warps is 18% slower for seqlen=2k, 2 warps is 5% slower // Using block size (64 x 256) is 27% slower for seqlen=2k // Using block size (256 x 64) is 85% slower for seqlen=2k, because of register spilling - run_flash_fwd, Is_dropout, Is_causal>(params, stream); - // run_flash_fwd, Is_dropout, Is_causal>(params, stream); - // run_flash_fwd, Is_dropout, Is_causal>(params, stream); + run_flash_fwd, Is_dropout, Is_causal>(params, stream); + // run_flash_fwd, Is_dropout, Is_causal>(params, stream); + // run_flash_fwd, Is_dropout, Is_causal>(params, stream); } else { - run_flash_fwd, Is_dropout, Is_causal>(params, stream); - // run_flash_fwd, Is_dropout, Is_causal>(params, stream); - // run_flash_fwd, Is_dropout, Is_causal>(params, stream); - // run_flash_fwd, Is_dropout, Is_causal>(params, stream); + run_flash_fwd, Is_dropout, Is_causal>(params, stream); + // run_flash_fwd, Is_dropout, Is_causal>(params, stream); + // run_flash_fwd, Is_dropout, Is_causal>(params, stream); + // run_flash_fwd, Is_dropout, Is_causal>(params, stream); } }); } @@ -204,15 +204,15 @@ void run_mha_fwd_hdim96(Flash_fwd_params ¶ms, cudaStream_t stream) { // For sm86 or sm89, 64 x 64 is the fastest for causal (because it's square), if (is_sm8x) { if constexpr(!Is_causal) { - run_flash_fwd, Is_dropout, Is_causal>(params, stream); + run_flash_fwd, Is_dropout, Is_causal>(params, stream); } else { - run_flash_fwd, Is_dropout, Is_causal>(params, stream); + run_flash_fwd, Is_dropout, Is_causal>(params, stream); } } else { - run_flash_fwd, Is_dropout, Is_causal>(params, stream); + run_flash_fwd, Is_dropout, Is_causal>(params, stream); } - // run_flash_fwd, Is_dropout, Is_causal>(params, stream); - // run_flash_fwd, Is_dropout, Is_causal>(params, stream); + // run_flash_fwd, Is_dropout, Is_causal>(params, stream); + // run_flash_fwd, Is_dropout, Is_causal>(params, stream); // These two are always slower // run_flash_fwd>(params, stream); // run_flash_fwd>(params, stream); @@ -230,26 
+230,26 @@ void run_mha_fwd_hdim128(Flash_fwd_params ¶ms, cudaStream_t stream) { // and 128 x 32 (48 KB smem) is the fastest for non-causal since we get 2 CTAs per SM. if (is_sm8x) { if constexpr(!Is_causal) { - run_flash_fwd, Is_dropout, Is_causal>(params, stream); + run_flash_fwd, Is_dropout, Is_causal>(params, stream); } else { - run_flash_fwd, Is_dropout, Is_causal>(params, stream); + run_flash_fwd, Is_dropout, Is_causal>(params, stream); } } else { - run_flash_fwd, Is_dropout, Is_causal>(params, stream); + run_flash_fwd, Is_dropout, Is_causal>(params, stream); } - // run_flash_fwd, Is_dropout, Is_causal>(params, stream); - // run_flash_fwd, Is_dropout, Is_causal>(params, stream); - // run_flash_fwd, Is_dropout, Is_causal>(params, stream); + // run_flash_fwd, Is_dropout, Is_causal>(params, stream); + // run_flash_fwd, Is_dropout, Is_causal>(params, stream); + // run_flash_fwd, Is_dropout, Is_causal>(params, stream); // Using 8 warps (128 x 128 and 256 x 64) is 28% slower for seqlen=2k - // run_flash_fwd, Is_dropout, Is_causal>(params, stream); - // run_flash_fwd, Is_dropout, Is_causal>(params, stream); + // run_flash_fwd, Is_dropout, Is_causal>(params, stream); + // run_flash_fwd, Is_dropout, Is_causal>(params, stream); // 1st ones are good for H100, A100 // 2nd one is good for A6000 bc we get slightly better occupancy } else { - run_flash_fwd, Is_dropout, Is_causal>(params, stream); - // run_flash_fwd, Is_dropout, Is_causal>(params, stream); - // run_flash_fwd, Is_dropout, Is_causal>(params, stream); - // run_flash_fwd, Is_dropout, Is_causal>(params, stream); + run_flash_fwd, Is_dropout, Is_causal>(params, stream); + // run_flash_fwd, Is_dropout, Is_causal>(params, stream); + // run_flash_fwd, Is_dropout, Is_causal>(params, stream); + // run_flash_fwd, Is_dropout, Is_causal>(params, stream); } }); } @@ -265,20 +265,20 @@ void run_mha_fwd_hdim160(Flash_fwd_params ¶ms, cudaStream_t stream) { // and 128 x 64 with 8 warps is the fastest for non-causal. 
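// ---------------------------------------------------------------------------
// Editor's note (not part of the patch): the template arguments of the
// run_flash_fwd / Flash_fwd_kernel_traits calls are not spelled out in these
// launcher hunks.  Judging from the rest of the patch (for example the
// kQKHeadDim / kVHeadDim members read in flash_fwd_kernel.h), each traits
// instantiation now carries separate QK and V head dims instead of a single
// Headdim.  A purely illustrative shape, with the parameter order assumed by
// the editor rather than taken from the patch:
#if 0  // sketch only, never compiled
run_flash_fwd<Flash_fwd_kernel_traits<QKHeaddim, VHeaddim, /*kBlockM=*/128,
                                      /*kBlockN=*/64, /*kNWarps=*/8,
                                      /*Is_Q_in_regs=*/false, /*Share_Q_K_smem=*/false, T>,
              Is_dropout, Is_causal>(params, stream);
#endif
// ---------------------------------------------------------------------------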
if (is_sm8x) { if constexpr(!Is_causal) { - run_flash_fwd, Is_dropout, Is_causal>(params, stream); + run_flash_fwd, Is_dropout, Is_causal>(params, stream); } else { - run_flash_fwd, Is_dropout, Is_causal>(params, stream); + run_flash_fwd, Is_dropout, Is_causal>(params, stream); } } else { - run_flash_fwd, Is_dropout, Is_causal>(params, stream); + run_flash_fwd, Is_dropout, Is_causal>(params, stream); } - // run_flash_fwd, Is_dropout, Is_causal>(params, stream); - // run_flash_fwd, Is_dropout, Is_causal>(params, stream); - // run_flash_fwd>(params, stream); - // run_flash_fwd>(params, stream); - // run_flash_fwd>(params, stream); - // run_flash_fwd>(params, stream); - // run_flash_fwd>(params, stream); + // run_flash_fwd, Is_dropout, Is_causal>(params, stream); + // run_flash_fwd, Is_dropout, Is_causal>(params, stream); + // run_flash_fwd>(params, stream); + // run_flash_fwd>(params, stream); + // run_flash_fwd>(params, stream); + // run_flash_fwd>(params, stream); + // run_flash_fwd>(params, stream); }); } @@ -287,15 +287,15 @@ void run_mha_fwd_hdim192(Flash_fwd_params ¶ms, cudaStream_t stream) { constexpr static int Headdim = 192; DROPOUT_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] { if constexpr(!Is_dropout) { - run_flash_fwd, Is_dropout, Is_causal>(params, stream); + run_flash_fwd, Is_dropout, Is_causal>(params, stream); } else { - run_flash_fwd, Is_dropout, Is_causal>(params, stream); + run_flash_fwd, Is_dropout, Is_causal>(params, stream); } - // run_flash_fwd, Is_dropout, Is_causal>(params, stream); - // run_flash_fwd, Is_dropout, Is_causal>(params, stream); - // run_flash_fwd>(params, stream); - // run_flash_fwd>(params, stream); - // run_flash_fwd>(params, stream); + // run_flash_fwd, Is_dropout, Is_causal>(params, stream); + // run_flash_fwd, Is_dropout, Is_causal>(params, stream); + // run_flash_fwd>(params, stream); + // run_flash_fwd>(params, stream); + // run_flash_fwd>(params, stream); }); } @@ -317,13 +317,13 @@ void run_mha_fwd_hdim256(Flash_fwd_params ¶ms, cudaStream_t stream) { // For A100, we want to run with 128 x 64 (128KB smem). // For H100 we want to run with 64 x 64 (96KB smem) since then we can get 2 CTAs per SM. if (max_smem_per_block >= 2 * Headdim * (128 + 2 * 64) && max_smem_per_sm < 4 * Headdim * (64 + 2 * 64)) { - run_flash_fwd, Is_dropout, Is_causal>(params, stream); + run_flash_fwd, Is_dropout, Is_causal>(params, stream); } else { - run_flash_fwd, Is_dropout, Is_causal>(params, stream); + run_flash_fwd, Is_dropout, Is_causal>(params, stream); } // 64 KB - // run_flash_fwd, Is_dropout, Is_causal>(params, stream); + // run_flash_fwd, Is_dropout, Is_causal>(params, stream); // 96 KB - // run_flash_fwd, Is_dropout, Is_causal>(params, stream); + // run_flash_fwd, Is_dropout, Is_causal>(params, stream); }); } diff --git a/csrc/flash_attn/src/flash_fwd_qkdim128_vdim256_bf16_causal_sm80.cu b/csrc/flash_attn/src/flash_fwd_qkdim128_vdim256_bf16_causal_sm80.cu new file mode 100644 index 000000000..b20271f2d --- /dev/null +++ b/csrc/flash_attn/src/flash_fwd_qkdim128_vdim256_bf16_causal_sm80.cu @@ -0,0 +1,10 @@ +// Copyright (c) 2024, Tri Dao. +// Splitting the different head dimensions to different files to speed up compilation. +// This file is auto-generated. 
See "generate_kernels.py" + +#include "flash_fwd_qkdim128_vdim256_sm80.h" + +template<> +void run_mha_fwd_(Flash_fwd_params ¶ms, cudaStream_t stream) { + run_mha_fwd_qkdim128_vdim256(params, stream); +} diff --git a/csrc/flash_attn/src/flash_fwd_qkdim128_vdim256_bf16_sm80.cu b/csrc/flash_attn/src/flash_fwd_qkdim128_vdim256_bf16_sm80.cu new file mode 100644 index 000000000..464e0b283 --- /dev/null +++ b/csrc/flash_attn/src/flash_fwd_qkdim128_vdim256_bf16_sm80.cu @@ -0,0 +1,10 @@ +// Copyright (c) 2024, Tri Dao. +// Splitting the different head dimensions to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py" + +#include "flash_fwd_qkdim128_vdim256_sm80.h" + +template<> +void run_mha_fwd_(Flash_fwd_params ¶ms, cudaStream_t stream) { + run_mha_fwd_qkdim128_vdim256(params, stream); +} diff --git a/csrc/flash_attn/src/flash_fwd_qkdim128_vdim256_fp16_causal_sm80.cu b/csrc/flash_attn/src/flash_fwd_qkdim128_vdim256_fp16_causal_sm80.cu new file mode 100644 index 000000000..5af5648fa --- /dev/null +++ b/csrc/flash_attn/src/flash_fwd_qkdim128_vdim256_fp16_causal_sm80.cu @@ -0,0 +1,10 @@ +// Copyright (c) 2024, Tri Dao. +// Splitting the different head dimensions to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py" + +#include "flash_fwd_qkdim128_vdim256_sm80.h" + +template<> +void run_mha_fwd_(Flash_fwd_params ¶ms, cudaStream_t stream) { + run_mha_fwd_qkdim128_vdim256(params, stream); +} diff --git a/csrc/flash_attn/src/flash_fwd_qkdim128_vdim256_fp16_sm80.cu b/csrc/flash_attn/src/flash_fwd_qkdim128_vdim256_fp16_sm80.cu new file mode 100644 index 000000000..62cb67ead --- /dev/null +++ b/csrc/flash_attn/src/flash_fwd_qkdim128_vdim256_fp16_sm80.cu @@ -0,0 +1,10 @@ +// Copyright (c) 2024, Tri Dao. +// Splitting the different head dimensions to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py" + +#include "flash_fwd_qkdim128_vdim256_sm80.h" + +template<> +void run_mha_fwd_(Flash_fwd_params ¶ms, cudaStream_t stream) { + run_mha_fwd_qkdim128_vdim256(params, stream); +} diff --git a/csrc/flash_attn/src/flash_fwd_qkdim128_vdim256_sm80.h b/csrc/flash_attn/src/flash_fwd_qkdim128_vdim256_sm80.h new file mode 100644 index 000000000..3900d1fd7 --- /dev/null +++ b/csrc/flash_attn/src/flash_fwd_qkdim128_vdim256_sm80.h @@ -0,0 +1,41 @@ +#include "flash_fwd_launch_template.h" + +template +void run_mha_fwd_qkdim128_vdim256(Flash_fwd_params ¶ms, cudaStream_t stream) { + constexpr static int QKHeaddim = 128; + constexpr static int VHeaddim = 256; + auto dprops = at::cuda::getCurrentDeviceProperties(); + bool is_sm8x = dprops->major == 8 && dprops->minor > 0; + DROPOUT_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] { + if constexpr(!Is_dropout) { + // For sm86 or sm89, 64 x 64 is the fastest for causal (because it's square), + // and 128 x 32 (48 KB smem) is the fastest for non-causal since we get 2 CTAs per SM. 
+ if (is_sm8x) { + if constexpr(!Is_causal) { + run_flash_fwd, Is_dropout, Is_causal>(params, stream); + } else { + run_flash_fwd, Is_dropout, Is_causal>(params, stream); + } + } else { + run_flash_fwd, Is_dropout, Is_causal>(params, stream); + // slow on A100 + // run_flash_fwd, Is_dropout, Is_causal>(params, stream); + } + // run_flash_fwd, Is_dropout, Is_causal>(params, stream); + // run_flash_fwd, Is_dropout, Is_causal>(params, stream); + // run_flash_fwd, Is_dropout, Is_causal>(params, stream); + // Using 8 warps (128 x 128 and 256 x 64) is 28% slower for seqlen=2k + // run_flash_fwd, Is_dropout, Is_causal>(params, stream); + // run_flash_fwd, Is_dropout, Is_causal>(params, stream); + // 1st ones are good for H100, A100 + // 2nd one is good for A6000 bc we get slightly better occupancy + } else { + run_flash_fwd, Is_dropout, Is_causal>(params, stream); + // A100 RuntimeError: CUDA error: an illegal memory access was encountered + // run_flash_fwd, Is_dropout, Is_causal>(params, stream); + // run_flash_fwd, Is_dropout, Is_causal>(params, stream); + // run_flash_fwd, Is_dropout, Is_causal>(params, stream); + // run_flash_fwd, Is_dropout, Is_causal>(params, stream); + } + }); +} diff --git a/csrc/flash_attn/src/flash_fwd_qkdim192_vdim128_bf16_causal_sm80.cu b/csrc/flash_attn/src/flash_fwd_qkdim192_vdim128_bf16_causal_sm80.cu new file mode 100644 index 000000000..52dcca482 --- /dev/null +++ b/csrc/flash_attn/src/flash_fwd_qkdim192_vdim128_bf16_causal_sm80.cu @@ -0,0 +1,10 @@ +// Copyright (c) 2024, Tri Dao. +// Splitting the different head dimensions to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py" + +#include "flash_fwd_qkdim192_vdim128_sm80.h" + +template<> +void run_mha_fwd_(Flash_fwd_params ¶ms, cudaStream_t stream) { + run_mha_fwd_qkdim192_vdim128(params, stream); +} diff --git a/csrc/flash_attn/src/flash_fwd_qkdim192_vdim128_bf16_sm80.cu b/csrc/flash_attn/src/flash_fwd_qkdim192_vdim128_bf16_sm80.cu new file mode 100644 index 000000000..cfe937021 --- /dev/null +++ b/csrc/flash_attn/src/flash_fwd_qkdim192_vdim128_bf16_sm80.cu @@ -0,0 +1,10 @@ +// Copyright (c) 2024, Tri Dao. +// Splitting the different head dimensions to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py" + +#include "flash_fwd_qkdim192_vdim128_sm80.h" + +template<> +void run_mha_fwd_(Flash_fwd_params ¶ms, cudaStream_t stream) { + run_mha_fwd_qkdim192_vdim128(params, stream); +} diff --git a/csrc/flash_attn/src/flash_fwd_qkdim192_vdim128_fp16_causal_sm80.cu b/csrc/flash_attn/src/flash_fwd_qkdim192_vdim128_fp16_causal_sm80.cu new file mode 100644 index 000000000..82db5ae67 --- /dev/null +++ b/csrc/flash_attn/src/flash_fwd_qkdim192_vdim128_fp16_causal_sm80.cu @@ -0,0 +1,10 @@ +// Copyright (c) 2024, Tri Dao. +// Splitting the different head dimensions to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py" + +#include "flash_fwd_qkdim192_vdim128_sm80.h" + +template<> +void run_mha_fwd_(Flash_fwd_params ¶ms, cudaStream_t stream) { + run_mha_fwd_qkdim192_vdim128(params, stream); +} diff --git a/csrc/flash_attn/src/flash_fwd_qkdim192_vdim128_fp16_sm80.cu b/csrc/flash_attn/src/flash_fwd_qkdim192_vdim128_fp16_sm80.cu new file mode 100644 index 000000000..2c5d5c7e9 --- /dev/null +++ b/csrc/flash_attn/src/flash_fwd_qkdim192_vdim128_fp16_sm80.cu @@ -0,0 +1,10 @@ +// Copyright (c) 2024, Tri Dao. +// Splitting the different head dimensions to different files to speed up compilation. 
+// This file is auto-generated. See "generate_kernels.py" + +#include "flash_fwd_qkdim192_vdim128_sm80.h" + +template<> +void run_mha_fwd_(Flash_fwd_params ¶ms, cudaStream_t stream) { + run_mha_fwd_qkdim192_vdim128(params, stream); +} diff --git a/csrc/flash_attn/src/flash_fwd_qkdim192_vdim128_sm80.h b/csrc/flash_attn/src/flash_fwd_qkdim192_vdim128_sm80.h new file mode 100644 index 000000000..8d259d6ca --- /dev/null +++ b/csrc/flash_attn/src/flash_fwd_qkdim192_vdim128_sm80.h @@ -0,0 +1,19 @@ +#include "flash_fwd_launch_template.h" + +template +void run_mha_fwd_qkdim192_vdim128(Flash_fwd_params ¶ms, cudaStream_t stream) { + constexpr static int QKHeaddim = 192; + constexpr static int VHeaddim = 128; + DROPOUT_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] { + if constexpr(!Is_dropout) { + run_flash_fwd, Is_dropout, Is_causal>(params, stream); + } else { + run_flash_fwd, Is_dropout, Is_causal>(params, stream); + } + // run_flash_fwd, Is_dropout, Is_causal>(params, stream); + // run_flash_fwd, Is_dropout, Is_causal>(params, stream); + // run_flash_fwd>(params, stream); + // run_flash_fwd>(params, stream); + // run_flash_fwd>(params, stream); + }); +} diff --git a/csrc/flash_attn/src/flash_fwd_qkdim32_vdim64_bf16_causal_sm80.cu b/csrc/flash_attn/src/flash_fwd_qkdim32_vdim64_bf16_causal_sm80.cu new file mode 100644 index 000000000..f28444255 --- /dev/null +++ b/csrc/flash_attn/src/flash_fwd_qkdim32_vdim64_bf16_causal_sm80.cu @@ -0,0 +1,10 @@ +// Copyright (c) 2024, Tri Dao. +// Splitting the different head dimensions to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py" + +#include "flash_fwd_qkdim32_vdim64_sm80.h" + +template<> +void run_mha_fwd_(Flash_fwd_params ¶ms, cudaStream_t stream) { + run_mha_fwd_qkdim32_vdim64(params, stream); +} diff --git a/csrc/flash_attn/src/flash_fwd_qkdim32_vdim64_bf16_sm80.cu b/csrc/flash_attn/src/flash_fwd_qkdim32_vdim64_bf16_sm80.cu new file mode 100644 index 000000000..0aa49a111 --- /dev/null +++ b/csrc/flash_attn/src/flash_fwd_qkdim32_vdim64_bf16_sm80.cu @@ -0,0 +1,10 @@ +// Copyright (c) 2024, Tri Dao. +// Splitting the different head dimensions to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py" + +#include "flash_fwd_qkdim32_vdim64_sm80.h" + +template<> +void run_mha_fwd_(Flash_fwd_params ¶ms, cudaStream_t stream) { + run_mha_fwd_qkdim32_vdim64(params, stream); +} diff --git a/csrc/flash_attn/src/flash_fwd_qkdim32_vdim64_fp16_causal_sm80.cu b/csrc/flash_attn/src/flash_fwd_qkdim32_vdim64_fp16_causal_sm80.cu new file mode 100644 index 000000000..b88785f29 --- /dev/null +++ b/csrc/flash_attn/src/flash_fwd_qkdim32_vdim64_fp16_causal_sm80.cu @@ -0,0 +1,10 @@ +// Copyright (c) 2024, Tri Dao. +// Splitting the different head dimensions to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py" + +#include "flash_fwd_qkdim32_vdim64_sm80.h" + +template<> +void run_mha_fwd_(Flash_fwd_params ¶ms, cudaStream_t stream) { + run_mha_fwd_qkdim32_vdim64(params, stream); +} diff --git a/csrc/flash_attn/src/flash_fwd_qkdim32_vdim64_fp16_sm80.cu b/csrc/flash_attn/src/flash_fwd_qkdim32_vdim64_fp16_sm80.cu new file mode 100644 index 000000000..28c42a9b8 --- /dev/null +++ b/csrc/flash_attn/src/flash_fwd_qkdim32_vdim64_fp16_sm80.cu @@ -0,0 +1,10 @@ +// Copyright (c) 2024, Tri Dao. +// Splitting the different head dimensions to different files to speed up compilation. +// This file is auto-generated. 
See "generate_kernels.py" + +#include "flash_fwd_qkdim32_vdim64_sm80.h" + +template<> +void run_mha_fwd_(Flash_fwd_params ¶ms, cudaStream_t stream) { + run_mha_fwd_qkdim32_vdim64(params, stream); +} diff --git a/csrc/flash_attn/src/flash_fwd_qkdim32_vdim64_sm80.h b/csrc/flash_attn/src/flash_fwd_qkdim32_vdim64_sm80.h new file mode 100644 index 000000000..4c4941471 --- /dev/null +++ b/csrc/flash_attn/src/flash_fwd_qkdim32_vdim64_sm80.h @@ -0,0 +1,11 @@ + +#include "flash_fwd_launch_template.h" + +template +void run_mha_fwd_qkdim32_vdim64(Flash_fwd_params ¶ms, cudaStream_t stream) { + constexpr static int QKHeaddim = 32; + constexpr static int VHeaddim = 64; + DROPOUT_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] { + run_flash_fwd, Is_dropout, Is_causal>(params, stream); + }); +} \ No newline at end of file diff --git a/csrc/flash_attn/src/flash_fwd_qkdim64_vdim128_bf16_causal_sm80.cu b/csrc/flash_attn/src/flash_fwd_qkdim64_vdim128_bf16_causal_sm80.cu new file mode 100644 index 000000000..762253d09 --- /dev/null +++ b/csrc/flash_attn/src/flash_fwd_qkdim64_vdim128_bf16_causal_sm80.cu @@ -0,0 +1,10 @@ +// Copyright (c) 2024, Tri Dao. +// Splitting the different head dimensions to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py" + +#include "flash_fwd_qkdim64_vdim128_sm80.h" + +template<> +void run_mha_fwd_(Flash_fwd_params ¶ms, cudaStream_t stream) { + run_mha_fwd_qkdim64_vdim128(params, stream); +} diff --git a/csrc/flash_attn/src/flash_fwd_qkdim64_vdim128_bf16_sm80.cu b/csrc/flash_attn/src/flash_fwd_qkdim64_vdim128_bf16_sm80.cu new file mode 100644 index 000000000..86a7616fe --- /dev/null +++ b/csrc/flash_attn/src/flash_fwd_qkdim64_vdim128_bf16_sm80.cu @@ -0,0 +1,10 @@ +// Copyright (c) 2024, Tri Dao. +// Splitting the different head dimensions to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py" + +#include "flash_fwd_qkdim64_vdim128_sm80.h" + +template<> +void run_mha_fwd_(Flash_fwd_params ¶ms, cudaStream_t stream) { + run_mha_fwd_qkdim64_vdim128(params, stream); +} diff --git a/csrc/flash_attn/src/flash_fwd_qkdim64_vdim128_fp16_causal_sm80.cu b/csrc/flash_attn/src/flash_fwd_qkdim64_vdim128_fp16_causal_sm80.cu new file mode 100644 index 000000000..0074f41ca --- /dev/null +++ b/csrc/flash_attn/src/flash_fwd_qkdim64_vdim128_fp16_causal_sm80.cu @@ -0,0 +1,10 @@ +// Copyright (c) 2024, Tri Dao. +// Splitting the different head dimensions to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py" + +#include "flash_fwd_qkdim64_vdim128_sm80.h" + +template<> +void run_mha_fwd_(Flash_fwd_params ¶ms, cudaStream_t stream) { + run_mha_fwd_qkdim64_vdim128(params, stream); +} diff --git a/csrc/flash_attn/src/flash_fwd_qkdim64_vdim128_fp16_sm80.cu b/csrc/flash_attn/src/flash_fwd_qkdim64_vdim128_fp16_sm80.cu new file mode 100644 index 000000000..7578c123f --- /dev/null +++ b/csrc/flash_attn/src/flash_fwd_qkdim64_vdim128_fp16_sm80.cu @@ -0,0 +1,10 @@ +// Copyright (c) 2024, Tri Dao. +// Splitting the different head dimensions to different files to speed up compilation. +// This file is auto-generated. 
See "generate_kernels.py" + +#include "flash_fwd_qkdim64_vdim128_sm80.h" + +template<> +void run_mha_fwd_(Flash_fwd_params ¶ms, cudaStream_t stream) { + run_mha_fwd_qkdim64_vdim128(params, stream); +} diff --git a/csrc/flash_attn/src/flash_fwd_qkdim64_vdim128_sm80.h b/csrc/flash_attn/src/flash_fwd_qkdim64_vdim128_sm80.h new file mode 100644 index 000000000..3ec8ee12d --- /dev/null +++ b/csrc/flash_attn/src/flash_fwd_qkdim64_vdim128_sm80.h @@ -0,0 +1,22 @@ +#include "flash_fwd_launch_template.h" + +template +void run_mha_fwd_qkdim64_vdim128(Flash_fwd_params ¶ms, cudaStream_t stream) { + constexpr static int QKHeaddim = 64; + constexpr static int VHeaddim = 128; + DROPOUT_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] { + if constexpr(!Is_dropout) { + // Using 8 warps is 18% slower for seqlen=2k, 2 warps is 5% slower + // Using block size (64 x 256) is 27% slower for seqlen=2k + // Using block size (256 x 64) is 85% slower for seqlen=2k, because of register spilling + run_flash_fwd, Is_dropout, Is_causal>(params, stream); + // run_flash_fwd, Is_dropout, Is_causal>(params, stream); + // run_flash_fwd, Is_dropout, Is_causal>(params, stream); + } else { + run_flash_fwd, Is_dropout, Is_causal>(params, stream); + // run_flash_fwd, Is_dropout, Is_causal>(params, stream); + // run_flash_fwd, Is_dropout, Is_causal>(params, stream); + // run_flash_fwd, Is_dropout, Is_causal>(params, stream); + } + }); +} diff --git a/csrc/flash_attn/src/flash_fwd_qkdim96_vdim192_bf16_causal_sm80.cu b/csrc/flash_attn/src/flash_fwd_qkdim96_vdim192_bf16_causal_sm80.cu new file mode 100644 index 000000000..a140b8d33 --- /dev/null +++ b/csrc/flash_attn/src/flash_fwd_qkdim96_vdim192_bf16_causal_sm80.cu @@ -0,0 +1,10 @@ +// Copyright (c) 2024, Tri Dao. +// Splitting the different head dimensions to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py" + +#include "flash_fwd_qkdim96_vdim192_sm80.h" + +template<> +void run_mha_fwd_(Flash_fwd_params ¶ms, cudaStream_t stream) { + run_mha_fwd_qkdim96_vdim192(params, stream); +} diff --git a/csrc/flash_attn/src/flash_fwd_qkdim96_vdim192_bf16_sm80.cu b/csrc/flash_attn/src/flash_fwd_qkdim96_vdim192_bf16_sm80.cu new file mode 100644 index 000000000..ee39b3da2 --- /dev/null +++ b/csrc/flash_attn/src/flash_fwd_qkdim96_vdim192_bf16_sm80.cu @@ -0,0 +1,10 @@ +// Copyright (c) 2024, Tri Dao. +// Splitting the different head dimensions to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py" + +#include "flash_fwd_qkdim96_vdim192_sm80.h" + +template<> +void run_mha_fwd_(Flash_fwd_params ¶ms, cudaStream_t stream) { + run_mha_fwd_qkdim96_vdim192(params, stream); +} diff --git a/csrc/flash_attn/src/flash_fwd_qkdim96_vdim192_fp16_causal_sm80.cu b/csrc/flash_attn/src/flash_fwd_qkdim96_vdim192_fp16_causal_sm80.cu new file mode 100644 index 000000000..8943b8922 --- /dev/null +++ b/csrc/flash_attn/src/flash_fwd_qkdim96_vdim192_fp16_causal_sm80.cu @@ -0,0 +1,10 @@ +// Copyright (c) 2024, Tri Dao. +// Splitting the different head dimensions to different files to speed up compilation. +// This file is auto-generated. 
See "generate_kernels.py" + +#include "flash_fwd_qkdim96_vdim192_sm80.h" + +template<> +void run_mha_fwd_(Flash_fwd_params ¶ms, cudaStream_t stream) { + run_mha_fwd_qkdim96_vdim192(params, stream); +} diff --git a/csrc/flash_attn/src/flash_fwd_qkdim96_vdim192_fp16_sm80.cu b/csrc/flash_attn/src/flash_fwd_qkdim96_vdim192_fp16_sm80.cu new file mode 100644 index 000000000..ce4b051a3 --- /dev/null +++ b/csrc/flash_attn/src/flash_fwd_qkdim96_vdim192_fp16_sm80.cu @@ -0,0 +1,10 @@ +// Copyright (c) 2024, Tri Dao. +// Splitting the different head dimensions to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py" + +#include "flash_fwd_qkdim96_vdim192_sm80.h" + +template<> +void run_mha_fwd_(Flash_fwd_params ¶ms, cudaStream_t stream) { + run_mha_fwd_qkdim96_vdim192(params, stream); +} diff --git a/csrc/flash_attn/src/flash_fwd_qkdim96_vdim192_sm80.h b/csrc/flash_attn/src/flash_fwd_qkdim96_vdim192_sm80.h new file mode 100644 index 000000000..bd106822d --- /dev/null +++ b/csrc/flash_attn/src/flash_fwd_qkdim96_vdim192_sm80.h @@ -0,0 +1,26 @@ +#include "flash_fwd_launch_template.h" + +template +void run_mha_fwd_qkdim96_vdim192(Flash_fwd_params ¶ms, cudaStream_t stream) { + constexpr static int QKHeaddim = 96; + constexpr static int VHeaddim = 192; + auto dprops = at::cuda::getCurrentDeviceProperties(); + bool is_sm8x = dprops->major == 8 && dprops->minor > 0; + DROPOUT_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] { + // For sm86 or sm89, 64 x 64 is the fastest for causal (because it's square), + if (is_sm8x) { + if constexpr(!Is_causal) { + run_flash_fwd, Is_dropout, Is_causal>(params, stream); + } else { + run_flash_fwd, Is_dropout, Is_causal>(params, stream); + } + } else { + run_flash_fwd, Is_dropout, Is_causal>(params, stream); + } + // run_flash_fwd, Is_dropout, Is_causal>(params, stream); + // run_flash_fwd, Is_dropout, Is_causal>(params, stream); + // These two are always slower + // run_flash_fwd>(params, stream); + // run_flash_fwd>(params, stream); + }); +} diff --git a/csrc/flash_attn/src/flash_fwd_split_hdim128_bf16_causal_sm80.cu b/csrc/flash_attn/src/flash_fwd_split_hdim128_bf16_causal_sm80.cu index a959c9ceb..00bbaa081 100644 --- a/csrc/flash_attn/src/flash_fwd_split_hdim128_bf16_causal_sm80.cu +++ b/csrc/flash_attn/src/flash_fwd_split_hdim128_bf16_causal_sm80.cu @@ -4,4 +4,4 @@ #include "flash_fwd_launch_template.h" -template void run_mha_fwd_splitkv_dispatch(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_splitkv_dispatch(Flash_fwd_params ¶ms, cudaStream_t stream); diff --git a/csrc/flash_attn/src/flash_fwd_split_hdim128_bf16_sm80.cu b/csrc/flash_attn/src/flash_fwd_split_hdim128_bf16_sm80.cu index e608e308e..ef2649ee6 100644 --- a/csrc/flash_attn/src/flash_fwd_split_hdim128_bf16_sm80.cu +++ b/csrc/flash_attn/src/flash_fwd_split_hdim128_bf16_sm80.cu @@ -4,4 +4,4 @@ #include "flash_fwd_launch_template.h" -template void run_mha_fwd_splitkv_dispatch(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_splitkv_dispatch(Flash_fwd_params ¶ms, cudaStream_t stream); diff --git a/csrc/flash_attn/src/flash_fwd_split_hdim128_fp16_causal_sm80.cu b/csrc/flash_attn/src/flash_fwd_split_hdim128_fp16_causal_sm80.cu index 3dd74e273..e610f55da 100644 --- a/csrc/flash_attn/src/flash_fwd_split_hdim128_fp16_causal_sm80.cu +++ b/csrc/flash_attn/src/flash_fwd_split_hdim128_fp16_causal_sm80.cu @@ -4,4 +4,4 @@ #include "flash_fwd_launch_template.h" -template void run_mha_fwd_splitkv_dispatch(Flash_fwd_params ¶ms, 
cudaStream_t stream); +template void run_mha_fwd_splitkv_dispatch(Flash_fwd_params ¶ms, cudaStream_t stream); diff --git a/csrc/flash_attn/src/flash_fwd_split_hdim128_fp16_sm80.cu b/csrc/flash_attn/src/flash_fwd_split_hdim128_fp16_sm80.cu index addacedf4..dd0018f44 100644 --- a/csrc/flash_attn/src/flash_fwd_split_hdim128_fp16_sm80.cu +++ b/csrc/flash_attn/src/flash_fwd_split_hdim128_fp16_sm80.cu @@ -4,4 +4,4 @@ #include "flash_fwd_launch_template.h" -template void run_mha_fwd_splitkv_dispatch(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_splitkv_dispatch(Flash_fwd_params ¶ms, cudaStream_t stream); diff --git a/csrc/flash_attn/src/flash_fwd_split_hdim160_bf16_causal_sm80.cu b/csrc/flash_attn/src/flash_fwd_split_hdim160_bf16_causal_sm80.cu index 8ace7bda9..2b05a20aa 100644 --- a/csrc/flash_attn/src/flash_fwd_split_hdim160_bf16_causal_sm80.cu +++ b/csrc/flash_attn/src/flash_fwd_split_hdim160_bf16_causal_sm80.cu @@ -4,4 +4,4 @@ #include "flash_fwd_launch_template.h" -template void run_mha_fwd_splitkv_dispatch(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_splitkv_dispatch(Flash_fwd_params ¶ms, cudaStream_t stream); diff --git a/csrc/flash_attn/src/flash_fwd_split_hdim160_bf16_sm80.cu b/csrc/flash_attn/src/flash_fwd_split_hdim160_bf16_sm80.cu index 1e133ec1a..78e309cdd 100644 --- a/csrc/flash_attn/src/flash_fwd_split_hdim160_bf16_sm80.cu +++ b/csrc/flash_attn/src/flash_fwd_split_hdim160_bf16_sm80.cu @@ -4,4 +4,4 @@ #include "flash_fwd_launch_template.h" -template void run_mha_fwd_splitkv_dispatch(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_splitkv_dispatch(Flash_fwd_params ¶ms, cudaStream_t stream); diff --git a/csrc/flash_attn/src/flash_fwd_split_hdim160_fp16_causal_sm80.cu b/csrc/flash_attn/src/flash_fwd_split_hdim160_fp16_causal_sm80.cu index 1723c69e0..0504ed5b9 100644 --- a/csrc/flash_attn/src/flash_fwd_split_hdim160_fp16_causal_sm80.cu +++ b/csrc/flash_attn/src/flash_fwd_split_hdim160_fp16_causal_sm80.cu @@ -4,4 +4,4 @@ #include "flash_fwd_launch_template.h" -template void run_mha_fwd_splitkv_dispatch(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_splitkv_dispatch(Flash_fwd_params ¶ms, cudaStream_t stream); diff --git a/csrc/flash_attn/src/flash_fwd_split_hdim160_fp16_sm80.cu b/csrc/flash_attn/src/flash_fwd_split_hdim160_fp16_sm80.cu index 892d2352a..f21f65c9a 100644 --- a/csrc/flash_attn/src/flash_fwd_split_hdim160_fp16_sm80.cu +++ b/csrc/flash_attn/src/flash_fwd_split_hdim160_fp16_sm80.cu @@ -4,4 +4,4 @@ #include "flash_fwd_launch_template.h" -template void run_mha_fwd_splitkv_dispatch(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_splitkv_dispatch(Flash_fwd_params ¶ms, cudaStream_t stream); diff --git a/csrc/flash_attn/src/flash_fwd_split_hdim192_bf16_causal_sm80.cu b/csrc/flash_attn/src/flash_fwd_split_hdim192_bf16_causal_sm80.cu index d07ee0af2..b7fc5e1b6 100644 --- a/csrc/flash_attn/src/flash_fwd_split_hdim192_bf16_causal_sm80.cu +++ b/csrc/flash_attn/src/flash_fwd_split_hdim192_bf16_causal_sm80.cu @@ -4,4 +4,4 @@ #include "flash_fwd_launch_template.h" -template void run_mha_fwd_splitkv_dispatch(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_splitkv_dispatch(Flash_fwd_params ¶ms, cudaStream_t stream); diff --git a/csrc/flash_attn/src/flash_fwd_split_hdim192_bf16_sm80.cu b/csrc/flash_attn/src/flash_fwd_split_hdim192_bf16_sm80.cu index 23cfa59d5..2364925ef 100644 --- a/csrc/flash_attn/src/flash_fwd_split_hdim192_bf16_sm80.cu +++ 
b/csrc/flash_attn/src/flash_fwd_split_hdim192_bf16_sm80.cu @@ -4,4 +4,4 @@ #include "flash_fwd_launch_template.h" -template void run_mha_fwd_splitkv_dispatch(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_splitkv_dispatch(Flash_fwd_params ¶ms, cudaStream_t stream); diff --git a/csrc/flash_attn/src/flash_fwd_split_hdim192_fp16_causal_sm80.cu b/csrc/flash_attn/src/flash_fwd_split_hdim192_fp16_causal_sm80.cu index 273a28442..049afac12 100644 --- a/csrc/flash_attn/src/flash_fwd_split_hdim192_fp16_causal_sm80.cu +++ b/csrc/flash_attn/src/flash_fwd_split_hdim192_fp16_causal_sm80.cu @@ -4,4 +4,4 @@ #include "flash_fwd_launch_template.h" -template void run_mha_fwd_splitkv_dispatch(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_splitkv_dispatch(Flash_fwd_params ¶ms, cudaStream_t stream); diff --git a/csrc/flash_attn/src/flash_fwd_split_hdim192_fp16_sm80.cu b/csrc/flash_attn/src/flash_fwd_split_hdim192_fp16_sm80.cu index 0f588d1f4..3c16d8f6e 100644 --- a/csrc/flash_attn/src/flash_fwd_split_hdim192_fp16_sm80.cu +++ b/csrc/flash_attn/src/flash_fwd_split_hdim192_fp16_sm80.cu @@ -4,4 +4,4 @@ #include "flash_fwd_launch_template.h" -template void run_mha_fwd_splitkv_dispatch(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_splitkv_dispatch(Flash_fwd_params ¶ms, cudaStream_t stream); diff --git a/csrc/flash_attn/src/flash_fwd_split_hdim256_bf16_causal_sm80.cu b/csrc/flash_attn/src/flash_fwd_split_hdim256_bf16_causal_sm80.cu index 370fe9ca3..fb707522e 100644 --- a/csrc/flash_attn/src/flash_fwd_split_hdim256_bf16_causal_sm80.cu +++ b/csrc/flash_attn/src/flash_fwd_split_hdim256_bf16_causal_sm80.cu @@ -4,4 +4,4 @@ #include "flash_fwd_launch_template.h" -template void run_mha_fwd_splitkv_dispatch(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_splitkv_dispatch(Flash_fwd_params ¶ms, cudaStream_t stream); diff --git a/csrc/flash_attn/src/flash_fwd_split_hdim256_bf16_sm80.cu b/csrc/flash_attn/src/flash_fwd_split_hdim256_bf16_sm80.cu index 508f07f7d..94f299c90 100644 --- a/csrc/flash_attn/src/flash_fwd_split_hdim256_bf16_sm80.cu +++ b/csrc/flash_attn/src/flash_fwd_split_hdim256_bf16_sm80.cu @@ -4,4 +4,4 @@ #include "flash_fwd_launch_template.h" -template void run_mha_fwd_splitkv_dispatch(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_splitkv_dispatch(Flash_fwd_params ¶ms, cudaStream_t stream); diff --git a/csrc/flash_attn/src/flash_fwd_split_hdim256_fp16_causal_sm80.cu b/csrc/flash_attn/src/flash_fwd_split_hdim256_fp16_causal_sm80.cu index 019ded67f..0c7ed2c67 100644 --- a/csrc/flash_attn/src/flash_fwd_split_hdim256_fp16_causal_sm80.cu +++ b/csrc/flash_attn/src/flash_fwd_split_hdim256_fp16_causal_sm80.cu @@ -4,4 +4,4 @@ #include "flash_fwd_launch_template.h" -template void run_mha_fwd_splitkv_dispatch(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_splitkv_dispatch(Flash_fwd_params ¶ms, cudaStream_t stream); diff --git a/csrc/flash_attn/src/flash_fwd_split_hdim256_fp16_sm80.cu b/csrc/flash_attn/src/flash_fwd_split_hdim256_fp16_sm80.cu index 708f5542a..8367a6f9f 100644 --- a/csrc/flash_attn/src/flash_fwd_split_hdim256_fp16_sm80.cu +++ b/csrc/flash_attn/src/flash_fwd_split_hdim256_fp16_sm80.cu @@ -4,4 +4,4 @@ #include "flash_fwd_launch_template.h" -template void run_mha_fwd_splitkv_dispatch(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_splitkv_dispatch(Flash_fwd_params ¶ms, cudaStream_t stream); diff --git 
a/csrc/flash_attn/src/flash_fwd_split_hdim32_bf16_causal_sm80.cu b/csrc/flash_attn/src/flash_fwd_split_hdim32_bf16_causal_sm80.cu index 5a205b7e7..ce3ee1383 100644 --- a/csrc/flash_attn/src/flash_fwd_split_hdim32_bf16_causal_sm80.cu +++ b/csrc/flash_attn/src/flash_fwd_split_hdim32_bf16_causal_sm80.cu @@ -4,4 +4,4 @@ #include "flash_fwd_launch_template.h" -template void run_mha_fwd_splitkv_dispatch(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_splitkv_dispatch(Flash_fwd_params ¶ms, cudaStream_t stream); diff --git a/csrc/flash_attn/src/flash_fwd_split_hdim32_bf16_sm80.cu b/csrc/flash_attn/src/flash_fwd_split_hdim32_bf16_sm80.cu index 2c576f118..3f8a058c2 100644 --- a/csrc/flash_attn/src/flash_fwd_split_hdim32_bf16_sm80.cu +++ b/csrc/flash_attn/src/flash_fwd_split_hdim32_bf16_sm80.cu @@ -4,4 +4,4 @@ #include "flash_fwd_launch_template.h" -template void run_mha_fwd_splitkv_dispatch(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_splitkv_dispatch(Flash_fwd_params ¶ms, cudaStream_t stream); diff --git a/csrc/flash_attn/src/flash_fwd_split_hdim32_fp16_causal_sm80.cu b/csrc/flash_attn/src/flash_fwd_split_hdim32_fp16_causal_sm80.cu index 484a15e93..bfcb6e98a 100644 --- a/csrc/flash_attn/src/flash_fwd_split_hdim32_fp16_causal_sm80.cu +++ b/csrc/flash_attn/src/flash_fwd_split_hdim32_fp16_causal_sm80.cu @@ -4,4 +4,4 @@ #include "flash_fwd_launch_template.h" -template void run_mha_fwd_splitkv_dispatch(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_splitkv_dispatch(Flash_fwd_params ¶ms, cudaStream_t stream); diff --git a/csrc/flash_attn/src/flash_fwd_split_hdim32_fp16_sm80.cu b/csrc/flash_attn/src/flash_fwd_split_hdim32_fp16_sm80.cu index 5474ae89d..2abfb9e72 100644 --- a/csrc/flash_attn/src/flash_fwd_split_hdim32_fp16_sm80.cu +++ b/csrc/flash_attn/src/flash_fwd_split_hdim32_fp16_sm80.cu @@ -4,4 +4,4 @@ #include "flash_fwd_launch_template.h" -template void run_mha_fwd_splitkv_dispatch(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_splitkv_dispatch(Flash_fwd_params ¶ms, cudaStream_t stream); diff --git a/csrc/flash_attn/src/flash_fwd_split_hdim64_bf16_causal_sm80.cu b/csrc/flash_attn/src/flash_fwd_split_hdim64_bf16_causal_sm80.cu index 8c7da41dd..aa61ba301 100644 --- a/csrc/flash_attn/src/flash_fwd_split_hdim64_bf16_causal_sm80.cu +++ b/csrc/flash_attn/src/flash_fwd_split_hdim64_bf16_causal_sm80.cu @@ -4,4 +4,4 @@ #include "flash_fwd_launch_template.h" -template void run_mha_fwd_splitkv_dispatch(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_splitkv_dispatch(Flash_fwd_params ¶ms, cudaStream_t stream); diff --git a/csrc/flash_attn/src/flash_fwd_split_hdim64_bf16_sm80.cu b/csrc/flash_attn/src/flash_fwd_split_hdim64_bf16_sm80.cu index 93f29dea8..4906716d7 100644 --- a/csrc/flash_attn/src/flash_fwd_split_hdim64_bf16_sm80.cu +++ b/csrc/flash_attn/src/flash_fwd_split_hdim64_bf16_sm80.cu @@ -4,4 +4,4 @@ #include "flash_fwd_launch_template.h" -template void run_mha_fwd_splitkv_dispatch(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_splitkv_dispatch(Flash_fwd_params ¶ms, cudaStream_t stream); diff --git a/csrc/flash_attn/src/flash_fwd_split_hdim64_fp16_causal_sm80.cu b/csrc/flash_attn/src/flash_fwd_split_hdim64_fp16_causal_sm80.cu index 1e2e12b8c..8d34ac42e 100644 --- a/csrc/flash_attn/src/flash_fwd_split_hdim64_fp16_causal_sm80.cu +++ b/csrc/flash_attn/src/flash_fwd_split_hdim64_fp16_causal_sm80.cu @@ -4,4 +4,4 @@ #include "flash_fwd_launch_template.h" -template void 
run_mha_fwd_splitkv_dispatch(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_splitkv_dispatch(Flash_fwd_params ¶ms, cudaStream_t stream); diff --git a/csrc/flash_attn/src/flash_fwd_split_hdim64_fp16_sm80.cu b/csrc/flash_attn/src/flash_fwd_split_hdim64_fp16_sm80.cu index 16c34ed3f..9fc79fbd2 100644 --- a/csrc/flash_attn/src/flash_fwd_split_hdim64_fp16_sm80.cu +++ b/csrc/flash_attn/src/flash_fwd_split_hdim64_fp16_sm80.cu @@ -4,4 +4,4 @@ #include "flash_fwd_launch_template.h" -template void run_mha_fwd_splitkv_dispatch(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_splitkv_dispatch(Flash_fwd_params ¶ms, cudaStream_t stream); diff --git a/csrc/flash_attn/src/flash_fwd_split_hdim96_bf16_causal_sm80.cu b/csrc/flash_attn/src/flash_fwd_split_hdim96_bf16_causal_sm80.cu index 50080c47e..9002f16c2 100644 --- a/csrc/flash_attn/src/flash_fwd_split_hdim96_bf16_causal_sm80.cu +++ b/csrc/flash_attn/src/flash_fwd_split_hdim96_bf16_causal_sm80.cu @@ -4,4 +4,4 @@ #include "flash_fwd_launch_template.h" -template void run_mha_fwd_splitkv_dispatch(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_splitkv_dispatch(Flash_fwd_params ¶ms, cudaStream_t stream); diff --git a/csrc/flash_attn/src/flash_fwd_split_hdim96_bf16_sm80.cu b/csrc/flash_attn/src/flash_fwd_split_hdim96_bf16_sm80.cu index ae56ddd4c..76d0c69a8 100644 --- a/csrc/flash_attn/src/flash_fwd_split_hdim96_bf16_sm80.cu +++ b/csrc/flash_attn/src/flash_fwd_split_hdim96_bf16_sm80.cu @@ -4,4 +4,4 @@ #include "flash_fwd_launch_template.h" -template void run_mha_fwd_splitkv_dispatch(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_splitkv_dispatch(Flash_fwd_params ¶ms, cudaStream_t stream); diff --git a/csrc/flash_attn/src/flash_fwd_split_hdim96_fp16_causal_sm80.cu b/csrc/flash_attn/src/flash_fwd_split_hdim96_fp16_causal_sm80.cu index ed305767e..fe1014408 100644 --- a/csrc/flash_attn/src/flash_fwd_split_hdim96_fp16_causal_sm80.cu +++ b/csrc/flash_attn/src/flash_fwd_split_hdim96_fp16_causal_sm80.cu @@ -4,4 +4,4 @@ #include "flash_fwd_launch_template.h" -template void run_mha_fwd_splitkv_dispatch(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_splitkv_dispatch(Flash_fwd_params ¶ms, cudaStream_t stream); diff --git a/csrc/flash_attn/src/flash_fwd_split_hdim96_fp16_sm80.cu b/csrc/flash_attn/src/flash_fwd_split_hdim96_fp16_sm80.cu index 022064656..611f0a4c1 100644 --- a/csrc/flash_attn/src/flash_fwd_split_hdim96_fp16_sm80.cu +++ b/csrc/flash_attn/src/flash_fwd_split_hdim96_fp16_sm80.cu @@ -4,4 +4,4 @@ #include "flash_fwd_launch_template.h" -template void run_mha_fwd_splitkv_dispatch(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_splitkv_dispatch(Flash_fwd_params ¶ms, cudaStream_t stream); diff --git a/csrc/flash_attn/src/flash_fwd_split_qkdim128_vdim256_bf16_causal_sm80.cu b/csrc/flash_attn/src/flash_fwd_split_qkdim128_vdim256_bf16_causal_sm80.cu new file mode 100644 index 000000000..1300e01d6 --- /dev/null +++ b/csrc/flash_attn/src/flash_fwd_split_qkdim128_vdim256_bf16_causal_sm80.cu @@ -0,0 +1,7 @@ +// Copyright (c) 2024, Tri Dao. +// Splitting the different head dimensions to different files to speed up compilation. +// This file is auto-generated. 
See "generate_kernels.py" + +#include "flash_fwd_launch_template.h" + +template void run_mha_fwd_splitkv_dispatch(Flash_fwd_params ¶ms, cudaStream_t stream); diff --git a/csrc/flash_attn/src/flash_fwd_split_qkdim128_vdim256_bf16_sm80.cu b/csrc/flash_attn/src/flash_fwd_split_qkdim128_vdim256_bf16_sm80.cu new file mode 100644 index 000000000..754b5d256 --- /dev/null +++ b/csrc/flash_attn/src/flash_fwd_split_qkdim128_vdim256_bf16_sm80.cu @@ -0,0 +1,7 @@ +// Copyright (c) 2024, Tri Dao. +// Splitting the different head dimensions to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py" + +#include "flash_fwd_launch_template.h" + +template void run_mha_fwd_splitkv_dispatch(Flash_fwd_params ¶ms, cudaStream_t stream); diff --git a/csrc/flash_attn/src/flash_fwd_split_qkdim128_vdim256_fp16_causal_sm80.cu b/csrc/flash_attn/src/flash_fwd_split_qkdim128_vdim256_fp16_causal_sm80.cu new file mode 100644 index 000000000..e72b90ca9 --- /dev/null +++ b/csrc/flash_attn/src/flash_fwd_split_qkdim128_vdim256_fp16_causal_sm80.cu @@ -0,0 +1,7 @@ +// Copyright (c) 2024, Tri Dao. +// Splitting the different head dimensions to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py" + +#include "flash_fwd_launch_template.h" + +template void run_mha_fwd_splitkv_dispatch(Flash_fwd_params ¶ms, cudaStream_t stream); diff --git a/csrc/flash_attn/src/flash_fwd_split_qkdim128_vdim256_fp16_sm80.cu b/csrc/flash_attn/src/flash_fwd_split_qkdim128_vdim256_fp16_sm80.cu new file mode 100644 index 000000000..c6dd9c923 --- /dev/null +++ b/csrc/flash_attn/src/flash_fwd_split_qkdim128_vdim256_fp16_sm80.cu @@ -0,0 +1,7 @@ +// Copyright (c) 2024, Tri Dao. +// Splitting the different head dimensions to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py" + +#include "flash_fwd_launch_template.h" + +template void run_mha_fwd_splitkv_dispatch(Flash_fwd_params ¶ms, cudaStream_t stream); diff --git a/csrc/flash_attn/src/flash_fwd_split_qkdim192_vdim128_bf16_causal_sm80.cu b/csrc/flash_attn/src/flash_fwd_split_qkdim192_vdim128_bf16_causal_sm80.cu new file mode 100644 index 000000000..db38107b7 --- /dev/null +++ b/csrc/flash_attn/src/flash_fwd_split_qkdim192_vdim128_bf16_causal_sm80.cu @@ -0,0 +1,7 @@ +// Copyright (c) 2024, Tri Dao. +// Splitting the different head dimensions to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py" + +#include "flash_fwd_launch_template.h" + +template void run_mha_fwd_splitkv_dispatch(Flash_fwd_params ¶ms, cudaStream_t stream); diff --git a/csrc/flash_attn/src/flash_fwd_split_qkdim192_vdim128_bf16_sm80.cu b/csrc/flash_attn/src/flash_fwd_split_qkdim192_vdim128_bf16_sm80.cu new file mode 100644 index 000000000..62cdffd8a --- /dev/null +++ b/csrc/flash_attn/src/flash_fwd_split_qkdim192_vdim128_bf16_sm80.cu @@ -0,0 +1,7 @@ +// Copyright (c) 2024, Tri Dao. +// Splitting the different head dimensions to different files to speed up compilation. +// This file is auto-generated. 
See "generate_kernels.py" + +#include "flash_fwd_launch_template.h" + +template void run_mha_fwd_splitkv_dispatch(Flash_fwd_params ¶ms, cudaStream_t stream); diff --git a/csrc/flash_attn/src/flash_fwd_split_qkdim192_vdim128_fp16_causal_sm80.cu b/csrc/flash_attn/src/flash_fwd_split_qkdim192_vdim128_fp16_causal_sm80.cu new file mode 100644 index 000000000..566dbf250 --- /dev/null +++ b/csrc/flash_attn/src/flash_fwd_split_qkdim192_vdim128_fp16_causal_sm80.cu @@ -0,0 +1,7 @@ +// Copyright (c) 2024, Tri Dao. +// Splitting the different head dimensions to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py" + +#include "flash_fwd_launch_template.h" + +template void run_mha_fwd_splitkv_dispatch(Flash_fwd_params ¶ms, cudaStream_t stream); diff --git a/csrc/flash_attn/src/flash_fwd_split_qkdim192_vdim128_fp16_sm80.cu b/csrc/flash_attn/src/flash_fwd_split_qkdim192_vdim128_fp16_sm80.cu new file mode 100644 index 000000000..9f3023f8f --- /dev/null +++ b/csrc/flash_attn/src/flash_fwd_split_qkdim192_vdim128_fp16_sm80.cu @@ -0,0 +1,7 @@ +// Copyright (c) 2024, Tri Dao. +// Splitting the different head dimensions to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py" + +#include "flash_fwd_launch_template.h" + +template void run_mha_fwd_splitkv_dispatch(Flash_fwd_params ¶ms, cudaStream_t stream); diff --git a/csrc/flash_attn/src/flash_fwd_split_hdim224_bf16_sm80.cu b/csrc/flash_attn/src/flash_fwd_split_qkdim32_vdim64_bf16_causal_sm80.cu similarity index 72% rename from csrc/flash_attn/src/flash_fwd_split_hdim224_bf16_sm80.cu rename to csrc/flash_attn/src/flash_fwd_split_qkdim32_vdim64_bf16_causal_sm80.cu index b06ae5ace..2da6200cd 100644 --- a/csrc/flash_attn/src/flash_fwd_split_hdim224_bf16_sm80.cu +++ b/csrc/flash_attn/src/flash_fwd_split_qkdim32_vdim64_bf16_causal_sm80.cu @@ -1,7 +1,7 @@ -// Copyright (c) 2023, Tri Dao. +// Copyright (c) 2024, Tri Dao. // Splitting the different head dimensions to different files to speed up compilation. // This file is auto-generated. See "generate_kernels.py" #include "flash_fwd_launch_template.h" -template void run_mha_fwd_splitkv_dispatch(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_splitkv_dispatch(Flash_fwd_params ¶ms, cudaStream_t stream); diff --git a/csrc/flash_attn/src/flash_fwd_split_hdim224_bf16_causal_sm80.cu b/csrc/flash_attn/src/flash_fwd_split_qkdim32_vdim64_bf16_sm80.cu similarity index 72% rename from csrc/flash_attn/src/flash_fwd_split_hdim224_bf16_causal_sm80.cu rename to csrc/flash_attn/src/flash_fwd_split_qkdim32_vdim64_bf16_sm80.cu index ea024d9ab..138d565e7 100644 --- a/csrc/flash_attn/src/flash_fwd_split_hdim224_bf16_causal_sm80.cu +++ b/csrc/flash_attn/src/flash_fwd_split_qkdim32_vdim64_bf16_sm80.cu @@ -1,7 +1,7 @@ -// Copyright (c) 2023, Tri Dao. +// Copyright (c) 2024, Tri Dao. // Splitting the different head dimensions to different files to speed up compilation. // This file is auto-generated. 
See "generate_kernels.py" #include "flash_fwd_launch_template.h" -template void run_mha_fwd_splitkv_dispatch(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_splitkv_dispatch(Flash_fwd_params ¶ms, cudaStream_t stream); diff --git a/csrc/flash_attn/src/flash_fwd_split_hdim224_fp16_sm80.cu b/csrc/flash_attn/src/flash_fwd_split_qkdim32_vdim64_fp16_causal_sm80.cu similarity index 54% rename from csrc/flash_attn/src/flash_fwd_split_hdim224_fp16_sm80.cu rename to csrc/flash_attn/src/flash_fwd_split_qkdim32_vdim64_fp16_causal_sm80.cu index 8cf2eabed..598fa570f 100644 --- a/csrc/flash_attn/src/flash_fwd_split_hdim224_fp16_sm80.cu +++ b/csrc/flash_attn/src/flash_fwd_split_qkdim32_vdim64_fp16_causal_sm80.cu @@ -1,7 +1,7 @@ -// Copyright (c) 2023, Tri Dao. +// Copyright (c) 2024, Tri Dao. // Splitting the different head dimensions to different files to speed up compilation. // This file is auto-generated. See "generate_kernels.py" #include "flash_fwd_launch_template.h" -template void run_mha_fwd_splitkv_dispatch(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_splitkv_dispatch(Flash_fwd_params ¶ms, cudaStream_t stream); diff --git a/csrc/flash_attn/src/flash_fwd_split_hdim224_fp16_causal_sm80.cu b/csrc/flash_attn/src/flash_fwd_split_qkdim32_vdim64_fp16_sm80.cu similarity index 54% rename from csrc/flash_attn/src/flash_fwd_split_hdim224_fp16_causal_sm80.cu rename to csrc/flash_attn/src/flash_fwd_split_qkdim32_vdim64_fp16_sm80.cu index b217f3789..4384ec420 100644 --- a/csrc/flash_attn/src/flash_fwd_split_hdim224_fp16_causal_sm80.cu +++ b/csrc/flash_attn/src/flash_fwd_split_qkdim32_vdim64_fp16_sm80.cu @@ -1,7 +1,7 @@ -// Copyright (c) 2023, Tri Dao. +// Copyright (c) 2024, Tri Dao. // Splitting the different head dimensions to different files to speed up compilation. // This file is auto-generated. See "generate_kernels.py" #include "flash_fwd_launch_template.h" -template void run_mha_fwd_splitkv_dispatch(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_splitkv_dispatch(Flash_fwd_params ¶ms, cudaStream_t stream); diff --git a/csrc/flash_attn/src/flash_fwd_split_qkdim64_vdim128_bf16_causal_sm80.cu b/csrc/flash_attn/src/flash_fwd_split_qkdim64_vdim128_bf16_causal_sm80.cu new file mode 100644 index 000000000..b700fb21a --- /dev/null +++ b/csrc/flash_attn/src/flash_fwd_split_qkdim64_vdim128_bf16_causal_sm80.cu @@ -0,0 +1,7 @@ +// Copyright (c) 2024, Tri Dao. +// Splitting the different head dimensions to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py" + +#include "flash_fwd_launch_template.h" + +template void run_mha_fwd_splitkv_dispatch(Flash_fwd_params ¶ms, cudaStream_t stream); diff --git a/csrc/flash_attn/src/flash_fwd_split_qkdim64_vdim128_bf16_sm80.cu b/csrc/flash_attn/src/flash_fwd_split_qkdim64_vdim128_bf16_sm80.cu new file mode 100644 index 000000000..e8dcddc4e --- /dev/null +++ b/csrc/flash_attn/src/flash_fwd_split_qkdim64_vdim128_bf16_sm80.cu @@ -0,0 +1,7 @@ +// Copyright (c) 2024, Tri Dao. +// Splitting the different head dimensions to different files to speed up compilation. +// This file is auto-generated. 
See "generate_kernels.py" + +#include "flash_fwd_launch_template.h" + +template void run_mha_fwd_splitkv_dispatch(Flash_fwd_params ¶ms, cudaStream_t stream); diff --git a/csrc/flash_attn/src/flash_fwd_split_qkdim64_vdim128_fp16_causal_sm80.cu b/csrc/flash_attn/src/flash_fwd_split_qkdim64_vdim128_fp16_causal_sm80.cu new file mode 100644 index 000000000..752f148bb --- /dev/null +++ b/csrc/flash_attn/src/flash_fwd_split_qkdim64_vdim128_fp16_causal_sm80.cu @@ -0,0 +1,7 @@ +// Copyright (c) 2024, Tri Dao. +// Splitting the different head dimensions to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py" + +#include "flash_fwd_launch_template.h" + +template void run_mha_fwd_splitkv_dispatch(Flash_fwd_params ¶ms, cudaStream_t stream); diff --git a/csrc/flash_attn/src/flash_fwd_split_qkdim64_vdim128_fp16_sm80.cu b/csrc/flash_attn/src/flash_fwd_split_qkdim64_vdim128_fp16_sm80.cu new file mode 100644 index 000000000..0eaf1b0e7 --- /dev/null +++ b/csrc/flash_attn/src/flash_fwd_split_qkdim64_vdim128_fp16_sm80.cu @@ -0,0 +1,7 @@ +// Copyright (c) 2024, Tri Dao. +// Splitting the different head dimensions to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py" + +#include "flash_fwd_launch_template.h" + +template void run_mha_fwd_splitkv_dispatch(Flash_fwd_params ¶ms, cudaStream_t stream); diff --git a/csrc/flash_attn/src/flash_fwd_split_qkdim96_vdim192_bf16_causal_sm80.cu b/csrc/flash_attn/src/flash_fwd_split_qkdim96_vdim192_bf16_causal_sm80.cu new file mode 100644 index 000000000..d9efa099e --- /dev/null +++ b/csrc/flash_attn/src/flash_fwd_split_qkdim96_vdim192_bf16_causal_sm80.cu @@ -0,0 +1,7 @@ +// Copyright (c) 2024, Tri Dao. +// Splitting the different head dimensions to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py" + +#include "flash_fwd_launch_template.h" + +template void run_mha_fwd_splitkv_dispatch(Flash_fwd_params ¶ms, cudaStream_t stream); diff --git a/csrc/flash_attn/src/flash_fwd_split_qkdim96_vdim192_bf16_sm80.cu b/csrc/flash_attn/src/flash_fwd_split_qkdim96_vdim192_bf16_sm80.cu new file mode 100644 index 000000000..34e9db839 --- /dev/null +++ b/csrc/flash_attn/src/flash_fwd_split_qkdim96_vdim192_bf16_sm80.cu @@ -0,0 +1,7 @@ +// Copyright (c) 2024, Tri Dao. +// Splitting the different head dimensions to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py" + +#include "flash_fwd_launch_template.h" + +template void run_mha_fwd_splitkv_dispatch(Flash_fwd_params ¶ms, cudaStream_t stream); diff --git a/csrc/flash_attn/src/flash_fwd_split_qkdim96_vdim192_fp16_causal_sm80.cu b/csrc/flash_attn/src/flash_fwd_split_qkdim96_vdim192_fp16_causal_sm80.cu new file mode 100644 index 000000000..389e228a4 --- /dev/null +++ b/csrc/flash_attn/src/flash_fwd_split_qkdim96_vdim192_fp16_causal_sm80.cu @@ -0,0 +1,7 @@ +// Copyright (c) 2024, Tri Dao. +// Splitting the different head dimensions to different files to speed up compilation. +// This file is auto-generated. 
See "generate_kernels.py" + +#include "flash_fwd_launch_template.h" + +template void run_mha_fwd_splitkv_dispatch(Flash_fwd_params ¶ms, cudaStream_t stream); diff --git a/csrc/flash_attn/src/flash_fwd_split_qkdim96_vdim192_fp16_sm80.cu b/csrc/flash_attn/src/flash_fwd_split_qkdim96_vdim192_fp16_sm80.cu new file mode 100644 index 000000000..8d9d9d6f4 --- /dev/null +++ b/csrc/flash_attn/src/flash_fwd_split_qkdim96_vdim192_fp16_sm80.cu @@ -0,0 +1,7 @@ +// Copyright (c) 2024, Tri Dao. +// Splitting the different head dimensions to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py" + +#include "flash_fwd_launch_template.h" + +template void run_mha_fwd_splitkv_dispatch(Flash_fwd_params ¶ms, cudaStream_t stream); diff --git a/csrc/flash_attn/src/generate_kernels.py b/csrc/flash_attn/src/generate_kernels.py index 119e34956..1b9bc9fbc 100644 --- a/csrc/flash_attn/src/generate_kernels.py +++ b/csrc/flash_attn/src/generate_kernels.py @@ -20,57 +20,97 @@ KERNEL_IMPL_TEMPLATE_FWD = """#include "flash_fwd_launch_template.h" template<> -void run_mha_fwd_<{DTYPE}, {HEAD_DIM}, {IS_CAUSAL}>(Flash_fwd_params ¶ms, cudaStream_t stream) {{ +void run_mha_fwd_<{DTYPE}, {HEAD_DIM}, {HEAD_DIM}, {IS_CAUSAL}>(Flash_fwd_params ¶ms, cudaStream_t stream) {{ run_mha_fwd_hdim{HEAD_DIM}<{DTYPE}, {IS_CAUSAL}>(params, stream); }} """ KERNEL_IMPL_TEMPLATE_FWD_SPLIT = """#include "flash_fwd_launch_template.h" -template void run_mha_fwd_splitkv_dispatch<{DTYPE}, {HEAD_DIM}, {IS_CAUSAL}>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_splitkv_dispatch<{DTYPE}, {HEAD_DIM}, {HEAD_DIM}, {IS_CAUSAL}>(Flash_fwd_params ¶ms, cudaStream_t stream); """ KERNEL_IMPL_TEMPLATE_BWD = """#include "flash_bwd_launch_template.h" template<> -void run_mha_bwd_<{DTYPE}, {HEAD_DIM}, {IS_CAUSAL}>(Flash_bwd_params ¶ms, cudaStream_t stream) {{ +void run_mha_bwd_<{DTYPE}, {HEAD_DIM}, {HEAD_DIM}, {IS_CAUSAL}>(Flash_bwd_params ¶ms, cudaStream_t stream) {{ run_mha_bwd_hdim{HEAD_DIM}<{DTYPE}, {IS_CAUSAL}>(params, stream); }} """ +KERNEL_IMPL_TEMPLATE_FWD_VDIM = """#include "flash_fwd_launch_template.h" + +template<> +void run_mha_fwd_<{DTYPE}, {QKHEAD_DIM}, {VHEAD_DIM}, {IS_CAUSAL}>(Flash_fwd_params ¶ms, cudaStream_t stream) {{ + run_mha_fwd_qkdim{QKHEAD_DIM}_vdim{VHEAD_DIM}<{DTYPE}, {IS_CAUSAL}>(params, stream); +}} +""" + +KERNEL_IMPL_TEMPLATE_FWD_SPLIT_VDIM = """#include "flash_fwd_launch_template.h" + +template void run_mha_fwd_splitkv_dispatch<{DTYPE}, {QKHEAD_DIM}, {VHEAD_DIM}, {IS_CAUSAL}>(Flash_fwd_params ¶ms, cudaStream_t stream); +""" + +KERNEL_IMPL_TEMPLATE_BWD_VDIM = """#include "flash_bwd_launch_template.h" + +template<> +void run_mha_bwd_<{DTYPE}, {QKHEAD_DIM}, {VHEAD_DIM}, {IS_CAUSAL}>(Flash_bwd_params ¶ms, cudaStream_t stream) {{ + run_mha_bwd_qkdim{QKHEAD_DIM}_vdim{VHEAD_DIM}<{DTYPE}, {IS_CAUSAL}>(params, stream); +}} +""" @dataclass class Kernel: sm: int dtype: str - head_dim: int + qkhead_dim: int + vhead_dim: int is_causal: bool direction: str @property def template(self) -> str: - if self.direction == "fwd": - return KERNEL_IMPL_TEMPLATE_FWD.format( - DTYPE=DTYPE_MAP[self.dtype], HEAD_DIM=self.head_dim, IS_CAUSAL=self.is_causal - ) - elif self.direction == "bwd": - return KERNEL_IMPL_TEMPLATE_BWD.format( - DTYPE=DTYPE_MAP[self.dtype], HEAD_DIM=self.head_dim, IS_CAUSAL=self.is_causal - ) + if self.qkhead_dim == self.vhead_dim: + if self.direction == "fwd": + return KERNEL_IMPL_TEMPLATE_FWD.format( + DTYPE=DTYPE_MAP[self.dtype], HEAD_DIM=self.qkhead_dim, 
IS_CAUSAL=self.is_causal + ) + elif self.direction == "bwd": + return KERNEL_IMPL_TEMPLATE_BWD.format( + DTYPE=DTYPE_MAP[self.dtype], HEAD_DIM=self.qkhead_dim, IS_CAUSAL=self.is_causal + ) + else: + return KERNEL_IMPL_TEMPLATE_FWD_SPLIT.format( + DTYPE=DTYPE_MAP[self.dtype], HEAD_DIM=self.qkhead_dim, IS_CAUSAL=self.is_causal + ) else: - return KERNEL_IMPL_TEMPLATE_FWD_SPLIT.format( - DTYPE=DTYPE_MAP[self.dtype], HEAD_DIM=self.head_dim, IS_CAUSAL=self.is_causal - ) + if self.direction == "fwd": + return KERNEL_IMPL_TEMPLATE_FWD_VDIM.format( + DTYPE=DTYPE_MAP[self.dtype], QKHEAD_DIM=self.qkhead_dim, VHEAD_DIM=self.vhead_dim, IS_CAUSAL=self.is_causal + ) + elif self.direction == "bwd": + return KERNEL_IMPL_TEMPLATE_BWD_VDIM.format( + DTYPE=DTYPE_MAP[self.dtype], QKHEAD_DIM=self.qkhead_dim, VHEAD_DIM=self.vhead_dim, IS_CAUSAL=self.is_causal + ) + else: + return KERNEL_IMPL_TEMPLATE_FWD_SPLIT_VDIM.format( + DTYPE=DTYPE_MAP[self.dtype], QKHEAD_DIM=self.qkhead_dim, VHEAD_DIM=self.vhead_dim, IS_CAUSAL=self.is_causal + ) @property def filename(self) -> str: - return f"flash_{self.direction}_hdim{self.head_dim}_{self.dtype}_{'causal_' if self.is_causal == 'true' else ''}sm{self.sm}.cu" + if self.qkhead_dim == self.vhead_dim: + return f"flash_{self.direction}_hdim{self.qkhead_dim}_{self.dtype}_{'causal_' if self.is_causal == 'true' else ''}sm{self.sm}.cu" + else: + return f"flash_{self.direction}_qkdim{self.qkhead_dim}_vdim{self.vhead_dim}_{self.dtype}_{'causal_' if self.is_causal == 'true' else ''}sm{self.sm}.cu" def get_all_kernels() -> List[Kernel]: for direction in ["fwd", "fwd_split", "bwd"]: - for dtype, head_dim, is_causal, sm in itertools.product(DTYPE_MAP.keys(), HEAD_DIMENSIONS, IS_CAUSAL, SM): - yield Kernel(sm=sm, dtype=dtype, head_dim=head_dim, is_causal=is_causal, direction=direction) + for dtype, qkhead_dim, is_causal, sm in itertools.product(DTYPE_MAP.keys(), HEAD_DIMENSIONS, IS_CAUSAL, SM): + for vhead_dim in [qkhead_dim, 2 * qkhead_dim]: + if vhead_dim <= 256: + yield Kernel(sm=sm, dtype=dtype, qkhead_dim=qkhead_dim, vhead_dim=vhead_dim,is_causal=is_causal, direction=direction) def write_kernel(kernel: Kernel, autogen_dir: Path) -> None: diff --git a/csrc/flash_attn/src/generate_switch_headdim.py b/csrc/flash_attn/src/generate_switch_headdim.py new file mode 100644 index 000000000..f7def19d4 --- /dev/null +++ b/csrc/flash_attn/src/generate_switch_headdim.py @@ -0,0 +1,66 @@ +import json +from pathlib import Path + +def write_file(): + TEMPLATE_PRELUDE = """#pragma once + +/// @param COND - a boolean expression to switch by +/// @param CONST_NAME - a name given for the constexpr bool variable. +/// @param ... - code to execute for true and false +/// +/// Usage: +/// ``` +/// BOOL_SWITCH(flag, BoolConst, [&] { +/// some_function(...); +/// }); +/// ``` +""" + + with open('headdim.json', 'r') as file: + read_list = json.load(file) + + read_list += [ + [32,32], + [64,64], + [96,96], + [128,128], + [160,160], + [192,192], + [256,256], + ] + + read_list = sorted(read_list, key=lambda x: (x[0], x[1])) + + TEMPLATE_BEGIN = """ +#define QKHEADDIM_VHEADDIM_SWITCH(QKHEADDIM, VHEADDIM, ...) 
\\ + [&] { \\ +""" + + TEMPLATE_BODY = "" + + for qkhead_dim, vhead_dim in read_list[:-1]: + TEMPLATE_BODY += f"""if (QKHEADDIM <= {qkhead_dim} && VHEADDIM <= {vhead_dim}) {{ \\ + constexpr static int kQKHeadDim = {qkhead_dim}; \\ + constexpr static int kVHeadDim = {vhead_dim}; \\ + return __VA_ARGS__(); \\ + }} else """ + + qkhead_dim, vhead_dim = read_list[-1] + TEMPLATE_BODY += f"""if (QKHEADDIM <= {qkhead_dim} && VHEADDIM <= {vhead_dim}) {{ \\ + constexpr static int kQKHeadDim = {qkhead_dim}; \\ + constexpr static int kVHeadDim = {vhead_dim}; \\ + return __VA_ARGS__(); \\ + }} \\ +""" + + TEMPLATE_END = """}() +""" + + TEMPLATE = TEMPLATE_PRELUDE + TEMPLATE_BEGIN + TEMPLATE_BODY + TEMPLATE_END + + # print(TEMPLATE) + with open(Path(__file__).parent.joinpath('static_switch_headdim.h'), 'w') as file: + file.write(TEMPLATE) + +if __name__ == '__main__': + write_file() diff --git a/csrc/flash_attn/src/kernel_traits.h b/csrc/flash_attn/src/kernel_traits.h index 5a7b74911..3daa46b47 100644 --- a/csrc/flash_attn/src/kernel_traits.h +++ b/csrc/flash_attn/src/kernel_traits.h @@ -12,7 +12,7 @@ using namespace cute; -template +template struct Flash_kernel_traits { #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800 @@ -46,8 +46,8 @@ struct Flash_kernel_traits { }; // If Share_Q_K_smem is true, that forces Is_Q_in_regs to be true -template > +template > struct Flash_fwd_kernel_traits : public Base { using Element = typename Base::Element; using ElementAccum = typename Base::ElementAccum; @@ -65,10 +65,14 @@ struct Flash_fwd_kernel_traits : public Base { static constexpr int kBlockM = kBlockM_; static constexpr int kBlockN = kBlockN_; - static constexpr int kHeadDim = kHeadDim_; - static_assert(kHeadDim % 32 == 0); - static constexpr int kBlockKSmem = kHeadDim % 64 == 0 ? 64 : 32; - static constexpr int kBlockKGmem = kHeadDim % 128 == 0 ? 128 : (kHeadDim % 64 == 0 ? 64 : 32); + static constexpr int kQKHeadDim = kQKHeadDim_; + static constexpr int kVHeadDim = kVHeadDim_; + static_assert(kQKHeadDim % 32 == 0); + static_assert(kVHeadDim % 32 == 0); + // TODO: split QK & V + static constexpr int kBlockKSmem = (kQKHeadDim % 64 == 0 && kVHeadDim % 64 == 0) ? 64 : 32; + static constexpr int kBlockKGmem = kQKHeadDim % 128 == 0 ? 128 : (kQKHeadDim % 64 == 0 ? 64 : 32); + static constexpr int kSwizzle = kBlockKSmem == 32 ? 
2 : 3; using TiledMma = TiledMMA< @@ -83,15 +87,17 @@ struct Flash_fwd_kernel_traits : public Base { Stride, _1>>{})); using SmemLayoutQ = decltype(tile_to_shape( SmemLayoutAtomQ{}, - Shape, Int>{})); + Shape, Int>{})); - using SmemLayoutKV = decltype(tile_to_shape( + using SmemLayoutK = decltype(tile_to_shape( SmemLayoutAtomQ{}, - Shape, Int>{})); - + Shape, Int>{})); + using SmemLayoutV = decltype(tile_to_shape( + SmemLayoutAtomQ{}, + Shape, Int>{})); // https://github.com/ColfaxResearch/cutlass-kernels/blob/a222587e6d59b93ba704853d3946fb686d8b8892/src/fmha/fmha_forward.cu#L434 using SmemLayoutVtransposed = decltype( - composition(SmemLayoutKV{}, make_layout(Shape, Int>{}, GenRowMajor{}))); + composition(SmemLayoutV{}, make_layout(Shape, Int>{}, GenRowMajor{}))); using SmemLayoutVtransposedNoSwizzle = decltype(get_nonswizzle_portion(SmemLayoutVtransposed{})); using SmemLayoutAtomO = decltype( @@ -100,16 +106,20 @@ struct Flash_fwd_kernel_traits : public Base { Stride, _1>>{})); using SmemLayoutO = decltype(tile_to_shape( SmemLayoutAtomO{}, - Shape, Int>{})); + Shape, Int>{})); using SmemCopyAtomO = Copy_Atom; using SmemCopyAtomOaccum = Copy_Atom; static constexpr int kSmemQSize = size(SmemLayoutQ{}) * sizeof(Element); - static constexpr int kSmemKVSize = size(SmemLayoutKV{}) * 2 * sizeof(Element); - static constexpr int kSmemSize = Share_Q_K_smem ? std::max(kSmemQSize, kSmemKVSize) : kSmemQSize + kSmemKVSize; + static constexpr int kSmemKSize = size(SmemLayoutK{}) * sizeof(Element); + static constexpr int kSmemVSize = size(SmemLayoutV{}) * sizeof(Element); + static constexpr int kSmemOSize = size(SmemLayoutO{}) * sizeof(Element); + static constexpr int kSmemSizeQKV = Share_Q_K_smem ? std::max(kSmemQSize, kSmemKSize + kSmemVSize) : kSmemQSize + kSmemKSize + kSmemVSize; + static constexpr int kSmemSize = std::max(kSmemSizeQKV, kSmemOSize); static constexpr int kGmemElemsPerLoad = sizeof(cute::uint128_t) / sizeof(Element); - static_assert(kHeadDim % kGmemElemsPerLoad == 0, "kHeadDim must be a multiple of kGmemElemsPerLoad"); + static_assert(kQKHeadDim % kGmemElemsPerLoad == 0, "kQKHeadDim must be a multiple of kGmemElemsPerLoad"); + static_assert(kVHeadDim % kGmemElemsPerLoad == 0, "kVHeadDim must be a multiple of kGmemElemsPerLoad"); // Using kBlockKSmem here is 6-10% faster than kBlockKGmem for d=128 because of bank conflicts. // For example, for d=128, smem is split into 2 "pages", each page takes care of columns // 0-63 and 64-127. If we have 16 threads per row for gmem read, when we write to smem, @@ -119,6 +129,9 @@ struct Flash_fwd_kernel_traits : public Base { static_assert(kNThreads % kGmemThreadsPerRow == 0, "kNThreads must be a multiple of kGmemThreadsPerRow"); using GmemLayoutAtom = Layout, Int>, Stride, _1>>; + // for global load thread mapping + static_assert(kBlockN % (kNThreads / kGmemThreadsPerRow) == 0, "kBlockN must be a multiple of kNThreads / kGmemThreadsPerRow"); + static_assert(kBlockM % (kNThreads / kGmemThreadsPerRow) == 0, "kBlockM must be a multiple of kNThreads / kGmemThreadsPerRow"); // We use CACHEGLOBAL instead of CACHEALWAYS for both Q and K/V, since we won't be reading // from the same address by the same threadblock. This is slightly faster. @@ -160,10 +173,10 @@ struct Flash_fwd_kernel_traits : public Base { // Is_V_in_regs is an option to reduce smem usage, but will increase register pressue. // No_double_buffer is another option to reduce smem usage, but will slow things down. 
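The new kSmemSize rule above is the part of this header worth pausing on: once kVHeadDim can exceed kQKHeadDim, the kBlockM x kVHeadDim output tile is no longer guaranteed to fit inside Q's footprint, so the forward traits now take a max with kSmemOSize. A hand-specialized sketch of that formula for qkdim 128 / vdim 256 with a 128 x 32 tile, 2-byte elements and Share_Q_K_smem = false (an illustration, not code generated from the traits):

#include <algorithm>

constexpr int kBlockM = 128, kBlockN = 32;
constexpr int kQKHeadDim = 128, kVHeadDim = 256;
constexpr int kElem = 2;  // sizeof(cutlass::half_t) or cutlass::bfloat16_t

constexpr int kSmemQSize = kBlockM * kQKHeadDim * kElem;   // 32 KB
constexpr int kSmemKSize = kBlockN * kQKHeadDim * kElem;   //  8 KB
constexpr int kSmemVSize = kBlockN * kVHeadDim  * kElem;   // 16 KB
constexpr int kSmemOSize = kBlockM * kVHeadDim  * kElem;   // 64 KB

// Share_Q_K_smem == false: Q + K + V = 56 KB ...
constexpr int kSmemSizeQKV = kSmemQSize + kSmemKSize + kSmemVSize;
// ... but the epilogue tile alone is 64 KB, so it is what sizes the allocation.
constexpr int kSmemSize = std::max(kSmemSizeQKV, kSmemOSize);

static_assert(kSmemSizeQKV == 56 * 1024, "");
static_assert(kSmemSize == kSmemOSize && kSmemSize == 64 * 1024, "");

The backward traits that follow get the analogous treatment: separate SmemLayoutK / SmemLayoutV layouts, split Q and dO layouts, and a second kBlockKSmem2 / kSwizzle2 pair keyed to the V head dim.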
-template > + typename Base=Flash_kernel_traits > struct Flash_bwd_kernel_traits : public Base { using Element = typename Base::Element; using ElementAccum = typename Base::ElementAccum; @@ -181,12 +194,18 @@ struct Flash_bwd_kernel_traits : public Base { static constexpr int kBlockM = kBlockM_; static constexpr int kBlockN = kBlockN_; - static constexpr int kHeadDim = kHeadDim_; - static_assert(kHeadDim % 32 == 0); - static constexpr int kBlockKSmem = kHeadDim % 64 == 0 ? 64 : 32; - static constexpr int kBlockKGmem = kHeadDim % 128 == 0 ? 128 : (kHeadDim % 64 == 0 ? 64 : 32); + static constexpr int kQKHeadDim = kQKHeadDim_; + static constexpr int kVHeadDim = kVHeadDim_; + static_assert(kQKHeadDim % 32 == 0); + static_assert(kVHeadDim % 32 == 0); + static constexpr int kBlockKSmem = kQKHeadDim % 64 == 0 ? 64 : 32; + static constexpr int kBlockKGmem = kQKHeadDim % 128 == 0 ? 128 : (kQKHeadDim % 64 == 0 ? 64 : 32); static constexpr int kSwizzle = kBlockKSmem == 32 ? 2 : 3; + static constexpr int kBlockKSmem2 = kVHeadDim % 64 == 0 ? 64 : 32; + static constexpr int kBlockKGmem2 = kVHeadDim % 128 == 0 ? 128 : (kVHeadDim % 64 == 0 ? 64 : 32); + static constexpr int kSwizzle2 = kBlockKSmem2 == 32 ? 2 : 3; + static constexpr int AtomLayoutMSdP = AtomLayoutMSdP_; static_assert(kNWarps % AtomLayoutMSdP == 0); static_assert(kNWarps % AtomLayoutNdKV == 0); @@ -207,25 +226,39 @@ struct Flash_bwd_kernel_traits : public Base { Layout, Int, _1>>, // 2x4x1 or 4x2x1 thread group Tile, Int<16 * kNWarps / AtomLayoutMdQ>, _16>>; - using SmemLayoutAtomQdO = decltype( + using SmemLayoutAtomQ = decltype( composition(Swizzle{}, Layout>, Stride, _1>>{})); - using SmemLayoutQdO = decltype(tile_to_shape( - SmemLayoutAtomQdO{}, - make_shape(Int{}, Int{}))); + using SmemLayoutAtomdO = decltype( + composition(Swizzle{}, + Layout>, + Stride, _1>>{})); + using SmemLayoutQ = decltype(tile_to_shape( + SmemLayoutAtomQ{}, + make_shape(Int{}, Int{}))); + using SmemLayoutdO = decltype(tile_to_shape( + SmemLayoutAtomdO{}, + make_shape(Int{}, Int{}))); - using SmemLayoutAtomKV = decltype( + using SmemLayoutAtomK = decltype( composition(Swizzle{}, Layout, Int>, Stride, _1>>{})); - using SmemLayoutKV = decltype(tile_to_shape( + using SmemLayoutAtomV = decltype( + composition(Swizzle{}, + Layout, Int>, + Stride, _1>>{})); + using SmemLayoutK = decltype(tile_to_shape( // SmemLayoutAtomQdO{}, - SmemLayoutAtomKV{}, - make_shape(Int{}, Int{}))); - + SmemLayoutAtomK{}, + make_shape(Int{}, Int{}))); + using SmemLayoutV = decltype(tile_to_shape( + // SmemLayoutAtomQdO{}, + SmemLayoutAtomV{}, + make_shape(Int{}, Int{}))); using SmemLayoutKtransposed = decltype( - composition(SmemLayoutKV{}, make_layout(Shape, Int>{}, GenRowMajor{}))); + composition(SmemLayoutK{}, make_layout(Shape, Int>{}, GenRowMajor{}))); using SmemLayoutKtransposedNoSwizzle = decltype(get_nonswizzle_portion(SmemLayoutKtransposed{})); // TODO: generalize to other values of kBlockN @@ -252,17 +285,28 @@ struct Flash_bwd_kernel_traits : public Base { using SmemCopyAtomPdS = Copy_Atom; - using SmemLayoutQdOtransposed = decltype( - composition(SmemLayoutQdO{}, make_layout(Shape, Int>{}, GenRowMajor{}))); - using SmemLayoutQdOtransposedNoSwizzle = decltype(get_nonswizzle_portion(SmemLayoutQdOtransposed{})); + using SmemLayoutQtransposed = decltype( + composition(SmemLayoutQ{}, make_layout(Shape, Int>{}, GenRowMajor{}))); + using SmemLayoutdOtransposed = decltype( + composition(SmemLayoutdO{}, make_layout(Shape, Int>{}, GenRowMajor{}))); + using SmemLayoutQtransposedNoSwizzle = 
decltype(get_nonswizzle_portion(SmemLayoutQtransposed{})); + using SmemLayoutdOtransposedNoSwizzle = decltype(get_nonswizzle_portion(SmemLayoutdOtransposed{})); - using SmemLayoutAtomdKV = decltype( + using SmemLayoutAtomdK = decltype( composition(Swizzle{}, Layout>, Stride, _1>>{})); - using SmemLayoutdKV = decltype(tile_to_shape( - SmemLayoutAtomdKV{}, - make_shape(Int{}, Int{}))); + using SmemLayoutAtomdV = decltype( + composition(Swizzle{}, + Layout>, + Stride, _1>>{})); + using SmemLayoutdK = decltype(tile_to_shape( + SmemLayoutAtomdK{}, + make_shape(Int{}, Int{}))); + using SmemLayoutdV = decltype(tile_to_shape( + SmemLayoutAtomdV{}, + make_shape(Int{}, Int{}))); + using SmemCopyAtomdKV = Copy_Atom; using SmemLayoutAtomdQ = decltype( @@ -271,26 +315,29 @@ struct Flash_bwd_kernel_traits : public Base { Stride, _1>>{})); using SmemLayoutdQ = decltype(tile_to_shape( SmemLayoutAtomdQ{}, - make_shape(Int{}, Int{}))); + make_shape(Int{}, Int{}))); using SmemCopyAtomdQ = Copy_Atom; // Double buffer for sQ - static constexpr int kSmemQdOSize = size(SmemLayoutQdO{}) * (No_double_buffer ? 2 : 3) * sizeof(Element); - static constexpr int kSmemKVSize = size(SmemLayoutKV{}) * 2 * sizeof(Element); + static constexpr int kSmemQSize = size(SmemLayoutQ{}) * (No_double_buffer ? 1 : 2) * sizeof(Element); + static constexpr int kSmemdOSize = size(SmemLayoutdO{}) * sizeof(Element); + static constexpr int kSmemKSize = size(SmemLayoutK{}) * sizeof(Element); + static constexpr int kSmemVSize = size(SmemLayoutV{}) * sizeof(Element); static constexpr int kSmemdSSize = size(SmemLayoutPdS{}) * sizeof(Element); static constexpr int kSmemPSize = size(SmemLayoutPdS{}) * sizeof(Element); static constexpr int kSmemdQSize = size(SmemLayoutdQ{}) * sizeof(Element); - static constexpr int kSmemSize = kSmemQdOSize + static constexpr int kSmemSize = kSmemQSize + kSmemdOSize + (!Is_V_in_regs - ? kSmemKVSize + kSmemdSSize + std::max(kSmemPSize, kSmemdQSize) - : std::max(kSmemKVSize, kSmemKVSize / 2 + kSmemdSSize + std::max(kSmemPSize, kSmemdQSize))); - static constexpr int kSmemSize1colblock = kSmemQdOSize + ? kSmemKSize + kSmemVSize + kSmemdSSize + std::max(kSmemPSize, kSmemdQSize) + : std::max(kSmemKSize + kSmemVSize, kSmemKSize + kSmemdSSize + std::max(kSmemPSize, kSmemdQSize))); + static constexpr int kSmemSize1colblock = kSmemQSize + kSmemdOSize + (!Is_V_in_regs - ? kSmemKVSize + kSmemdSSize + kSmemPSize - : std::max(kSmemKVSize, kSmemKVSize / 2 + kSmemdSSize + kSmemPSize)); + ? kSmemKSize + kSmemVSize + kSmemdSSize + kSmemPSize + : std::max(kSmemKSize + kSmemVSize, kSmemKSize + kSmemdSSize + kSmemPSize)); static constexpr int kGmemElemsPerLoad = sizeof(cute::uint128_t) / sizeof(Element); - static_assert(kHeadDim % kGmemElemsPerLoad == 0, "kHeadDim must be a multiple of kGmemElemsPerLoad"); + static_assert(kQKHeadDim % kGmemElemsPerLoad == 0, "kHeadDim must be a multiple of kGmemElemsPerLoad"); + static_assert(kVHeadDim % kGmemElemsPerLoad == 0, "kHeadDim must be a multiple of kGmemElemsPerLoad"); // Using kBlockKSmem instead of kHeadDim here to avoid bank conflicts, but doesn't seem // to affect speed in practice. 
static constexpr int kGmemThreadsPerRow = kBlockKSmem / kGmemElemsPerLoad; diff --git a/csrc/flash_attn/src/static_switch.h b/csrc/flash_attn/src/static_switch.h index a57702f6c..3cb31f9c6 100644 --- a/csrc/flash_attn/src/static_switch.h +++ b/csrc/flash_attn/src/static_switch.h @@ -112,3 +112,4 @@ return __VA_ARGS__(); \ } \ }() + diff --git a/csrc/flash_attn/src/static_switch_headdim.h b/csrc/flash_attn/src/static_switch_headdim.h new file mode 100644 index 000000000..8d2e97b3e --- /dev/null +++ b/csrc/flash_attn/src/static_switch_headdim.h @@ -0,0 +1,69 @@ +// Inspired by +// https://github.com/NVIDIA/DALI/blob/main/include/dali/core/static_switch.h +// and https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/Dispatch.h + +#pragma once + +/// @param COND - a boolean expression to switch by +/// @param CONST_NAME - a name given for the constexpr bool variable. +/// @param ... - code to execute for true and false +/// +/// Usage: +/// ``` +/// BOOL_SWITCH(flag, BoolConst, [&] { +/// some_function(...); +/// }); +/// ``` + +#define QKHEADDIM_VHEADDIM_SWITCH(QKHEADDIM, VHEADDIM, ...) \ + [&] { \ + if (QKHEADDIM <= 32 && VHEADDIM <= 32) { \ + constexpr static int kQKHeadDim = 32; \ + constexpr static int kVHeadDim = 32; \ + return __VA_ARGS__(); \ + } else if (QKHEADDIM <= 32 && VHEADDIM <= 64) { \ + constexpr static int kQKHeadDim = 32; \ + constexpr static int kVHeadDim = 64; \ + return __VA_ARGS__(); \ + } else if (QKHEADDIM <= 64 && VHEADDIM <= 64) { \ + constexpr static int kQKHeadDim = 64; \ + constexpr static int kVHeadDim = 64; \ + return __VA_ARGS__(); \ + } else if (QKHEADDIM <= 64 && VHEADDIM <= 128) { \ + constexpr static int kQKHeadDim = 64; \ + constexpr static int kVHeadDim = 128; \ + return __VA_ARGS__(); \ + } else if (QKHEADDIM <= 96 && VHEADDIM <= 96) { \ + constexpr static int kQKHeadDim = 96; \ + constexpr static int kVHeadDim = 96; \ + return __VA_ARGS__(); \ + } else if (QKHEADDIM <= 96 && VHEADDIM <= 192) { \ + constexpr static int kQKHeadDim = 96; \ + constexpr static int kVHeadDim = 192; \ + return __VA_ARGS__(); \ + } else if (QKHEADDIM <= 128 && VHEADDIM <= 128) { \ + constexpr static int kQKHeadDim = 128; \ + constexpr static int kVHeadDim = 128; \ + return __VA_ARGS__(); \ + } else if (QKHEADDIM <= 128 && VHEADDIM <= 256) { \ + constexpr static int kQKHeadDim = 128; \ + constexpr static int kVHeadDim = 256; \ + return __VA_ARGS__(); \ + } else if (QKHEADDIM <= 160 && VHEADDIM <= 160) { \ + constexpr static int kQKHeadDim = 160; \ + constexpr static int kVHeadDim = 160; \ + return __VA_ARGS__(); \ + } else if (QKHEADDIM <= 192 && VHEADDIM <= 128) { \ + constexpr static int kQKHeadDim = 192; \ + constexpr static int kVHeadDim = 128; \ + return __VA_ARGS__(); \ + } else if (QKHEADDIM <= 192 && VHEADDIM <= 192) { \ + constexpr static int kQKHeadDim = 192; \ + constexpr static int kVHeadDim = 192; \ + return __VA_ARGS__(); \ + } else if (QKHEADDIM <= 256 && VHEADDIM <= 256) { \ + constexpr static int kQKHeadDim = 256; \ + constexpr static int kVHeadDim = 256; \ + return __VA_ARGS__(); \ + } \ + }() diff --git a/flash_attn/flash_attn_interface.py b/flash_attn/flash_attn_interface.py index ecb3515c0..fa86b8c26 100644 --- a/flash_attn/flash_attn_interface.py +++ b/flash_attn/flash_attn_interface.py @@ -543,6 +543,7 @@ def forward( ): if softmax_scale is None: softmax_scale = q.shape[-1] ** (-0.5) + ctx.headdim_qk = q.shape[-1] # before padding out, q, k, v, out_padded, softmax_lse, S_dmask, rng_state = _flash_attn_forward( q, k, @@ -588,8 +589,8 @@ def 
backward(ctx, dout, *args): ctx.deterministic, rng_state=rng_state, ) - dq = dq[..., : dout.shape[-1]] # We could have padded the head dimension - dk = dk[..., : dout.shape[-1]] + dq = dq[..., : ctx.headdim_qk] # We could have padded the head dimension + dk = dk[..., : ctx.headdim_qk] dv = dv[..., : dout.shape[-1]] return dq, dk, dv, None, None, None, None, None, None, None, None @@ -617,6 +618,7 @@ def forward( ): if softmax_scale is None: softmax_scale = q.shape[-1] ** (-0.5) + ctx.headdim_qk = q.shape[-1] # before padding out, q, k, v, out_padded, softmax_lse, S_dmask, rng_state = _flash_attn_varlen_forward( q, k, @@ -675,8 +677,8 @@ def backward(ctx, dout, *args): ctx.deterministic, rng_state=rng_state, ) - dq = dq[..., : dout.shape[-1]] # We could have padded the head dimension - dk = dk[..., : dout.shape[-1]] + dq = dq[..., : ctx.headdim_qk] # We could have padded the head dimension + dk = dk[..., : ctx.headdim_qk] dv = dv[..., : dout.shape[-1]] return dq, dk, dv, None, None, None, None, None, None, None, None, None, None, None, None, None diff --git a/headdim.json b/headdim.json new file mode 100644 index 000000000..94aa8daf7 --- /dev/null +++ b/headdim.json @@ -0,0 +1 @@ +[[32, 64], [64, 128], [96, 192], [128, 256], [192, 128]] \ No newline at end of file diff --git a/setup.py b/setup.py index fd67f645b..8c899d84b 100644 --- a/setup.py +++ b/setup.py @@ -10,6 +10,7 @@ from pathlib import Path from packaging.version import parse, Version import platform +import json from setuptools import setup, find_packages import subprocess @@ -62,6 +63,30 @@ # For CI, we want the option to build with C++11 ABI since the nvcr images use C++11 ABI FORCE_CXX11_ABI = os.getenv("FLASH_ATTENTION_FORCE_CXX11_ABI", "FALSE") == "TRUE" +list_headdim = [] +compile_list_headdim = [] +if not SKIP_CUDA_BUILD and not IS_ROCM: + with open('headdim.json', 'r') as file: + list_headdim = json.load(file) + # "csrc/flash_attn/src/flash_fwd_qkdim32_vdim64_fp16_sm80.cu" + for ii in ["fwd", "bwd"]: + for jj in list_headdim: + for kk in ["fp16", "bf16"]: + for ll in ["", "_causal"]: + compile_list_headdim.append( + f"csrc/flash_attn/src/flash_{ii}_qkdim{jj[0]}_vdim{jj[1]}_{kk}{ll}_sm80.cu" + ) + + # "csrc/flash_attn/src/flash_fwd_split_qkdim32_vdim64_fp16_causal_sm80.cu" + for jj in list_headdim: + for kk in ["fp16", "bf16"]: + for ll in ["", "_causal"]: + compile_list_headdim.append( + f"csrc/flash_attn/src/flash_fwd_split_qkdim{jj[0]}_vdim{jj[1]}_{kk}{ll}_sm80.cu" + ) + + from csrc.flash_attn.src.generate_switch_headdim import write_file + write_file() def get_platform(): """ @@ -264,7 +289,8 @@ def validate_and_update_archs(archs): "csrc/flash_attn/src/flash_fwd_split_hdim192_bf16_causal_sm80.cu", "csrc/flash_attn/src/flash_fwd_split_hdim256_fp16_causal_sm80.cu", "csrc/flash_attn/src/flash_fwd_split_hdim256_bf16_causal_sm80.cu", - ], + + ] + compile_list_headdim, extra_compile_args={ "cxx": ["-O3", "-std=c++17"] + generator_flag, "nvcc": append_nvcc_threads( diff --git a/tests/test_flash_attn_head.py b/tests/test_flash_attn_head.py new file mode 100644 index 000000000..b2b55cc49 --- /dev/null +++ b/tests/test_flash_attn_head.py @@ -0,0 +1,1262 @@ +import math + +import pytest +import torch +import torch.nn.functional as F +from einops import rearrange, repeat +from flash_attn import ( + flash_attn_func, + flash_attn_kvpacked_func, + flash_attn_qkvpacked_func, + flash_attn_varlen_func, + flash_attn_varlen_kvpacked_func, + flash_attn_varlen_qkvpacked_func, + flash_attn_with_kvcache, +) +from flash_attn.bert_padding 
import pad_input, unpad_input +from flash_attn.flash_attn_interface import _get_block_size_n +from flash_attn.layers.rotary import apply_rotary_emb + +MAX_HEADDIM_SM8x = 192 + + +is_sm75 = torch.cuda.get_device_capability("cuda") == (7, 5) +is_sm8x = torch.cuda.get_device_capability("cuda")[0] == 8 +is_sm80 = torch.cuda.get_device_capability("cuda") == (8, 0) +is_sm90 = torch.cuda.get_device_capability("cuda") == (9, 0) + + +def attn_bias_from_alibi_slopes( + slopes, seqlen_q, seqlen_k, query_padding_mask=None, key_padding_mask=None, causal=False, key_leftpad=None +): + batch, nheads = slopes.shape + device = slopes.device + slopes = rearrange(slopes, "b h -> b h 1 1") + if causal: + return torch.arange(-seqlen_k + 1, 1, device=device, dtype=torch.float32) * slopes + else: + row_idx = rearrange(torch.arange(seqlen_q, device=device, dtype=torch.long), "s -> s 1") + col_idx = torch.arange(seqlen_k, device=device, dtype=torch.long) + if key_leftpad is not None: + key_leftpad = rearrange(key_leftpad, "b -> b 1 1 1") + col_idx = repeat(col_idx, "s -> b 1 1 s", b=key_leftpad.shape[0]) + col_idx = torch.where(col_idx >= key_leftpad, col_idx - key_leftpad, 2**32) + sk = ( + seqlen_k + if key_padding_mask is None + else rearrange(key_padding_mask.sum(-1), "b -> b 1 1 1") + ) + sq = ( + seqlen_q + if query_padding_mask is None + else rearrange(query_padding_mask.sum(-1), "b -> b 1 1 1") + ) + relative_pos = torch.abs(row_idx + sk - sq - col_idx) + return -slopes * relative_pos.to(dtype=slopes.dtype) + + +def generate_random_padding_mask(max_seqlen, batch_size, device, mode="random"): + assert mode in ["full", "random", "third"] + if mode == "full": + lengths = torch.full((batch_size, 1), max_seqlen, device=device, dtype=torch.int32) + elif mode == "random": + lengths = torch.randint( + max(1, max_seqlen - 20), max_seqlen + 1, (batch_size, 1), device=device + ) + elif mode == "third": + lengths = torch.randint(max_seqlen // 3, max_seqlen + 1, (batch_size, 1), device=device) + padding_mask = ( + repeat(torch.arange(max_seqlen, device=device), "s -> b s", b=batch_size) < lengths + ) + return padding_mask + + +def generate_qkv( + q, k, v, query_padding_mask=None, key_padding_mask=None, kvpacked=False, qkvpacked=False +): + """ + Arguments: + q: (batch_size, seqlen_q, nheads, d) + k: (batch_size, seqlen_k, nheads_k, d) + v: (batch_size, seqlen_k, nheads_k, d) + query_padding_mask: (batch_size, seqlen), bool + key_padding_mask: (batch_size, seqlen), bool + """ + assert not (kvpacked and qkvpacked) + batch_size, seqlen_q, nheads, d = q.shape + _, seqlen_k, nheads_k, _ = k.shape + assert k.shape == (batch_size, seqlen_k, nheads_k, d) + assert v.shape == (batch_size, seqlen_k, nheads_k, d) + + if query_padding_mask is not None: + q_unpad, indices_q, cu_seqlens_q, max_seqlen_q = unpad_input(q, query_padding_mask) + output_pad_fn = lambda output_unpad: pad_input( + output_unpad, indices_q, batch_size, seqlen_q + ) + else: + q_unpad = rearrange(q, "b s h d -> (b s) h d") + cu_seqlens_q = torch.arange( + 0, (batch_size + 1) * seqlen_q, step=seqlen_q, dtype=torch.int32, device=q_unpad.device + ) + max_seqlen_q = seqlen_q + output_pad_fn = lambda output_unpad: rearrange( + output_unpad, "(b s) h d -> b s h d", b=batch_size + ) + + if key_padding_mask is not None: + k_unpad, indices_k, cu_seqlens_k, max_seqlen_k = unpad_input(k, key_padding_mask) + v_unpad, _, _, _ = unpad_input(v, key_padding_mask) + else: + k_unpad = rearrange(k, "b s h d -> (b s) h d") + v_unpad = rearrange(v, "b s h d -> (b s) h d") + 
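+ # No padding mask: every sequence has full length, so the cumulative sequence-start
+ # offsets expected by the varlen kernels are simply multiples of seqlen_k
+ # (batch_size + 1 entries, starting at 0).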
cu_seqlens_k = torch.arange( + 0, (batch_size + 1) * seqlen_k, step=seqlen_k, dtype=torch.int32, device=k_unpad.device + ) + max_seqlen_k = seqlen_k + + if qkvpacked: + assert (query_padding_mask == key_padding_mask).all() + assert nheads == nheads_k + qkv_unpad = torch.stack([q_unpad, k_unpad, v_unpad], dim=1) + qkv = torch.stack([q, k, v], dim=2) + if query_padding_mask is not None: + dqkv_pad_fn = lambda dqkv_unpad: pad_input(dqkv_unpad, indices_q, batch_size, seqlen_q) + else: + dqkv_pad_fn = lambda dqkv_unpad: rearrange( + dqkv_unpad, "(b s) t h d -> b s t h d", b=batch_size + ) + return ( + qkv_unpad.detach().requires_grad_(), + cu_seqlens_q, + max_seqlen_q, + qkv.detach().requires_grad_(), + output_pad_fn, + dqkv_pad_fn, + ) + elif kvpacked: + kv_unpad = torch.stack([k_unpad, v_unpad], dim=1) + kv = torch.stack([k, v], dim=2) + dq_pad_fn = output_pad_fn + if key_padding_mask is not None: + dkv_pad_fn = lambda dkv_unpad: pad_input(dkv_unpad, indices_k, batch_size, seqlen_k) + else: + dkv_pad_fn = lambda dkv_unpad: rearrange( + dkv_unpad, "(b s) t h d -> b s t h d", b=batch_size + ) + return ( + q_unpad.detach().requires_grad_(), + kv_unpad.detach().requires_grad_(), + cu_seqlens_q, + cu_seqlens_k, + max_seqlen_q, + max_seqlen_k, + q.detach().requires_grad_(), + kv.detach().requires_grad_(), + output_pad_fn, + dq_pad_fn, + dkv_pad_fn, + ) + else: + dq_pad_fn = output_pad_fn + if key_padding_mask is not None: + dk_pad_fn = lambda dk_unpad: pad_input(dk_unpad, indices_k, batch_size, seqlen_k) + else: + dk_pad_fn = lambda dk_unpad: rearrange(dk_unpad, "(b s) h d -> b s h d", b=batch_size) + return ( + q_unpad.detach().requires_grad_(), + k_unpad.detach().requires_grad_(), + v_unpad.detach().requires_grad_(), + cu_seqlens_q, + cu_seqlens_k, + max_seqlen_q, + max_seqlen_k, + q.detach().requires_grad_(), + k.detach().requires_grad_(), + v.detach().requires_grad_(), + output_pad_fn, + dq_pad_fn, + dk_pad_fn, + ) + + +def construct_local_mask( + seqlen_q, + seqlen_k, + window_size=(-1, -1), # -1 means infinite window size + query_padding_mask=None, + key_padding_mask=None, + device=None, + key_leftpad=None, +): + row_idx = rearrange(torch.arange(seqlen_q, device=device, dtype=torch.long), "s -> s 1") + col_idx = torch.arange(seqlen_k, device=device, dtype=torch.long) + if key_leftpad is not None: + key_leftpad = rearrange(key_leftpad, "b -> b 1 1 1") + col_idx = repeat(col_idx, "s -> b 1 1 s", b=key_leftpad.shape[0]) + col_idx = torch.where(col_idx >= key_leftpad, col_idx - key_leftpad, 2**32) + sk = ( + seqlen_k + if key_padding_mask is None + else rearrange(key_padding_mask.sum(-1), "b -> b 1 1 1") + ) + sq = ( + seqlen_q + if query_padding_mask is None + else rearrange(query_padding_mask.sum(-1), "b -> b 1 1 1") + ) + if window_size[0] < 0: + return col_idx > row_idx + sk - sq + window_size[1] + else: + sk = torch.full_like(col_idx, seqlen_k) if key_padding_mask is None else sk + return torch.logical_or( + col_idx > torch.minimum(row_idx + sk - sq + window_size[1], sk), + col_idx < row_idx + sk - sq - window_size[0], + ) + + +def attention_ref( + q, + k, + v, + query_padding_mask=None, + key_padding_mask=None, + attn_bias=None, + dropout_p=0.0, + dropout_mask=None, + causal=False, + window_size=(-1, -1), # -1 means infinite window size + softcap=0.0, + upcast=True, + reorder_ops=False, + key_leftpad=None, +): + """ + Arguments: + q: (batch_size, seqlen_q, nheads, head_dim) + k: (batch_size, seqlen_k, nheads_k, head_dim) + v: (batch_size, seqlen_k, nheads_k, head_dim) + query_padding_mask: 
(batch_size, seqlen_q) + key_padding_mask: (batch_size, seqlen_k) + attn_bias: broadcastable to (batch_size, nheads, seqlen_q, seqlen_k) + dropout_p: float + dropout_mask: (batch_size, nheads, seqlen_q, seqlen_k) + causal: whether to apply causal masking + window_size: (int, int), left and right window size + upcast: whether to cast all inputs to fp32, do all computation in fp32, then cast + output back to fp16/bf16. + reorder_ops: whether to change the order of operations (scaling k instead of scaling q, etc.) + without changing the math. This is to estimate the numerical error from operation + reordering. + Output: + output: (batch_size, seqlen_q, nheads, head_dim) + attention: (batch_size, nheads, seqlen_q, seqlen_k), softmax after dropout + """ + if causal: + window_size = (window_size[0], 0) + dtype_og = q.dtype + if upcast: + q, k, v = q.float(), k.float(), v.float() + seqlen_q, seqlen_k = q.shape[1], k.shape[1] + k = repeat(k, "b s h d -> b s (h g) d", g=q.shape[2] // k.shape[2]) + v = repeat(v, "b s h d -> b s (h g) d", g=q.shape[2] // v.shape[2]) + d = q.shape[-1] + if not reorder_ops: + scores = torch.einsum("bthd,bshd->bhts", q / math.sqrt(d), k) + else: + scores = torch.einsum("bthd,bshd->bhts", q, k / math.sqrt(d)) + if softcap > 0: + scores = scores / softcap + scores = scores.tanh() + scores = scores * softcap + if key_padding_mask is not None: + scores.masked_fill_(rearrange(~key_padding_mask, "b s -> b 1 1 s"), float("-inf")) + if window_size[0] >= 0 or window_size[1] >= 0: + local_mask = construct_local_mask( + seqlen_q, + seqlen_k, + window_size, + query_padding_mask, + key_padding_mask, + q.device, + key_leftpad=key_leftpad, + ) + scores.masked_fill_(local_mask, float("-inf")) + if attn_bias is not None: + scores = scores + attn_bias + attention = torch.softmax(scores, dim=-1).to(v.dtype) + # Some rows might be completely masked out so we fill them with zero instead of NaN + if window_size[0] >= 0 or window_size[1] >= 0: + attention = attention.masked_fill(torch.all(local_mask, dim=-1, keepdim=True), 0.0) + # We want to mask here so that the attention matrix doesn't have any NaNs + # Otherwise we'll get NaN in dV + if query_padding_mask is not None: + attention = attention.masked_fill(rearrange(~query_padding_mask, "b s -> b 1 s 1"), 0.0) + dropout_scaling = 1.0 / (1 - dropout_p) + # attention_drop = attention.masked_fill(~dropout_mask, 0.0) * dropout_scaling + # output = torch.einsum('bhts,bshd->bthd', attention_drop , v) + if dropout_mask is not None: + attention_drop = attention.masked_fill(~dropout_mask, 0.0) + else: + attention_drop = attention + output = torch.einsum("bhts,bshd->bthd", attention_drop, v * dropout_scaling) + if query_padding_mask is not None: + output.masked_fill_(rearrange(~query_padding_mask, "b s -> b s 1 1"), 0.0) + return output.to(dtype=dtype_og), attention.to(dtype=dtype_og) + + +def attention_kvpacked_ref( + q, + kv, + query_padding_mask=None, + key_padding_mask=None, + attn_bias=None, + dropout_p=0.0, + dropout_mask=None, + causal=False, + window_size=(-1, -1), # -1 means infinite window size + softcap=0.0, + upcast=True, + reorder_ops=False, + key_leftpad=None, +): + return attention_ref( + q, + kv[:, :, 0], + kv[:, :, 1], + query_padding_mask, + key_padding_mask, + attn_bias, + dropout_p, + dropout_mask, + upcast=upcast, + causal=causal, + window_size=window_size, + softcap=softcap, + reorder_ops=reorder_ops, + key_leftpad=key_leftpad, + ) + + +def attention_qkvpacked_ref( + qkv, + key_padding_mask=None, + attn_bias=None, + 
dropout_p=0.0, + dropout_mask=None, + causal=False, + window_size=(-1, -1), # -1 means infinite window size + softcap=0.0, + upcast=True, + reorder_ops=False, +): + return attention_ref( + qkv[:, :, 0], + qkv[:, :, 1], + qkv[:, :, 2], + key_padding_mask, + key_padding_mask, + attn_bias, + dropout_p, + dropout_mask, + upcast=upcast, + causal=causal, + window_size=window_size, + softcap=softcap, + reorder_ops=reorder_ops, + ) + + +def generate_sparsity_mask(seqlen, sparsity=0.3): + repeats = seqlen // 16 // 2 + # mask = torch.stack([torch.tensor([1, 0] * repeats, dtype=torch.bool, device='cuda'), + # torch.tensor([0, 1] * repeats, dtype=torch.bool, device='cuda')], dim=-1) + # mask = torch.stack([torch.tensor([1, 1] * repeats, dtype=torch.bool, device='cuda'), + # torch.tensor([1, 1] * repeats, dtype=torch.bool, device='cuda')], dim=-1) + # mask = torch.stack([torch.tensor([1, 1] * repeats, dtype=torch.bool, device='cuda')], dim=-1) + # mask = torch.stack([torch.tensor([1, 0] * repeats, dtype=torch.bool, device='cuda')], dim=-1) + nrow, ncol = seqlen // 16, seqlen // 256 + mask = torch.rand(nrow, ncol, device="cuda") < sparsity + return mask + + +def attention_blocksparse_ref(qkv, blockmask, attn_mask, dropout_p, dropout_mask): + """ + Arguments: + qkv: (batch_size, seqlen, 3, nheads, head_dim) + blockmask: (seqlen / 16, seqlen / 256) + attn_mask: (batch_size, seqlen) + dropout_p: float + dropout_mask: (batch_size, nheads, seqlen, seqlen) + Output: + output: (batch_size, seqlen, nheads, head_dim) + attention: softmax after dropout + """ + q, k, v = qkv.float().unbind(dim=2) + d = qkv.shape[-1] + seqlen = qkv.shape[1] + scores = torch.einsum("bthd,bshd->bhts", q / math.sqrt(d), k) + scores.masked_fill_(rearrange(~attn_mask, "b s -> b 1 1 s"), float("-inf")) + blockmask = repeat(blockmask, "s_16 s_256 -> (s_16 16) (s_256 256)") + blockmask = blockmask[:seqlen, :seqlen] + scores.masked_fill_(rearrange(~blockmask, "t s -> 1 1 t s"), float("-inf")) + attention = torch.softmax(scores, dim=-1) + attention = attention.masked_fill(rearrange(~attn_mask, "b s -> b 1 s 1"), 0.0) + attention = attention.masked_fill_(rearrange(~blockmask, "t s -> 1 1 t s"), 0.0) + attention_drop = attention.masked_fill(~dropout_mask, 0.0) / (1 - dropout_p) + output = torch.einsum("bhts,bshd->bthd", attention_drop, v) + output.masked_fill_(rearrange(~attn_mask, "b s -> b s 1 1"), 0) + return output.to(dtype=qkv.dtype), attention.to(dtype=qkv.dtype) + + +def convert_flash_attn_S_to_softmax( + S, + seqlen_q, + seqlen_k, + query_padding_mask, + key_padding_mask, + head_dim, + is_dropout, + causal=False, + window_size=(-1, -1), # -1 means infinite window size +): + """FlashAttention stores the S matrix in a different way. + Arguments: + S: (batch_size, nheads, seqlen_q_rounded, seqlen_k_rounded) + query_padding_mask: (batch_size, seqlen_q_rounded) + key_padding_mask: (batch_size, seqlen_k_rounded) + """ + if causal: + window_size = (window_size[0], 0) + seqlen_q_rounded, seqlen_k_rounded = S.shape[-2:] + S_converted = S + if window_size[0] >= 0 or window_size[1] >= 0: + local_mask = construct_local_mask( + seqlen_q, + seqlen_k, + window_size, + query_padding_mask, + key_padding_mask, + S.device, + ) + local_mask = F.pad( + local_mask, + (0, seqlen_k_rounded - seqlen_k, 0, seqlen_q_rounded - seqlen_q), + value=True, + ) + S_converted = S_converted.masked_fill(local_mask, 0.0) + + # Need to zero out things not in attention_mask in case S was initialized with random values + # and some of those values aren't overwritten. 
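+ # S is returned with the rounded (padded) sequence lengths, so the padding masks are first
+ # padded out to seqlen_q_rounded / seqlen_k_rounded, the padded entries are zeroed, and the
+ # result is cropped back to (seqlen_q, seqlen_k) at the end (F.pad with a negative size crops).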
+ seqlen_q_og = ( + query_padding_mask.shape[-1] if query_padding_mask is not None else seqlen_q_rounded + ) + if query_padding_mask is not None: + query_padding_mask = F.pad(query_padding_mask, (0, seqlen_q_rounded - seqlen_q_og)) + S_converted = S_converted.masked_fill(rearrange(~query_padding_mask, "b s -> b 1 s 1"), 0.0) + seqlen_k_og = key_padding_mask.shape[-1] if key_padding_mask is not None else seqlen_k + if key_padding_mask is not None: + key_padding_mask = F.pad(key_padding_mask, (0, seqlen_k_rounded - seqlen_k_og)) + S_converted = S_converted.masked_fill(rearrange(~key_padding_mask, "b s -> b 1 1 s"), 0.0) + S_converted = F.pad(S_converted, (0, 0, 0, seqlen_q_og - seqlen_q_rounded)) + S_converted = F.pad(S_converted, (0, seqlen_k_og - seqlen_k_rounded)) + return S_converted[:, :, :seqlen_q, :seqlen_k] + + +def normalize_flash_attn_S( + attn_unnorm, + q, + k, + v, + query_padding_mask=None, + key_padding_mask=None, + attn_bias=None, + is_dropout=False, + causal=False, + window_size=(-1, -1), # -1 means infinite window size +): + """ + Arguments: + q: (batch_size, seqlen_q, nheads, head_dim) + k, v: (batch_size, seqlen_k, nheads, head_dim) + key_padding_mask: (batch_size, seqlen_q) + attn_bias: broadcastable to (batch_size, nheads, seqlen_q, seqlen_k) + Output: + softmax_lse: (batch_size, nheads, seqlen_q) + softmax_max: (batch_size, nheads, seqlen_q) + """ + if causal: + window_size = (window_size[0], 0) + q, k, v = q.float(), k.float(), v.float() + _, seqlen_q, _, head_dim = q.shape + seqlen_k = k.shape[1] + scores = torch.einsum("bthd,bshd->bhts", q / math.sqrt(head_dim), k) + if key_padding_mask is not None: + scores.masked_fill_(rearrange(~key_padding_mask, "b s -> b 1 1 s"), float("-inf")) + if window_size[0] >= 0 or window_size[1] >= 0: + local_mask = construct_local_mask( + seqlen_q, + seqlen_k, + window_size, + query_padding_mask, + key_padding_mask, + q.device, + ) + scores.masked_fill_(local_mask, float("-inf")) + if attn_bias is not None: + scores = scores + attn_bias.to(dtype=scores.dtype) + block_size_n = _get_block_size_n(scores.device, head_dim, is_dropout, causal) + scores_block = scores.split(block_size_n, dim=-1) + lse_block = torch.stack([torch.logsumexp(s, dim=-1) for s in scores_block], dim=-1) + lse = torch.logsumexp(lse_block, dim=-1) + # lse could be -inf (i.e. all values in scores are -inf), and we want to set those to inf + # so that when we do torch.exp(m - lse), we get 0.0 instead of NaN. + lse[lse == float("-inf")] = float("inf") + scores_max_block = torch.stack([torch.amax(s, dim=-1) for s in scores_block], dim=-1) + cummax_block = torch.cummax(scores_max_block.flip(-1), dim=-1).values.flip(-1).unbind(dim=-1) + attn_unnorm_block = attn_unnorm.split(block_size_n, dim=-1) + attn_norm = torch.cat( + [ + a * rearrange(torch.exp(m - lse), "b h s -> b h s 1") + for a, m in zip(attn_unnorm_block, cummax_block) + ], + dim=-1, + ) + if query_padding_mask is not None: + attn_norm.masked_fill_(rearrange(~query_padding_mask, "b s -> b 1 s 1"), 0.0) + return attn_norm.to(dtype=attn_unnorm.dtype) + + +def get_dropout_fraction( + dropout_mask, + query_padding_mask=None, + key_padding_mask=None, + causal=False, + window_size=(-1, -1), # -1 means infinite window size +): + """ + dropout_mask: (batch_size, nheads, seqlen_q, seqlen_k), bool. True means keep, False means drop. 
+ query_padding_mask: (batch_size, seqlen_q) + key_padding_mask: (batch_size, seqlen_k) + """ + if causal: + window_size = (window_size[0], 0) + batch_size, nheads, seqlen_q, seqlen_k = dropout_mask.shape + dropped = ~dropout_mask + valid = torch.ones_like(dropout_mask) + if query_padding_mask is not None: + dropped.masked_fill_(rearrange(~query_padding_mask, "b s -> b 1 s 1"), False) + valid.masked_fill_(rearrange(~query_padding_mask, "b s -> b 1 s 1"), False) + if key_padding_mask is not None: + dropped.masked_fill_(rearrange(~key_padding_mask, "b s -> b 1 1 s"), False) + valid.masked_fill_(rearrange(~key_padding_mask, "b s -> b 1 1 s"), False) + if window_size[0] >= 0 or window_size[1] >= 0: + local_mask = construct_local_mask( + seqlen_q, + seqlen_k, + window_size, + query_padding_mask, + key_padding_mask, + dropout_mask.device, + ) + dropped.masked_fill_(local_mask, False) + valid.masked_fill_(local_mask, False) + dropped_total = dropped.sum() + return dropped.sum() / valid.sum() + + + +@pytest.mark.parametrize("kvpacked", [False]) +# @pytest.mark.parametrize("kvpacked", [False]) +@pytest.mark.parametrize("dtype", ([torch.float16] if is_sm75 else [torch.float16, torch.bfloat16])) +# @pytest.mark.parametrize("dtype", [torch.bfloat16]) +# @pytest.mark.parametrize("mha_type", ["mha", "mqa", "gqa"]) +@pytest.mark.parametrize("h,h_k,h_v",[ + (32,4,16), +]) +# @pytest.mark.parametrize("mha_type", ["mha"]) +@pytest.mark.parametrize("deterministic", [False, True]) +# @pytest.mark.parametrize("deterministic", [True]) +@pytest.mark.parametrize("alibi", [False, True]) +# @pytest.mark.parametrize("alibi", [False]) +@pytest.mark.parametrize("local", [False, True]) +# @pytest.mark.parametrize("local", [False]) +@pytest.mark.parametrize("causal", [False, True]) +# @pytest.mark.parametrize("causal", [True]) +@pytest.mark.parametrize("d", [32, 40, 59, 64, 96, 111, 128, 160, 192, 224, 256]) +# @pytest.mark.parametrize("d", [32, 64, 96, 128, 160, 192, 224, 256]) +# @pytest.mark.parametrize('d', [32, 40, 64, 80, 96, 128, 160, 192]) +# @pytest.mark.parametrize('d', [32, 64, 96, 128, 160, 192]) +# @pytest.mark.parametrize('d', [56, 80]) +# @pytest.mark.parametrize("d", [64]) +@pytest.mark.parametrize( + "seqlen_q,seqlen_k", + [ + (113, 203), + (128, 217), + (113, 211), + (108, 256), + (256, 512), + (512, 256), + (1024, 1024), + (1023, 1024), + (1024, 1023), + (2048, 2048), + ], +) +# @pytest.mark.parametrize('seqlen_q,seqlen_k', [(256, 128)]) +@pytest.mark.parametrize("dropout_p", [0.0, 0.17]) +# @pytest.mark.parametrize("dropout_p", [0.0]) +@pytest.mark.parametrize("softcap", [0.0, 50.0]) +def test_flash_attn_output( + seqlen_q, seqlen_k, h, h_k, h_v, d, dropout_p, causal, local, alibi, deterministic, dtype, kvpacked, softcap +): + if ( + max(seqlen_q, seqlen_k) >= 2048 + and torch.cuda.get_device_properties("cuda").total_memory <= 16 * 2**30 + ): + pytest.skip() # Reference implementation OOM + if softcap > 0.0 and dropout_p > 0.0: + pytest.skip("Softcap and dropout not supported together") + device = "cuda" + # set seed + torch.random.manual_seed(0) + batch_size = 1 # 4 + nheads = h # 6 if softcap == 0.0 else 4 # softcap reference impl takes more memory + nheads_k = h_k # nheads if mha_type == "mha" else (1 if mha_type == "mqa" else 2) + nheads_v = h_v + assert nheads % nheads_k == 0 + assert nheads % nheads_v == 0 + assert (not kvpacked) or (nheads_k == nheads_v) + window_size = (-1, -1) if not local else torch.randint(0, seqlen_k, (2,)) + q = torch.randn(batch_size, seqlen_q, nheads, d, device=device, 
dtype=dtype, requires_grad=True) + if softcap > 0: + # Ensure the values of qk are at least within softcap range. + q = q * softcap + if kvpacked: + kv = torch.randn( + batch_size, seqlen_k, 2, nheads_k, d, device=device, dtype=dtype, requires_grad=True + ) + else: + k = torch.randn( + batch_size, seqlen_k, nheads_k, d, device=device, dtype=dtype, requires_grad=True + ) + v = torch.randn( + batch_size, seqlen_k, nheads_k, d, device=device, dtype=dtype, requires_grad=True + ) + if alibi: + alibi_slopes = torch.rand(batch_size, nheads, device=device, dtype=torch.float32) * 0.3 + attn_bias = attn_bias_from_alibi_slopes(alibi_slopes, seqlen_q, seqlen_k, causal=causal) + else: + alibi_slopes, attn_bias = None, None + + if kvpacked: + out, lse, S_dmask = flash_attn_kvpacked_func( + q, + kv, + dropout_p, + causal=causal, + window_size=window_size, + softcap=softcap, + alibi_slopes=alibi_slopes, + deterministic=deterministic, + return_attn_probs=True, + ) + else: + out, lse, S_dmask = flash_attn_func( + q, + k, + v, + dropout_p, + causal=causal, + window_size=window_size, + softcap=softcap, + alibi_slopes=alibi_slopes, + deterministic=deterministic, + return_attn_probs=True, + ) + if dropout_p > 0.0: + S_dmask_converted = convert_flash_attn_S_to_softmax( + S_dmask, + seqlen_q, + seqlen_k, + None, + None, + d, + dropout_p > 0.0, + causal=causal, + window_size=window_size, + ) + dropout_mask = S_dmask_converted >= 0 + attn_unnorm = S_dmask_converted.abs() + if kvpacked: + kv_rep = repeat(kv, "b s two h d -> b s two (h g) d", g=nheads // nheads_k) + k_rep, v_rep = kv_rep.unbind(dim=2) + else: + k_rep = repeat(k, "b s h d -> b s (h g) d", g=nheads // nheads_k) + v_rep = repeat(v, "b s h d -> b s (h g) d", g=nheads // nheads_v) + attn = normalize_flash_attn_S( + attn_unnorm, + q, + k_rep, + v_rep, + None, + None, + attn_bias, + dropout_p > 0.0, + causal=causal, + window_size=window_size, + ) + dropout_fraction = get_dropout_fraction( + dropout_mask, None, None, causal=causal, window_size=window_size + ).item() + print(f"Actual dropout fraction: {dropout_fraction}") + else: + dropout_mask = None + + if kvpacked: + out_ref, attn_ref = attention_kvpacked_ref( + q, + kv, + None, + None, + attn_bias, + dropout_p, + dropout_mask, + causal=causal, + window_size=window_size, + softcap=softcap, + ) + out_pt, attn_pt = attention_kvpacked_ref( + q, + kv, + None, + None, + attn_bias, + dropout_p, + dropout_mask, + causal=causal, + window_size=window_size, + softcap=softcap, + upcast=False, + reorder_ops=True, + ) + else: + out_ref, attn_ref = attention_ref( + q, + k, + v, + None, + None, + attn_bias, + dropout_p, + dropout_mask, + causal=causal, + window_size=window_size, + softcap=softcap, + ) + out_pt, attn_pt = attention_ref( + q, + k, + v, + None, + None, + attn_bias, + dropout_p, + dropout_mask, + causal=causal, + window_size=window_size, + softcap=softcap, + upcast=False, + reorder_ops=True, + ) + + print(f"Output max diff: {(out - out_ref).abs().max().item()}") + print(f"Output mean diff: {(out - out_ref).abs().mean().item()}") + print(f"Pytorch max diff: {(out_pt - out_ref).abs().max().item()}") + print(f"Pytorch mean diff: {(out_pt - out_ref).abs().mean().item()}") + if dropout_p > 0.0: + print(f"Attention max diff: {(attn - attn_ref).abs().max().item()}") + print(f"Attention Pytorch max diff: {(attn_pt - attn_ref).abs().max().item()}") + + g = torch.randn_like(out) + do_o = (g.float() * out.float()).sum(-1) + if (d <= MAX_HEADDIM_SM8x or dropout_p == 0) or (is_sm80 or is_sm90): + if kvpacked: + ( + dq, + 
dkv, + ) = torch.autograd.grad(out, (q, kv), g) + dk, dv = dkv.unbind(2) + ( + dq_ref, + dkv_ref, + ) = torch.autograd.grad(out_ref, (q, kv), g) + dk_ref, dv_ref = dkv_ref.unbind(2) + ( + dq_pt, + dkv_pt, + ) = torch.autograd.grad(out_pt, (q, kv), g) + dk_pt, dv_pt = dkv_pt.unbind(2) + else: + ( + dq, + dk, + dv, + ) = torch.autograd.grad(out, (q, k, v), g) + ( + dq_ref, + dk_ref, + dv_ref, + ) = torch.autograd.grad(out_ref, (q, k, v), g) + ( + dq_pt, + dk_pt, + dv_pt, + ) = torch.autograd.grad(out_pt, (q, k, v), g) + print(f"dQ max diff: {(dq - dq_ref).abs().max().item()}") + print(f"dK max diff: {(dk - dk_ref).abs().max().item()}") + print(f"dV max diff: {(dv - dv_ref).abs().max().item()}") + print(f"dQ mean diff: {(dq - dq_ref).abs().mean().item()}") + print(f"dK mean diff: {(dk - dk_ref).abs().mean().item()}") + print(f"dV mean diff: {(dv - dv_ref).abs().mean().item()}") + print(f"dQ Pytorch max diff: {(dq_pt - dq_ref).abs().max().item()}") + print(f"dK Pytorch max diff: {(dk_pt - dk_ref).abs().max().item()}") + print(f"dV Pytorch max diff: {(dv_pt - dv_ref).abs().max().item()}") + print(f"dQ Pytorch mean diff: {(dq_pt - dq_ref).abs().mean().item()}") + print(f"dK Pytorch mean diff: {(dk_pt - dk_ref).abs().mean().item()}") + print(f"dV Pytorch mean diff: {(dv_pt - dv_ref).abs().mean().item()}") + + # Check that FlashAttention's numerical error is at most twice the numerical error + # of a Pytorch implementation. + assert (out - out_ref).abs().max().item() <= 2 * (out_pt - out_ref).abs().max().item() + + if dropout_p > 0.0: + assert (attn - attn_ref).abs().max().item() <= 2 * (attn_pt - attn_ref).abs().max().item() + # With alibi, many of the prob values are 0.0 & -0.0 so dropout_fraction isn't accurate + if not alibi: + assert abs(dropout_fraction - dropout_p) <= (0.01 if not local else 0.025) + + if (d <= MAX_HEADDIM_SM8x or dropout_p == 0) or (is_sm80 or is_sm90): + assert (dq - dq_ref).abs().max().item() <= 3 * (dq_pt - dq_ref).abs().max().item() + assert (dk - dk_ref).abs().max().item() <= 3 * (dk_pt - dk_ref).abs().max().item() + assert (dv - dv_ref).abs().max().item() <= 3 * (dv_pt - dv_ref).abs().max().item() + + + + +@pytest.mark.parametrize("dtype", ([torch.float16] if is_sm75 else [torch.float16, torch.bfloat16])) +# @pytest.mark.parametrize("dtype", [torch.float16]) +@pytest.mark.parametrize("deterministic", [False, True]) +# @pytest.mark.parametrize("deterministic", [True]) +@pytest.mark.parametrize("alibi", [False, True]) +# @pytest.mark.parametrize("alibi", [True]) +@pytest.mark.parametrize("local", [False, True]) +# @pytest.mark.parametrize("local", [False]) +@pytest.mark.parametrize("causal", [False, True]) +# @pytest.mark.parametrize("causal", [True]) +@pytest.mark.parametrize("h,h_k,h_v",[ + (32,4,16), +]) +@pytest.mark.parametrize("d", [32, 40, 59, 64, 80, 96, 111, 128, 160, 192, 224, 256]) +# @pytest.mark.parametrize('d', [32, 64, 96, 128, 160, 192, 224, 256]) +# @pytest.mark.parametrize('d', [32, 40, 64, 80, 96, 128, 160, 192]) +# @pytest.mark.parametrize('d', [32, 64, 96, 128, 160, 192]) +# @pytest.mark.parametrize('d', [56, 80]) +# @pytest.mark.parametrize("d", [64]) +@pytest.mark.parametrize("swap_sq_sk", [False, True]) +# @pytest.mark.parametrize("swap_sq_sk", [False]) +@pytest.mark.parametrize( + "seqlen_q,seqlen_k", + [ + (3, 1024), + (1, 339), + (64, 800), + (3, 799), + (64, 2048), + (16, 20000), + (16, 100000), + (128, 128), + (256, 256), + ], +) +# @pytest.mark.parametrize('seqlen_q,seqlen_k', [(256, 128)]) +def test_flash_attn_splitkv( + seqlen_q, 
seqlen_k,h,h_k,h_v, swap_sq_sk, d, causal, local, alibi, deterministic, dtype +): + if swap_sq_sk: + seqlen_q, seqlen_k = seqlen_k, seqlen_q + device = "cuda" + # set seed + torch.random.manual_seed(0) + batch_size = 1 + nheads = h # 12 + nheads_k = h_k + nheads_v = h_v + window_size = (-1, -1) if not local else torch.randint(0, seqlen_k, (2,)) + q = torch.randn(batch_size, seqlen_q, nheads, d, device=device, dtype=dtype, requires_grad=True) + k = torch.randn(batch_size, seqlen_k, nheads_k, d, device=device, dtype=dtype, requires_grad=True) + v = torch.randn(batch_size, seqlen_k, nheads_v, d, device=device, dtype=dtype, requires_grad=True) + if alibi: + alibi_slopes = torch.rand(batch_size, nheads, device=device, dtype=torch.float32) * 0.3 + attn_bias = attn_bias_from_alibi_slopes(alibi_slopes, seqlen_q, seqlen_k, causal=causal) + else: + alibi_slopes, attn_bias = None, None + out, lse, _ = flash_attn_func( + q, + k, + v, + 0.0, + causal=causal, + window_size=window_size, + alibi_slopes=alibi_slopes, + deterministic=deterministic, + return_attn_probs=True, + ) + out_ref, attn_ref = attention_ref( + q, k, v, None, None, attn_bias, 0.0, None, causal=causal, window_size=window_size + ) + out_pt, attn_pt = attention_ref( + q, + k, + v, + None, + None, + attn_bias, + 0.0, + None, + causal=causal, + window_size=window_size, + upcast=False, + reorder_ops=True, + ) + + print(f"Output max diff: {(out - out_ref).abs().max().item()}") + print(f"Output mean diff: {(out - out_ref).abs().mean().item()}") + print(f"Pytorch max diff: {(out_pt - out_ref).abs().max().item()}") + print(f"Pytorch mean diff: {(out_pt - out_ref).abs().mean().item()}") + + g = torch.randn_like(out) + do_o = (g.float() * out.float()).sum(-1) + ( + dq, + dk, + dv, + ) = torch.autograd.grad(out, (q, k, v), g) + ( + dq_ref, + dk_ref, + dv_ref, + ) = torch.autograd.grad(out_ref, (q, k, v), g) + ( + dq_pt, + dk_pt, + dv_pt, + ) = torch.autograd.grad(out_pt, (q, k, v), g) + print(f"dQ max diff: {(dq - dq_ref).abs().max().item()}") + print(f"dK max diff: {(dk - dk_ref).abs().max().item()}") + print(f"dV max diff: {(dv - dv_ref).abs().max().item()}") + print(f"dQ mean diff: {(dq - dq_ref).abs().mean().item()}") + print(f"dK mean diff: {(dk - dk_ref).abs().mean().item()}") + print(f"dV mean diff: {(dv - dv_ref).abs().mean().item()}") + print(f"dQ Pytorch max diff: {(dq_pt - dq_ref).abs().max().item()}") + print(f"dK Pytorch max diff: {(dk_pt - dk_ref).abs().max().item()}") + print(f"dV Pytorch max diff: {(dv_pt - dv_ref).abs().max().item()}") + print(f"dQ Pytorch mean diff: {(dq_pt - dq_ref).abs().mean().item()}") + print(f"dK Pytorch mean diff: {(dk_pt - dk_ref).abs().mean().item()}") + print(f"dV Pytorch mean diff: {(dv_pt - dv_ref).abs().mean().item()}") + + # Check that FlashAttention's numerical error is at most twice the numerical error + # of a Pytorch implementation. 
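+ # out_ref runs the reference attention in fp32, while out_pt repeats the same math in the
+ # input dtype with reordered operations, so its error estimates the dtype's rounding noise.
+ # The small absolute terms below keep the checks meaningful when out_pt happens to agree with
+ # out_ref almost exactly, and the gradient bound is loosened (mult = 8) when alibi is used.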
+ assert (out - out_ref).abs().max().item() <= 2 * (out_pt - out_ref).abs().max().item() + 1e-5 + + mult = 2 if not alibi else 8 + assert (dq - dq_ref).abs().max().item() <= mult * (dq_pt - dq_ref).abs().max().item() + 2e-4 + assert (dk - dk_ref).abs().max().item() <= mult * (dk_pt - dk_ref).abs().max().item() + 2e-4 + assert (dv - dv_ref).abs().max().item() <= mult * (dv_pt - dv_ref).abs().max().item() + 2e-4 + + +# # @pytest.mark.parametrize("dtype", ([torch.float16] if is_sm75 else [torch.float16, torch.bfloat16])) +# @pytest.mark.parametrize("dtype", [torch.float16]) +# @pytest.mark.parametrize("num_splits", [1, 0]) +# # @pytest.mark.parametrize("num_splits", [1]) +# @pytest.mark.parametrize("mha_type", ["mha", "mqa", "gqa"]) +# # @pytest.mark.parametrize("mha_type", ["mha"]) +# @pytest.mark.parametrize("new_kv", [False, True]) +# # @pytest.mark.parametrize("new_kv", [False]) +# @pytest.mark.parametrize("alibi", [False, True]) +# # @pytest.mark.parametrize("alibi", [False]) +# @pytest.mark.parametrize("local", [False, True]) +# # @pytest.mark.parametrize("local", [False]) +# @pytest.mark.parametrize("causal", [False, True]) +# # @pytest.mark.parametrize("causal", [False]) +# @pytest.mark.parametrize("seqlen_new_eq_seqlen_q", [True, False]) +# # @pytest.mark.parametrize("seqlen_new_eq_seqlen_q", [True]) +# @pytest.mark.parametrize("rotary_interleaved", [False, True]) +# # @pytest.mark.parametrize("rotary_interleaved", [False]) +# @pytest.mark.parametrize("rotary_fraction", [0.0, 0.5, 1.0]) +# # @pytest.mark.parametrize("rotary_fraction", [0.0]) +# @pytest.mark.parametrize("paged_kv_block_size", [None, 256]) +# # @pytest.mark.parametrize("paged_kv_block_size", [256, 512]) +# # @pytest.mark.parametrize("paged_kv_block_size", [None]) +# @pytest.mark.parametrize("has_leftpad", [False, True]) +# # @pytest.mark.parametrize("has_leftpad", [True]) +# # @pytest.mark.parametrize("has_batch_idx", [False, True]) +# @pytest.mark.parametrize("has_batch_idx", [False]) +# @pytest.mark.parametrize("d", [32, 59, 64, 80, 128, 256]) +# # @pytest.mark.parametrize("d", [32, 64, 96, 128, 160, 192, 224, 256]) +# # @pytest.mark.parametrize('d', [32, 40, 64, 80, 96, 128, 160, 192]) +# # @pytest.mark.parametrize('d', [56, 80]) +# # @pytest.mark.parametrize("d", [128]) +# @pytest.mark.parametrize( +# "seqlen_q,seqlen_k", +# [ +# (1, 128), +# (1, 339), +# (3, 1024), +# (64, 800), +# (64, 256), +# (3, 799), +# (64, 2048), +# (16, 20000), +# (1, 128 * 1024), +# (16, 128 * 1024), +# (128, 128), +# ], +# ) +# # @pytest.mark.parametrize('seqlen_q,seqlen_k', [(256, 128)]) +# def test_flash_attn_kvcache( +# seqlen_q, +# seqlen_k, +# d, +# has_batch_idx, +# has_leftpad, +# paged_kv_block_size, +# rotary_fraction, +# rotary_interleaved, +# seqlen_new_eq_seqlen_q, +# causal, +# local, +# alibi, +# new_kv, +# mha_type, +# num_splits, +# dtype, +# ): +# if seqlen_q > seqlen_k and new_kv: +# pytest.skip() +# if not new_kv and rotary_fraction > 0.0: +# pytest.skip() +# if has_batch_idx and paged_kv_block_size is not None: +# pytest.skip() +# if has_leftpad and paged_kv_block_size is not None: +# pytest.skip() +# device = "cuda" +# # set seed +# torch.random.manual_seed(0) +# batch_size = 2 +# batch_size_cache = batch_size if not has_batch_idx else batch_size * 2 +# nheads = 6 +# # rotary_dim must be a multiple of 16, and must be <= d +# rotary_dim = math.floor(int(rotary_fraction * d) / 16) * 16 +# nheads_k = nheads if mha_type == "mha" else (1 if mha_type == "mqa" else 3) +# assert nheads % nheads_k == 0 +# window_size = (-1, 
-1) if not local else torch.randint(0, seqlen_k, (2,)) +# q = torch.randn(batch_size, seqlen_q, nheads, d, device=device, dtype=dtype) +# seqlen_new = seqlen_q if seqlen_new_eq_seqlen_q else torch.randint(1, seqlen_q + 1, (1,)).item() +# if new_kv: +# k = torch.randn(batch_size, seqlen_new, nheads_k, d, device=device, dtype=dtype) +# v = torch.randn(batch_size, seqlen_new, nheads_k, d, device=device, dtype=dtype) +# else: +# k, v = None, None +# if paged_kv_block_size is None: +# k_cache = torch.randn(batch_size_cache, seqlen_k, nheads_k, d, device=device, dtype=dtype) +# v_cache = torch.randn(batch_size_cache, seqlen_k, nheads_k, d, device=device, dtype=dtype) +# block_table = None +# else: +# ( +# k_cache, +# v_cache, +# block_table, +# k_cache_paged, +# v_cache_paged, +# num_blocks, +# ) = _generate_block_kvcache( +# seqlen_k, paged_kv_block_size, batch_size, nheads_k, d, device, dtype +# ) +# cache_seqlens = torch.randint( +# 0 if new_kv else 1, +# # If we don't use seqlen_q in the case of causal and rotary, cos/sin won't be long enough +# ( +# (seqlen_k - (seqlen_q if (causal or local) and rotary_dim > 1 else seqlen_new) + 1) +# if new_kv +# else (seqlen_k + 1) +# ), +# (batch_size,), +# dtype=torch.int32, +# device=device, +# ) +# if has_leftpad: +# cache_leftpad = torch.cat([torch.randint(0, cache_seqlens[i].item(), (1,), dtype=torch.int32, device=device) +# if cache_seqlens[i].item() > 0 else torch.zeros(1, dtype=torch.int32, device=device) +# for i in range(batch_size)]) +# else: +# cache_leftpad = None +# arange = rearrange(torch.arange(seqlen_k, device=device), "s -> 1 s") +# cache_seqlens_expanded = rearrange(cache_seqlens, "b -> b 1") +# key_padding_mask = arange < cache_seqlens_expanded + (seqlen_new if new_kv else 0) +# if has_leftpad: +# key_padding_mask = torch.logical_and( +# key_padding_mask, arange >= cache_leftpad.unsqueeze(-1).expand(-1, seqlen_k) +# ) +# if has_batch_idx: +# cache_batch_idx = torch.randperm(batch_size_cache, dtype=torch.int32, device=device)[ +# :batch_size +# ] +# else: +# cache_batch_idx = None +# if alibi: +# alibi_slopes = torch.rand(batch_size, nheads, device=device, dtype=torch.float32) * 0.3 +# attn_bias = attn_bias_from_alibi_slopes( +# alibi_slopes, seqlen_q, seqlen_k, None, key_padding_mask, causal=causal, key_leftpad=cache_leftpad +# ) +# else: +# alibi_slopes, attn_bias = None, None +# # cache_seqlens = torch.tensor([64], dtype=torch.int32, device=device) +# if rotary_dim > 0: +# angle = ( +# torch.rand( +# seqlen_k if paged_kv_block_size is None else num_blocks * paged_kv_block_size, +# rotary_dim // 2, +# device=device, +# ) +# * 2 +# * math.pi +# ) +# cos = torch.cos(angle).to(dtype=dtype) +# sin = torch.sin(angle).to(dtype=dtype) +# if causal or local: +# q_ro = apply_rotary_emb( +# q, cos, sin, seqlen_offsets=cache_seqlens, interleaved=rotary_interleaved +# ) +# else: +# q_ro = rearrange( +# apply_rotary_emb( +# rearrange(q, "b s h d -> b 1 (s h) d"), +# cos, +# sin, +# seqlen_offsets=cache_seqlens, +# interleaved=rotary_interleaved, +# ), +# "b 1 (s h) d -> b s h d", +# s=seqlen_q, +# ) +# # q_ro = q +# k_ro = apply_rotary_emb( +# k, cos, sin, seqlen_offsets=cache_seqlens, interleaved=rotary_interleaved +# ) +# else: +# cos, sin = None, None +# q_ro, k_ro = q, k +# # k_cache[:, 64:] = -1 +# k_cache_ref = ( +# k_cache if not has_batch_idx else k_cache[cache_batch_idx.to(dtype=torch.long)] +# ).clone() +# v_cache_ref = ( +# v_cache if not has_batch_idx else v_cache[cache_batch_idx.to(dtype=torch.long)] +# ).clone() +# if new_kv: +# 
update_mask = torch.logical_and( +# cache_seqlens_expanded <= arange, arange < cache_seqlens_expanded + seqlen_new +# ) +# k_cache_ref[update_mask] = rearrange(k_ro, "b s ... -> (b s) ...") +# v_cache_ref[update_mask] = rearrange(v, "b s ... -> (b s) ...") +# k_cache_rep = repeat(k_cache_ref, "b s h d -> b s (h g) d", g=nheads // nheads_k) +# v_cache_rep = repeat(v_cache_ref, "b s h d -> b s (h g) d", g=nheads // nheads_k) +# out = flash_attn_with_kvcache( +# q, +# k_cache if paged_kv_block_size is None else k_cache_paged, +# v_cache if paged_kv_block_size is None else v_cache_paged, +# k, +# v, +# rotary_cos=cos, +# rotary_sin=sin, +# cache_seqlens=cache_seqlens, +# cache_batch_idx=cache_batch_idx, +# cache_leftpad=cache_leftpad, +# block_table=block_table, +# causal=causal, +# window_size=window_size, +# rotary_interleaved=rotary_interleaved, +# alibi_slopes=alibi_slopes, +# num_splits=num_splits, +# ) +# # out = flash_attn_with_kvcache( +# # q, k_cache, v_cache, cache_seqlens=cache_seqlens, causal=causal, window_size=window_size +# # ) +# # out = flash_attn_with_kvcache(q, k_cache, v_cache, causal=causal, window_size=window_size) +# # qk = torch.einsum("bqhd,bkhd->bhqk", q, k_cache_ref) +# # m = qk.amax(-1, keepdim=True) +# # s_tmp = torch.exp((qk - m) / math.sqrt(d)) +# # o1 = torch.einsum('bhst,bthd->bshd', s_tmp, v_cache_ref) +# # lse_ref = torch.logsumexp(qk / math.sqrt(d), -1) +# # probs = torch.softmax(qk, dim=-1) +# out_ref, _ = attention_ref( +# q_ro, +# k_cache_rep, +# v_cache_rep, +# None, +# key_padding_mask, +# attn_bias, +# 0.0, +# None, +# causal=causal, +# window_size=window_size, +# key_leftpad=cache_leftpad, +# ) +# out_pt, _ = attention_ref( +# q_ro, +# k_cache_rep, +# v_cache_rep, +# None, +# key_padding_mask, +# attn_bias, +# 0.0, +# None, +# causal=causal, +# window_size=window_size, +# upcast=False, +# reorder_ops=True, +# key_leftpad=cache_leftpad, +# ) +# print(f"Output max diff: {(out - out_ref).abs().max().item()}") +# print(f"Output mean diff: {(out - out_ref).abs().mean().item()}") +# print(f"Pytorch max diff: {(out_pt - out_ref).abs().max().item()}") +# print(f"Pytorch mean diff: {(out_pt - out_ref).abs().mean().item()}") + +# # Check that FlashAttention's numerical error is at most twice the numerical error +# # of a Pytorch implementation. +# if new_kv: +# if paged_kv_block_size is None: +# k_cache_select = ( +# k_cache if not has_batch_idx else k_cache[cache_batch_idx.to(dtype=torch.long)] +# ) +# v_cache_select = ( +# v_cache if not has_batch_idx else v_cache[cache_batch_idx.to(dtype=torch.long)] +# ) +# else: +# k_cache_select = rearrange( +# k_cache_paged[block_table.to(dtype=torch.long).flatten()], +# "(b nblocks) block_size ... -> b (nblocks block_size) ...", +# b=batch_size, +# )[:, :seqlen_k] +# v_cache_select = rearrange( +# v_cache_paged[block_table.to(dtype=torch.long).flatten()], +# "(b nblocks) block_size ... 
-> b (nblocks block_size) ...", +# b=batch_size, +# )[:, :seqlen_k] +# assert torch.allclose(k_cache_select, k_cache_ref, rtol=1e-3, atol=1e-3) +# assert torch.equal(v_cache_select, v_cache_ref) +# mult = 3 if not alibi else 5 +# assert (out - out_ref).abs().max().item() <= mult * (out_pt - out_ref).abs().max().item() + 1e-5 + + diff --git a/tests/test_flash_attn_headdim.py b/tests/test_flash_attn_headdim.py new file mode 100644 index 000000000..1a4613d12 --- /dev/null +++ b/tests/test_flash_attn_headdim.py @@ -0,0 +1,935 @@ +import math + +import pytest +import torch +import torch.nn.functional as F +from einops import rearrange, repeat +from flash_attn import ( + flash_attn_func, + flash_attn_kvpacked_func, + flash_attn_qkvpacked_func, + flash_attn_varlen_func, + flash_attn_varlen_kvpacked_func, + flash_attn_varlen_qkvpacked_func, + flash_attn_with_kvcache, +) +from flash_attn.bert_padding import pad_input, unpad_input +# from flash_attn.flash_attn_interface import _get_block_size_n +from flash_attn.layers.rotary import apply_rotary_emb + +MAX_HEADDIM_SM8x = 192 + + +is_sm75 = torch.cuda.get_device_capability("cuda") == (7, 5) +is_sm8x = torch.cuda.get_device_capability("cuda")[0] == 8 +is_sm80 = torch.cuda.get_device_capability("cuda") == (8, 0) +is_sm90 = torch.cuda.get_device_capability("cuda") == (9, 0) + + +def attn_bias_from_alibi_slopes( + slopes, seqlen_q, seqlen_k, query_padding_mask=None, key_padding_mask=None, causal=False, key_leftpad=None +): + batch, nheads = slopes.shape + device = slopes.device + slopes = rearrange(slopes, "b h -> b h 1 1") + if causal: + return torch.arange(-seqlen_k + 1, 1, device=device, dtype=torch.float32) * slopes + else: + row_idx = rearrange(torch.arange(seqlen_q, device=device, dtype=torch.long), "s -> s 1") + col_idx = torch.arange(seqlen_k, device=device, dtype=torch.long) + if key_leftpad is not None: + key_leftpad = rearrange(key_leftpad, "b -> b 1 1 1") + col_idx = repeat(col_idx, "s -> b 1 1 s", b=key_leftpad.shape[0]) + col_idx = torch.where(col_idx >= key_leftpad, col_idx - key_leftpad, 2**32) + sk = ( + seqlen_k + if key_padding_mask is None + else rearrange(key_padding_mask.sum(-1), "b -> b 1 1 1") + ) + sq = ( + seqlen_q + if query_padding_mask is None + else rearrange(query_padding_mask.sum(-1), "b -> b 1 1 1") + ) + relative_pos = torch.abs(row_idx + sk - sq - col_idx) + return -slopes * relative_pos.to(dtype=slopes.dtype) + + +def generate_random_padding_mask(max_seqlen, batch_size, device, mode="random"): + assert mode in ["full", "random", "third"] + if mode == "full": + lengths = torch.full((batch_size, 1), max_seqlen, device=device, dtype=torch.int32) + elif mode == "random": + lengths = torch.randint( + max(1, max_seqlen - 20), max_seqlen + 1, (batch_size, 1), device=device + ) + elif mode == "third": + lengths = torch.randint(max_seqlen // 3, max_seqlen + 1, (batch_size, 1), device=device) + padding_mask = ( + repeat(torch.arange(max_seqlen, device=device), "s -> b s", b=batch_size) < lengths + ) + return padding_mask + + +def generate_qkv( + q, k, v, query_padding_mask=None, key_padding_mask=None, kvpacked=False, qkvpacked=False +): + """ + Arguments: + q: (batch_size, seqlen_q, nheads, d) + k: (batch_size, seqlen_k, nheads_k, d) + v: (batch_size, seqlen_k, nheads_k, d) + query_padding_mask: (batch_size, seqlen), bool + key_padding_mask: (batch_size, seqlen), bool + """ + assert not (kvpacked and qkvpacked) + batch_size, seqlen_q, nheads, d = q.shape + _, seqlen_k, nheads_k, _ = k.shape + assert k.shape == (batch_size, 
seqlen_k, nheads_k, d) + assert v.shape == (batch_size, seqlen_k, nheads_k, d) + + if query_padding_mask is not None: + q_unpad, indices_q, cu_seqlens_q, max_seqlen_q = unpad_input(q, query_padding_mask) + output_pad_fn = lambda output_unpad: pad_input( + output_unpad, indices_q, batch_size, seqlen_q + ) + else: + q_unpad = rearrange(q, "b s h d -> (b s) h d") + cu_seqlens_q = torch.arange( + 0, (batch_size + 1) * seqlen_q, step=seqlen_q, dtype=torch.int32, device=q_unpad.device + ) + max_seqlen_q = seqlen_q + output_pad_fn = lambda output_unpad: rearrange( + output_unpad, "(b s) h d -> b s h d", b=batch_size + ) + + if key_padding_mask is not None: + k_unpad, indices_k, cu_seqlens_k, max_seqlen_k = unpad_input(k, key_padding_mask) + v_unpad, _, _, _ = unpad_input(v, key_padding_mask) + else: + k_unpad = rearrange(k, "b s h d -> (b s) h d") + v_unpad = rearrange(v, "b s h d -> (b s) h d") + cu_seqlens_k = torch.arange( + 0, (batch_size + 1) * seqlen_k, step=seqlen_k, dtype=torch.int32, device=k_unpad.device + ) + max_seqlen_k = seqlen_k + + if qkvpacked: + assert (query_padding_mask == key_padding_mask).all() + assert nheads == nheads_k + qkv_unpad = torch.stack([q_unpad, k_unpad, v_unpad], dim=1) + qkv = torch.stack([q, k, v], dim=2) + if query_padding_mask is not None: + dqkv_pad_fn = lambda dqkv_unpad: pad_input(dqkv_unpad, indices_q, batch_size, seqlen_q) + else: + dqkv_pad_fn = lambda dqkv_unpad: rearrange( + dqkv_unpad, "(b s) t h d -> b s t h d", b=batch_size + ) + return ( + qkv_unpad.detach().requires_grad_(), + cu_seqlens_q, + max_seqlen_q, + qkv.detach().requires_grad_(), + output_pad_fn, + dqkv_pad_fn, + ) + elif kvpacked: + kv_unpad = torch.stack([k_unpad, v_unpad], dim=1) + kv = torch.stack([k, v], dim=2) + dq_pad_fn = output_pad_fn + if key_padding_mask is not None: + dkv_pad_fn = lambda dkv_unpad: pad_input(dkv_unpad, indices_k, batch_size, seqlen_k) + else: + dkv_pad_fn = lambda dkv_unpad: rearrange( + dkv_unpad, "(b s) t h d -> b s t h d", b=batch_size + ) + return ( + q_unpad.detach().requires_grad_(), + kv_unpad.detach().requires_grad_(), + cu_seqlens_q, + cu_seqlens_k, + max_seqlen_q, + max_seqlen_k, + q.detach().requires_grad_(), + kv.detach().requires_grad_(), + output_pad_fn, + dq_pad_fn, + dkv_pad_fn, + ) + else: + dq_pad_fn = output_pad_fn + if key_padding_mask is not None: + dk_pad_fn = lambda dk_unpad: pad_input(dk_unpad, indices_k, batch_size, seqlen_k) + else: + dk_pad_fn = lambda dk_unpad: rearrange(dk_unpad, "(b s) h d -> b s h d", b=batch_size) + return ( + q_unpad.detach().requires_grad_(), + k_unpad.detach().requires_grad_(), + v_unpad.detach().requires_grad_(), + cu_seqlens_q, + cu_seqlens_k, + max_seqlen_q, + max_seqlen_k, + q.detach().requires_grad_(), + k.detach().requires_grad_(), + v.detach().requires_grad_(), + output_pad_fn, + dq_pad_fn, + dk_pad_fn, + ) + + +def construct_local_mask( + seqlen_q, + seqlen_k, + window_size=(-1, -1), # -1 means infinite window size + query_padding_mask=None, + key_padding_mask=None, + device=None, + key_leftpad=None, +): + row_idx = rearrange(torch.arange(seqlen_q, device=device, dtype=torch.long), "s -> s 1") + col_idx = torch.arange(seqlen_k, device=device, dtype=torch.long) + if key_leftpad is not None: + key_leftpad = rearrange(key_leftpad, "b -> b 1 1 1") + col_idx = repeat(col_idx, "s -> b 1 1 s", b=key_leftpad.shape[0]) + col_idx = torch.where(col_idx >= key_leftpad, col_idx - key_leftpad, 2**32) + sk = ( + seqlen_k + if key_padding_mask is None + else rearrange(key_padding_mask.sum(-1), "b -> b 1 1 1") + ) + 
sq = ( + seqlen_q + if query_padding_mask is None + else rearrange(query_padding_mask.sum(-1), "b -> b 1 1 1") + ) + if window_size[0] < 0: + return col_idx > row_idx + sk - sq + window_size[1] + else: + sk = torch.full_like(col_idx, seqlen_k) if key_padding_mask is None else sk + return torch.logical_or( + col_idx > torch.minimum(row_idx + sk - sq + window_size[1], sk), + col_idx < row_idx + sk - sq - window_size[0], + ) + + +def attention_ref( + q, + k, + v, + query_padding_mask=None, + key_padding_mask=None, + attn_bias=None, + dropout_p=0.0, + dropout_mask=None, + causal=False, + window_size=(-1, -1), # -1 means infinite window size + softcap=0.0, + upcast=True, + reorder_ops=False, + key_leftpad=None, +): + """ + Arguments: + q: (batch_size, seqlen_q, nheads, head_dim) + k: (batch_size, seqlen_k, nheads_k, head_dim) + v: (batch_size, seqlen_k, nheads_k, head_dim) + query_padding_mask: (batch_size, seqlen_q) + key_padding_mask: (batch_size, seqlen_k) + attn_bias: broadcastable to (batch_size, nheads, seqlen_q, seqlen_k) + dropout_p: float + dropout_mask: (batch_size, nheads, seqlen_q, seqlen_k) + causal: whether to apply causal masking + window_size: (int, int), left and right window size + upcast: whether to cast all inputs to fp32, do all computation in fp32, then cast + output back to fp16/bf16. + reorder_ops: whether to change the order of operations (scaling k instead of scaling q, etc.) + without changing the math. This is to estimate the numerical error from operation + reordering. + Output: + output: (batch_size, seqlen_q, nheads, head_dim) + attention: (batch_size, nheads, seqlen_q, seqlen_k), softmax after dropout + """ + if causal: + window_size = (window_size[0], 0) + dtype_og = q.dtype + if upcast: + q, k, v = q.float(), k.float(), v.float() + seqlen_q, seqlen_k = q.shape[1], k.shape[1] + k = repeat(k, "b s h d -> b s (h g) d", g=q.shape[2] // k.shape[2]) + v = repeat(v, "b s h d -> b s (h g) d", g=q.shape[2] // v.shape[2]) + d = q.shape[-1] + if not reorder_ops: + scores = torch.einsum("bthd,bshd->bhts", q / math.sqrt(d), k) + else: + scores = torch.einsum("bthd,bshd->bhts", q, k / math.sqrt(d)) + if softcap > 0: + scores = scores / softcap + scores = scores.tanh() + scores = scores * softcap + if key_padding_mask is not None: + scores.masked_fill_(rearrange(~key_padding_mask, "b s -> b 1 1 s"), float("-inf")) + if window_size[0] >= 0 or window_size[1] >= 0: + local_mask = construct_local_mask( + seqlen_q, + seqlen_k, + window_size, + query_padding_mask, + key_padding_mask, + q.device, + key_leftpad=key_leftpad, + ) + scores.masked_fill_(local_mask, float("-inf")) + if attn_bias is not None: + scores = scores + attn_bias + attention = torch.softmax(scores, dim=-1).to(v.dtype) + # Some rows might be completely masked out so we fill them with zero instead of NaN + if window_size[0] >= 0 or window_size[1] >= 0: + attention = attention.masked_fill(torch.all(local_mask, dim=-1, keepdim=True), 0.0) + # We want to mask here so that the attention matrix doesn't have any NaNs + # Otherwise we'll get NaN in dV + if query_padding_mask is not None: + attention = attention.masked_fill(rearrange(~query_padding_mask, "b s -> b 1 s 1"), 0.0) + dropout_scaling = 1.0 / (1 - dropout_p) + # attention_drop = attention.masked_fill(~dropout_mask, 0.0) * dropout_scaling + # output = torch.einsum('bhts,bshd->bthd', attention_drop , v) + if dropout_mask is not None: + attention_drop = attention.masked_fill(~dropout_mask, 0.0) + else: + attention_drop = attention + output = 
torch.einsum("bhts,bshd->bthd", attention_drop, v * dropout_scaling) + if query_padding_mask is not None: + output.masked_fill_(rearrange(~query_padding_mask, "b s -> b s 1 1"), 0.0) + return output.to(dtype=dtype_og), attention.to(dtype=dtype_og) + + +def attention_kvpacked_ref( + q, + kv, + query_padding_mask=None, + key_padding_mask=None, + attn_bias=None, + dropout_p=0.0, + dropout_mask=None, + causal=False, + window_size=(-1, -1), # -1 means infinite window size + softcap=0.0, + upcast=True, + reorder_ops=False, + key_leftpad=None, +): + return attention_ref( + q, + kv[:, :, 0], + kv[:, :, 1], + query_padding_mask, + key_padding_mask, + attn_bias, + dropout_p, + dropout_mask, + upcast=upcast, + causal=causal, + window_size=window_size, + softcap=softcap, + reorder_ops=reorder_ops, + key_leftpad=key_leftpad, + ) + + +def attention_qkvpacked_ref( + qkv, + key_padding_mask=None, + attn_bias=None, + dropout_p=0.0, + dropout_mask=None, + causal=False, + window_size=(-1, -1), # -1 means infinite window size + softcap=0.0, + upcast=True, + reorder_ops=False, +): + return attention_ref( + qkv[:, :, 0], + qkv[:, :, 1], + qkv[:, :, 2], + key_padding_mask, + key_padding_mask, + attn_bias, + dropout_p, + dropout_mask, + upcast=upcast, + causal=causal, + window_size=window_size, + softcap=softcap, + reorder_ops=reorder_ops, + ) + + +def generate_sparsity_mask(seqlen, sparsity=0.3): + repeats = seqlen // 16 // 2 + # mask = torch.stack([torch.tensor([1, 0] * repeats, dtype=torch.bool, device='cuda'), + # torch.tensor([0, 1] * repeats, dtype=torch.bool, device='cuda')], dim=-1) + # mask = torch.stack([torch.tensor([1, 1] * repeats, dtype=torch.bool, device='cuda'), + # torch.tensor([1, 1] * repeats, dtype=torch.bool, device='cuda')], dim=-1) + # mask = torch.stack([torch.tensor([1, 1] * repeats, dtype=torch.bool, device='cuda')], dim=-1) + # mask = torch.stack([torch.tensor([1, 0] * repeats, dtype=torch.bool, device='cuda')], dim=-1) + nrow, ncol = seqlen // 16, seqlen // 256 + mask = torch.rand(nrow, ncol, device="cuda") < sparsity + return mask + + +def attention_blocksparse_ref(qkv, blockmask, attn_mask, dropout_p, dropout_mask): + """ + Arguments: + qkv: (batch_size, seqlen, 3, nheads, head_dim) + blockmask: (seqlen / 16, seqlen / 256) + attn_mask: (batch_size, seqlen) + dropout_p: float + dropout_mask: (batch_size, nheads, seqlen, seqlen) + Output: + output: (batch_size, seqlen, nheads, head_dim) + attention: softmax after dropout + """ + q, k, v = qkv.float().unbind(dim=2) + d = qkv.shape[-1] + seqlen = qkv.shape[1] + scores = torch.einsum("bthd,bshd->bhts", q / math.sqrt(d), k) + scores.masked_fill_(rearrange(~attn_mask, "b s -> b 1 1 s"), float("-inf")) + blockmask = repeat(blockmask, "s_16 s_256 -> (s_16 16) (s_256 256)") + blockmask = blockmask[:seqlen, :seqlen] + scores.masked_fill_(rearrange(~blockmask, "t s -> 1 1 t s"), float("-inf")) + attention = torch.softmax(scores, dim=-1) + attention = attention.masked_fill(rearrange(~attn_mask, "b s -> b 1 s 1"), 0.0) + attention = attention.masked_fill_(rearrange(~blockmask, "t s -> 1 1 t s"), 0.0) + attention_drop = attention.masked_fill(~dropout_mask, 0.0) / (1 - dropout_p) + output = torch.einsum("bhts,bshd->bthd", attention_drop, v) + output.masked_fill_(rearrange(~attn_mask, "b s -> b s 1 1"), 0) + return output.to(dtype=qkv.dtype), attention.to(dtype=qkv.dtype) + + +def convert_flash_attn_S_to_softmax( + S, + seqlen_q, + seqlen_k, + query_padding_mask, + key_padding_mask, + head_dim, + is_dropout, + causal=False, + window_size=(-1, -1), # 
-1 means infinite window size
+):
+    """FlashAttention stores the S matrix in a different way.
+    Arguments:
+        S: (batch_size, nheads, seqlen_q_rounded, seqlen_k_rounded)
+        query_padding_mask: (batch_size, seqlen_q_rounded)
+        key_padding_mask: (batch_size, seqlen_k_rounded)
+    """
+    if causal:
+        window_size = (window_size[0], 0)
+    seqlen_q_rounded, seqlen_k_rounded = S.shape[-2:]
+    S_converted = S
+    if window_size[0] >= 0 or window_size[1] >= 0:
+        local_mask = construct_local_mask(
+            seqlen_q,
+            seqlen_k,
+            window_size,
+            query_padding_mask,
+            key_padding_mask,
+            S.device,
+        )
+        local_mask = F.pad(
+            local_mask,
+            (0, seqlen_k_rounded - seqlen_k, 0, seqlen_q_rounded - seqlen_q),
+            value=True,
+        )
+        S_converted = S_converted.masked_fill(local_mask, 0.0)
+
+    # Need to zero out things not in attention_mask in case S was initialized with random values
+    # and some of those values aren't overwritten.
+    seqlen_q_og = (
+        query_padding_mask.shape[-1] if query_padding_mask is not None else seqlen_q_rounded
+    )
+    if query_padding_mask is not None:
+        query_padding_mask = F.pad(query_padding_mask, (0, seqlen_q_rounded - seqlen_q_og))
+        S_converted = S_converted.masked_fill(rearrange(~query_padding_mask, "b s -> b 1 s 1"), 0.0)
+    seqlen_k_og = key_padding_mask.shape[-1] if key_padding_mask is not None else seqlen_k
+    if key_padding_mask is not None:
+        key_padding_mask = F.pad(key_padding_mask, (0, seqlen_k_rounded - seqlen_k_og))
+        S_converted = S_converted.masked_fill(rearrange(~key_padding_mask, "b s -> b 1 1 s"), 0.0)
+    S_converted = F.pad(S_converted, (0, 0, 0, seqlen_q_og - seqlen_q_rounded))
+    S_converted = F.pad(S_converted, (0, seqlen_k_og - seqlen_k_rounded))
+    return S_converted[:, :, :seqlen_q, :seqlen_k]
+
+def _get_block_size_n_headdim(device, qk_head_dim, v_head_dim, is_dropout, is_causal):
+    # This should match the block sizes in the CUDA kernel
+    assert qk_head_dim <= 256
+    major, minor = torch.cuda.get_device_capability(device)
+    is_sm8x = major == 8 and minor > 0  # Only include sm86 and sm89, exclude sm80 (A100)
+    is_sm80 = major == 8 and minor == 0
+    is_sm90 = major == 9 and minor == 0
+    if qk_head_dim <= 32:
+        return 128
+    if qk_head_dim <= 64:
+        return 128 if not is_dropout else 64
+    elif qk_head_dim <= 96:
+        return 64
+    elif qk_head_dim <= 128:
+        # QK and V head dims can differ in this build: with v_head_dim == 256 and dropout,
+        # the kernel uses block size 64.
+        if v_head_dim == 256 and is_dropout:
+            return 64
+        if is_sm8x:
+            return 64 if (not is_dropout and is_causal) else 32
+        else:
+            return 64 if not is_dropout else 32
+    elif qk_head_dim <= 160:
+        if is_sm8x:
+            return 64
+        else:
+            return 32
+    elif qk_head_dim <= 192:
+        return 64
+    elif qk_head_dim <= 224:
+        return 64
+    elif qk_head_dim <= 256:
+        return 64
+
+
+def normalize_flash_attn_S(
+    attn_unnorm,
+    q,
+    k,
+    v,
+    query_padding_mask=None,
+    key_padding_mask=None,
+    attn_bias=None,
+    is_dropout=False,
+    causal=False,
+    window_size=(-1, -1),  # -1 means infinite window size
+):
+    """
+    Arguments:
+        q: (batch_size, seqlen_q, nheads, head_dim)
+        k: (batch_size, seqlen_k, nheads, head_dim)
+        v: (batch_size, seqlen_k, nheads, v_head_dim)
+        key_padding_mask: (batch_size, seqlen_k)
+        attn_bias: broadcastable to (batch_size, nheads, seqlen_q, seqlen_k)
+    Output:
+        attn_norm: (batch_size, nheads, seqlen_q, seqlen_k), attn_unnorm rescaled so that it
+            matches the reference softmax normalization.
+    """
+    if causal:
+        window_size = (window_size[0], 0)
+    q, k, v = q.float(), k.float(), v.float()
+    _, seqlen_q, _, head_dim = q.shape
+    seqlen_k = k.shape[1]
+    v_head_dim = v.shape[-1]
+    scores = torch.einsum("bthd,bshd->bhts", q / math.sqrt(head_dim), k)
+    if key_padding_mask is not None:
+
scores.masked_fill_(rearrange(~key_padding_mask, "b s -> b 1 1 s"), float("-inf")) + if window_size[0] >= 0 or window_size[1] >= 0: + local_mask = construct_local_mask( + seqlen_q, + seqlen_k, + window_size, + query_padding_mask, + key_padding_mask, + q.device, + ) + scores.masked_fill_(local_mask, float("-inf")) + if attn_bias is not None: + scores = scores + attn_bias.to(dtype=scores.dtype) + block_size_n = _get_block_size_n_headdim(scores.device, head_dim, v_head_dim, is_dropout, causal) + scores_block = scores.split(block_size_n, dim=-1) + lse_block = torch.stack([torch.logsumexp(s, dim=-1) for s in scores_block], dim=-1) + lse = torch.logsumexp(lse_block, dim=-1) + # lse could be -inf (i.e. all values in scores are -inf), and we want to set those to inf + # so that when we do torch.exp(m - lse), we get 0.0 instead of NaN. + lse[lse == float("-inf")] = float("inf") + scores_max_block = torch.stack([torch.amax(s, dim=-1) for s in scores_block], dim=-1) + cummax_block = torch.cummax(scores_max_block.flip(-1), dim=-1).values.flip(-1).unbind(dim=-1) + attn_unnorm_block = attn_unnorm.split(block_size_n, dim=-1) + attn_norm = torch.cat( + [ + a * rearrange(torch.exp(m - lse), "b h s -> b h s 1") + for a, m in zip(attn_unnorm_block, cummax_block) + ], + dim=-1, + ) + if query_padding_mask is not None: + attn_norm.masked_fill_(rearrange(~query_padding_mask, "b s -> b 1 s 1"), 0.0) + return attn_norm.to(dtype=attn_unnorm.dtype) + + +def get_dropout_fraction( + dropout_mask, + query_padding_mask=None, + key_padding_mask=None, + causal=False, + window_size=(-1, -1), # -1 means infinite window size +): + """ + dropout_mask: (batch_size, nheads, seqlen_q, seqlen_k), bool. True means keep, False means drop. + query_padding_mask: (batch_size, seqlen_q) + key_padding_mask: (batch_size, seqlen_k) + """ + if causal: + window_size = (window_size[0], 0) + batch_size, nheads, seqlen_q, seqlen_k = dropout_mask.shape + dropped = ~dropout_mask + valid = torch.ones_like(dropout_mask) + if query_padding_mask is not None: + dropped.masked_fill_(rearrange(~query_padding_mask, "b s -> b 1 s 1"), False) + valid.masked_fill_(rearrange(~query_padding_mask, "b s -> b 1 s 1"), False) + if key_padding_mask is not None: + dropped.masked_fill_(rearrange(~key_padding_mask, "b s -> b 1 1 s"), False) + valid.masked_fill_(rearrange(~key_padding_mask, "b s -> b 1 1 s"), False) + if window_size[0] >= 0 or window_size[1] >= 0: + local_mask = construct_local_mask( + seqlen_q, + seqlen_k, + window_size, + query_padding_mask, + key_padding_mask, + dropout_mask.device, + ) + dropped.masked_fill_(local_mask, False) + valid.masked_fill_(local_mask, False) + dropped_total = dropped.sum() + return dropped.sum() / valid.sum() + + +@pytest.mark.parametrize("kvpacked", [False]) +@pytest.mark.parametrize("dtype", ([torch.float16] if is_sm75 else [torch.float16, torch.bfloat16])) +# @pytest.mark.parametrize("dtype", [torch.bfloat16]) +@pytest.mark.parametrize("mha_type", ["mha", "mqa", "gqa"]) +# @pytest.mark.parametrize("mha_type", ["mha"]) +@pytest.mark.parametrize("deterministic", [False, True]) +# @pytest.mark.parametrize("deterministic", [True]) +@pytest.mark.parametrize("alibi", [False, True]) +# @pytest.mark.parametrize("alibi", [False]) +@pytest.mark.parametrize("local", [False, True]) +# @pytest.mark.parametrize("local", [False]) +@pytest.mark.parametrize("causal", [False, True]) +# @pytest.mark.parametrize("causal", [True]) +# @pytest.mark.parametrize("d", [32, 40, 59, 64, 96, 111, 128, 160, 192, 224, 256]) +# 
@pytest.mark.parametrize("d", [32, 64, 96, 128, 160, 192, 224, 256]) +# @pytest.mark.parametrize('d', [32, 40, 64, 80, 96, 128, 160, 192]) +# @pytest.mark.parametrize('d', [32, 64, 96, 128, 160, 192]) +# @pytest.mark.parametrize('d', [56, 80]) +# @pytest.mark.parametrize("d", [64]) +@pytest.mark.parametrize("d,v_d", [ + (32, 64), + (64, 128), + (96, 192), + (128, 256) + ]) +@pytest.mark.parametrize( + "seqlen_q,seqlen_k", + [ + (113, 203), + (128, 217), + (113, 211), + (108, 256), + (256, 512), + (512, 256), + (1024, 1024), + (1023, 1024), + (1024, 1023), + (2048, 2048), + ], +) +# @pytest.mark.parametrize('seqlen_q,seqlen_k', [(256, 128)]) +@pytest.mark.parametrize("dropout_p", [0.0, 0.17]) +# @pytest.mark.parametrize("dropout_p", [0.0]) +@pytest.mark.parametrize("softcap", [0.0, 50.0]) +def test_flash_attn_output( + seqlen_q, seqlen_k, d, v_d, dropout_p, causal, local, alibi, deterministic, mha_type, dtype, kvpacked, softcap +): + if ( + max(seqlen_q, seqlen_k) >= 2048 + and torch.cuda.get_device_properties("cuda").total_memory <= 16 * 2**30 + ): + pytest.skip() # Reference implementation OOM + if softcap > 0.0 and dropout_p > 0.0: + pytest.skip("Softcap and dropout not supported together") + device = "cuda" + # set seed + torch.random.manual_seed(0) + batch_size = 4 + nheads = 6 if softcap == 0.0 else 4 # softcap reference impl takes more memory + nheads_k = nheads if mha_type == "mha" else (1 if mha_type == "mqa" else 2) + assert nheads % nheads_k == 0 + window_size = (-1, -1) if not local else torch.randint(0, seqlen_k, (2,)) + q = torch.randn(batch_size, seqlen_q, nheads, d, device=device, dtype=dtype, requires_grad=True) + if softcap > 0: + # Ensure the values of qk are at least within softcap range. + q = q * softcap + assert kvpacked == False + k = torch.randn( + batch_size, seqlen_k, nheads_k, d, device=device, dtype=dtype, requires_grad=True + ) + v = torch.randn( + batch_size, seqlen_k, nheads_k, v_d, device=device, dtype=dtype, requires_grad=True + ) + if alibi: + alibi_slopes = torch.rand(batch_size, nheads, device=device, dtype=torch.float32) * 0.3 + attn_bias = attn_bias_from_alibi_slopes(alibi_slopes, seqlen_q, seqlen_k, causal=causal) + else: + alibi_slopes, attn_bias = None, None + + out, lse, S_dmask = flash_attn_func( + q, + k, + v, + dropout_p, + causal=causal, + window_size=window_size, + softcap=softcap, + alibi_slopes=alibi_slopes, + deterministic=deterministic, + return_attn_probs=True, + ) + if dropout_p > 0.0: + S_dmask_converted = convert_flash_attn_S_to_softmax( + S_dmask, + seqlen_q, + seqlen_k, + None, + None, + d, + dropout_p > 0.0, + causal=causal, + window_size=window_size, + ) + dropout_mask = S_dmask_converted >= 0 + attn_unnorm = S_dmask_converted.abs() + k_rep = repeat(k, "b s h d -> b s (h g) d", g=nheads // nheads_k) + v_rep = repeat(v, "b s h d -> b s (h g) d", g=nheads // nheads_k) + attn = normalize_flash_attn_S( + attn_unnorm, + q, + k_rep, + v_rep, + None, + None, + attn_bias, + dropout_p > 0.0, + causal=causal, + window_size=window_size, + ) + dropout_fraction = get_dropout_fraction( + dropout_mask, None, None, causal=causal, window_size=window_size + ).item() + print(f"Actual dropout fraction: {dropout_fraction}") + else: + dropout_mask = None + + out_ref, attn_ref = attention_ref( + q, + k, + v, + None, + None, + attn_bias, + dropout_p, + dropout_mask, + causal=causal, + window_size=window_size, + softcap=softcap, + ) + out_pt, attn_pt = attention_ref( + q, + k, + v, + None, + None, + attn_bias, + dropout_p, + dropout_mask, + causal=causal, 
+ window_size=window_size, + softcap=softcap, + upcast=False, + reorder_ops=True, + ) + + print(f"Output max diff: {(out - out_ref).abs().max().item()}") + print(f"Output mean diff: {(out - out_ref).abs().mean().item()}") + print(f"Pytorch max diff: {(out_pt - out_ref).abs().max().item()}") + print(f"Pytorch mean diff: {(out_pt - out_ref).abs().mean().item()}") + if dropout_p > 0.0: + print(f"Attention max diff: {(attn - attn_ref).abs().max().item()}") + print(f"Attention Pytorch max diff: {(attn_pt - attn_ref).abs().max().item()}") + + g = torch.randn_like(out) + do_o = (g.float() * out.float()).sum(-1) + if ((d <= MAX_HEADDIM_SM8x and v_d <= MAX_HEADDIM_SM8x) or dropout_p == 0) or (is_sm80 or is_sm90): + ( + dq, + dk, + dv, + ) = torch.autograd.grad(out, (q, k, v), g) + ( + dq_ref, + dk_ref, + dv_ref, + ) = torch.autograd.grad(out_ref, (q, k, v), g) + ( + dq_pt, + dk_pt, + dv_pt, + ) = torch.autograd.grad(out_pt, (q, k, v), g) + print(f"dQ max diff: {(dq - dq_ref).abs().max().item()}") + print(f"dK max diff: {(dk - dk_ref).abs().max().item()}") + print(f"dV max diff: {(dv - dv_ref).abs().max().item()}") + print(f"dQ mean diff: {(dq - dq_ref).abs().mean().item()}") + print(f"dK mean diff: {(dk - dk_ref).abs().mean().item()}") + print(f"dV mean diff: {(dv - dv_ref).abs().mean().item()}") + print(f"dQ Pytorch max diff: {(dq_pt - dq_ref).abs().max().item()}") + print(f"dK Pytorch max diff: {(dk_pt - dk_ref).abs().max().item()}") + print(f"dV Pytorch max diff: {(dv_pt - dv_ref).abs().max().item()}") + print(f"dQ Pytorch mean diff: {(dq_pt - dq_ref).abs().mean().item()}") + print(f"dK Pytorch mean diff: {(dk_pt - dk_ref).abs().mean().item()}") + print(f"dV Pytorch mean diff: {(dv_pt - dv_ref).abs().mean().item()}") + + # Check that FlashAttention's numerical error is at most twice the numerical error + # of a Pytorch implementation. 
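+    # out_ref is computed with upcast=True (all math in fp32), while out_pt runs in the
+    # kernel's dtype with reordered operations, so (out_pt - out_ref) serves as the baseline
+    # numerical error of an unfused implementation at the same precision.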
+ assert (out - out_ref).abs().max().item() <= 2 * (out_pt - out_ref).abs().max().item() + + if dropout_p > 0.0: + assert (attn - attn_ref).abs().max().item() <= 2 * (attn_pt - attn_ref).abs().max().item() + # With alibi, many of the prob values are 0.0 & -0.0 so dropout_fraction isn't accurate + if not alibi: + assert abs(dropout_fraction - dropout_p) <= (0.01 if not local else 0.025) + + if ((d <= MAX_HEADDIM_SM8x and v_d <= MAX_HEADDIM_SM8x) or dropout_p == 0) or (is_sm80 or is_sm90): + assert (dq - dq_ref).abs().max().item() <= 3 * (dq_pt - dq_ref).abs().max().item() + assert (dk - dk_ref).abs().max().item() <= 3 * (dk_pt - dk_ref).abs().max().item() + assert (dv - dv_ref).abs().max().item() <= 3 * (dv_pt - dv_ref).abs().max().item() + +@pytest.mark.parametrize("dtype", ([torch.float16] if is_sm75 else [torch.float16, torch.bfloat16])) +# @pytest.mark.parametrize("dtype", [torch.float16]) +@pytest.mark.parametrize("deterministic", [False, True]) +# @pytest.mark.parametrize("deterministic", [True]) +@pytest.mark.parametrize("alibi", [False, True]) +# @pytest.mark.parametrize("alibi", [True]) +@pytest.mark.parametrize("local", [False, True]) +# @pytest.mark.parametrize("local", [False]) +@pytest.mark.parametrize("causal", [False, True]) +# @pytest.mark.parametrize("causal", [True]) +@pytest.mark.parametrize("d,v_d", [ + (32, 64), + (64, 128), + (96, 192), + (128, 256) + ]) +@pytest.mark.parametrize("swap_sq_sk", [False, True]) +# @pytest.mark.parametrize("swap_sq_sk", [False]) +@pytest.mark.parametrize( + "seqlen_q,seqlen_k", + [ + (3, 1024), + (1, 339), + (64, 800), + (3, 799), + (64, 2048), + (16, 20000), + (16, 100000), + (128, 128), + (256, 256), + ], +) +# @pytest.mark.parametrize('seqlen_q,seqlen_k', [(256, 128)]) +def test_flash_attn_splitkv( + seqlen_q, seqlen_k, swap_sq_sk, d, v_d, causal, local, alibi, deterministic, dtype +): + if swap_sq_sk: + seqlen_q, seqlen_k = seqlen_k, seqlen_q + device = "cuda" + # set seed + torch.random.manual_seed(0) + batch_size = 1 + nheads = 12 + window_size = (-1, -1) if not local else torch.randint(0, seqlen_k, (2,)) + q = torch.randn(batch_size, seqlen_q, nheads, d, device=device, dtype=dtype, requires_grad=True) + k = torch.randn(batch_size, seqlen_k, nheads, d, device=device, dtype=dtype, requires_grad=True) + v = torch.randn(batch_size, seqlen_k, nheads, v_d, device=device, dtype=dtype, requires_grad=True) + if alibi: + alibi_slopes = torch.rand(batch_size, nheads, device=device, dtype=torch.float32) * 0.3 + attn_bias = attn_bias_from_alibi_slopes(alibi_slopes, seqlen_q, seqlen_k, causal=causal) + else: + alibi_slopes, attn_bias = None, None + out, lse, _ = flash_attn_func( + q, + k, + v, + 0.0, + causal=causal, + window_size=window_size, + alibi_slopes=alibi_slopes, + deterministic=deterministic, + return_attn_probs=True, + ) + out_ref, attn_ref = attention_ref( + q, k, v, None, None, attn_bias, 0.0, None, causal=causal, window_size=window_size + ) + out_pt, attn_pt = attention_ref( + q, + k, + v, + None, + None, + attn_bias, + 0.0, + None, + causal=causal, + window_size=window_size, + upcast=False, + reorder_ops=True, + ) + + print(f"Output max diff: {(out - out_ref).abs().max().item()}") + print(f"Output mean diff: {(out - out_ref).abs().mean().item()}") + print(f"Pytorch max diff: {(out_pt - out_ref).abs().max().item()}") + print(f"Pytorch mean diff: {(out_pt - out_ref).abs().mean().item()}") + + g = torch.randn_like(out) + do_o = (g.float() * out.float()).sum(-1) + ( + dq, + dk, + dv, + ) = torch.autograd.grad(out, (q, k, v), g) + ( + 
dq_ref, + dk_ref, + dv_ref, + ) = torch.autograd.grad(out_ref, (q, k, v), g) + ( + dq_pt, + dk_pt, + dv_pt, + ) = torch.autograd.grad(out_pt, (q, k, v), g) + print(f"dQ max diff: {(dq - dq_ref).abs().max().item()}") + print(f"dK max diff: {(dk - dk_ref).abs().max().item()}") + print(f"dV max diff: {(dv - dv_ref).abs().max().item()}") + print(f"dQ mean diff: {(dq - dq_ref).abs().mean().item()}") + print(f"dK mean diff: {(dk - dk_ref).abs().mean().item()}") + print(f"dV mean diff: {(dv - dv_ref).abs().mean().item()}") + print(f"dQ Pytorch max diff: {(dq_pt - dq_ref).abs().max().item()}") + print(f"dK Pytorch max diff: {(dk_pt - dk_ref).abs().max().item()}") + print(f"dV Pytorch max diff: {(dv_pt - dv_ref).abs().max().item()}") + print(f"dQ Pytorch mean diff: {(dq_pt - dq_ref).abs().mean().item()}") + print(f"dK Pytorch mean diff: {(dk_pt - dk_ref).abs().mean().item()}") + print(f"dV Pytorch mean diff: {(dv_pt - dv_ref).abs().mean().item()}") + + # Check that FlashAttention's numerical error is at most twice the numerical error + # of a Pytorch implementation. + assert (out - out_ref).abs().max().item() <= 2 * (out_pt - out_ref).abs().max().item() + 1e-5 + + mult = 2 if not alibi else 8 + assert (dq - dq_ref).abs().max().item() <= mult * (dq_pt - dq_ref).abs().max().item() + 2e-4 + assert (dk - dk_ref).abs().max().item() <= mult * (dk_pt - dk_ref).abs().max().item() + 2e-4 + assert (dv - dv_ref).abs().max().item() <= mult * (dv_pt - dv_ref).abs().max().item() + 2e-4 + +
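+
+# Minimal usage sketch (not part of the test suite) of what the "d,v_d" parametrization
+# above exercises: in this build flash_attn_func accepts a QK head dim that differs from
+# the V head dim, and the output takes the V head dim, e.g.
+#   q = torch.randn(2, 128, 4, 64, device="cuda", dtype=torch.float16)
+#   k = torch.randn(2, 256, 4, 64, device="cuda", dtype=torch.float16)
+#   v = torch.randn(2, 256, 4, 128, device="cuda", dtype=torch.float16)
+#   out = flash_attn_func(q, k, v, 0.0, causal=True)  # out: (2, 128, 4, 128)
+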