티스토리 뷰
Deep Learning/PyTorch
PyTorch - Multi GPUs로 학습된 Model의 Weights을 DataParallel을 호출하지 않고 Load하는 방법
developer0hye 2020. 10. 8. 01:41본 게시글은 해당 링크를 참고하여 작성된 게시글입니다.
Multi GPUs를 사용하여 Model을 학습시키고 Weights을 저장하면 모든 Parameter의 Key값에 "module."이 붙은채로 저장됩니다.
그런데, 이렇게 되면 Model의 멤버 변수(self.*)를 클래스 외부에서 접근할때 코딩 시 객체명.* 으로 접근하지 못하고 객체명.module.*과 같이 "module"을 붙여주어야 접근이 가능해집니다. 이렇게 되면, 코드를 작성할때 항상 module의 존재를 신경써주어야 합니다.
문제를 파악해봅시다. module이 붙는 이유는 Model을 병렬화 시켰기 때문이고, Model을 병렬화 시켰던 이유는 Load하고자하는 Weights의 Key값에 'module.'이 붙었기 때문입니다. 해결 방법은 간단합니다. Key값에 붙는 'module.'을 제거하면됩니다. 아래는 예시코드입니다.
model = ...
checkpoint = ...
state_dict = checkpoint['state_dict']
keys = state_dict.keys()
values = state_dict.values()
new_keys = []
for key in keys:
new_key = key[7:] # remove the 'module.'
new_keys.append(new_key)
new_dict = OrderedDict(list(zip(new_keys, values)))
model.load_state_dict(new_dict)
'Deep Learning > PyTorch' 카테고리의 다른 글
[PyTorch] Get a single batch from DataLoader without iterating (0) | 2021.01.24 |
---|---|
[PyTorch] torch.exp 와 auto mixed precision (0) | 2021.01.19 |
[PyTorch] Depthwise Convolutional Layer 속도 향상 방법 (0) | 2021.01.08 |
Image(Numpy, Opencv) To Tensor (0) | 2020.12.22 |
layer 별 learning rate 할당 방법 (0) | 2020.09.23 |
댓글
공지사항
최근에 올라온 글
최근에 달린 댓글
- Total
- Today
- Yesterday
링크
TAG
- Lowest Common Ancestor
- 백준 11053
- 단축키
- 순열
- 백준
- PyCharm
- 조합
- 백준 11437
- cosine
- 백트래킹
- LCA
- C++ Deploy
- 자료구조
- 백준 1766
- FairMOT
- 이분탐색
- 가장 긴 증가하는 부분 수열
- 위상 정렬 알고리즘
- 파이참
- MOT
- ㅂ
- 인공지능을 위한 선형대수
- 문제집
일 | 월 | 화 | 수 | 목 | 금 | 토 |
---|---|---|---|---|---|---|
1 | 2 | 3 | 4 | |||
5 | 6 | 7 | 8 | 9 | 10 | 11 |
12 | 13 | 14 | 15 | 16 | 17 | 18 |
19 | 20 | 21 | 22 | 23 | 24 | 25 |
26 | 27 | 28 | 29 | 30 | 31 |
글 보관함