티스토리 뷰

Deep Learning/PyTorch

[PyTorch] AMP Loss dtype Issue

developer0hye 2021. 6. 23. 22:27

기존에 Loss를 항상 float32타입으로 계산하도록 코드를 짰었다.

 

이유는 AMP를 쓸때 Loss를 float16타입으로 계산하면 정밀도가 떨어지고 오버플로우 문제가 발생할걸로 예상했기 때문... 근데 이렇게 하니까 AMP 사용해서 학습시키니 모델 성능이 크게 저하됨(VOC 데이터셋 기준 mAP 가 2.0 정도 드랍됨, 모델은 CenterNet-ResNet18 기준)

 

이부분을 모델 아웃풋 텐서의 타입으로 맞춰주니 성능 저하 없고 오히려 성능이 향상됨

 

아래 프로젝트 개발하다 해당 문제를 발견하게됨

 

https://github.com/developer0hye/Simple-CenterNet

 

developer0hye/Simple-CenterNet

PyTorch Implementation of CenterNet(Object as Points) - developer0hye/Simple-CenterNet

github.com

 

 

댓글
공지사항
최근에 올라온 글
최근에 달린 댓글
Total
Today
Yesterday
링크
«   2025/01   »
1 2 3 4
5 6 7 8 9 10 11
12 13 14 15 16 17 18
19 20 21 22 23 24 25
26 27 28 29 30 31
글 보관함