meta tensor stuff #769

Open · wants to merge 2 commits into main
flash_attn/layers/rotary.py (1 addition & 1 deletion)

@@ -342,7 +342,7 @@ def __init__(
         self.base = float(base)
         self.pos_idx_in_fp32 = pos_idx_in_fp32
         # Generate and save the inverse frequency buffer (non trainable)
-        inv_freq = self._compute_inv_freq(device)
+        inv_freq = self._compute_inv_freq('cuda' if device == 'meta' else device)
         self.register_buffer("inv_freq", inv_freq, persistent=False)
         self.interleaved = interleaved
         self.scale_base = scale_base
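
Note on the rotary.py hunk: inv_freq is produced by a torch.arange-based computation inside _compute_inv_freq, so a module built with device='meta' would otherwise end up with an inv_freq buffer holding no actual values. A minimal usage sketch, assuming flash_attn with this change, a CUDA device, and the device argument passed as the string 'meta':

from flash_attn.layers.rotary import RotaryEmbedding

# The module is requested on the meta device (e.g. while the surrounding model
# is being built without allocating real memory), but the inverse-frequency
# buffer is still materialized on a real device by the change above.
rotary = RotaryEmbedding(dim=64, device="meta")
print(rotary.inv_freq.device)  # expected: cuda:0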
flash_attn/models/gpt.py (2 additions & 2 deletions)

@@ -666,7 +666,7 @@ def forward(self, input_ids, position_ids=None, inference_params=None, num_last_
         CausalLMOutput = namedtuple("CausalLMOutput", ["logits"])
         return CausalLMOutput(logits=lm_logits)

-    def load_state_dict(self, state_dict, strict=True):
+    def load_state_dict(self, state_dict, strict=True, **hf_kwargs):
         # Remapping from our checkpoints that used a different ordering of layers in the block
         # Previous: Attn / MLP -> Dropout -> Add -> LN
         # Current: Dropout -> Add -> LN -> Attn / MLP
@@ -690,7 +690,7 @@ def load_state_dict(self, state_dict, strict=True):
             ln_bias = state_dict.pop("transformer.ln_0.bias")
             state_dict[f"transformer.layers.0.norm1.weight"] = ln_weight
             state_dict[f"transformer.layers.0.norm1.bias"] = ln_bias
-        return super().load_state_dict(state_dict, strict=strict)
+        return super().load_state_dict(state_dict, strict=strict, **hf_kwargs)


 def shard_state_dict_tp(state_dict, config, world_size, rank):
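
Note on the gpt.py hunk: the extra **hf_kwargs are forwarded untouched to nn.Module.load_state_dict, so callers can pass through arguments such as assign=True (available since PyTorch 2.1), which swaps meta-initialized parameters for the loaded tensors instead of copying into storage that does not exist. A hedged sketch, assuming the model can be fully constructed with device='meta' (which is what the rest of this PR works toward) and a hypothetical flash_attn-format checkpoint on disk:

import torch
from transformers import GPT2Config
from flash_attn.models.gpt import GPTLMHeadModel

config = GPT2Config()  # placeholder; real settings depend on the checkpoint

# Parameters are created on the meta device, so no real memory is allocated yet.
model = GPTLMHeadModel(config, device="meta")

# assign=True replaces the meta parameters with the loaded tensors rather than
# copying into them (copying would fail, since meta tensors have no storage).
state_dict = torch.load("checkpoint.pt")  # hypothetical local checkpoint
model.load_state_dict(state_dict, assign=True)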
flash_attn/modules/mha.py (4 additions & 0 deletions)

@@ -468,6 +468,8 @@ def __init__(
     def allocate_inference_cache(self, batch_size, max_seqlen, dtype=None):
         dtype = self.out_proj.weight.dtype if dtype is None else dtype
         device = self.out_proj.weight.device
+        if device == 'meta':
+            device = 'cuda' # do something else
         return torch.empty(
             batch_size,
             max_seqlen,
@@ -819,6 +821,8 @@ def __init__(
     def allocate_inference_cache(self, batch_size, max_seqlen, dtype=None):
         dtype = self.out_proj.weight.dtype if dtype is None else dtype
         device = self.out_proj.weight.device
+        if device == 'meta':
+            device = 'cuda'
         return torch.empty(
             batch_size,
             max_seqlen,
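
Note on the mha.py hunks: the inference (KV) cache has to live on real storage because generation writes actual keys and values into it, even if the attention weights are still meta tensors awaiting materialization. Below is a self-contained sketch of the pattern these hunks aim at; it checks device.type because module weights carry a torch.device object rather than a plain string, so treat it as an illustrative variant, not the PR's exact code:

import torch

def allocate_kv_cache_like(weight, batch_size, max_seqlen, num_heads, head_dim, dtype=None):
    # Allocate a (batch, seqlen, 2, heads, head_dim) cache on real storage,
    # falling back to CUDA when the reference weight is a meta tensor.
    dtype = weight.dtype if dtype is None else dtype
    device = weight.device
    if device.type == "meta":          # weights not materialized yet
        device = torch.device("cuda")  # the cache itself still needs real memory
    return torch.empty(batch_size, max_seqlen, 2, num_heads, head_dim,
                       dtype=dtype, device=device)

kv_cache = allocate_kv_cache_like(torch.empty(16, 16, device="meta"),
                                  batch_size=2, max_seqlen=2048,
                                  num_heads=16, head_dim=64,
                                  dtype=torch.float16)
print(kv_cache.device)  # expected: cuda:0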
flash_attn/utils/pretrained.py (7 additions & 6 deletions)

@@ -12,11 +12,10 @@
 from transformers.utils.hub import cached_file, get_checkpoint_shard_files


-def state_dict_from_pretrained(model_name, device=None, dtype=None):
+def state_dict_from_pretrained(model_name, device=None, dtype=None, load_safe=False):
     # If not fp32, then we don't want to load directly to the GPU
     mapped_device = "cpu" if dtype not in [torch.float32, None] else device
     is_sharded = False
-    load_safe = False
     resolved_archive_file = None

     weights_path = os.path.join(model_name, WEIGHTS_NAME)
@@ -45,11 +44,13 @@ def state_dict_from_pretrained(model_name, device=None, dtype=None):
         is_sharded = True
         load_safe = True
     else:  # Try loading from HF hub instead of from local files
-        resolved_archive_file = cached_file(model_name, WEIGHTS_NAME,
-                                            _raise_exceptions_for_missing_entries=False)
+        resolved_archive_file = cached_file(
+            model_name, SAFE_WEIGHTS_NAME if load_safe else WEIGHTS_NAME,
+            _raise_exceptions_for_missing_entries=False)
         if resolved_archive_file is None:
-            resolved_archive_file = cached_file(model_name, WEIGHTS_INDEX_NAME,
-                                                _raise_exceptions_for_missing_entries=False)
+            resolved_archive_file = cached_file(
+                model_name, SAFE_WEIGHTS_INDEX_NAME if load_safe else WEIGHTS_INDEX_NAME,
+                _raise_exceptions_for_missing_entries=False)
             if resolved_archive_file is not None:
                 is_sharded = True
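
Note on the pretrained.py hunk: load_safe is promoted from a hard-coded local flag to a keyword argument, so the Hub lookup can request the safetensors file names (SAFE_WEIGHTS_NAME / SAFE_WEIGHTS_INDEX_NAME) instead of the .bin ones. A hedged usage sketch; 'gpt2' is only an example repo id, and the call assumes that repo actually publishes safetensors weights (otherwise cached_file returns None and resolution falls through as before):

import torch
from flash_attn.utils.pretrained import state_dict_from_pretrained

# A non-fp32 dtype makes the function load to CPU first (the mapped_device logic
# above) before converting; load_safe=True steers the Hub lookup toward
# *.safetensors files instead of *.bin.
state_dict = state_dict_from_pretrained("gpt2", dtype=torch.float16, load_safe=True)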