Table of Contents

Diffusion Modelを扱う際、生成された画像にチェッカーボードアーティファクトが現れることがあります。これらのアーティファクトは、Diffusion ModelのUNetアーキテクチャで使用されるアップサンプリング手法に起因することが多いです。

チェッカーボードアーティファクトの一般的な原因は、UNet内のアップサンプリングに転置畳み込み(デコンボリューションとも呼ばれる)を使用することです。転置畳み込みは出力ピクセルに不均等な重複を生じさせ、チェッカーボードに似た視覚的なパターンにつながる可能性があります。

これらのアーティファクトを軽減するには、以下のアプローチを検討してください:

  1. 代替アップサンプリング手法の使用:転置畳み込みを、最近傍補間に続いて標準的な畳み込みを行う方法や、バイリニア補間に続いて畳み込みを行う方法などの他のアップサンプリング技術に置き換えます。これらの手法は、チェッカーボードパターンを生じさせることなく、より滑らかな結果を生成する傾向があります。
  2. カーネルサイズとストライドの調整:転置畳み込みを使用する必要がある場合は、出力空間の均等なカバレッジを確保するためにカーネルサイズとストライドを慎重に選択してください。不均等な重複につながる設定を避けてください。
  3. 後処理技術:残留アーティファクトを滑らかにするために後処理フィルタを適用します。ガウシアンぼかしやメディアンフィルタリングなどの技術は、チェッカーボードパターンの可視性を低減するのに役立ちます。
import torch
import torch.nn as nn
import torch.nn.functional as F

class UpsampleBlock(nn.Module):
    def __init__(self, in_channels, out_channels, method='nearest',
                 scale_factor=2, kernel_size=3, stride=2, padding=1):
        super(UpsampleBlock, self).__init__()
        self.method = method
        self.scale_factor = scale_factor

        if method in ['nearest', 'bilinear']:
            # 代替アップサンプリング:補間 + 畳み込み
            self.conv = nn.Conv2d(in_channels, out_channels,
                                 kernel_size=kernel_size,
                                 padding=kernel_size//2)
        elif method == 'transpose':
            # アーティファクトを避けるために慎重に選択されたパラメータを持つ転置畳み込み
            # デフォルト:kernel_size=4、stride=2、padding=1は均等なカバレッジを保証
            # これは以下を満たします:output_size = (input_size - 1) * stride - 2 * padding + kernel_size
            # 2倍アップサンプリングの場合:output = 2 * input
            self.conv_transpose = nn.ConvTranspose2d(
                in_channels, out_channels,
                kernel_size=kernel_size if method == 'transpose' else 4,
                stride=stride,
                padding=padding
            )
        else:
            raise ValueError(f"Unsupported upsampling method: {method}")

    def forward(self, x):
        if self.method == 'nearest':
            x = F.interpolate(x, scale_factor=self.scale_factor, mode='nearest')
            x = self.conv(x)
        elif self.method == 'bilinear':
            x = F.interpolate(x, scale_factor=self.scale_factor, mode='bilinear', align_corners=False)
            x = self.conv(x)
        elif self.method == 'transpose':
            x = self.conv_transpose(x)
        return x