티스토리 뷰
[Python] onnxruntime gpu(device) 에 업로드된 데이터 cpu(host)로 다운로드 안하고 바로 inference 하는 방법
developer0hye 2024. 4. 11. 23:56onnxruntime inference 예제를 찾아보면 거의 input은 cpu에서 pre-processing한 numpy array(on cpu)를 session.run 함수의 입력으로 주는 경우가 많습니다.
그치만 실제로는 pre-processing도 GPU에서 하고 이걸 굳이 cpu 로 내려서 입력하는 일은 없는 게 일반적일겁니다. GPU, CPU 업로드, 다운로드 횟수는 줄일 수 있으면 최대한 줄여야 하는 아주 악의 축 같은 작업입니다. 특히 input 사이즈가 큰데 GPU 업로드 했다 CPU로 다운로드 했다 하다보면 차라리 CPU로 구현하는 것만 못한 속도가 나올겁니다.
그래서, GPU에 있는 데이터를 바로 추론할 수 있어야 합니다! onnxruntime 은 당연히 이런 기능을 제공하고 있습니다. 공식 예제는 아래 링크를 참고하시면 됩니다.
https://onnxruntime.ai/docs/api/python/api_summary.html#data-on-device
아래는 제가 구현한 예제입니다. io_binding을 쓰면 기존 흔히 보던 onnxruntime inference 코드와는 좀 차이점이 발생합니다. iobinding 객체를 생성해야하고 bind_input, bind_output 함수를 호출하면서 상세하게 input/output name, type, data pointer 등을 세팅하는 과정이 필요합니다. output 또한 iobinding 객체에서 꺼내오듯 코드를 구현해야합니다.
import torch
import torchvision
import numpy as np
import onnxruntime
model = torchvision.models.resnet18(weights=None)
model.eval()
input_on_cpu = torch.randn((1, 3, 224, 224), dtype=torch.float32)
torch.onnx.export(model, input_on_cpu, "resnet18.onnx", input_names=['input'], output_names=['output'])
session = onnxruntime.InferenceSession("resnet18.onnx", providers=['CUDAExecutionProvider'])
output = session.run(None, {'input': input_on_cpu.numpy()})[0]
print(output.sum())
input_on_gpu = input_on_cpu.to('cuda')
iobinding = session.io_binding()
iobinding.bind_input(name='input',
device_type='cuda',
device_id=0,
element_type=np.float32,
shape=input_on_gpu.shape,
buffer_ptr=input_on_gpu.data_ptr())
iobinding.bind_output(name='output', device_type='cuda', device_id=0)
session.run_with_iobinding(iobinding)
output = iobinding.copy_outputs_to_cpu()[0]
print(output.sum())
'Deep Learning' 카테고리의 다른 글
RT DETR의 백본은 무엇일까 (0) | 2024.04.16 |
---|---|
action detection 이랑 tracking 논문/코드 찾아보는 중 (0) | 2024.04.13 |
ConvNet vs ViT 비교 논문 (0) | 2024.02.29 |
20240222 YOLOv9 가 공개되다. (0) | 2024.02.22 |
META V-JEPA (0) | 2024.02.18 |
- Total
- Today
- Yesterday
- 조합
- 백트래킹
- ㅂ
- 자료구조
- 인공지능을 위한 선형대수
- FairMOT
- C++ Deploy
- PyCharm
- 가장 긴 증가하는 부분 수열
- 순열
- 이분탐색
- cosine
- 백준 11053
- LCA
- 문제집
- MOT
- 백준 11437
- 단축키
- 위상 정렬 알고리즘
- Lowest Common Ancestor
- 백준 1766
- 파이참
- 백준
일 | 월 | 화 | 수 | 목 | 금 | 토 |
---|---|---|---|---|---|---|
1 | 2 | 3 | 4 | |||
5 | 6 | 7 | 8 | 9 | 10 | 11 |
12 | 13 | 14 | 15 | 16 | 17 | 18 |
19 | 20 | 21 | 22 | 23 | 24 | 25 |
26 | 27 | 28 | 29 | 30 | 31 |