티스토리 뷰

PyTorch 모델을 ONNX 형식으로 변환할 때, 모델의 forward 메소드가 추가 인자를 받는 경우가 있습니다. 예를 들어, 학습/추론 모드를 구분하는 phase 인자와 같은 것들이죠. 이런 경우 ONNX 변환 시 어떻게 처리해야 하는지 알아보겠습니다.

 

문제 상황 

일반적으로 PyTorch 모델의 forward 메소드는 입력 텐서만 받지만, 때로는 다음과 같이 추가 인자를 받도록 구현되어 있을 수 있습니다:

def forward(self, x, phase='train'):
    # phase에 따라 다른 동작을 수행
    if phase == 'train':
        # 학습 시의 동작
        ...
    else:
        # 추론 시의 동작
        ...
    return output

해결 방법

ONNX 변환 시 이러한 추가 인자를 전달하려면 torch.onnx.export 함수를 다음과 같이 사용해야 합니다:

 

# 1. 더미 입력 준비
dummy_input = torch.randn(1, 3, 112, 112)

# 3. ONNX 변환
torch.onnx.export(
    net,                     # 모델
    (dummy_input, 'test'),   # 입력 튜플: (입력 텐서, phase 인자)
    output_path,             # 출력 파일 경로
    input_names=['input', 'phase'],  # 입력 이름들
    output_names=['output'],         # 출력 이름
    dynamic_axes={
        'input': {0: 'batch_size'},  # 동적 배치 크기
    },
    opset_version=11,               # ONNX 버전
    do_constant_folding=True        # 최적화
)

 

입력을 튜플로(dummy_input, 'test') 넣어주면 됩니다.

댓글
공지사항
최근에 올라온 글
최근에 달린 댓글
Total
Today
Yesterday
링크
«   2025/04   »
1 2 3 4 5
6 7 8 9 10 11 12
13 14 15 16 17 18 19
20 21 22 23 24 25 26
27 28 29 30
글 보관함