티스토리 뷰
PyTorch 개발자들에게 swin 사용법은 3가지 정도가 있을 거 같다.
1. Official swin project 갖다 쓰기
https://github.com/microsoft/Swin-Transformer
2. timm 거 갖다 쓰기
https://github.com/rwightman/pytorch-image-models
3. mmdetection 거 갖다 쓰기
https://github.com/open-mmlab/mmdetection
1, 2, 3 모두다 pytorch외에 라이브러리를 따로 설치하거나, 프로젝트를 따로 받아서 내 프로젝트에서 쓸 수 있게끔 필요한 파일, 함수만 골라쓰는등의 불편함이 존재한다.
timm의 swin의 경우 pretrained weight을 사용하는 경우 입력 이미지 사이즈가 사전 학습과정에서 쓰인 사이즈에 제한된다. 즉 거의 못쓴다고 봐야한다.
torchvision 에도 0.13 버전부터 swin을 지원하고 있다. 게다가 timm 처럼 pretrained weight을 사용하는 경우 이미지 사이즈에 종속되지 않는다.
https://pytorch.org/vision/0.13/search.html?q=swin&check_keywords=yes&area=default
그리고 torchvision은 웬만하면 torch 설치할때 같이 설치하는 라이브러리다 보니 라이브러리를 더 설치해야해! 하는 부담감도 적다.
사용법
내가 테스트한 torchvision 버전은 0.14.0+cu117 이다.
아래코드는 예제다.
import torchvision
import torch
model = torchvision.models.swin_t(weights=torchvision.models.Swin_T_Weights.DEFAULT)
model(torch.zeros(1, 3, 448, 224))
model = torchvision.models.swin_v2_t(weights=torchvision.models.Swin_V2_T_Weights.DEFAULT)
model(torch.zeros(1, 3, 256, 128))
근데, torchvision에 있는 model 들이 대개 timm 처럼 finetuning에 초점을 맞춰서 친절히 구현돼있진 않다.
swin을 백본으로 가져다 쓰려는 사람들은 대체로 마지막에 global average pooling 되기 전, 피쳐맵을 얻어오고 싶을 경우가 대다수일 것이다.
그런 경우 아래처럼 코드를 작성하면 된다. avgpool, flatten, head 레이어를 그냥 Identity 레이어로 변경해주면된다. 그런다음 나온 피쳐맵을 갖고 잘 주물 주물 해주면된다.
import torchvision
import torch
import torch.nn as nn
model = torchvision.models.swin_v2_t(weights=torchvision.models.Swin_V2_T_Weights.DEFAULT)
out = model(torch.zeros(1, 3, 256, 128))
print(out.shape) # [1, 1000]
model.avgpool = nn.Identity()
model.flatten = nn.Identity()
model.head = nn.Identity()
out = model(torch.zeros(1, 3, 256, 128))
print(out.shape) # [1, 768, 8, 4]
아 근데 나는 1/8, 1/16, 1/32 배 다운샘플링된 피쳐맵도 필요한데... 라고 한다면 코드를 좀 수정해야한다.
torchvision/models/swin_transformer.py 에서 SwinTransformer 클래스에 함수를 하나 추가해주자. 굳이 꼭 SwinTransformer 클래스에 함수를 추가할 필요는 없고 아래코드의 의미만 이해하고 적절히 잘 가져다쓰면 될듯하다.
def forward_featrues(self, x):
multiscale_features = []
for layer in self.features:
if isinstance(layer, PatchMergingV2) or isinstance(layer, PatchMerging):
multiscale_features.append(self.permute(x))
x = layer(x)
multiscale_features.append(self.permute(x))
return multiscale_features
관련해서 pr도 날리긴했는데 이건 머지는 안될 거 같다. 내가 코드 까먹으면 다시 볼려고 남긴 것도 있다 ㅋㅋ...
https://github.com/pytorch/vision/pull/6887
근데... 내 태스크/내 데이터셋에서 파인튜닝 해보니 성능이 영 안나왔다...
'Deep Learning > PyTorch' 카테고리의 다른 글
[PyTorch] Dataset을 Concat 할 수 있다고...? (1) | 2023.03.11 |
---|---|
TorchLightning 기반 프로젝트 모음 (0) | 2022.11.11 |
PyTorch Contribution! (0) | 2022.08.11 |
YOLOv5 Contribution! Weight Decay (2) | 2022.07.17 |
[PyTorch] AutoMixedPrecision 주의점 (0) | 2022.03.27 |
- Total
- Today
- Yesterday
- 이분탐색
- 위상 정렬 알고리즘
- 조합
- 순열
- 백준 11437
- ㅂ
- MOT
- 파이참
- FairMOT
- 백준
- 단축키
- PyCharm
- 백준 1766
- C++ Deploy
- 백트래킹
- 문제집
- 가장 긴 증가하는 부분 수열
- LCA
- 인공지능을 위한 선형대수
- Lowest Common Ancestor
- cosine
- 백준 11053
- 자료구조
일 | 월 | 화 | 수 | 목 | 금 | 토 |
---|---|---|---|---|---|---|
1 | 2 | 3 | 4 | |||
5 | 6 | 7 | 8 | 9 | 10 | 11 |
12 | 13 | 14 | 15 | 16 | 17 | 18 |
19 | 20 | 21 | 22 | 23 | 24 | 25 |
26 | 27 | 28 | 29 | 30 | 31 |