Table of Contents

Stable Diffusionのようなlatent diffusionモデルで作業していると、生成画像の不具合がVAE(Variational Autoencoder)コンポーネントに起因していることがよくあります。このスクリプトは、入力画像をlatent空間にエンコードし、それをピクセル空間にデコードし直すことで、VAEだけを切り離してデバッグできるようにします。入力画像と出力画像がほぼ一致すれば、VAEは正しく機能していると判断できます(VAEによる再構成は非可逆なので、細部がわずかに劣化するのは正常です)。

git clone https://gist.github.com/7fe430bc5640a2dafa9f9814f9c9b8d9.git

または、以下のスクリプト全文を参照してください:

import argparse
import os
from typing import Tuple, Optional

import numpy as np
import torch
from PIL import Image
from diffusers import AutoencoderKL


def pick_device_and_dtype(*, allow_mps: bool = False) -> Tuple[torch.device, torch.dtype]:
    """Choose the compute device and the floating-point dtype to run the VAE in.

    Preference order:
      1. CUDA, in float16.
      2. Apple Silicon MPS, in float16 -- only when ``allow_mps`` is True and
         the backend reports itself available.
      3. CPU, in float32 (half precision is only used on accelerators).

    Args:
        allow_mps: Opt in to the MPS backend. Defaults to False, so the default
            behaviour never selects MPS.

    Returns:
        A ``(device, dtype)`` tuple.
    """
    if torch.cuda.is_available():
        return torch.device("cuda"), torch.float16
    if allow_mps:
        # getattr guard: very old torch builds have no torch.backends.mps.
        mps_backend = getattr(torch.backends, "mps", None)
        if mps_backend is not None and mps_backend.is_available():
            return torch.device("mps"), torch.float16
    return torch.device("cpu"), torch.float32


def load_vae(repo_id: str, device: torch.device, dtype: torch.dtype, hf_token: Optional[str] = None) -> AutoencoderKL:
    """Load only the ``vae`` subfolder of a diffusers repo and ready it for inference.

    Args:
        repo_id: Hugging Face repo id whose ``vae`` subfolder holds the weights.
        device: Device the model is moved to after loading.
        dtype: Weight dtype, forwarded as ``torch_dtype``.
        hf_token: Optional access token. When None, ``from_pretrained`` falls
            back to cached credentials or anonymous access, which fails for
            repos that require authentication.

    Returns:
        The AutoencoderKL on ``device`` in eval mode.
    """
    model = AutoencoderKL.from_pretrained(
        repo_id,
        torch_dtype=dtype,
        subfolder="vae",
        token=hf_token,
    )
    model.to(device)
    # Inference only: put the module in eval mode.
    model.eval()
    return model


