티스토리 뷰
[ONNX, PyTorch, TensorRT] GatherElements 지원 이슈
developer0hye 2021. 6. 30. 22:32PyTorch 로 구현한 모델에 torch.gather 가 포함된 경우 ONNX 변환 시 ONNX 모델에 GatherElements 연산이 포함될 수 있다.
onnx-tensorrt 프로젝트내에 Supported Operators에 관한 문서를 보면 GatherElements를 지원하나, TensorRT 버전에 따라 지원이 될 수도 있고 안될수도 있다. 적어도 7.2.1.6 버전에서는 지원이 안되는 것을 확인하였다. (ONNX2TensorRT 과정은 내가 맡은 부분이 아니라 확실하진 않다...)
https://github.com/onnx/onnx-tensorrt/blob/master/docs/operators.md
onnx/onnx-tensorrt
ONNX-TensorRT: TensorRT backend for ONNX. Contribute to onnx/onnx-tensorrt development by creating an account on GitHub.
github.com
8. 대 버전부터 지원되나보다. 근데 아직(20210630 기준) 8. 대는 Early Access버전이다.
https://docs.nvidia.com/deeplearning/tensorrt/release-notes/tensorrt-8.html
Release Notes :: NVIDIA Deep Learning TensorRT Documentation
This is the TensorRT 8.0.0 Early Access (EA) release notes and is applicable to Linux x86 users. These release notes are applicable to workstation, server, and JetPack users unless appended specifically with (not applicable for Jetson platforms). This rele
docs.nvidia.com

텐서알티 버전업을 하기 쉬운 환경이라면 버전업하면 해결될 문제이지만... 그게 어렵다고 하면 PyTorch2ONNX 과정에서 어떻게든 되게끔 만들어야한다... 구글링을 해보니 해결 방법이 있었다.
https://ask.csdn.net/questions/1480565
No importer registered for op: GatherElements-开源项目-CSDN问答
Encountered the same problem today. [ERROR] INVALID_ARGUMENT: getPluginCreator could not find plugin GatherElements version 1 It looks like for now the only option is to rewrite the code to avoid using it (I'm porting pytorch -> onnx first and only then on
ask.csdn.net
기존에 torch.gather를 사용한 부분에서 torch.만 없애고 아래의 함수를 거치도록 코드를 수정하면 된다.
다만, 결과는 꼭 확인해야한다. 배치 데이터를 넣었을때 모든 데이터가 제대로 처리되는지 꼭 확인해야한다.
def gather(input, dim, index):
indices = [torch.arange(size, dtype=torch.int32, device=index.device) for size in index.shape]
indices = list(torch.meshgrid(*indices))
indices[dim] = index
sizes = list(reversed(list(itertools.accumulate(reversed(input.shape), operator.mul))))
index = sum((index * size for index, size in zip(indices, sizes[1:] + [1])))
output = input.flatten()[index]
return output
'Deep Learning > PyTorch' 카테고리의 다른 글
| [PyTorch Lightning] 관련 링크 모음 (0) | 2021.08.31 |
|---|---|
| YOLOv4, Scaled YOLOv4 를 구현함에 있어 유의할 점 (4) | 2021.08.10 |
| [PyTorch] AMP Loss dtype Issue (0) | 2021.06.23 |
| [PyTorch] DistributedDataParallel 예시 코드 및 참고 자료 모음 (4) | 2021.06.18 |
| [PyTorch] timm(rwightman/pytorch-image-models) 백본 스테이지별 채널 수 확인 코드 (0) | 2021.06.18 |
- Total
- Today
- Yesterday
- 순열
- Lowest Common Ancestor
- 백트래킹
- 파이참
- cosine
- 단축키
- ㅂ
- 이분탐색
- 가장 긴 증가하는 부분 수열
- PyCharm
- 문제집
- C++ Deploy
- 조합
- 백준
- 백준 11053
- 위상 정렬 알고리즘
- MOT
- 자료구조
- LCA
- 인공지능을 위한 선형대수
- FairMOT
- 백준 11437
- 백준 1766
| 일 | 월 | 화 | 수 | 목 | 금 | 토 |
|---|---|---|---|---|---|---|
| 1 | 2 | 3 | 4 | 5 | 6 | |
| 7 | 8 | 9 | 10 | 11 | 12 | 13 |
| 14 | 15 | 16 | 17 | 18 | 19 | 20 |
| 21 | 22 | 23 | 24 | 25 | 26 | 27 |
| 28 | 29 | 30 | 31 |
