티스토리 뷰

본 게시글은 해당 링크 참고하여 작성된 게시글입니다.

 

Multi GPUs를 사용하여 Model을 학습시키고 Weights을 저장하면 모든 Parameter의 Key값에 "module."이 붙은채로 저장됩니다.

 

그런데, 이렇게 되면 Model의 멤버 변수(self.*)를 클래스 외부에서 접근할때 코딩 시 객체명.* 으로 접근하지 못하고 객체명.module.*과 같이 "module"을 붙여주어야 접근이 가능해집니다. 이렇게 되면, 코드를 작성할때 항상 module의 존재를 신경써주어야 합니다.

 

문제를 파악해봅시다. module이 붙는 이유는 Model을 병렬화 시켰기 때문이고, Model을 병렬화 시켰던 이유는 Load하고자하는 Weights의 Key값에 'module.'이 붙었기 때문입니다. 해결 방법은 간단합니다. Key값에 붙는 'module.'을 제거하면됩니다. 아래는 예시코드입니다.

 

model = ...
checkpoint = ...

state_dict = checkpoint['state_dict']
keys = state_dict.keys()

values = state_dict.values()

new_keys = []

for key in keys:
  new_key = key[7:]    # remove the 'module.'
  new_keys.append(new_key)

new_dict = OrderedDict(list(zip(new_keys, values)))
model.load_state_dict(new_dict)
댓글
공지사항
최근에 올라온 글
최근에 달린 댓글
Total
Today
Yesterday
링크
«   2025/01   »
1 2 3 4
5 6 7 8 9 10 11
12 13 14 15 16 17 18
19 20 21 22 23 24 25
26 27 28 29 30 31
글 보관함