画像認識の現場でtimm (PyTorch Image Models) は非常に便利です。timm.create_model('vit_base_patch16_224', pretrained=True)のコード一つで最先端モデルを利用できますが、内部のテンソル形状の変化まで自信を持って説明できるでしょうか。
製造業の検査システムなど、実用的な精度と速度の両立が求められる現場において、新アーキテクチャ採用時は「スクラッチ実装」による検証が推奨されます。ブラックボックスのままでは、精度低下や推論エラー時の原因特定が困難になるためです。また、Hugging Face Transformersの最新アップデートでTensorFlowやFlaxのサポートが終了し、PyTorch中心のモジュール型設計へ移行しました。この業界の流れからも、PyTorchベースのアーキテクチャ理解はAIエンジニアの必須スキルです。
本記事では、Vision Transformer (ViT) をPyTorchのみでゼロから実装します。アルゴリズムの原理となる数式をコードへ翻訳し、データ処理やAttention(注目)の集まり方を詳細に可視化します。エッジ推論では計算速度の観点からCNNが依然強力ですが、ここではフィルターによる局所的な特徴抽出というCNNの常識を脇に置き、ViTの新しい世界へ踏み出します。データから仮説を立て、実装と実験で検証するサイクルを通じて、モデルの挙動を深く理解していきましょう。
1. なぜ今、CNNではなくViTを「実装」して学ぶのか
画像認識はConvolutional Neural Networks (CNN) の時代が長く続きましたが、2020年の論文「An Image is Worth 16x16 Words」以降、状況は変化しました。ViTが注目され、実装して学ぶべき理由を解説します。
帰納的バイアス(Inductive Bias)の違いを理解する
CNNとViTの最大の違いは「帰納的バイアス」(未知データ予測時に使う学習アルゴリズムの仮定)の量です。
- CNN: 「画像の特徴は局所的(Locality)」「特徴は位置に依存しない(Translation Invariance)」という強いバイアスを持ち、少データでも効率よく学習できます。
- ViT: 画像特有のバイアスを極力排除し、Self-Attention機構で離れたパッチ同士の関係性(大域的な特徴)をデータから直接学習します。
ViTはバイアスが少ない分、学習に大量のデータ(または強力な正則化)が必要ですが、データ量が増えればCNNの性能限界を超えてスケールするポテンシャルを持ちます。実装を通じてこの「関係性を学ぶ」メカニズムを体感し、精度向上の仮説検証に役立てることが本記事の狙いです。
ライブラリ利用の限界とスクラッチ実装の効能
ライブラリは便利ですが、入力チャンネル数の変更や特殊なAttentionマスクの適用、マルチモーダル化などのカスタマイズには内部実装の理解が不可欠です。
スクラッチ実装には以下の効能があります。
- デバッグ能力の向上: エラー発生時、どの層の形状不一致が原因か即座に特定できる。
- 論文読解力の向上: 数式とコードの対応関係が頭に入り、最新論文の実装が容易になる。
- 応用力: ViTベースの派生モデル(Swin Transformer, MAEなど)への理解が早まる。
本チュートリアルのゴール:ViT-Baseの再現
今回は最も標準的な ViT-Base モデルをターゲットにします。入力画像サイズは224x224、パッチサイズは16x16を想定し、これらは変数として定義して変更可能な設計にします。
2. 実装環境の準備とデータセットの前処理
まずは環境を整えます。特別なライブラリは極力避けますが、テンソル操作を直感的に記述でき、Transformer実装のデファクトスタンダードになりつつある einops は、可読性向上のためあえて使用します。
PyTorch環境のセットアップと必要ライブラリ
以下のライブラリを使用します。
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
from einops import rearrange, repeat
from einops.layers.torch import Rearrange
import matplotlib.pyplot as plt
import numpy as np
# バージョン確認(参考)
# print(torch.__version__) # 例: 2.0.1
einops が未インストールの場合は pip install einops で導入してください。
画像データのテンソル化とパッチ分割の概念
ViTは画像をそのまま受け取らず、自然言語処理の単語(トークン)分割と同様に、小さな「パッチ」の列として扱います。
例えば、224x224の画像を16x16のパッチに分割する場合:
- 横方向: $224 / 16 = 14$ 個
- 縦方向: $224 / 16 = 14$ 個
- 合計パッチ数 ($N$): $14 \times 14 = 196$ 個
各パッチは $16 \times 16$ ピクセルでRGBの3チャンネルを持つため、1パッチの情報量は $16 \times 16 \times 3 = 768$ 次元です。
つまり、ViTへの入力は [Batch, 196, 768] という形状のテンソルになります。
DataLoaderの構築とAugmentation戦略
動作確認用にはCIFAR-10を使用します。ViTはImageNet等の大規模データセットで真価を発揮しますが、スクラッチ実装の検証には軽量データセットで十分です。ただし、入力サイズはViTの標準に合わせてリサイズします。
# ハイパーパラメータ設定
BATCH_SIZE = 64
IMAGE_SIZE = 224 # ViT標準
PATCH_SIZE = 16
NUM_CLASSES = 10
transform = transforms.Compose([
transforms.Resize((IMAGE_SIZE, IMAGE_SIZE)),
transforms.ToTensor(),
transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
])
# データセットのダウンロード(初回のみ)
trainset = torchvision.datasets.CIFAR10(root='./data', train=True,
download=True, transform=transform)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=BATCH_SIZE,
shuffle=True, num_workers=2)
# データの形状確認
data_iter = iter(trainloader)
images, labels = next(data_iter)
print(f"Input Shape: {images.shape}") # torch.Size([64, 3, 224, 224])
3. Patch Embedding:画像を「言葉」に変換する
ここからが実装の本番です。画像をパッチに分割し、埋め込みベクトル(Embedding)に変換する層を作成します。
Conv2dを使った効率的なパッチ分割実装
パッチ分割を for 文で行うのは非効率です。通常は、カーネルサイズとストライドをパッチサイズと同じに設定した畳み込み層(Conv2d)を通し、パッチ分割と線形射影(Linear Projection)を同時に行います。これによりGPUの並列処理能力を最大限に活かし、処理速度を向上させます。
class PatchEmbedding(nn.Module):
def __init__(self, in_channels=3, patch_size=16, emb_size=768, img_size=224):
super().__init__()
self.patch_size = patch_size
self.n_patches = (img_size // patch_size) ** 2
# Conv2dでパッチ分割と射影を同時に行う
# kernel_size=16, stride=16 にすることで、重ならないパッチを生成
self.projection = nn.Sequential(
nn.Conv2d(in_channels, emb_size, kernel_size=patch_size, stride=patch_size),
Rearrange('b e h w -> b (h w) e') # [B, C, H, W] -> [B, N, E]
)
def forward(self, x):
# x: [Batch, 3, 224, 224]
x = self.projection(x)
# x: [Batch, 196, 768]
return x
ここで Rearrange が活躍します。Conv2dの出力 [Batch, 768, 14, 14] を、Transformer入力用に [Batch, 14*14, 768](つまり [Batch, 196, 768])へ変形しています。
CLSトークンの追加とその役割
ViTでは、画像全体の特徴を集約して分類するため、特殊な「CLS(Classification)トークン」をシーケンスの先頭に追加します。これは自然言語処理モデルBERTの設計思想を踏襲したものです。
Transformerのエンコーダー出力のうち、このCLSトークンに対応するベクトルのみを最終的な分類ヘッド(MLP Head)に入力し、他のパッチトークンの出力は直接的な分類には使用しません。
学習可能なPositional Embeddingの実装
Transformerは構造上、入力パッチの順序(位置関係)を認識できず、左上も右下のパッチも対等に扱われます。そのため、位置情報を表すベクトル(Positional Embedding)の加算が必要です。
ViTでは固定のサイン波ではなく、学習可能なパラメータとして位置埋め込みを定義するのが一般的です。
これらをまとめた完全なEmbedding層の実装は以下の通りです。
class PatchEmbedding(nn.Module):
def __init__(self, in_channels=3, patch_size=16, emb_size=768, img_size=224):
super().__init__()
self.patch_size = patch_size
self.n_patches = (img_size // patch_size) ** 2
self.projection = nn.Sequential(
nn.Conv2d(in_channels, emb_size, kernel_size=patch_size, stride=patch_size),
Rearrange('b e h w -> b (h w) e')
)
# CLSトークン (学習可能パラメータ)
self.cls_token = nn.Parameter(torch.randn(1, 1, emb_size))
# Positional Embedding (学習可能パラメータ)
# パッチ数 + CLSトークン分の長さが必要
self.positions = nn.Parameter(torch.randn(1, self.n_patches + 1, emb_size))
def forward(self, x):
b, _, _, _ = x.shape
x = self.projection(x) # [B, 196, 768]
# バッチサイズ分だけCLSトークンを複製
cls_tokens = repeat(self.cls_token, '() n e -> b n e', b=b)
# 先頭にCLSトークンを結合
x = torch.cat([cls_tokens, x], dim=1) # [B, 197, 768]
# 位置情報を加算
x += self.positions
return x
4. Multi-Head Self-Attention (MSA) の核心を記述する
ViTの心臓部であるMulti-Head Self-Attentionです。ここでは入力パッチ(トークン)同士が「お互いにどれくらい関連しているか」を計算します。
Query, Key, Valueの生成と次元数
Attention機構では、入力ベクトル $x$ から全結合層(Linear)を用いて Query ($Q$), Key ($K$), Value ($V$) の3つのベクトルを生成します。
Scaled Dot-Product Attentionの数式とコードの対応
Attentionの計算式は以下の通りです。
$ \text{Attention}(Q, K, V) = \text{softmax}\left(\frac{QK^T}{\sqrt{d_k}}\right)V $
これをコードに落とし込みます。マルチヘッド化のため、埋め込み次元 emb_size を num_heads で分割します。
class MultiHeadAttention(nn.Module):
def __init__(self, emb_size=768, num_heads=12, dropout=0.):
super().__init__()
self.emb_size = emb_size
self.num_heads = num_heads
self.head_dim = emb_size // num_heads
# Q, K, V を一度に生成して分割する効率的な実装
self.qkv = nn.Linear(emb_size, emb_size * 3)
self.att_drop = nn.Dropout(dropout)
self.projection = nn.Linear(emb_size, emb_size)
def forward(self, x):
# x: [Batch, N, Emb_size]
# 1. Q, K, V の生成と変形
# [B, N, 3*Emb] -> 3つに分割 -> [B, N, Heads, Head_Dim]
qkv = rearrange(self.qkv(x), 'b n (h d qkv) -> (qkv) b h n d', h=self.num_heads, qkv=3)
queries, keys, values = qkv[0], qkv[1], qkv[2]
# 2. Attention Scoreの計算 (Scaled Dot-Product)
# Q * K^T / sqrt(d)
energy = torch.einsum('bhqd, bhkd -> bhqk', queries, keys) # [B, Heads, N, N]
scaling = self.emb_size ** (1/2)
att = torch.softmax(energy / scaling, dim=-1)
att = self.att_drop(att)
# 3. Valueとの積
# Att * V
out = torch.einsum('bhal, bhlv -> bhav', att, values) # [B, Heads, N, Head_Dim]
# 4. ヘッドの結合
out = rearrange(out, 'b h n d -> b n (h d)')
out = self.projection(out)
return out
マルチヘッド化による表現力の向上
num_heads=12 の場合、768次元のベクトルを64次元×12個に分割して処理します。これにより、あるヘッドは「色」、別のヘッドは「形状」の関係性など、異なる視点の特徴を同時に捉えることが可能になります。
5. Encoder BlockとMLPの積層
Attention層だけでは非線形性が不足するため、MLP(Multi-Layer Perceptron)と組み合わせて一つのEncoder Blockを構成します。
Layer Normalizationと残差接続(Skip Connection)の実装
ViTでは、各層の 前 にNormalizationを適用する「Pre-Norm」構成が一般的です。また、勾配消失を防ぐため、入出力を足し合わせる残差接続(Skip Connection)が必須です。これにより、深いネットワークでも安定した学習が可能になります。
class FeedForward(nn.Sequential):
def __init__(self, emb_size, expansion=4, drop_p=0.):
super().__init__(
nn.Linear(emb_size, expansion * emb_size),
nn.GELU(),
nn.Dropout(drop_p),
nn.Linear(expansion * emb_size, emb_size),
nn.Dropout(drop_p)
)
class TransformerEncoderBlock(nn.Module):
def __init__(self, emb_size=768, num_heads=12, forward_expansion=4, drop_p=0.):
super().__init__()
self.norm1 = nn.LayerNorm(emb_size)
self.mha = MultiHeadAttention(emb_size, num_heads, drop_p)
self.norm2 = nn.LayerNorm(emb_size)
self.ff = FeedForward(emb_size, forward_expansion, drop_p)
self.dropout = nn.Dropout(drop_p)
def forward(self, x):
# Skip Connection 1
x = x + self.dropout(self.mha(self.norm1(x)))
# Skip Connection 2
x = x + self.dropout(self.ff(self.norm2(x)))
return x
Transformer Encoderのクラス化とスタッキング
最後に、このブロックを指定回数(ViT-Baseでは12回)積み重ねます。
class TransformerEncoder(nn.Sequential):
def __init__(self, depth=12, **kwargs):
super().__init__(*[TransformerEncoderBlock(kwargs) for _ in range(depth)])
6. 全体結合と小規模学習実験
すべてのパーツが揃いました。これらを統合して ViT クラスを完成させます。
ViTクラスの完成とMLP Headの実装
最終的な出力は、CLSトークン(インデックス0)の特徴量を取り出し、LayerNormを通した後に分類用のLinear層へ入力します。
class ViT(nn.Module):
def __init__(self,
in_channels=3,
patch_size=16,
emb_size=768,
img_size=224,
depth=12,
num_classes=10,
kwargs):
super().__init__()
self.patch_embedding = PatchEmbedding(in_channels, patch_size, emb_size, img_size)
self.transformer_encoder = TransformerEncoder(depth, emb_size=emb_size, **kwargs)
self.mlp_head = nn.Sequential(
nn.LayerNorm(emb_size),
nn.Linear(emb_size, num_classes)
)
def forward(self, x):
x = self.patch_embedding(x)
x = self.transformer_encoder(x)
# CLSトークンのみを使用
x = self.mlp_head(x[:, 0])
return x
CIFAR-10を用いた学習ループの構築
モデルをインスタンス化し、学習ループを回します。GPU利用時はモデルを .cuda() で転送してください。
# モデルのインスタンス化
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = ViT(num_classes=10).to(device)
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.0001)
# 簡易学習ループ(1エポックのみ例示)
model.train()
for i, (inputs, labels) in enumerate(trainloader):
inputs, labels = inputs.to(device), labels.to(device)
optimizer.zero_grad()
outputs = model(inputs)
loss = criterion(outputs, labels)
loss.backward()
optimizer.step()
if i % 100 == 0:
print(f'Batch {i}, Loss: {loss.item():.4f}')
実行するとLossが徐々に下がる様子が確認でき、スクラッチ実装が正しく機能していることがわかります。実験結果から仮説を検証する第一歩です。
Attention Mapの可視化による動作確認
学習が進むと、モデルは画像の重要部分に「注目」します。MultiHeadAttention クラス内の att 変数(Softmax後の重み)を取り出すことで、モデルの注目箇所をヒートマップとして可視化できます。
具体的には、CLSトークンと他パッチのAttentionスコアを抽出し、元の画像サイズ(14x14パッチ)にリシェイプして表示します。これにより、「犬」の分類時に「耳」や「顔」に注目が集まっているかを確認でき、モデルの説明性(Explainability)向上に役立ちます。以下のコードは、抽出したAttentionスコアを元画像に重ねて可視化する実装例です。
import matplotlib.pyplot as plt
import cv2
import numpy as np
def visualize_attention(image, attention_weights):
# attention_weights: (num_heads, num_patches+1, num_patches+1) のテンソルを想定
# 最初のヘッドのCLSトークン(インデックス0)から他パッチ(インデックス1以降)へのAttentionを抽出
cls_attention = attention_weights[0, 0, 1:].detach().cpu().numpy()
# 14x14のパッチグリッドにリシェイプ (画像サイズ224 / パッチサイズ16 = 14)
attention_map = cls_attention.reshape(14, 14)
# 元画像サイズ(224x224)にリサイズして滑らかにする
attention_map = cv2.resize(attention_map, (224, 224))
# 0〜1の範囲に正規化
attention_map = (attention_map - attention_map.min()) / (attention_map.max() - attention_map.min())
# 可視化のセットアップ
fig, axes = plt.subplots(1, 2, figsize=(10, 5))
# 元画像の表示
axes[0].imshow(image)
axes[0].set_title("Original Image")
axes[0].axis("off")
# Attention Mapをヒートマップとして重ねて表示
axes[1].imshow(image)
axes[1].imshow(attention_map, cmap="jet", alpha=0.5)
axes[1].set_title("Attention Map")
axes[1].axis("off")
plt.show()
# ※ 実際の使用時は、モデルの推論時に取得した attention_weights と対象の画像を渡して実行します
# visualize_attention(original_image, attention_weights)
まとめ:実装から見えてくるAIの「視線」
今回はライブラリを使わず、PyTorchでVision Transformerをスクラッチ実装しました。以下の点がクリアになったはずです。
- 画像も「言葉」と同じ: Patch Embeddingにより、画像データは単なるベクトルの列として扱われる。
- 帰納的バイアスの排除: CNNのような畳み込み操作なしに、Self-Attentionだけで大域的な特徴を捉えている。
- 計算コストの正体: Attentionの計算量はパッチ数 $N$ の二乗 $O(N^2)$ に比例します。例えば、解像度を2倍にするとパッチ数は4倍になり、Attentionの計算量は16倍に跳ね上がります。エッジ推論など速度が求められる環境や高解像度画像への適用には、精度とスピードのトレードオフを考慮した工夫が必要です。
この実装をベースにすれば、パッチサイズやEncoderの深さを変える実験も自由自在です。ブラックボックスを恐れず、中身を理解した上でツールを使いこなすことが重要です。
コメント