티스토리 뷰
import torch
import torch.nn as nn
class Model(nn.Module):
def __init__(self):
super(Model, self).__init__()
self.conv1 = nn.Conv2d(3, 16, 3, 1, 1)
self.bn1 = nn.BatchNorm2d(16)
self.conv2 = nn.Conv2d(16, 16, 3, 1, 1)
self.bn2 = nn.BatchNorm2d(16)
def forward(self, x):
x = torch.relu(self.bn1(self.conv1(x)))
x = x + torch.relu(self.bn2(self.conv2(x)))
return x
input = torch.randn((1, 3, 736, 1280))
model = Model()
traced_model = torch.jit.trace(model , input.cpu())
traced_model .save("test.ts")
위와 같이 코드를 작성하고 모델을 TorchScript를 사용하여 저장했습니다. 그러고나서 C++상에서 libtorch와 TRTorch를 이용하여 모델을 로드하고 추론해보니 모델의 Output 값의 오차가 크게 발생하였습니다.
정확히는 TRTorch로 모델을 TensorRT화 한 이후에 오차가 크게 발생하였습니다. 코드를 보았을때 의심이 갔던 부분은 input을 cpu로 처리한것과 eval을 호출하지 않아 trace과정에 실제 추론 과정과 다르게 모델이 trace가 되었을수 있겠다는 생각이 들었습니다.
해결 방법은 아래와 같습니다.
import torch
import torch.nn as nn
class Model(nn.Module):
def __init__(self):
super(Model, self).__init__()
self.conv1 = nn.Conv2d(3, 16, 3, 1, 1)
self.bn1 = nn.BatchNorm2d(16)
self.conv2 = nn.Conv2d(16, 16, 3, 1, 1)
self.bn2 = nn.BatchNorm2d(16)
def forward(self, x):
x = torch.relu(self.bn1(self.conv1(x)))
x = x + torch.relu(self.bn2(self.conv2(x)))
return x
model = Model()
model.eval()
traced_model = torch.jit.trace(model, input)
traced_model.save("test.ts")
모델을 trace하기전에 eval을 호출해주고 trace한 다음에 모델을 저장하였습니다. 이렇게 코드를 수정한 결과 C++ 상에서 모델을 로드하고 추론하였을때 오차가 발생하지 않았습니다.
'Deep Learning > PyTorch' 카테고리의 다른 글
[PyTorch] Mish 메모리 이슈 (0) | 2021.02.26 |
---|---|
[PyTorch & LibTorch] OpenCV Image to Normalized Tensor C++ 코드 (0) | 2021.02.18 |
[PyTorch] conv layer default initialization method (0) | 2021.02.12 |
[PyTorch] torch.set_deterministic(True) 에러가 떠요! (0) | 2021.01.28 |
[PyTorch] auto mixed precision(amp) 사용시 주의할 점 (0) | 2021.01.27 |
공지사항
최근에 올라온 글
최근에 달린 댓글
- Total
- Today
- Yesterday
링크
TAG
- 백준
- FairMOT
- 자료구조
- Lowest Common Ancestor
- PyCharm
- 백트래킹
- 순열
- ㅂ
- 백준 11053
- C++ Deploy
- 인공지능을 위한 선형대수
- 문제집
- MOT
- 가장 긴 증가하는 부분 수열
- 백준 1766
- 위상 정렬 알고리즘
- cosine
- 단축키
- 백준 11437
- 조합
- 이분탐색
- 파이참
- LCA
일 | 월 | 화 | 수 | 목 | 금 | 토 |
---|---|---|---|---|---|---|
1 | 2 | 3 | 4 | 5 | ||
6 | 7 | 8 | 9 | 10 | 11 | 12 |
13 | 14 | 15 | 16 | 17 | 18 | 19 |
20 | 21 | 22 | 23 | 24 | 25 | 26 |
27 | 28 | 29 | 30 |
글 보관함