티스토리 뷰
[PyTorch] DistributedDataParallel 예시 코드 및 참고 자료 모음
developer0hye 2021. 6. 18. 18:12
기존에 Single node, multiple GPUs System(그냥 PC 1대에 GPU 여러대 꽂힌 피시로 생각, pytorch 공식 문서에서 이렇게 표기했길래 따라 씀) 에서 multiple gpu 를 활용함에 있어 DataParallel 모듈을 활용했다. 그런데, 쓰다보니 GPU 메모리 불균형 문제가 너무 심해서 메모리가 쏠리는 한 GPU 가 터지지 않게 해당 GPU의 메모리 사용량에 배치사이즈를 맞추다보니 다수개의 GPU를 제대로 활용하지 못했었다.
이와 관련해서는 아래의 포스트에서도 언급되어 있다.
그래서 이참에 DDP(DistributedDataParallel) 모듈을 사용해보려 했는데 DP(DataParallel) 처럼 모듈로 모델만 감싸주면 되는 줄 알았는데 그게 아니였다...
world_size 니 rank 니 생소한 용어들을 이해하지 않고 코드만 복붙해서 쓰려다보니 계속 오류가 났고 실행은 됐는데 Gpu 별로 메모리 분배도 고르게 되지 않았다... (world_size 는 학습에 필요한 총 프로세스 수이고 Single node, multiple GPUs System 에서는 사용할 GPU의 개수로 생각하면 된다. rank = 그냥 프로세스별 아이디로 생각하면 면될듯)
(현재 시스템이 Single node, multiple GPUs System 이라는 가정하에)
어떤 함수 def training(...): 에서 data load, forward, backward, step(optimization) 과정이 다 일어난다고 해보자 일단 DP는 Single Process 에서 forward 과정에서만 GPU가 병렬적으로 사용되지만, DDP는 아예 이 training(...) 함수가 Multiple Process 에서 동작되어야 한다. 이 차이가 가장 큰 차이였다. 이걸 이해하지 않은채 Single Process 에서 model 을 DDP로 감싸고 사용 가능한 GPU id를 gpu_inds 에 다 때려박으니 그냥 DP를 사용할때와 같이 동작하는 것 이였다...
결국 아래 pytorch 공식 예제 소스코드를 참고하여 구현하는데 성공하였다.
pytorch에서 제공하는 multiprocessing 패키지를 이용하고 병렬로 호출되는 main_worker 함수에서 프로세스 ID를 gpu 라는 변수에 저장하게끔 구현하고 한 프로세스에는 프로세스 ID와 동일한 GPU를 사용하도록 하는것이 관건이였다.
https://github.com/pytorch/examples/blob/master/imagenet/main.py
밑에 건 페이스북에서 제안한 DEIT 트랜스포머에 관한 프로젝트이다. 이 프로젝트에서도 DDP를 사용하고 있다. 이 프로젝트를 기반으로하는 트랜스포머 프로젝트들이 많으니 어느정도 신뢰하고 이 프로젝트의 main.py 코드를 참고해도 될듯하다.
https://github.com/facebookresearch/deit
torchrun 혹은 python -m torch.distributed.launch --use_env --nproc_per_node=* *.py 로 파일을 실행해보고 아래 부분에서 값을 출력해보고
https://github.com/facebookresearch/deit/blob/main/utils.py#L218-L220
하나의 프로세스에서만 print를 허용케 하는 아래 부분을 주석 처리(분석할때만)한다음
https://github.com/facebookresearch/deit/blob/main/utils.py#L238
파일을 실행시켜보면서 중간중간 print로 값을 찍다보면 감을 잡을 수 있을 것이다.
그리고 dataloader 에서 데이터 load 하기전에 sampler의 에폭을 꼭꼭 세팅해주어야 한다.
https://github.com/facebookresearch/deit/blob/main/main.py#L373
또 한가지 신경쓸부분은 model을 DP나 DDP로 감싸고나면, model의 인스턴스 변수와 함수를 접근할때 .module을 붙여야하고 이러한 model의 weight(=state_dict)를 저장하게 되면 key 값에 "module"이 붙게된다. 이러면 또 다음번에 모델 load할때 후처리 과정이 들어가는데 deit 는 이러한 후처리를 피하는 대신 변수를 아예 따로 하나 더 만들었다.
https://github.com/facebookresearch/deit/blob/main/main.py#L297-L300
아 그리고 seed를 프로세스마다 달리 할당한다.
https://github.com/facebookresearch/deit/blob/main/main.py#L182-L184
이유는 데이터 어그먼테이션이 프로세스별로 달리 적용되기 위함이라고 한다.
*직접 이슈를 작성했다. 이슈 관리가 굉장히 잘되는 프로젝트라 궁금한게 있으면 직접 개발자들한테 질문을 남기면 수일내로 답변을 얻을 수 있을 것이다.
https://github.com/facebookresearch/deit/issues/150
사족으로 deit 프로젝트에 기여도 했다.
https://github.com/facebookresearch/deit/pull/118
그리고, 또 중요한건데 모델이 (1) 배치노말라이제이션 레이어를 포함하고 DDP 학습시에 (2) 각 gpu에 할당되는 배치사이즈가 작다면, torch.nn.SyncBatchNorm.convert_sync_batchnorm 를 사용하여 각 gpu 간 배치노말라이제이션을 위한 통계치(평균, 분산)를 계산할때 모든 gpu에 load된 피쳐맵의 통계치를 구하도록 만들어주는 것이다. 이거 안해주면 gpu별로 평균, 분산 따로 구하는데 배치노말라이제이션 기법은 배치사이즈가 적으면 배치 데이터별로 평균, 분산이 매번 크게 달라지기 때문에 학습이 잘 안될 수도 있다. yolov5에도 이를 위한 처리 코드가 따로 있다.
https://github.com/ultralytics/yolov5/blob/master/train.py#L219
그외 공식 튜토리얼 링크
https://pytorch.org/tutorials/beginner/dist_overview.html
https://pytorch.org/tutorials/intermediate/ddp_tutorial.html
아예 테크니컬 리포트? 논문?도 있다.
PyTorch Distributed: Experiences on Accelerating Data Parallel Training
http://www.vldb.org/pvldb/vol13/p3005-li.pdf
아 그리고 DDP는 데이터로더단에서 sampler 도 전용 sampler를 설정해줘야한다.
이를 통해 process간에 dataset에서 data를 load할때 서로 중복이 안되는 데이터를 load 하도록 동작하는 것 같다.
실제로 한 process에서 iteration이 일어나는 횟수는 len(dataset)/(batch_size_per_gpu * num_multi_process) 로 생각하면 될 거 같다.
흠... 역시 원큐에 되는 건 없나보다.
2021-06-21
Detection 모델을 학습시키기 위해 DDP를 써봤는데 쓰니까 AP가 약 3정도 저하됐다. 굉장히 많이 저하된건데 이유를 모르겠다. SyncBN을 적용시켜봐야겠다.
2022-03-03
회사의 오픈프로젝트에 ddp 를 이용한 학습 코드를 올려놓았다.
https://github.com/MarkAny-Vision-AI/CenterNeXt
https://github.com/MarkAny-Vision-AI/CenterNeXt/commit/e17e4cf60f20b62a55c8ef90af2e6d88
위 커밋의 변경사항을 분석해보면 감이 올 것이다.
'Deep Learning > PyTorch' 카테고리의 다른 글
[ONNX, PyTorch, TensorRT] GatherElements 지원 이슈 (0) | 2021.06.30 |
---|---|
[PyTorch] AMP Loss dtype Issue (0) | 2021.06.23 |
[PyTorch] timm(rwightman/pytorch-image-models) 백본 스테이지별 채널 수 확인 코드 (0) | 2021.06.18 |
[PyTorch] Gradient accumulation 예제 코드 (0) | 2021.06.16 |
PyTorch to Onnx to TensorRT 과정을 위해 참고할만한 링크들 (0) | 2021.05.30 |
- Total
- Today
- Yesterday
- 백준 11437
- 파이참
- 자료구조
- 조합
- 백준
- LCA
- PyCharm
- ㅂ
- 백준 11053
- 위상 정렬 알고리즘
- 순열
- 단축키
- FairMOT
- Lowest Common Ancestor
- 인공지능을 위한 선형대수
- MOT
- 백트래킹
- 가장 긴 증가하는 부분 수열
- 문제집
- cosine
- 백준 1766
- 이분탐색
- C++ Deploy
일 | 월 | 화 | 수 | 목 | 금 | 토 |
---|---|---|---|---|---|---|
1 | 2 | 3 | 4 | |||
5 | 6 | 7 | 8 | 9 | 10 | 11 |
12 | 13 | 14 | 15 | 16 | 17 | 18 |
19 | 20 | 21 | 22 | 23 | 24 | 25 |
26 | 27 | 28 | 29 | 30 | 31 |