Skip to content

Commit

Permalink
Update
Browse files Browse the repository at this point in the history
  • Loading branch information
sakasa authored Jun 16, 2020
1 parent edbcc12 commit d7bb7fb
Showing 1 changed file with 10 additions and 1 deletion.
11 changes: 10 additions & 1 deletion MachineLearning/DeepLearning/PyTorch/memo.md
Original file line number Diff line number Diff line change
@@ -1,11 +1,20 @@
## モデルの保存
```python
save_path = './dir/file.pth'
torch.save(model.state_dict(), save_path) # 学習済みモデルパラメータ, 保存先パス
```

## モデルの読み込み
```python
load_path = './dir/file.pth'
model = ModelClass(*args, **kwargs)
model.load_state_dict(torch.load(save_path))
model.load_state_dict(torch.load(load_path))
```

### GPU上で保存されたパラメータをGPU上でロードする場合
```python
load_path = './dir/file.pth'
model = ModelClass(*args, **kwargs)
model.load_state_dict(torch.load(load_path, map_location={'cuda:0': 'cpu'}))
```

0 comments on commit d7bb7fb

Please sign in to comment.