(Python) Swin Transformer - Image Classification

1. get_started.md

Following the MODEL HUB link, you can browse the collection of pretrained models.

I am going to use Swin-T from Swin-V1, pretrained on ImageNet-1K. The log linked there also shows the full training history.

2. Install

Let's look at the installation requirements:

  • Python == 3.7
  • CUDA >= 10.2 (with cuDNN >= 7)
  • PyTorch >= 1.8.0 and torchvision >= 0.9.0 (with CUDA >= 10.2)

2.1. pytorch docker

When each project needs different versions of Python and its libraries, virtual environments (python venv, conda, and so on) let you keep the projects isolated and work with several Python versions side by side. Docker takes the same idea one step further by isolating the whole runtime, and that is what is used here.
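For reference, a minimal sketch of what using the official PyTorch Docker image might look like (the image tag and mount path are assumptions, not taken from the post):

# assumed tag satisfying the PyTorch >= 1.8.0 requirement
docker pull pytorch/pytorch:1.8.1-cuda11.1-cudnn8-runtime

# start a container with GPU access and the current project directory mounted
docker run --gpus all -it --rm -v "$(pwd)":/workspace pytorch/pytorch:1.8.1-cuda11.1-cudnn8-runtime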

2.2. Clone

Docker would be the better option, but I tried this on Colab instead.
See the post "Using GitHub code in Colab" for details. I cloned the GitHub repository into a SwinTransformer folder; a rough sketch of those steps is shown below.
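The Colab side might look roughly like this (a sketch; the Drive path is an assumption based on the folder mentioned above):

# mount Google Drive so the clone survives the session
from google.colab import drive
drive.mount('/content/drive')

# clone the official repository into the SwinTransformer folder
!mkdir -p /content/drive/MyDrive/SwinTransformer
%cd /content/drive/MyDrive/SwinTransformer
!git clone https://github.com/microsoft/Swin-Transformer.git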

2.3. Checking the versions

import torch
import torchvision

!python --version
print("Torch version:{}".format(torch.__version__))
print("cuda version: {}".format(torch.version.cuda))
print("cudnn version:{}".format(torch.backends.cudnn.version()))
print("Torchvision version:{}".format(torchvision.__version__))

Only PyTorch needs to be updated to 1.8 or later; everything else already satisfies the requirements.
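One way that upgrade might look on Colab (a sketch; the pinned versions are an assumption, pick ones matching the runtime's CUDA toolkit):

!pip install torch==1.8.1 torchvision==0.9.1   # assumed versions that satisfy the >= 1.8.0 requirement

After that, the repository's remaining dependencies are installed: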

pip install timm==0.4.12
pip install opencv-python==4.4.0.46 termcolor==1.1.0 yacs==0.1.8

Once everything is installed, Colab prompts you to restart the runtime.

3. Importing the required modules

import timm
import torchvision
from torchvision import transforms as T
import torch
from PIL import Image

4. Preparing the data

URL = "https://raw.githubusercontent.com/SharanSMenon/swin-transformer-hub/main/imagenet_labels.json" # Imagenet labels
!wget https://www.allaboutbirds.org/guide/assets/photo/306327661-480px.jpg -O house_finch.jpg

This fetches the ImageNet label JSON file and one bird photo; the image should be classified as a house finch. It cannot go into the model as-is, though, and needs to be preprocessed first.

trans_ = T.Compose([
                    T.Resize(256),
                    T.CenterCrop(224), # the model requires 224x224 images
                    T.ToTensor(), # convert to a PyTorch tensor
                    T.Normalize(timm.data.IMAGENET_DEFAULT_MEAN, timm.data.IMAGENET_DEFAULT_STD) # Swin Transformer was trained on normalized data,
                    # therefore we must normalize our images too
])

image = Image.open("house_finch.jpg")
transformed = trans_(image) # convert the PIL image to a preprocessed PyTorch tensor
transformed.size()
>> torch.Size([3, 224, 224])

Now add a batch dimension to the transformed tensor.

batch = transformed.unsqueeze(0) # the model expects a batch of images, so create a batch of 1 image
                                 # unsqueeze(0): insert a new dimension at position 0
batch.size()
>> torch.Size([1, 3, 224, 224]) # now it can be fed to the model

5. Creating the model

Check the timm model list for models whose names start with swin.

timm.list_models("swin*", pretrained=True) # This will list all the swin transformer models available

Out of these I am going to use the pretrained Swin-T (swin_tiny_patch4_window7_224).

model = timm.create_model('swin_tiny_patch4_window7_224', pretrained=True)
print(model)
SwinTransformer(
  (patch_embed): PatchEmbed(
    (proj): Conv2d(3, 96, kernel_size=(4, 4), stride=(4, 4))
    (norm): LayerNorm((96,), eps=1e-05, elementwise_affine=True)
  )
  (pos_drop): Dropout(p=0.0, inplace=False)
  (layers): Sequential(
    (0): BasicLayer(
      dim=96, input_resolution=(56, 56), depth=2
      (blocks): ModuleList(
        (0): SwinTransformerBlock(
          (norm1): LayerNorm((96,), eps=1e-05, elementwise_affine=True)
          (attn): WindowAttention(
            (qkv): Linear(in_features=96, out_features=288, bias=True)
            (attn_drop): Dropout(p=0.0, inplace=False)
            (proj): Linear(in_features=96, out_features=96, bias=True)
            (proj_drop): Dropout(p=0.0, inplace=False)
            (softmax): Softmax(dim=-1)
          )
          (drop_path): Identity()
          (norm2): LayerNorm((96,), eps=1e-05, elementwise_affine=True)
          (mlp): Mlp(
            (fc1): Linear(in_features=96, out_features=384, bias=True)
            (act): GELU()
            (fc2): Linear(in_features=384, out_features=96, bias=True)
            (drop): Dropout(p=0.0, inplace=False)
          )
        )
        (1): SwinTransformerBlock(
          (norm1): LayerNorm((96,), eps=1e-05, elementwise_affine=True)
          (attn): WindowAttention(
            (qkv): Linear(in_features=96, out_features=288, bias=True)
            (attn_drop): Dropout(p=0.0, inplace=False)
            (proj): Linear(in_features=96, out_features=96, bias=True)
            (proj_drop): Dropout(p=0.0, inplace=False)
            (softmax): Softmax(dim=-1)
          )
          (drop_path): DropPath()
          (norm2): LayerNorm((96,), eps=1e-05, elementwise_affine=True)
          (mlp): Mlp(
            (fc1): Linear(in_features=96, out_features=384, bias=True)
            (act): GELU()
            (fc2): Linear(in_features=384, out_features=96, bias=True)
            (drop): Dropout(p=0.0, inplace=False)
          )
        )
      )
      (downsample): PatchMerging(
        input_resolution=(56, 56), dim=96
        (reduction): Linear(in_features=384, out_features=192, bias=False)
        (norm): LayerNorm((384,), eps=1e-05, elementwise_affine=True)
      )
    )
    (1): BasicLayer(
      dim=192, input_resolution=(28, 28), depth=2
      (blocks): ModuleList(
        (0): SwinTransformerBlock(
          (norm1): LayerNorm((192,), eps=1e-05, elementwise_affine=True)
          (attn): WindowAttention(
            (qkv): Linear(in_features=192, out_features=576, bias=True)
            (attn_drop): Dropout(p=0.0, inplace=False)
            (proj): Linear(in_features=192, out_features=192, bias=True)
            (proj_drop): Dropout(p=0.0, inplace=False)
            (softmax): Softmax(dim=-1)
          )
          (drop_path): DropPath()
          (norm2): LayerNorm((192,), eps=1e-05, elementwise_affine=True)
          (mlp): Mlp(
            (fc1): Linear(in_features=192, out_features=768, bias=True)
            (act): GELU()
            (fc2): Linear(in_features=768, out_features=192, bias=True)
            (drop): Dropout(p=0.0, inplace=False)
          )
        )
        (1): SwinTransformerBlock(
          (norm1): LayerNorm((192,), eps=1e-05, elementwise_affine=True)
          (attn): WindowAttention(
            (qkv): Linear(in_features=192, out_features=576, bias=True)
            (attn_drop): Dropout(p=0.0, inplace=False)
            (proj): Linear(in_features=192, out_features=192, bias=True)
            (proj_drop): Dropout(p=0.0, inplace=False)
            (softmax): Softmax(dim=-1)
          )
          (drop_path): DropPath()
          (norm2): LayerNorm((192,), eps=1e-05, elementwise_affine=True)
          (mlp): Mlp(
            (fc1): Linear(in_features=192, out_features=768, bias=True)
            (act): GELU()
            (fc2): Linear(in_features=768, out_features=192, bias=True)
            (drop): Dropout(p=0.0, inplace=False)
          )
        )
      )
      (downsample): PatchMerging(
        input_resolution=(28, 28), dim=192
        (reduction): Linear(in_features=768, out_features=384, bias=False)
        (norm): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
      )
    )
    (2): BasicLayer(
      dim=384, input_resolution=(14, 14), depth=6
      (blocks): ModuleList(
        (0): SwinTransformerBlock(
          (norm1): LayerNorm((384,), eps=1e-05, elementwise_affine=True)
          (attn): WindowAttention(
            (qkv): Linear(in_features=384, out_features=1152, bias=True)
            (attn_drop): Dropout(p=0.0, inplace=False)
            (proj): Linear(in_features=384, out_features=384, bias=True)
            (proj_drop): Dropout(p=0.0, inplace=False)
            (softmax): Softmax(dim=-1)
          )
          (drop_path): DropPath()
          (norm2): LayerNorm((384,), eps=1e-05, elementwise_affine=True)
          (mlp): Mlp(
            (fc1): Linear(in_features=384, out_features=1536, bias=True)
            (act): GELU()
            (fc2): Linear(in_features=1536, out_features=384, bias=True)
            (drop): Dropout(p=0.0, inplace=False)
          )
        )
        (1): SwinTransformerBlock(
          (norm1): LayerNorm((384,), eps=1e-05, elementwise_affine=True)
          (attn): WindowAttention(
            (qkv): Linear(in_features=384, out_features=1152, bias=True)
            (attn_drop): Dropout(p=0.0, inplace=False)
            (proj): Linear(in_features=384, out_features=384, bias=True)
            (proj_drop): Dropout(p=0.0, inplace=False)
            (softmax): Softmax(dim=-1)
          )
          (drop_path): DropPath()
          (norm2): LayerNorm((384,), eps=1e-05, elementwise_affine=True)
          (mlp): Mlp(
            (fc1): Linear(in_features=384, out_features=1536, bias=True)
            (act): GELU()
            (fc2): Linear(in_features=1536, out_features=384, bias=True)
            (drop): Dropout(p=0.0, inplace=False)
          )
        )
        (2): SwinTransformerBlock(
          (norm1): LayerNorm((384,), eps=1e-05, elementwise_affine=True)
          (attn): WindowAttention(
            (qkv): Linear(in_features=384, out_features=1152, bias=True)
            (attn_drop): Dropout(p=0.0, inplace=False)
            (proj): Linear(in_features=384, out_features=384, bias=True)
            (proj_drop): Dropout(p=0.0, inplace=False)
            (softmax): Softmax(dim=-1)
          )
          (drop_path): DropPath()
          (norm2): LayerNorm((384,), eps=1e-05, elementwise_affine=True)
          (mlp): Mlp(
            (fc1): Linear(in_features=384, out_features=1536, bias=True)
            (act): GELU()
            (fc2): Linear(in_features=1536, out_features=384, bias=True)
            (drop): Dropout(p=0.0, inplace=False)
          )
        )
        (3): SwinTransformerBlock(
          (norm1): LayerNorm((384,), eps=1e-05, elementwise_affine=True)
          (attn): WindowAttention(
            (qkv): Linear(in_features=384, out_features=1152, bias=True)
            (attn_drop): Dropout(p=0.0, inplace=False)
            (proj): Linear(in_features=384, out_features=384, bias=True)
            (proj_drop): Dropout(p=0.0, inplace=False)
            (softmax): Softmax(dim=-1)
          )
          (drop_path): DropPath()
          (norm2): LayerNorm((384,), eps=1e-05, elementwise_affine=True)
          (mlp): Mlp(
            (fc1): Linear(in_features=384, out_features=1536, bias=True)
            (act): GELU()
            (fc2): Linear(in_features=1536, out_features=384, bias=True)
            (drop): Dropout(p=0.0, inplace=False)
          )
        )
        (4): SwinTransformerBlock(
          (norm1): LayerNorm((384,), eps=1e-05, elementwise_affine=True)
          (attn): WindowAttention(
            (qkv): Linear(in_features=384, out_features=1152, bias=True)
            (attn_drop): Dropout(p=0.0, inplace=False)
            (proj): Linear(in_features=384, out_features=384, bias=True)
            (proj_drop): Dropout(p=0.0, inplace=False)
            (softmax): Softmax(dim=-1)
          )
          (drop_path): DropPath()
          (norm2): LayerNorm((384,), eps=1e-05, elementwise_affine=True)
          (mlp): Mlp(
            (fc1): Linear(in_features=384, out_features=1536, bias=True)
            (act): GELU()
            (fc2): Linear(in_features=1536, out_features=384, bias=True)
            (drop): Dropout(p=0.0, inplace=False)
          )
        )
        (5): SwinTransformerBlock(
          (norm1): LayerNorm((384,), eps=1e-05, elementwise_affine=True)
          (attn): WindowAttention(
            (qkv): Linear(in_features=384, out_features=1152, bias=True)
            (attn_drop): Dropout(p=0.0, inplace=False)
            (proj): Linear(in_features=384, out_features=384, bias=True)
            (proj_drop): Dropout(p=0.0, inplace=False)
            (softmax): Softmax(dim=-1)
          )
          (drop_path): DropPath()
          (norm2): LayerNorm((384,), eps=1e-05, elementwise_affine=True)
          (mlp): Mlp(
            (fc1): Linear(in_features=384, out_features=1536, bias=True)
            (act): GELU()
            (fc2): Linear(in_features=1536, out_features=384, bias=True)
            (drop): Dropout(p=0.0, inplace=False)
          )
        )
      )
      (downsample): PatchMerging(
        input_resolution=(14, 14), dim=384
        (reduction): Linear(in_features=1536, out_features=768, bias=False)
        (norm): LayerNorm((1536,), eps=1e-05, elementwise_affine=True)
      )
    )
    (3): BasicLayer(
      dim=768, input_resolution=(7, 7), depth=2
      (blocks): ModuleList(
        (0): SwinTransformerBlock(
          (norm1): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
          (attn): WindowAttention(
            (qkv): Linear(in_features=768, out_features=2304, bias=True)
            (attn_drop): Dropout(p=0.0, inplace=False)
            (proj): Linear(in_features=768, out_features=768, bias=True)
            (proj_drop): Dropout(p=0.0, inplace=False)
            (softmax): Softmax(dim=-1)
          )
          (drop_path): DropPath()
          (norm2): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
          (mlp): Mlp(
            (fc1): Linear(in_features=768, out_features=3072, bias=True)
            (act): GELU()
            (fc2): Linear(in_features=3072, out_features=768, bias=True)
            (drop): Dropout(p=0.0, inplace=False)
          )
        )
        (1): SwinTransformerBlock(
          (norm1): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
          (attn): WindowAttention(
            (qkv): Linear(in_features=768, out_features=2304, bias=True)
            (attn_drop): Dropout(p=0.0, inplace=False)
            (proj): Linear(in_features=768, out_features=768, bias=True)
            (proj_drop): Dropout(p=0.0, inplace=False)
            (softmax): Softmax(dim=-1)
          )
          (drop_path): DropPath()
          (norm2): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
          (mlp): Mlp(
            (fc1): Linear(in_features=768, out_features=3072, bias=True)
            (act): GELU()
            (fc2): Linear(in_features=3072, out_features=768, bias=True)
            (drop): Dropout(p=0.0, inplace=False)
          )
        )
      )
    )
  )
  (norm): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
  (avgpool): AdaptiveAvgPool1d(output_size=1)
  (head): Linear(in_features=768, out_features=1000, bias=True)
)

Looking at the architecture, there are four stages (BasicLayer) built from SwinTransformerBlocks; at each PatchMerging step the channel dimension doubles (96 → 192 → 384 → 768) while the token grid is halved (56 → 28 → 14 → 7).
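As a quick check, the stage structure can be read straight off the model (a sketch assuming timm's BasicLayer exposes the dim/input_resolution/depth attributes seen in the printout above):

# inspect the four stages of the loaded Swin-T
for i, stage in enumerate(model.layers):
    print(f"stage {i}: dim={stage.dim}, resolution={stage.input_resolution}, depth={stage.depth}")

# total parameter count (roughly 28M for Swin-T)
print(sum(p.numel() for p in model.parameters()))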

6. Test

with torch.no_grad():
  output = model(batch) # torch.no_grad() increases speed by disabling gradients.
  # pass in our 'batch' to the model
output
tensor([[-1.9016e-01,  1.9889e-01,  3.0785e-01,  5.0095e-01,  6.4167e-01,
          7.0143e-01, -1.1234e-02,  3.5708e-01,  1.0447e+00,  4.7405e-01,
          4.1761e+00,  1.4396e+00,  8.2452e+00,  1.1930e+00,  1.5377e+00,
          1.2554e+00,  8.3897e-01, -3.7446e-01, -1.3489e+00,  7.0715e-01,
          9.3417e-01,  7.1980e-01,  2.0256e-01,  7.6432e-01,  1.0556e+00,
          4.1342e-01,  7.8627e-01,  1.8811e-01,  5.9584e-01,  5.2687e-01,
          3.7547e-01, -9.9655e-01,  7.5490e-01, -2.6306e-02,  4.7791e-01,
          3.0842e-01,  5.6430e-01, -4.9928e-01,  6.1437e-01,  2.5116e-01,
          3.1124e-01,  5.7473e-01,  4.7467e-01,  5.8698e-01,  5.4986e-01,
          2.4520e-01, -2.3827e-01,  1.2559e+00,  7.9930e-02,  3.1247e-01,
          6.1390e-01,  3.2108e-01,  1.0920e+00, -1.2128e-01,  9.9993e-02,
         -5.6232e-01,  2.9175e-01, -1.9913e-01,  2.2639e-02,  8.9008e-01,
         -8.3613e-02,  1.1607e-01,  1.6977e-01,  4.4581e-01,  1.9575e-01,
          3.0114e-01,  5.2494e-01, -3.3385e-02,  6.9526e-01,  2.3393e-01,
          2.1891e-01, -1.5875e-02, -1.4747e-01,  1.1835e+00, -6.7120e-01,
          4.7606e-01,  4.5420e-01, -4.2010e-01,  8.1708e-01, -4.4538e-02,
          1.2209e+00,  5.4841e-01,  6.9280e-01,  7.4113e-01, -1.7145e-01,
          6.5327e-01,  1.3126e+00,  6.8698e-02,  3.3557e-01,  8.9700e-02,
          1.8628e-01,  2.2029e-01,  7.8222e-01,  2.9551e-01,  6.1749e-01,
          1.1113e-01, -8.3187e-01, -2.3362e-01,  7.1607e-01,  2.7000e-01,
          7.0056e-02,  6.5012e-01, -1.1278e-01,  5.5624e-01, -7.2767e-02,
          5.2746e-01,  2.3039e-01,  3.1528e-01, -1.1438e-02, -1.8073e-01,
          7.9599e-01,  4.5522e-01,  1.8323e-01,  4.5729e-01, -1.4132e-01,
          3.9605e-01,  4.5185e-01,  8.1261e-01, -3.0709e-01, -1.9369e-03,
         -6.2317e-02, -1.9239e-01,  1.2727e-01, -1.6577e-01, -4.2616e-01,
         -4.5357e-01,  1.1885e+00,  1.1287e-01,  3.7298e-02,  6.0257e-01,
          1.8089e-01, -2.7116e-01,  3.7338e-01,  1.1794e+00,  6.2387e-02,
          2.3052e-01,  3.3541e-01, -4.4384e-03, -1.0064e+00,  6.8106e-01,
          1.7592e+00,  4.7248e-01,  5.7414e-01, -7.9369e-01,  4.0171e-01,
          2.4801e-02,  2.6289e-01,  8.7113e-02,  2.1482e-01, -2.2497e-01,
          5.4912e-01,  3.6707e-01,  1.9378e-02,  4.2420e-01,  2.5855e-01,
         -4.1984e-01,  3.0794e-01,  8.5788e-02, -1.0902e-01,  4.0053e-02,
          2.1857e-01,  1.4636e-02,  8.1957e-02, -2.2590e-01, -2.3231e-01,
          2.4350e-01, -4.9512e-02,  4.8668e-01, -2.2230e-01,  1.1205e-01,
          5.9115e-01, -2.4379e-01,  1.2021e-01,  1.0780e-01, -1.9410e-01,
         -6.2192e-02,  6.2526e-01,  8.1567e-01,  6.0292e-01,  4.0620e-01,
         -2.8007e-01,  4.8430e-01,  2.1697e-01,  3.3052e-01,  1.0079e-01,
          6.2138e-01,  5.8198e-01,  2.3499e-01,  7.4986e-01,  8.5305e-01,
          7.4808e-01, -3.1625e-03,  1.3825e-01,  1.5847e-01,  4.1359e-01,
          1.7468e-02, -4.0559e-01,  7.9949e-01,  5.6901e-02,  7.2078e-01,
          3.6567e-01, -2.5864e-01,  2.4481e-01,  1.9460e-01, -7.2914e-02,
          4.1781e-01,  7.1999e-01, -8.2583e-02,  2.0628e-01, -1.1075e-01,
          3.7352e-01,  4.6917e-02,  2.0226e-01,  2.6312e-01,  2.9837e-01,
         -3.1440e-02,  3.6304e-01,  9.2160e-02,  1.0028e-01, -1.6261e-01,
          2.4298e-01,  4.3064e-01,  6.2832e-01,  3.5762e-01,  8.8712e-01,
         -1.2376e-01,  1.5457e-01,  1.5564e-01,  9.0174e-01,  3.5670e-01,
         -2.5416e-01,  2.4155e-01,  3.5712e-01,  6.8735e-01,  3.7985e-01,
          5.6126e-03, -1.3895e-01, -1.2159e-01, -4.2863e-01,  1.8999e-01,
          2.3478e-01,  2.7616e-01, -1.7466e-01,  1.3764e-02,  2.8008e-01,
         -8.1785e-02, -1.0094e-01,  2.1708e-02, -3.0793e-01,  8.8364e-02,
          3.6706e-01, -5.3914e-01,  7.2833e-01, -2.6394e-01, -1.7443e-01,
          1.6182e-01,  7.7405e-01,  5.2431e-01,  2.7311e-01,  1.7658e-01,
          4.7212e-01,  3.5499e-01,  1.9276e-01, -6.9727e-02,  3.7494e-01,
          8.2984e-01, -1.9308e-02,  5.8385e-01,  3.5803e-01,  1.1516e+00,
          1.3227e+00,  1.1154e+00,  1.1374e+00,  5.9013e-01,  7.6104e-01,
          4.2783e-01,  6.6518e-01,  1.0178e+00,  3.0298e-01,  1.1665e+00,
          8.5852e-01,  7.0953e-01, -6.7987e-02,  4.2619e-01,  2.7169e-02,
          7.5299e-01,  1.1239e+00,  3.3798e-01,  3.9194e-01,  4.5476e-01,
         -1.2264e-01,  1.5644e-01,  1.5443e-02,  5.5933e-01,  4.5060e-01,
          1.0366e+00,  7.4167e-01,  1.1647e+00,  1.4232e-01,  2.2146e-01,
          2.4077e-01,  6.0314e-01,  5.8562e-01,  7.3030e-01,  2.1645e-01,
         -3.8634e-01,  3.9612e-01,  1.5674e-01,  2.8146e-01,  2.1685e-01,
          4.0272e-01,  5.1490e-01,  1.8958e-01,  3.7581e-01,  1.0959e-01,
          3.2805e-01,  5.0887e-02,  3.9491e-01,  3.3445e-02,  4.7147e-01,
          7.0775e-01,  1.7443e-01,  1.5620e+00,  2.8595e-01,  1.0009e+00,
          4.4089e-01,  1.0637e+00, -1.2779e-01,  1.1127e-01,  3.8324e-01,
          7.8786e-01,  5.8636e-01,  3.7145e-01, -1.0142e-01,  3.4647e-01,
          8.3113e-02,  1.5621e-01,  6.0177e-01,  1.7819e-01, -2.9716e-01,
          3.7846e-01,  4.4441e-01,  2.2138e-01, -1.9090e-01,  2.8334e-01,
          2.3959e-01,  3.0895e-01,  7.1432e-01,  1.7525e-01,  5.2820e-01,
          6.4451e-01,  3.8495e-01,  5.0214e-01,  1.8482e-01, -2.3971e-01,
         -2.3657e-01,  9.1942e-01,  1.4066e+00,  1.1907e+00,  2.4637e-01,
          6.3798e-01,  1.0244e+00,  1.1758e+00, -7.6327e-03,  1.2490e+00,
          1.2514e-01,  9.7387e-01,  7.3128e-01,  1.3924e+00,  1.3600e+00,
          1.4896e-01,  9.4429e-01,  4.5068e-02,  1.1212e+00,  1.8605e-01,
          4.5022e-01,  1.2625e+00,  5.3233e-01,  3.7362e-01,  7.6741e-01,
          8.9300e-01,  4.0013e-01, -1.2323e-01,  3.7272e-01,  5.7332e-01,
          3.3325e-01,  4.2860e-01,  1.1771e+00,  6.5505e-01,  2.6137e-01,
         -5.6197e-01,  1.2330e-01,  3.9434e-01, -6.2766e-01, -1.7922e-02,
          5.6302e-01, -3.4868e-01, -1.0203e-01, -1.1533e-01,  3.9456e-01,
         -7.3878e-02, -3.4227e-01, -6.0979e-02, -4.5315e-01, -2.9820e-01,
          1.3319e-01, -2.2808e-01, -7.5696e-01, -6.3929e-01, -4.8852e-01,
         -5.9017e-01, -3.4703e-02,  1.1454e-01,  5.8063e-02, -4.5922e-01,
         -5.9126e-01, -7.2966e-01, -1.6626e-01, -1.6055e-01, -2.6499e-01,
         -1.2013e-01, -1.2224e-01, -3.6532e-01, -5.2364e-01, -4.4169e-01,
         -2.2341e-02, -5.8387e-02,  9.2884e-02, -2.5580e-01, -2.7870e-01,
         -9.0193e-01,  1.4743e-01,  1.8272e-01, -3.6225e-01,  7.2989e-02,
          1.2155e-02, -4.0540e-01, -3.4014e-01,  1.6540e-02,  1.7883e-01,
         -6.6057e-01, -7.4604e-01,  3.3700e-01,  6.1531e-01, -4.8099e-01,
         -5.6576e-01, -6.3594e-01, -7.3908e-01,  4.5265e-01, -6.9846e-01,
         -2.7634e-01,  4.3537e-01,  7.5218e-02, -1.9517e-01, -1.0576e+00,
         -5.7038e-01, -4.9136e-02,  1.0134e-01, -5.1726e-01,  7.3077e-01,
         -6.1172e-01,  9.1187e-02,  4.7658e-01,  4.2102e-01,  6.2047e-01,
          4.4012e-02, -4.6376e-01, -4.5354e-01, -5.5197e-01,  4.2342e-01,
          6.6338e-01, -7.8547e-01, -1.5879e-01, -6.8415e-01, -5.4500e-01,
         -1.9469e-01, -6.7532e-01, -3.6719e-01, -3.4788e-01, -8.8570e-01,
         -4.7493e-01, -5.0555e-01, -4.9782e-01, -4.3000e-01, -5.1897e-01,
         -9.2349e-01, -2.3556e-01, -1.9488e-01,  9.4515e-01,  1.1331e-01,
          8.7227e-02,  2.2582e-02, -8.5214e-02,  1.2412e-01,  3.8071e-01,
         -1.7089e-01, -2.1744e-01, -3.6068e-01, -7.2279e-01,  1.5909e-01,
         -1.3547e-01,  5.5003e-01, -1.2745e-01,  1.6347e-01,  5.3833e-02,
          5.9629e-01,  6.2980e-01, -7.8568e-01, -5.3709e-01, -5.3084e-01,
         -3.3307e-01, -7.1869e-01, -2.6294e-01, -7.4432e-01, -3.3571e-01,
         -3.0180e-01,  1.3953e-01, -1.3105e-01, -1.6679e-01, -5.8059e-01,
         -4.0118e-01, -4.0796e-01, -2.3450e-01, -3.7842e-01,  2.3672e-01,
         -5.2231e-01, -1.9563e-02, -2.5553e-01, -5.0407e-01, -1.2448e-01,
         -1.4636e-01, -7.0233e-01, -3.0430e-01,  1.8414e-01, -3.1405e-01,
         -5.3228e-01, -7.3455e-01, -2.9528e-01, -1.4198e-01, -8.1183e-03,
          1.3087e-01, -3.8116e-01, -1.8200e-01, -3.9521e-01, -2.0075e-01,
         -5.9787e-01, -4.2323e-01, -4.4982e-01, -5.2729e-01, -5.1305e-02,
         -1.2284e-01, -2.2798e-01, -5.5582e-01, -1.4420e-01, -6.3921e-01,
         -3.6248e-01,  2.3947e-02, -1.0276e-01,  1.4596e-01,  9.9152e-02,
         -2.9461e-01, -5.2575e-01,  2.2913e-02, -1.7880e-01, -3.9180e-01,
         -6.4333e-02, -2.7347e-01, -2.3845e-01,  2.2865e-01, -7.7296e-01,
         -3.8672e-02, -5.3160e-01,  2.4262e-01, -5.0585e-01,  7.9512e-01,
         -5.4792e-01, -3.5188e-01, -1.6749e-03,  5.7811e-01, -1.0483e-01,
          2.9546e-01, -9.9502e-01, -5.9201e-01, -1.9802e-01,  1.7315e+00,
         -9.1269e-01, -6.9820e-01,  1.1115e-01, -3.3711e-01, -3.0991e-02,
         -4.2238e-01,  5.0731e-01, -6.5463e-01, -6.5915e-01, -1.2187e-01,
         -3.6533e-01,  7.7418e-02, -1.2348e-01, -2.0200e-01,  6.2378e-01,
          3.3811e-01,  5.7402e-01, -2.7601e-01, -6.3712e-01, -1.5360e-01,
         -6.1983e-01, -2.5320e-01, -1.6812e-01,  4.3989e-02, -7.6879e-01,
          5.0545e-01,  3.9386e-02, -6.1623e-01, -4.1230e-01, -4.3297e-02,
         -3.6549e-01,  5.5354e-01,  6.5465e-02,  2.6579e-01,  2.3514e-01,
         -2.1405e-01, -8.1598e-01, -7.6677e-01,  1.7969e-01, -7.5569e-01,
         -5.7384e-01, -1.6247e-01, -7.2302e-01, -2.2339e-01,  2.6205e-02,
         -6.6659e-01, -6.8074e-01, -4.3061e-01,  1.0212e-01, -5.0308e-01,
         -3.8882e-01, -1.8231e-01, -3.7277e-01,  1.8769e-01, -2.1780e-01,
         -4.6009e-01, -5.6345e-01, -3.1182e-01,  2.4363e-01, -1.9799e-02,
         -3.2701e-01, -1.1985e-01, -1.0040e-01, -1.6500e-01, -5.7042e-02,
         -4.3935e-01, -3.0716e-01, -1.2662e-01,  5.3732e-01, -6.1773e-01,
          2.2006e-01, -3.8895e-01, -7.0388e-01, -9.6321e-02,  9.3200e-01,
         -2.4885e-01, -9.9200e-01, -5.6791e-01, -9.2939e-01,  1.1409e-01,
         -5.3784e-01,  6.5583e-01,  1.7979e-01, -5.9499e-01, -3.8889e-01,
         -4.0091e-01, -5.5202e-01, -6.6172e-01, -2.8358e-01,  1.3915e-01,
         -2.9885e-01,  4.0023e-01,  5.6519e-01, -3.1619e-02,  1.1973e+00,
         -2.3130e-01, -5.0424e-01, -1.4633e-01, -3.6287e-01, -3.7308e-02,
         -4.4300e-01, -6.4421e-01, -6.4729e-01, -2.2907e-01,  5.4190e-01,
         -3.4271e-01, -2.1642e-01, -7.2845e-01, -2.8447e-01, -5.7509e-01,
         -3.5847e-01, -1.5119e-01, -1.6043e-01, -2.4737e-03, -6.4167e-01,
          2.9207e-01, -2.5146e-03, -7.9427e-01,  1.9586e-01, -1.5453e-01,
         -4.0220e-01, -7.9659e-01, -6.0393e-01,  3.2927e-01, -6.0892e-01,
         -5.2895e-01,  1.0301e-01,  2.3389e-01, -4.2930e-01, -4.5528e-01,
          2.2037e-01, -2.1419e-01, -7.9580e-01, -5.3957e-01, -5.2205e-03,
         -5.0920e-01, -1.6658e-01, -2.6721e-01,  4.2789e-02, -3.3528e-01,
          7.5573e-01, -1.9973e-01, -1.5124e-01, -1.4544e-01,  1.2947e-02,
         -8.9012e-02, -1.6414e-01, -8.4208e-01, -6.3775e-02, -7.4057e-01,
          2.3504e-01, -7.8663e-01, -8.2063e-01,  1.3774e+00, -1.6367e-01,
         -9.9647e-01, -6.3561e-01, -4.6837e-01, -4.4509e-01, -1.7012e-01,
         -5.2070e-01, -5.3334e-01, -4.1822e-01,  2.5975e-01,  5.3876e-01,
         -4.7741e-01, -2.6495e-01, -5.4020e-01, -3.3319e-01, -4.0689e-01,
         -1.8129e-01,  3.9623e-01, -8.1908e-01,  8.5798e-02, -3.8031e-01,
         -2.6029e-02, -7.3476e-01, -1.3109e-01, -3.1236e-01, -4.4590e-01,
          2.8202e-01, -1.8046e-01, -1.2541e-01, -2.6163e-01, -2.8035e-01,
          5.1533e-02, -2.8707e-01,  1.8908e-01, -3.2279e-02, -1.4058e-01,
         -2.9554e-01,  4.2201e-02,  4.0757e-01, -5.5805e-01, -9.3672e-01,
         -2.2159e-01, -5.9347e-01, -5.2154e-01,  6.6839e-01,  2.8101e-02,
         -2.3748e-01,  6.8123e-02,  1.2480e-01, -5.1256e-01,  9.1172e-03,
         -1.0238e-02, -5.8544e-02,  5.1971e-01,  3.3778e-01,  1.3275e-01,
         -5.9229e-02, -5.1294e-01, -4.1083e-01, -4.6636e-02, -4.2335e-01,
         -1.2473e+00, -3.7039e-01, -2.0056e-02, -2.0438e-01,  2.3793e-02,
         -5.2832e-01, -3.1301e-01, -9.1200e-02, -9.5909e-02,  7.5334e-02,
         -2.6696e-01, -7.4490e-01, -1.6476e-01, -4.8951e-03, -6.0641e-01,
         -5.9582e-02, -2.6576e-02, -4.5913e-01,  1.1998e-01, -2.3965e-01,
         -7.2982e-01, -2.3401e-01, -2.9764e-01,  5.6456e-02,  8.9344e-01,
          2.8097e-01, -7.1086e-01, -1.1443e+00, -1.8808e-01, -5.4427e-01,
         -3.2140e-01,  5.1490e-02, -8.2018e-01, -1.6115e-01, -4.9608e-02,
          2.7496e-01, -1.7680e-01,  2.3059e-01, -5.1096e-01, -3.2908e-01,
          5.3168e-01, -5.8806e-01, -3.3796e-01, -1.4284e-01, -3.4888e-01,
         -4.8914e-01,  2.6956e-01, -8.2569e-01, -5.0551e-01, -5.4816e-02,
          5.7785e-01, -8.5647e-02, -3.3454e-01,  2.1260e-01, -2.8888e-01,
          1.8412e-01, -5.0054e-01, -5.7881e-01, -3.6718e-01, -7.9511e-01,
         -2.8605e-01,  8.7803e-02,  1.9759e-02, -4.7048e-01, -4.2033e-01,
         -6.2413e-01, -4.3208e-01, -2.5880e-01,  1.3739e-01, -2.8092e-02,
         -7.0037e-01, -2.1985e-01, -3.1013e-01, -4.5317e-01, -5.9855e-01,
         -2.1277e-01,  2.8163e-02, -2.7099e-01, -8.3828e-01, -3.0602e-01,
         -2.8236e-01, -8.1016e-01, -5.7811e-01,  4.1197e-01, -1.9368e-01,
          6.2221e-02, -5.8416e-01,  3.4758e-02, -6.7568e-01, -5.2671e-01,
         -5.7511e-01, -2.6696e-01, -7.0816e-02, -3.4019e-01, -3.0295e-02,
         -1.1176e-01, -3.5160e-01, -5.5518e-01, -6.5310e-01,  3.0679e-01,
          1.2405e-01, -2.5834e-01,  2.0472e-01,  3.8454e-01,  3.6679e-01,
         -1.8764e-01, -1.9986e-01, -2.8072e-01, -4.5951e-01, -1.6553e-01,
         -1.7848e-01,  3.6908e-01,  3.1331e-02, -7.5528e-02, -2.0375e-02,
         -2.0628e-01, -5.5666e-01, -8.6257e-01, -2.3624e-01, -6.9168e-01,
         -1.2105e-01, -4.1813e-01, -9.2024e-01,  1.8456e-02,  3.8384e-01,
         -4.1279e-01, -4.0707e-01, -4.5136e-02, -3.4236e-01, -2.3643e-01,
         -7.0221e-01, -2.1239e-01, -3.6989e-01, -1.8822e-01, -2.9873e-01,
         -2.3639e-01,  1.2757e+00,  1.3328e+00,  5.4690e-01, -1.7743e-02,
          1.1688e-01, -2.3275e-01,  2.0169e-01, -2.2757e-01, -2.2208e-01,
         -6.0521e-01,  2.8892e-01,  6.2703e-01, -4.3468e-01, -2.8957e-01,
          3.4508e-01,  5.1537e-02,  6.0912e-01, -3.3402e-01, -5.1515e-01,
         -5.4792e-01,  6.3807e-01, -2.1738e-01, -4.3179e-01, -2.8177e-01,
         -6.2700e-01, -4.1343e-01, -2.1286e-01, -3.4088e-01,  3.2639e-02,
         -1.1227e-01,  5.0254e-01, -2.2646e-01,  2.8809e-01,  4.3122e-01,
         -1.9679e-01, -4.5594e-01, -2.0170e-01,  2.7217e-02,  1.7588e-01,
          2.9440e-01, -1.8112e-01, -1.7842e-01, -3.1087e-01, -1.5393e-01,
          2.4738e-01, -5.0606e-01, -7.0970e-01, -4.8142e-01, -2.4103e-02,
          1.4307e-01, -2.1021e-01, -4.3767e-01,  2.3254e-01,  7.9012e-01,
         -1.9629e-01,  9.4204e-01, -5.9398e-02,  1.2561e+00,  1.3093e+00,
          8.0038e-01,  1.6374e+00,  6.1325e-01,  1.7617e-01, -4.9355e-01]])

The output comes back in this form: a vector of 1,000 logits, one per ImageNet class, and the class the model considers most likely for the given image has the largest value.

class_ = output.argmax(dim=1) # argmax(dim=1): index of the largest value along the class dimension
class_
>> tensor([12])

Class index 12 has the highest score. Since indexing starts from 0, this is actually the 13th entry in the label list!
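Instead of just the argmax, the logits can also be turned into probabilities to see how confident the model is (a small extra sketch, not part of the original walkthrough):

probs = torch.softmax(output, dim=1)        # convert logits to probabilities
top5_prob, top5_idx = torch.topk(probs, k=5, dim=1)
for p, idx in zip(top5_prob[0], top5_idx[0]):
    print(idx.item(), round(p.item(), 4))   # class index and its probability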

7. Result

import json
from urllib.request import urlopen

response = urlopen(URL) # remember I defined URL in the first cell
classes = json.loads(response.read())
classes[class_]
>> 'house finch'

Looking up that index in the JSON label list gives exactly the right answer!
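To wrap up, the whole pipeline can be folded into a small helper that reuses the objects defined above (a sketch; the function name is mine, not from the post):

def predict(image_path, model, transform, labels):
    # open the image, preprocess it, and return the predicted label string
    img = Image.open(image_path).convert("RGB")
    batch = transform(img).unsqueeze(0)      # 1 x 3 x 224 x 224
    with torch.no_grad():
        logits = model(batch)
    return labels[logits.argmax(dim=1).item()]

print(predict("house_finch.jpg", model, trans_, classes))
# expected: 'house finch', matching the result above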