def ensure_multiple_of_8(size: Tuple[int, int], multiple: int = 8) -> Tuple[int, int]:
    """Round a ``(width, height)`` pair down to the nearest multiple of ``multiple``.

    Each side is clamped to at least ``multiple`` so the result is never zero;
    e.g. ``(513, 770) -> (512, 768)`` and ``(5, 3) -> (8, 8)`` with the default.

    Args:
        size: ``(width, height)`` in pixels.
        multiple: Alignment step. Defaults to 8, because the VAE latents are
            1/8 of the image resolution in each spatial dimension.

    Returns:
        The aligned ``(width, height)``.

    Raises:
        ValueError: If ``multiple`` is not positive.
    """
    if multiple <= 0:
        raise ValueError(f"multiple must be positive, got {multiple}")
    w, h = size
    aligned_w = max(multiple, (w // multiple) * multiple)
    aligned_h = max(multiple, (h // multiple) * multiple)
    return aligned_w, aligned_h


def image_to_tensor(
    img: Image.Image,
    target_size: Optional[Tuple[int, int]],
    device: torch.device,
    dtype: torch.dtype,
) -> torch.Tensor:
    """
    Turn a PIL image into a batched tensor of shape (1, 3, H, W) scaled to [-1, 1].

    The image is converted to RGB, then bicubically resized to ``target_size``
    (given as (W, H)). When ``target_size`` is None, the current size is rounded
    down to multiples of 8 instead, as the SD VAE requires.
    """
    rgb = img.convert("RGB")
    size = ensure_multiple_of_8(rgb.size) if target_size is None else target_size
    rgb = rgb.resize(size, Image.BICUBIC)

    # (H, W, 3) float32 array in [0, 1].
    pixels = np.array(rgb, dtype=np.float32) / 255.0
    # HWC -> CHW, then add a leading batch axis -> (1, 3, H, W).
    batch = torch.from_numpy(pixels).permute(2, 0, 1)[None]
    batch = batch.to(device=device, dtype=dtype)
    # Map [0, 1] -> [-1, 1] (done after the dtype cast, as before).
    return batch * 2.0 - 1.0


def encode_to_latents(vae: AutoencoderKL, tensor_img: torch.Tensor, sample: bool, seed: Optional[int] = None) -> torch.Tensor:
    """
    Encode a [-1, 1] image tensor into the VAE latent space.

    Args:
        vae: The autoencoder whose encoder is used.
        tensor_img: Image batch in [-1, 1].
        sample: If True, draw from the latent posterior; otherwise use its mean.
        seed: Optional seed for a generator on ``tensor_img``'s device. It only
            influences the result when ``sample`` is True.

    Returns:
        Latents multiplied by ``vae.config.scaling_factor``, shape (B, 4, H/8, W/8).
    """
    generator = None if seed is None else torch.Generator(device=tensor_img.device).manual_seed(seed)

    with torch.no_grad():
        dist = vae.encode(tensor_img).latent_dist
        if sample:
            z = dist.sample(generator=generator)
        else:
            z = dist.mean
        return z * vae.config.scaling_factor


def decode_from_latents(vae: AutoencoderKL, latents: torch.Tensor) -> torch.Tensor:
    """
    Undo the scaling_factor applied at encode time and decode to pixel space.

    Args:
        vae: The autoencoder whose decoder is used.
        latents: Latents that were multiplied by ``vae.config.scaling_factor``.

    Returns:
        Image tensor of shape (B, 3, H, W), nominally in [-1, 1].
    """
    with torch.no_grad():
        unscaled = latents / vae.config.scaling_factor
        return vae.decode(unscaled).sample


def tensor_to_image(tensor_img: torch.Tensor) -> Image.Image:
    """
    Convert a [-1, 1] image tensor to an 8-bit RGB PIL Image.

    Args:
        tensor_img: Tensor of shape (1, 3, H, W) or (3, H, W). For a 4-D input
            only the first batch element is converted. Any floating dtype is
            accepted (float32, float16, bfloat16).

    Returns:
        A uint8 PIL Image of size (W, H).
    """
    if tensor_img.dim() == 4:
        tensor_img = tensor_img[0]
    # Move to CPU and upcast to float32 before any arithmetic: bfloat16 tensors
    # cannot be converted with .numpy() at all, and doing the [0, 255] scaling
    # in float16 (the CUDA dtype) needlessly loses precision.
    as_float = tensor_img.detach().to(device="cpu", dtype=torch.float32)
    unit = (as_float / 2 + 0.5).clamp(0, 1)  # [-1, 1] -> [0, 1]
    np_img = (unit.permute(1, 2, 0).numpy() * 255).round().astype(np.uint8)  # CHW -> HWC
    return Image.fromarray(np_img)


def main():
    """Command-line entry point: VAE-encode an image, decode it again, and save the result.

    Raises:
        ValueError: If only one of --width/--height is given, or if either is
            not a positive multiple of 8.
    """
    parser = argparse.ArgumentParser(description="VAE encode/decode with stable diffusion using diffusers.")
    parser.add_argument("--image", type=str, required=True, help="Path to input image.")
    parser.add_argument("--out", type=str, default="decoded.png", help="Path to save the decoded image.")
    parser.add_argument(
        "--width", type=int, default=None,
        help="Target width (multiple of 8; use together with --height). If omitted, the image size is rounded down.",
    )
    parser.add_argument(
        "--height", type=int, default=None,
        help="Target height (multiple of 8; use together with --width). If omitted, the image size is rounded down.",
    )
    parser.add_argument("--sample", action="store_true", help="Sample from posterior instead of using mean (stochastic).")
    parser.add_argument("--seed", type=int, default=None, help="Random seed used when --sample is set.")
    parser.add_argument("--repo", type=str, default="CompVis/stable-diffusion-v1-4", help="HF repo id for the model.")
    parser.add_argument("--hf_token_env", type=str, default="HF_TOKEN", help="Env var name for HF token if required.")
    args = parser.parse_args()

    # Validate the size flags before loading any weights so a bad call fails fast.
    # Previously, passing only one of --width/--height was silently ignored.
    target_size = None
    if (args.width is None) != (args.height is None):
        raise ValueError("width and height must be given together.")
    if args.width is not None:
        if args.width <= 0 or args.height <= 0:
            raise ValueError("width and height must be positive.")
        if args.width % 8 != 0 or args.height % 8 != 0:
            raise ValueError("width and height must be multiples of 8.")
        target_size = (args.width, args.height)  # (W, H)

    if args.seed is not None and not args.sample:
        # The posterior mean involves no randomness, so the seed is unused.
        print("Note: --seed has no effect without --sample.")

    device, dtype = pick_device_and_dtype()
    hf_token = os.environ.get(args.hf_token_env)

    # Load only the VAE subfolder of the selected repo (default: SD v1-4).
    vae = load_vae(args.repo, device=device, dtype=dtype, hf_token=hf_token)

    # Open the image in a context manager so the file handle is released.
    with Image.open(args.image) as pil_img:
        tensor_img = image_to_tensor(pil_img, target_size, device=device, dtype=dtype)

    # Encode to latents
    latents = encode_to_latents(vae, tensor_img, sample=args.sample, seed=args.seed)
    print(f"Encoded latents shape: {tuple(latents.shape)} (scaled by {vae.config.scaling_factor})")

    # Decode back to image tensor
    decoded_tensor = decode_from_latents(vae, latents)
    print(f"Decoded tensor shape: {tuple(decoded_tensor.shape)} (value range ~[-1, 1])")

    # Convert to PIL and save
    decoded_img = tensor_to_image(decoded_tensor)
    decoded_img.save(args.out)
    print(f"Saved decoded image to: {args.out}")


if __name__ == "__main__":
    # Run the CLI only when executed as a script, not when imported as a module.
    main()

使用例:

python debug_diffusion_vae.py --image input.png --out out.png --width 512 --height 512 --sample --seed 42 --repo CompVis/stable-diffusion-v1-4