Skip to content

Commit

Permalink
update
Browse files Browse the repository at this point in the history
  • Loading branch information
bastiscode committed Nov 13, 2024
1 parent e1c89ba commit 1bf0d76
Showing 1 changed file with 13 additions and 7 deletions.
20 changes: 13 additions & 7 deletions python/text_utils/io.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import glob
import json
import os.path
from typing import Any

Expand Down Expand Up @@ -31,7 +32,7 @@ def save_checkpoint(
optimizer_state_dict: dict[str, Any] | None = None,
lr_scheduler_state_dict: dict[str, Any] | None = None,
loss_fn_state_dict: dict[str, Any] | None = None,
**kwargs: Any
**kwargs: Any,
) -> None:
"""
Saves a checkpoint to a file.
Expand Down Expand Up @@ -63,23 +64,28 @@ def save_checkpoint(
"optimizer_state_dict": optimizer_state_dict,
"lr_scheduler_state_dict": lr_scheduler_state_dict,
"loss_fn_state_dict": loss_fn_state_dict,
**kwargs
**kwargs,
}
torch.save(state, f=checkpoint_path)


def load_checkpoint(
checkpoint_path: str,
device: torch.device = torch.device("cpu")
checkpoint_path: str, device: torch.device = torch.device("cpu")
) -> dict[str, Any]:
return torch.load(checkpoint_path, map_location=device)


def load_text_file(
path: str
) -> list[str]:
def load_text_file(path: str) -> list[str]:
text = []
with open(path, "r", encoding="utf8") as inf:
for line in inf:
text.append(line.rstrip("\r\n"))
return text


def load_jsonl_file(path: str) -> list:
data = []
with open(path, "r", encoding="utf8") as inf:
for line in inf:
data.append(json.loads(line.rstrip("\r\n")))
return data

0 comments on commit 1bf0d76

Please sign in to comment.