fixed memory issue (more torch compatible); added deletion of unused tensors and emptying of the cache to release memory; benchmarked up to 128 enc/dec chunks (with and without release cache).
einrone committed Jan 9, 2025
1 parent f1c0633 commit b032359
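The core idiom behind the fix, in isolation: drop the last reference to an intermediate tensor, then ask the CUDA caching allocator to return the now-unused blocks. A minimal sketch of the general PyTorch pattern — not code from this repository, with a made-up model call:

import torch

def rollout_step(model, x, release_cache=False):
    # One autoregressive step under bfloat16 autocast, as in the diff below.
    with torch.autocast(device_type="cuda", dtype=torch.bfloat16):
        y_pred = model(x)
    out = y_pred.float().cpu()  # copy the result off the GPU first
    if release_cache:
        # empty_cache() can only return blocks holding no live tensors,
        # so the reference must be dropped before calling it.
        del y_pred
        torch.cuda.empty_cache()
    return out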
Showing 1 changed file with 68 additions and 45 deletions.
113 changes: 68 additions & 45 deletions bris/model.py
@@ -29,26 +29,27 @@ def __init__(

        super().__init__(*args, **kwargs)

+       # Lazy init
        self.model_comm_group = None
+       self.model_comm_group_id = 0
+       self.model_comm_group_rank = 0
+       self.model_comm_num_groups = 1
+       self.legacy = True

    def set_model_comm_group(
        self,
        model_comm_group: ProcessGroup,
-       model_comm_group_id: int,
-       model_comm_group_rank: int,
-       model_comm_num_groups: int,
-       model_comm_group_size: int,
+       model_comm_group_id: int = None,
+       model_comm_group_rank: int = None,
+       model_comm_num_groups: int = None,
+       model_comm_group_size: int = None,
    ) -> None:
        self.model_comm_group = model_comm_group
-       self.model_comm_group_id = model_comm_group_id
-       self.model_comm_group_rank = model_comm_group_rank
-       self.model_comm_num_groups = model_comm_num_groups
-       self.model_comm_group_size = model_comm_group_size
+       if not self.legacy:
+           self.model_comm_group_id = model_comm_group_id
+           self.model_comm_group_rank = model_comm_group_rank
+           self.model_comm_num_groups = model_comm_num_groups
+           self.model_comm_group_size = model_comm_group_size

    def set_reader_groups(
        self,
@@ -87,103 +88,125 @@ def __init__(
        data_reader: Iterable,
        forecast_length: int,
        variable_indices: list,
+       release_cache: bool = False,
        **kwargs
    ) -> None:
-       super().__init__(
-           *args, **kwargs)
+       super().__init__(*args, **kwargs)

        self.model = checkpoint.model
        self.data_indices = self.model.data_indices
        self.metadata = checkpoint.metadata

        #TODO: where should these come from, add asserts?
-       self.frequency = self.metadata["config"]["data"]["frequency"]
+       self.frequency = self.metadata.config.data.frequency
        if isinstance(self.frequency, str) and self.frequency[-1] == 'h':
            self.frequency = int(self.frequency[0:-1])

        self.forecast_length = forecast_length
        self.latitudes = data_reader.latitudes
        self.longitudes = data_reader.longitudes
        self.variable_indices = variable_indices[0]  # Assume we only have one decoder

-       self.set_static_forcings(data_reader, self.metadata["config"]["data"]["forcing"])
+       # This makes it backwards compatible with older
+       # anemoi-models versions, e.g. legendary gnome, etc.
+       if hasattr(self.data_indices, "internal_model") and hasattr(self.data_indices, "internal_data"):
+           self.internal_model = self.data_indices.internal_model
+           self.internal_data = self.data_indices.internal_data
+       else:
+           self.internal_model = self.data_indices.model
+           self.internal_data = self.data_indices.data
+
+       self.set_static_forcings(data_reader, self.metadata.config.data.forcing)

        self.model.eval()
+       self.release_cache = release_cache

    def set_static_forcings(self, data_reader, selection):
        self.static_forcings = {}
        data = torch.from_numpy(data_reader[0].squeeze(axis=1).swapaxes(0, 1))
-       data_normalized = self.model.pre_processors(data, in_place=False)
+       data_normalized = self.model.pre_processors(data, in_place=True)

+       # np.ndarray defaults to np.float64, while torch tensors default to torch.float32.
+       # Without explicit conversion and casting to torch.float32, writing a numpy array
+       # into a torch tensor may not cast it down automatically, i.e. the updated x tensor
+       # would internally hold torch.float64, increasing memory use in both CPU and GPU RAM.

if "cos_latitude" in selection:
self.static_forcings["cos_latitude"] = np.cos(data_reader.latitudes * np.pi / 180.)
self.static_forcings["cos_latitude"] = torch.from_numpy(np.cos(data_reader.latitudes * np.pi / 180.)).float()

if "sin_latitude" in selection:
self.static_forcings["sin_latitude"] = np.sin(data_reader.latitudes * np.pi / 180.)
self.static_forcings["sin_latitude"] = torch.from_numpy(np.sin(data_reader.latitudes * np.pi / 180.)).float()

if "cos_longitude" in selection:
self.static_forcings["cos_longitude"] = np.cos(data_reader.longitudes * np.pi / 180. )
self.static_forcings["cos_longitude"] = torch.from_numpy(np.cos(data_reader.longitudes * np.pi / 180. )).float()

if "sin_longitude" in selection:
self.static_forcings["sin_longitude"] = np.sin(data_reader.longitudes * np.pi / 180.)
self.static_forcings["sin_longitude"] = torch.from_numpy(np.sin(data_reader.longitudes * np.pi / 180.)).float()

if "lsm" in selection:
self.static_forcings["lsm"] = data_normalized[..., data_reader.name_to_index["lsm"]]
self.static_forcings["lsm"] = data_normalized[..., data_reader.name_to_index["lsm"]].float()

if "z" in selection:
self.static_forcings["z"] = data_normalized[..., data_reader.name_to_index["z"]]
self.static_forcings["z"] = data_normalized[..., data_reader.name_to_index["z"]].float()

del data_normalized

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.model(x, self.model_comm_group)

    def advance_input_predict(self, x, y_pred, time):
-       data_indices = self.model.data_indices
-
        x = x.roll(-1, dims=1)

        # Get prognostic variables:
-       x[:, -1, :, :, data_indices.internal_model.input.prognostic] = y_pred[..., data_indices.internal_model.output.prognostic]
+       x[:, -1, :, :, self.internal_model.input.prognostic] = y_pred[..., self.internal_model.output.prognostic]

-       forcings = get_dynamic_forcings(time, self.latitudes, self.longitudes, self.metadata["config"]["data"]["forcing"])
+       forcings = get_dynamic_forcings(time, self.latitudes, self.longitudes, self.metadata.config.data.forcing)
        forcings.update(self.static_forcings)

        for forcing, value in forcings.items():
-           if type(value) == np.ndarray:
-               x[:, -1, :, :, data_indices.internal_model.input.name_to_index[forcing]] = torch.from_numpy(value)
+           if isinstance(value, np.ndarray):
+               x[:, -1, :, :, self.internal_model.input.name_to_index[forcing]] = torch.from_numpy(value).to(dtype=x.dtype)
            else:
-               x[:, -1, :, :, data_indices.internal_model.input.name_to_index[forcing]] = value
+               x[:, -1, :, :, self.internal_model.input.name_to_index[forcing]] = value
        return x

    @torch.inference_mode
    def predict_step(self, batch: tuple, batch_idx: int) -> torch.Tensor:
-       data_indices = self.model.data_indices
-       multistep = self.metadata["config"]["training"]["multistep_input"]
+       multistep = self.metadata.config.training.multistep_input

        batch = self.allgather_batch(batch)

        batch, time_stamp = batch
        time = np.datetime64(time_stamp[0], 'h')  # Consider not forcing 'h' here and instead generalize time + self.frequency
        times = [time]
-       y_preds = np.zeros((batch.shape[0], self.forecast_length, batch.shape[-2], len(self.variable_indices)))
+       y_preds = torch.empty((batch.shape[0], self.forecast_length, batch.shape[-2], len(self.variable_indices)), dtype=batch.dtype, device="cpu")

        # Insert analysis for t=0
-       y_analysis = batch[:, multistep-1, 0, ...]
-       y_analysis[..., data_indices.internal_data.output.diagnostic] = 0.  # Set diagnostic variables to zero
-       y_preds[:, 0, ...] = y_analysis[..., self.variable_indices].cpu().to(torch.float32).numpy()
+       y_analysis = batch[:, multistep-1, ...].cpu()
+       y_analysis[..., self.internal_data.output.diagnostic] = 0.  # Set diagnostic variables to zero
+       y_preds[:, 0, ...] = y_analysis[..., self.variable_indices]

        # Possibly have to extend this to handle imputer, see _step in forecaster.
-       batch = self.model.pre_processors(batch, in_place=False)
-       x = batch[..., data_indices.internal_data.input.full]
-       with torch.amp.autocast(device_type="cuda", dtype=torch.bfloat16):
+       batch = self.model.pre_processors(batch, in_place=True)
+       x = batch[..., self.internal_data.input.full]
+
+       with torch.autocast(device_type="cuda", dtype=torch.bfloat16):
            for fcast_step in range(self.forecast_length-1):
                y_pred = self(x)
                time += self.frequency
                x = self.advance_input_predict(x, y_pred, time)
-               y_preds[:, fcast_step+1, ...] = self.model.post_processors(y_pred, in_place=False)[:, 0, ..., self.variable_indices].cpu().to(torch.float32).numpy()
+               y_preds[:, fcast_step+1] = self.model.post_processors(y_pred, in_place=True)[:, 0, :, self.variable_indices].cpu()

                times.append(time)
+               if self.release_cache:
+                   del y_pred
+                   torch.cuda.empty_cache()
-       return {"pred": [y_preds], "times": times, "group_rank": self.model_comm_group_rank, "ensemble_member": 0}
+       return {"pred": [y_preds.to(torch.float32).numpy()], "times": times, "group_rank": self.model_comm_group_rank, "ensemble_member": 0}

    def allgather_batch(self, batch: torch.Tensor) -> torch.Tensor:
        return batch  # Not implemented properly

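The dtype comment added in set_static_forcings above can be seen in two lines of standalone code. An illustrative snippet (not repository code) of why the new .float() casts matter:

import numpy as np
import torch

lats = np.linspace(-90.0, 90.0, 4)  # numpy defaults to float64
print(torch.from_numpy(np.cos(lats * np.pi / 180.0)).dtype)          # torch.float64
print(torch.from_numpy(np.cos(lats * np.pi / 180.0)).float().dtype)  # torch.float32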
@@ -212,4 +235,4 @@ def advance_input_predict(self, x, y_pred):

    @torch.inference_mode
    def predict_step(self, batch: torch.Tensor, batch_idx: int) -> torch.Tensor:
-       pass
+       pass
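For reference, the rolling-window update that advance_input_predict performs — shift the multistep input one slot and overwrite the newest slot with the latest prediction, with forcings refreshed separately — reduces to the following sketch (hypothetical shapes and a single shared prognostic index set, illustrative only):

import torch

def advance(x: torch.Tensor, y_pred: torch.Tensor, prognostic: list) -> torch.Tensor:
    # x: (batch, multistep, ensemble, grid, vars), oldest step first.
    x = x.roll(-1, dims=1)  # shift the window so the last slot is reusable
    # Prognostic variables come from the model output; dynamic forcings
    # would be written into the same slot for the new valid time.
    x[:, -1, :, :, prognostic] = y_pred[..., prognostic]
    return x

# Usage with made-up sizes: 2 steps of history, 10 grid points, 3 variables.
x = torch.zeros(1, 2, 1, 10, 3)
y_pred = torch.randn(1, 1, 10, 3)
x = advance(x, y_pred, prognostic=[0, 1])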

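And a simple way to quantify the effect of release_cache, in the spirit of the benchmark mentioned in the commit message — a hypothetical measuring harness, not the one used for the 128-chunk numbers:

import torch

def peak_cuda_mb(fn) -> float:
    # Peak CUDA memory allocated during one call of fn, in MiB.
    torch.cuda.reset_peak_memory_stats()
    fn()
    torch.cuda.synchronize()
    return torch.cuda.max_memory_allocated() / 1024**2

# Usage: compare peak_cuda_mb(lambda: predictor.predict_step(batch, 0))
# with release_cache=True versus release_cache=False.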