From d7bb7fba51a91d73856680dda52dd2b818efed9e Mon Sep 17 00:00:00 2001 From: pickles Date: Tue, 16 Jun 2020 10:22:27 +0900 Subject: [PATCH] Update --- MachineLearning/DeepLearning/PyTorch/memo.md | 11 ++++++++++- 1 file changed, 10 insertions(+), 1 deletion(-) diff --git a/MachineLearning/DeepLearning/PyTorch/memo.md b/MachineLearning/DeepLearning/PyTorch/memo.md index 5d812f5..b13000a 100644 --- a/MachineLearning/DeepLearning/PyTorch/memo.md +++ b/MachineLearning/DeepLearning/PyTorch/memo.md @@ -1,11 +1,20 @@ ## モデルの保存 ```python +save_path = './dir/file.pth' torch.save(model.state_dict(), save_path) # 学習済みモデルパラメータ, 保存先パス ``` ## モデルの読み込み ```python +load_path = './dir/file.pth' model = ModelClass(*args, **kwargs) -model.load_state_dict(torch.load(save_path)) +model.load_state_dict(torch.load(load_path)) +``` + +### GPU上で保存されたパラメータをGPU上でロードする場合 +```python +load_path = './dir/file.pth' +model = ModelClass(*args, **kwargs) +model.load_state_dict(torch.load(load_path, map_location={'cuda:0': 'cpu'})) ```