When working with latent diffusion models like Stable Diffusion, issues in the generated images can often be traced back to the VAE (Variational Autoencoder) component. This script lets you isolate and debug the VAE by encoding an input image to latent space and decoding it straight back to pixel space, with no diffusion step in between. If the input and output images closely match, the VAE is working correctly; visible blur, color shifts, or artifacts point to a VAE problem.
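
At its core, the round trip is just two diffusers calls; everything else in the script is plumbing. As a minimal sketch (assuming vae is a loaded AutoencoderKL and x is a (1, 3, H, W) image tensor normalized to [-1, 1], as the script prepares below):

import torch  # vae (AutoencoderKL) and x (normalized image tensor) are assumed to already exist

with torch.no_grad():
    latents = vae.encode(x).latent_dist.mean * vae.config.scaling_factor  # (1, 4, H/8, W/8)
    recon = vae.decode(latents / vae.config.scaling_factor).sample  # back to pixels in [-1, 1]

The scaling factor cancels out in a pure round trip; the script applies it anyway so the latents match what the rest of the diffusion pipeline would see.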

git clone https://gist.github.com/7fe430bc5640a2dafa9f9814f9c9b8d9.git

or copy the full script below:

import argparse
import os
from typing import Tuple, Optional

import numpy as np
import torch
from PIL import Image
from diffusers import AutoencoderKL


def pick_device_and_dtype() -> Tuple[torch.device, torch.dtype]:
    if torch.cuda.is_available():
        return torch.device("cuda"), torch.float16
    # If you're on Apple Silicon and want to use MPS, uncomment below:
    # if torch.backends.mps.is_available():
    #     return torch.device("mps"), torch.float16
    return torch.device("cpu"), torch.float32


def load_vae(repo_id: str, device: torch.device, dtype: torch.dtype, hf_token: Optional[str] = None) -> AutoencoderKL:
    vae = AutoencoderKL.from_pretrained(
        repo_id,
        subfolder="vae",
        torch_dtype=dtype,
        token=hf_token,  # If None, it will use cached creds or anonymous (will fail if auth is required)
    )
    vae.to(device)
    vae.eval()
    return vae


def ensure_multiple_of_8(size: Tuple[int, int]) -> Tuple[int, int]:
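    """Round (W, H) down to the nearest multiples of 8; the SD VAE downsamples by a factor of 8."""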
    w, h = size
    w8 = (w // 8) * 8
    h8 = (h // 8) * 8
    # Avoid zero
    w8 = max(8, w8)
    h8 = max(8, h8)
    return w8, h8


def image_to_tensor(
    img: Image.Image,
    target_size: Optional[Tuple[int, int]],
    device: torch.device,
    dtype: torch.dtype,
) -> torch.Tensor:
    """
    Convert PIL Image -> normalized tensor in [-1, 1], shape (1, 3, H, W).
    Resizes to target_size (W, H) if provided; otherwise, rounds down to multiples of 8.
    """
    img = img.convert("RGB")
    if target_size is None:
        # Make dimensions multiples of 8 (required for SD VAE)
        target_size = ensure_multiple_of_8(img.size)
    img = img.resize(target_size, Image.BICUBIC)

    np_img = np.array(img).astype(np.float32) / 255.0  # (H, W, 3) in [0, 1]
    torch_img = torch.from_numpy(np_img).permute(2, 0, 1).unsqueeze(0)  # (1, 3, H, W)
    torch_img = torch_img.to(device=device, dtype=dtype)
    torch_img = (torch_img * 2.0) - 1.0  # Normalize to [-1, 1]
    return torch_img


def encode_to_latents(vae: AutoencoderKL, tensor_img: torch.Tensor, sample: bool, seed: Optional[int] = None) -> torch.Tensor:
    """
    Encode image tensor in [-1, 1] to latent space.
    Returns latents scaled by vae.config.scaling_factor, shape (B, 4, H/8, W/8).
    """
    if seed is not None:
        g = torch.Generator(device=tensor_img.device).manual_seed(seed)
    else:
        g = None

    with torch.no_grad():
        posterior = vae.encode(tensor_img).latent_dist
        latents = posterior.sample(generator=g) if sample else posterior.mean
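        # For SD v1.x VAEs, scaling_factor is 0.18215; it normalizes latent magnitudes for the UNet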
        latents = latents * vae.config.scaling_factor
    return latents


def decode_from_latents(vae: AutoencoderKL, latents: torch.Tensor) -> torch.Tensor:
    """
    Decode scaled latents back to image tensor in [-1, 1], shape (B, 3, H, W).
    """
    with torch.no_grad():
        scaled = latents / vae.config.scaling_factor
        decoded = vae.decode(scaled).sample
    return decoded


def tensor_to_image(tensor_img: torch.Tensor) -> Image.Image:
    """
    Convert tensor in [-1, 1] to PIL Image (uint8).
    Expects shape (1, 3, H, W) or (3, H, W).
    """
    if tensor_img.dim() == 4:
        tensor_img = tensor_img[0]
    tensor_img = (tensor_img / 2 + 0.5).clamp(0, 1)  # to [0, 1]
    np_img = (tensor_img.detach().cpu().permute(1, 2, 0).numpy() * 255).round().astype(np.uint8)
    return Image.fromarray(np_img)


def main():
    parser = argparse.ArgumentParser(description="VAE encode/decode round trip for Stable Diffusion using diffusers.")
    parser.add_argument("--image", type=str, required=True, help="Path to input image.")
    parser.add_argument("--out", type=str, default="decoded.png", help="Path to save the decoded image.")
    parser.add_argument("--width", type=int, default=None, help="Target width (must be multiple of 8). If omitted, rounded down.")
    parser.add_argument("--height", type=int, default=None, help="Target height (must be multiple of 8). If omitted, rounded down.")
    parser.add_argument("--sample", action="store_true", help="Sample from posterior instead of using mean (stochastic).")
    parser.add_argument("--seed", type=int, default=None, help="Random seed used when --sample is set.")
    parser.add_argument("--repo", type=str, default="CompVis/stable-diffusion-v1-4", help="HF repo id for the model.")
    parser.add_argument("--hf_token_env", type=str, default="HF_TOKEN", help="Env var name for HF token if required.")
    args = parser.parse_args()

    device, dtype = pick_device_and_dtype()
    hf_token = os.environ.get(args.hf_token_env, None)

    # Load only the VAE from the SD v1-4 checkpoint
    vae = load_vae(args.repo, device=device, dtype=dtype, hf_token=hf_token)

    # Load and prepare image tensor
    pil_img = Image.open(args.image)
    target_size = None
    if (args.width is None) != (args.height is None):
        raise ValueError("Provide both --width and --height, or neither.")
    if args.width is not None:
        if args.width % 8 != 0 or args.height % 8 != 0:
            raise ValueError("width and height must be multiples of 8.")
        target_size = (args.width, args.height)  # (W, H)
    tensor_img = image_to_tensor(pil_img, target_size, device=device, dtype=dtype)

    # Encode to latents
    latents = encode_to_latents(vae, tensor_img, sample=args.sample, seed=args.seed)
    print(f"Encoded latents shape: {tuple(latents.shape)} (scaled by {vae.config.scaling_factor})")

    # Decode back to image tensor
    decoded_tensor = decode_from_latents(vae, latents)
    print(f"Decoded tensor shape: {tuple(decoded_tensor.shape)} (value range ~[-1, 1])")

    # Convert to PIL and save
    decoded_img = tensor_to_image(decoded_tensor)
    decoded_img.save(args.out)
    print(f"Saved decoded image to: {args.out}")


if __name__ == "__main__":
    main()
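
The script needs torch, diffusers, numpy, and Pillow; something like the following should cover it (the exact torch install command varies by platform and CUDA version):

pip install torch diffusers numpy pillow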

Usage example:

python debug_diffusion_vae.py --image input.png --out out.png --width 512 --height 512 --sample --seed 42 --repo CompVis/stable-diffusion-v1-4
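
To go beyond eyeballing the two images, you can quantify the round trip. A minimal sketch (a hypothetical psnr_check.py, not part of the gist, assuming the input was resized to 512x512 as in the usage example above):

# psnr_check.py -- quantify VAE reconstruction quality (hypothetical helper script)
import numpy as np
from PIL import Image

def psnr(a: np.ndarray, b: np.ndarray) -> float:
    """Peak signal-to-noise ratio between two uint8 RGB images of equal size."""
    mse = np.mean((a.astype(np.float64) - b.astype(np.float64)) ** 2)
    return float("inf") if mse == 0 else 10.0 * np.log10(255.0 ** 2 / mse)

# Resize the original the same way the script does, so the arrays line up
orig = np.array(Image.open("input.png").convert("RGB").resize((512, 512), Image.BICUBIC))
recon = np.array(Image.open("out.png").convert("RGB"))
print(f"PSNR: {psnr(orig, recon):.2f} dB")  # a broken VAE scores dramatically lower than a healthy one

Omit --sample for a deterministic comparison; sampling from the posterior adds a small amount of noise from run to run.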