티스토리 뷰

Deep Learning/PyTorch

[PyTorch] Swin 모델 사용법

developer0hye 2022. 11. 2. 18:50

PyTorch 개발자들에게 swin 사용법은 3가지 정도가 있을 거 같다.

 

1. Official swin project 갖다 쓰기

 

https://github.com/microsoft/Swin-Transformer

 

GitHub - microsoft/Swin-Transformer: This is an official implementation for "Swin Transformer: Hierarchical Vision Transformer u

This is an official implementation for "Swin Transformer: Hierarchical Vision Transformer using Shifted Windows". - GitHub - microsoft/Swin-Transformer: This is an official implementation...

github.com

 

2. timm 거 갖다 쓰기

https://github.com/rwightman/pytorch-image-models

 

GitHub - rwightman/pytorch-image-models: PyTorch image models, scripts, pretrained weights -- ResNet, ResNeXT, EfficientNet, Eff

PyTorch image models, scripts, pretrained weights -- ResNet, ResNeXT, EfficientNet, EfficientNetV2, NFNet, Vision Transformer, MixNet, MobileNet-V3/V2, RegNet, DPN, CSPNet, and more - GitHub - rwig...

github.com

 

3. mmdetection 거 갖다 쓰기

https://github.com/open-mmlab/mmdetection

 

GitHub - open-mmlab/mmdetection: OpenMMLab Detection Toolbox and Benchmark

OpenMMLab Detection Toolbox and Benchmark. Contribute to open-mmlab/mmdetection development by creating an account on GitHub.

github.com

 

1, 2, 3 모두다 pytorch외에 라이브러리를 따로 설치하거나, 프로젝트를 따로 받아서 내 프로젝트에서 쓸 수 있게끔 필요한 파일, 함수만 골라쓰는등의 불편함이 존재한다.

 

timm의 swin의 경우 pretrained weight을 사용하는 경우 입력 이미지 사이즈가 사전 학습과정에서 쓰인 사이즈에 제한된다. 즉 거의 못쓴다고 봐야한다.

 

torchvision 에도 0.13 버전부터 swin을 지원하고 있다. 게다가 timm 처럼 pretrained weight을 사용하는 경우 이미지 사이즈에 종속되지 않는다.

 

https://pytorch.org/vision/0.13/search.html?q=swin&check_keywords=yes&area=default 

 

Search — Torchvision 0.13 documentation

Shortcuts

pytorch.org

 

그리고 torchvision은 웬만하면 torch 설치할때 같이 설치하는 라이브러리다 보니 라이브러리를 더 설치해야해! 하는 부담감도 적다.

 

사용법

내가 테스트한 torchvision 버전은 0.14.0+cu117 이다.

 

아래코드는 예제다.

import torchvision
import torch

model = torchvision.models.swin_t(weights=torchvision.models.Swin_T_Weights.DEFAULT)

model(torch.zeros(1, 3, 448, 224))

model = torchvision.models.swin_v2_t(weights=torchvision.models.Swin_V2_T_Weights.DEFAULT)

model(torch.zeros(1, 3, 256, 128))

근데, torchvision에 있는 model 들이 대개 timm 처럼 finetuning에 초점을 맞춰서 친절히 구현돼있진 않다.

 

swin을 백본으로 가져다 쓰려는 사람들은 대체로 마지막에 global average pooling 되기 전, 피쳐맵을 얻어오고 싶을 경우가 대다수일 것이다. 

 

그런 경우 아래처럼 코드를 작성하면 된다. avgpool, flatten, head 레이어를 그냥 Identity 레이어로 변경해주면된다. 그런다음 나온 피쳐맵을 갖고 잘 주물 주물 해주면된다.

 

import torchvision
import torch
import torch.nn as nn

model = torchvision.models.swin_v2_t(weights=torchvision.models.Swin_V2_T_Weights.DEFAULT)

out = model(torch.zeros(1, 3, 256, 128))
print(out.shape) # [1, 1000]

model.avgpool = nn.Identity()
model.flatten = nn.Identity()
model.head = nn.Identity()

out = model(torch.zeros(1, 3, 256, 128))
print(out.shape) # [1, 768, 8, 4]

 

아 근데 나는 1/8, 1/16, 1/32 배 다운샘플링된 피쳐맵도 필요한데... 라고 한다면 코드를 좀 수정해야한다.

 

torchvision/models/swin_transformer.py 에서 SwinTransformer 클래스에 함수를 하나 추가해주자. 굳이 꼭 SwinTransformer 클래스에 함수를 추가할 필요는 없고 아래코드의 의미만 이해하고 적절히 잘 가져다쓰면 될듯하다.

 

def forward_featrues(self, x):
    multiscale_features = []
    for layer in self.features:
        if isinstance(layer, PatchMergingV2) or isinstance(layer, PatchMerging):
            multiscale_features.append(self.permute(x))
        x = layer(x)
    multiscale_features.append(self.permute(x))
    return multiscale_features

 

관련해서 pr도 날리긴했는데 이건 머지는 안될 거 같다. 내가 코드 까먹으면 다시 볼려고 남긴 것도 있다 ㅋㅋ...

 

https://github.com/pytorch/vision/pull/6887

 

add a function for multi-scale feature extraction of swin models by developer0hye · Pull Request #6887 · pytorch/vision

It is related to #6886.

github.com

 

근데... 내 태스크/내 데이터셋에서 파인튜닝 해보니 성능이 영 안나왔다...

댓글
공지사항
최근에 올라온 글
최근에 달린 댓글
Total
Today
Yesterday
링크
«   2024/05   »
1 2 3 4
5 6 7 8 9 10 11
12 13 14 15 16 17 18
19 20 21 22 23 24 25
26 27 28 29 30 31
글 보관함