Sana: Efficient High-Resolution Image Synthesis with Linear Diffusion Transformer

Github : https://github.com/NVlabs/Sana

 

GitHub - NVlabs/Sana: SANA: Efficient High-Resolution Image Synthesis with Linear Diffusion Transformer

SANA: Efficient High-Resolution Image Synthesis with Linear Diffusion Transformer - NVlabs/Sana

github.com

 

Sana는 엔비디아에서 발표한 4096 × 4096 해상도의 이미지를 효율적으로 생성할 수 있는 텍스트-이미지 프레임워크로 0.6B 및 1.6B 의 작은 파라미터 크기를 가지고 있어 노트북 GPU를 활용하여 배포가 가능하다. 

 

Sana 0.6B 모델의 경우 최대 4096 × 4096 해상도의 이미지를 생성할 수 있도록 학습되었으며 Auto Encoder 및 Auto Encoder의 압축 비율을 32로 설정하고 Linear DiT (Document Image Transformer)를 활용하여 훈련 및 추론 비용을 줄였다고 한다.

해당 모델의 아키텍처는 위와 같은데, 단계별로 보면
1) 이미지 데이터 준비

입력 데이터: 고해상도 이미지와 그에 맞는 텍스트(프롬프트) 쌍을 학습 데이터로 사용.

 

2) 이미지 압축(AutoEncoder)

(1) Encoder 

- 고해상도 이미지를 AutoEncoder의 인코더에 통과시켜 압축된 잠재 표현(latent representation)으로 변환.

- 원본 이미지 정보를 32배로 압축하여 계산 효율을 높임. 

 

(2) Decoder

- 학습 시, 생성된 잠재 표현을 디코더에 통과시켜 원본 이미지로 복원.

- 복원된 이미지와 원본 이미지 간의 차이를 계산(재구성 손실, Reconstruction Loss).

 

3) 텍스트 프롬프트 처리(SLM)

(1) 텍스트 프롬프트를 소형 언어 모델(Small LLM : gemma, Qwen)에 입력하여 세부화된 강화된 프롬프트(Enhanced Prompt)를 생성.

(2) 이 프롬프트는 이미지 생성 과정을 안내하는 핵심 정보로 사용됨.

 

4) 생성 모델 학습 (Linear DiT)

(1) 프롬프트 + 잠재 표현 결합:

- 압축된 이미지 잠재 표현과 강화된 텍스트 프롬프트를 함께 처리.

- 시간 임베딩(Time Embedding)을 추가하여 생성 과정의 단계별 정보를 제공.

 

(2) Linear Attention & Mix-FFN:

- Linear DiT 모듈을 통해 잠재 표현을 처리.

- 선형 어텐션과 Mix-FFN(Fully Connected Network)으로 효율적인 학습과 특성 추출 수행.

 

5) 디코딩 및 이미지 복원

(1) Linear DiT에서 최종 생성된 잠재 표현을 AutoEncoder의 디코더에 통과시켜 고해상도 이미지를 복원.

(2) 복원된 이미지가 입력 텍스트 프롬프트를 최대한 반영하도록 학습.

 

6) 손실 함수 계산 및 최적화

(1) 재구성 손실 (Reconstruction Loss): 원본 이미지와 디코딩된 이미지 간의 차이를 계산.

(2) 텍스트-이미지 일치 손실: 입력 텍스트 프롬프트와 생성된 이미지 간의 일치도를 평가.

(3) 최적화: 위 두 손실 값을 최소화하도록 모델의 가중치를 업데이트.

SANA 학습 코드 

# Copyright 2024 NVIDIA CORPORATION & AFFILIATES
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# 라이선스에 따라 소프트웨어를 "있는 그대로" 제공하며,
# 명시적이거나 묵시적인 어떠한 보증도 제공하지 않습니다.
#
# SPDX-License-Identifier: Apache-2.0

import datetime
import getpass
import hashlib
import json
import os
import os.path as osp
import random
import time
import types
import warnings
from dataclasses import asdict
from pathlib import Path

import numpy as np
import pyrallis
import torch
from accelerate import Accelerator, InitProcessGroupKwargs
from accelerate.utils import DistributedType
from PIL import Image
from termcolor import colored

# 모든 경고 메시지 무시
warnings.filterwarnings("ignore")

# 커스텀 모듈 임포트
from diffusion import DPMS, FlowEuler, Scheduler
from diffusion.data.builder import build_dataloader, build_dataset
from diffusion.data.wids import DistributedRangedSampler
from diffusion.model.builder import (
    build_model,
    get_tokenizer_and_text_encoder,
    get_vae,
    vae_decode,
    vae_encode,
)
from diffusion.model.respace import compute_density_for_timestep_sampling
from diffusion.utils.checkpoint import load_checkpoint, save_checkpoint
from diffusion.utils.config import SanaConfig
from diffusion.utils.data_sampler import AspectRatioBatchSampler
from diffusion.utils.dist_utils import clip_grad_norm_, flush, get_world_size
from diffusion.utils.logger import LogBuffer, get_root_logger
from diffusion.utils.lr_scheduler import build_lr_scheduler
from diffusion.utils.misc import (
    DebugUnderflowOverflow,
    init_random_seed,
    read_config,
    set_random_seed,
)
from diffusion.utils.optimizer import auto_scale_lr, build_optimizer

# 토크나이저 병렬 처리 비활성화
os.environ["TOKENIZERS_PARALLELISM"] = "false"


def set_fsdp_env():
    """
    Fully Sharded Data Parallel(FSDP) 훈련을 위한 환경 변수 설정 함수
    """
    os.environ["ACCELERATE_USE_FSDP"] = "true"
    os.environ["FSDP_AUTO_WRAP_POLICY"] = "TRANSFORMER_BASED_WRAP"
    os.environ["FSDP_BACKWARD_PREFETCH"] = "BACKWARD_PRE"
    os.environ["FSDP_TRANSFORMER_CLS_TO_WRAP"] = "SanaBlock"


@torch.inference_mode()
def log_validation(
    accelerator, config, model, logger, step, device, vae=None, init_noise=None
):
    """
    검증을 수행하여 이미지 생성 및 로그 기록을 수행하는 함수
    
    Args:
        accelerator: 분산 훈련을 처리하는 Accelerator 객체
        config: 설정 객체
        model: 디퓨전 모델
        logger: 로그 기록을 위한 로거
        step: 현재 훈련 단계
        device: 계산을 수행할 디바이스
        vae: 이미지 인코딩/디코딩을 위한 VAE (옵션)
        init_noise: 이미지 생성을 위한 초기 노이즈 (옵션)
    
    Returns:
        image_logs: 생성된 이미지와 해당 프롬프트의 리스트
    """
    torch.cuda.empty_cache()
    vis_sampler = config.scheduler.vis_sampler
    model = accelerator.unwrap_model(model).eval()  # 모델을 평가 모드로 설정

    # 이미지 크기 및 종횡비 정의
    hw = torch.tensor([[image_size, image_size]], dtype=torch.float, device=device).repeat(1, 1)
    ar = torch.tensor([[1.0]], device=device).repeat(1, 1)

    # 무조건 생성에 사용할 null 임베딩 로드
    null_y = torch.load(null_embed_path, map_location="cpu")
    null_y = null_y["uncond_prompt_embeds"].to(device)

    logger.info("검증 실행 중... ")
    image_logs = []

    def run_sampling(init_z=None, label_suffix="", vae=None, sampler="dpm-solver"):
        """
        지정된 샘플러를 사용하여 샘플링을 수행하는 헬퍼 함수
        """
        latents = []
        current_image_logs = []
        for prompt in validation_prompts:
            # 잠재 노이즈 초기화
            z = (
                torch.randn(
                    1,
                    config.vae.vae_latent_dim,
                    latent_size,
                    latent_size,
                    device=device,
                )
                if init_z is None
                else init_z
            )
            # 프롬프트 임베딩 로드
            embed = torch.load(
                osp.join(
                    config.train.valid_prompt_embed_root,
                    f"{prompt[:50]}_{valid_prompt_embed_suffix}",
                ),
                map_location="cpu",
            )
            caption_embs, emb_masks = embed["caption_embeds"].to(device), embed[
                "emb_mask"
            ].to(device)

            # 모델 키워드 인자 준비
            model_kwargs = dict(
                data_info={"img_hw": hw, "aspect_ratio": ar}, mask=emb_masks
            )

            # 설정에 따라 샘플러 선택
            if sampler == "dpm-solver":
                dpm_solver = DPMS(
                    model.forward_with_dpmsolver,
                    condition=caption_embs,
                    uncondition=null_y,
                    cfg_scale=4.5,
                    model_kwargs=model_kwargs,
                )
                denoised = dpm_solver.sample(
                    z,
                    steps=14,
                    order=2,
                    skip_type="time_uniform",
                    method="multistep",
                )
            elif sampler == "flow_euler":
                flow_solver = FlowEuler(
                    model,
                    condition=caption_embs,
                    uncondition=null_y,
                    cfg_scale=4.5,
                    model_kwargs=model_kwargs,
                )
                denoised = flow_solver.sample(z, steps=28)
            elif sampler == "flow_dpm-solver":
                dpm_solver = DPMS(
                    model.forward_with_dpmsolver,
                    condition=caption_embs,
                    uncondition=null_y,
                    cfg_scale=4.5,
                    model_type="flow",
                    model_kwargs=model_kwargs,
                    schedule="FLOW",
                )
                denoised = dpm_solver.sample(
                    z,
                    steps=20,
                    order=2,
                    skip_type="time_uniform_flow",
                    method="multistep",
                    flow_shift=config.scheduler.flow_shift,
                )
            else:
                raise ValueError(f"{sampler}는 구현되지 않았습니다.")

            latents.append(denoised)

        torch.cuda.empty_cache()

        # VAE 초기화 (제공되지 않은 경우)
        if vae is None:
            vae = get_vae(
                config.vae.vae_type,
                config.vae.vae_pretrained,
                accelerator.device,
            ).to(torch.float16)

        # 잠재 벡터를 이미지로 디코딩
        for prompt, latent in zip(validation_prompts, latents):
            latent = latent.to(torch.float16)
            samples = vae_decode(config.vae.vae_type, vae, latent)
            samples = (
                torch.clamp(127.5 * samples + 128.0, 0, 255)
                .permute(0, 2, 3, 1)
                .to("cpu", dtype=torch.uint8)
                .numpy()[0]
            )
            image = Image.fromarray(samples)
            current_image_logs.append(
                {"validation_prompt": prompt + label_suffix, "images": [image]}
            )

        return current_image_logs

    # 원래 노이즈로 샘플링 실행
    image_logs += run_sampling(init_z=None, label_suffix="", vae=vae, sampler=vis_sampler)

    # 초기 노이즈가 제공된 경우 추가 샘플링 실행
    if init_noise is not None:
        init_noise = torch.clone(init_noise).to(device)
        image_logs += run_sampling(
            init_z=init_noise,
            label_suffix=" w/ init noise",
            vae=vae,
            sampler=vis_sampler,
        )

    # 로그를 위한 이미지 형식화
    formatted_images = []
    for log in image_logs:
        images = log["images"]
        validation_prompt = log["validation_prompt"]
        for image in images:
            formatted_images.append((validation_prompt, np.asarray(image)))

    # 다양한 트래커에 이미지 로깅
    for tracker in accelerator.trackers:
        if tracker.name == "tensorboard":
            for validation_prompt, image in formatted_images:
                tracker.writer.add_images(
                    validation_prompt, image[None, ...], step, dataformats="NHWC"
                )
        elif tracker.name == "wandb":
            import wandb

            wandb_images = []
            for validation_prompt, image in formatted_images:
                wandb_images.append(
                    wandb.Image(image, caption=validation_prompt, file_type="jpg")
                )
            tracker.log({"validation": wandb_images})
        else:
            logger.warn(f"{tracker.name}에 대한 이미지 로깅이 구현되지 않았습니다.")

    def concatenate_images(image_caption, images_per_row=5, image_format="webp"):
        """
        여러 이미지를 하나의 그리드 이미지로 결합하는 함수
        """
        import io

        images = [log["images"][0] for log in image_caption]
        if images[0].size[0] > 1024:
            images = [image.resize((1024, 1024)) for image in images]

        widths, heights = zip(*(img.size for img in images))
        max_width = max(widths)
        total_height = sum(
            heights[i : i + images_per_row][0]
            for i in range(0, len(images), images_per_row)
        )

        new_im = Image.new("RGB", (max_width * images_per_row, total_height))

        y_offset = 0
        for i in range(0, len(images), images_per_row):
            row_images = images[i : i + images_per_row]
            x_offset = 0
            for img in row_images:
                new_im.paste(img, (x_offset, y_offset))
                x_offset += max_width
            y_offset += heights[i]
        webp_image_bytes = io.BytesIO()
        new_im.save(webp_image_bytes, format=image_format)
        webp_image_bytes.seek(0)
        new_im = Image.open(webp_image_bytes)

        return new_im

    # 시각화를 로컬에 저장하는 옵션
    if config.train.local_save_vis:
        file_format = "webp"
        local_vis_save_path = osp.join(config.work_dir, "log_vis")
        os.umask(0o000)
        os.makedirs(local_vis_save_path, exist_ok=True)
        concatenated_image = concatenate_images(
            image_logs, images_per_row=5, image_format=file_format
        )
        save_path = (
            osp.join(local_vis_save_path, f"vis_{step}.{file_format}")
            if init_noise is None
            else osp.join(
                local_vis_save_path, f"vis_{step}_w_init.{file_format}"
            )
        )
        concatenated_image.save(save_path)

    # 정리 작업
    del vae
    flush()
    return image_logs


def train(
    config,
    args,
    accelerator,
    model,
    optimizer,
    lr_scheduler,
    train_dataloader,
    train_diffusion,
    logger,
):
    """
    디퓨전 모델의 주요 훈련 루프를 수행하는 함수
    
    Args:
        config: 설정 객체
        args: 명령줄 인자
        accelerator: 분산 훈련을 처리하는 Accelerator 객체
        model: 디퓨전 모델
        optimizer: 훈련을 위한 옵티마이저
        lr_scheduler: 학습률 스케줄러
        train_dataloader: 훈련 데이터의 DataLoader
        train_diffusion: 디퓨전 프로세스를 위한 스케줄러 객체
        logger: 로그 기록을 위한 로거
    """
    if getattr(config.train, "debug_nan", False):
        # 훈련 중 오버플로우를 감지하기 위한 NaN 디버거 등록
        DebugUnderflowOverflow(model)
        logger.info("NaN 디버거가 등록되었습니다. 훈련 중 오버플로우를 감지하기 시작합니다.")
    
    log_buffer = LogBuffer()

    global_step = start_step + 1
    skip_step = max(config.train.skip_step, global_step) % train_dataloader_len
    skip_step = skip_step if skip_step < (train_dataloader_len - 20) else 0
    loss_nan_timer = 0

    # 배치 샘플러를 위한 데이터셋 캐싱
    if args.caching and config.model.multi_scale:
        caching_start = time.time()
        logger.info(
            f"배치 샘플러를 위한 데이터셋 캐싱을 시작합니다: {cache_file}. \n"
            f"시간이 많이 소요될 수 있습니다... 훈련이 시작되지 않습니다."
        )
        train_dataloader.batch_sampler.sampler.set_start(
            max(train_dataloader.batch_sampler.exist_ids, 0)
        )
        accelerator.wait_for_everyone()
        for index, _ in enumerate(train_dataloader):
            accelerator.wait_for_everyone()
            if index % 2000 == 0:
                logger.info(
                    f"rank: {rank}, 캐시된 파일 길이: {len(train_dataloader.batch_sampler.cached_idx)} / {len(train_dataloader)}"
                )
                print(
                    f"rank: {rank}, 캐시된 파일 길이: {len(train_dataloader.batch_sampler.cached_idx)} / {len(train_dataloader)}"
                )
            if (time.time() - caching_start) / 3600 > 3.7:
                # 캐싱이 약 3.7시간 이상 진행된 경우 저장
                json.dump(
                    train_dataloader.batch_sampler.cached_idx,
                    open(cache_file, "w"),
                    indent=4,
                )
                accelerator.wait_for_everyone()
                break
            if len(train_dataloader.batch_sampler.cached_idx) == len(train_dataloader) - 1000:
                logger.info(
                    f"캐시 저장 중 - rank: {rank}, 캐시된 파일 길이: {len(train_dataloader.batch_sampler.cached_idx)} / {len(train_dataloader)}"
                )
                json.dump(
                    train_dataloader.batch_sampler.cached_idx,
                    open(cache_file, "w"),
                    indent=4,
                )
            accelerator.wait_for_everyone()
            continue
        accelerator.wait_for_everyone()
        print(
            f"rank-{rank} 캐시 파일 저장 중: {len(train_dataloader.batch_sampler.cached_idx)}"
        )
        json.dump(
            train_dataloader.batch_sampler.cached_idx,
            open(cache_file, "w"),
            indent=4,
        )
        return

    # 모델 훈련 시작
    for epoch in range(start_epoch + 1, config.train.num_epochs + 1):
        time_start, last_tic = time.time(), time.time()
        sampler = (
            train_dataloader.batch_sampler.sampler
            if (num_replicas > 1 or config.model.multi_scale)
            else train_dataloader.sampler
        )
        sampler.set_epoch(epoch)
        sampler.set_start(
            max((skip_step - 1) * config.train.train_batch_size, 0)
        )
        if skip_step > 1 and accelerator.is_main_process:
            logger.info(f"스킵된 단계: {skip_step}")
        skip_step = 1  # 첫 에폭 이후 스킵 단계 리셋
        data_time_start = time.time()
        data_time_all = 0
        lm_time_all = 0
        vae_time_all = 0
        model_time_all = 0

        for step, batch in enumerate(train_dataloader):
            # 모든 프로세스 동기화
            accelerator.wait_for_everyone()
            data_time_all += time.time() - data_time_start

            # VAE를 사용하여 이미지 인코딩 (사전 계산된 VAE 피처를 로드하지 않는 경우)
            vae_time_start = time.time()
            if load_vae_feat:
                z = batch[0].to(accelerator.device)
            else:
                with torch.no_grad():
                    with torch.amp.autocast(
                        "cuda",
                        enabled=(
                            config.model.mixed_precision == "fp16"
                            or config.model.mixed_precision == "bf16"
                        ),
                    ):
                        z = vae_encode(
                            config.vae.vae_type,
                            vae,
                            batch[0],
                            config.vae.sample_posterior,
                            accelerator.device,
                        )

            accelerator.wait_for_everyone()
            vae_time_all += time.time() - vae_time_start

            clean_images = z
            data_info = batch[3]

            # 텍스트 임베딩 처리
            lm_time_start = time.time()
            if load_text_feat:
                y = batch[1]  # 형상: [배치 크기, 1, N, C]
                y_mask = batch[2]  # 형상: [배치 크기, 1, 1, N]
            else:
                if "T5" in config.text_encoder.text_encoder_name:
                    with torch.no_grad():
                        txt_tokens = tokenizer(
                            batch[1],
                            max_length=max_length,
                            padding="max_length",
                            truncation=True,
                            return_tensors="pt",
                        ).to(accelerator.device)
                        y = text_encoder(txt_tokens.input_ids, attention_mask=txt_tokens.attention_mask)[0][
                            :, None
                        ]
                        y_mask = txt_tokens.attention_mask[:, None, None]
                elif (
                    "gemma" in config.text_encoder.text_encoder_name
                    or "Qwen" in config.text_encoder.text_encoder_name
                ):
                    with torch.no_grad():
                        if not config.text_encoder.chi_prompt:
                            max_length_all = config.text_encoder.model_max_length
                            prompt = batch[1]
                        else:
                            chi_prompt = "\n".join(config.text_encoder.chi_prompt)
                            prompt = [chi_prompt + i for i in batch[1]]
                            num_chi_prompt_tokens = len(tokenizer.encode(chi_prompt))
                            max_length_all = (
                                num_chi_prompt_tokens
                                + config.text_encoder.model_max_length
                                - 2
                            )  # 특수 토큰 조정
                        txt_tokens = tokenizer(
                            prompt,
                            padding="max_length",
                            max_length=max_length_all,
                            truncation=True,
                            return_tensors="pt",
                        ).to(accelerator.device)
                        select_index = [0] + list(
                            range(-config.text_encoder.model_max_length + 1, 0)
                        )  # 첫 번째 BOS 토큰 및 마지막 N-1 토큰 선택
                        y = text_encoder(
                            txt_tokens.input_ids,
                            attention_mask=txt_tokens.attention_mask,
                        )[0][:, None, select_index]
                        y_mask = txt_tokens.attention_mask[:, None, None][
                            :, :, :, select_index
                        ]
                else:
                    print("에러 발생")
                    exit()

            # 배치 내 각 이미지에 대해 랜덤 타임스텝 샘플링
            bs = clean_images.shape[0]
            timesteps = torch.randint(
                0, config.scheduler.train_sampling_steps, (bs,), device=clean_images.device
            ).long()

            if config.scheduler.weighting_scheme in ["logit_normal"]:
                # diffusers의 훈련 유틸리티에서 타임스텝 샘플링 조정
                u = compute_density_for_timestep_sampling(
                    weighting_scheme=config.scheduler.weighting_scheme,
                    batch_size=bs,
                    logit_mean=config.scheduler.logit_mean,
                    logit_std=config.scheduler.logit_std,
                    mode_scale=None,  # 사용되지 않음
                )
                timesteps = (u * config.scheduler.train_sampling_steps).long().to(
                    clean_images.device
                )

            grad_norm = None
            accelerator.wait_for_everyone()
            lm_time_all += time.time() - lm_time_start
            model_time_start = time.time()

            # 그래디언트 누적
            with accelerator.accumulate(model):
                # 그래디언트 초기화
                optimizer.zero_grad()
                # 훈련 손실 계산
                loss_term = train_diffusion.training_losses(
                    model,
                    clean_images,
                    timesteps,
                    model_kwargs=dict(y=y, mask=y_mask, data_info=data_info),
                )
                loss = loss_term["loss"].mean()
                # 역전파
                accelerator.backward(loss)
                # 그래디언트 클리핑 적용
                if accelerator.sync_gradients:
                    grad_norm = accelerator.clip_grad_norm_(
                        model.parameters(), config.train.gradient_clip
                    )
                # 옵티마이저 및 스케줄러 업데이트
                optimizer.step()
                lr_scheduler.step()
                accelerator.wait_for_everyone()
                model_time_all += time.time() - model_time_start

            # NaN 손실 체크
            if torch.any(torch.isnan(loss)):
                loss_nan_timer += 1

            lr = lr_scheduler.get_last_lr()[0]
            logs = {args.loss_report_name: accelerator.gather(loss).mean().item()}
            if grad_norm is not None:
                logs.update(grad_norm=accelerator.gather(grad_norm).mean().item())
            log_buffer.update(logs)

            # 지정된 간격마다 로그 기록
            if (step + 1) % config.train.log_interval == 0 or (step + 1) == 1:
                accelerator.wait_for_everyone()
                t = (time.time() - last_tic) / config.train.log_interval
                t_d = data_time_all / config.train.log_interval
                t_m = model_time_all / config.train.log_interval
                t_lm = lm_time_all / config.train.log_interval
                t_vae = vae_time_all / config.train.log_interval
                avg_time = (time.time() - time_start) / (step + 1)
                eta = str(
                    datetime.timedelta(
                        seconds=int(avg_time * (total_steps - global_step - 1))
                    )
                )
                eta_epoch = str(
                    datetime.timedelta(
                        seconds=int(
                            avg_time
                            * (
                                train_dataloader_len
                                - sampler.step_start
                                // config.train.train_batch_size
                                - step
                                - 1
                            )
                        )
                    )
                )
                log_buffer.average()

                current_step = (
                    global_step - sampler.step_start // config.train.train_batch_size
                ) % train_dataloader_len
                current_step = (
                    train_dataloader_len if current_step == 0 else current_step
                )
                info = (
                    f"에폭: {epoch} | 글로벌 스텝: {global_step} | 로컬 스텝: {current_step} // {train_dataloader_len}, "
                    f"총 ETA: {eta}, 에폭 ETA: {eta_epoch}, 시간: 전체:{t:.3f}, 모델:{t_m:.3f}, 데이터:{t_d:.3f}, "
                    f"LM:{t_lm:.3f}, VAE:{t_vae:.3f}, 학습률:{lr:.3e}, 캡션: {batch[5][0]}, "
                )
                info += (
                    f"s:({model.module.h}, {model.module.w}), "
                    if hasattr(model, "module")
                    else f"s:({model.h}, {model.w}), "
                )

                # 손실 및 기타 메트릭을 로그에 추가
                info += ", ".join(
                    [f"{k}:{v:.4f}" for k, v in log_buffer.output.items()]
                )
                last_tic = time.time()
                log_buffer.clear()
                data_time_all = 0
                model_time_all = 0
                lm_time_all = 0
                vae_time_all = 0
                if accelerator.is_main_process:
                    logger.info(info)

            logs.update(lr=lr)
            if accelerator.is_main_process:
                accelerator.log(logs, step=global_step)

            global_step += 1

            # 손실이 너무 자주 NaN인 경우 훈련 중단
            if loss_nan_timer > 20:
                raise ValueError("손실이 너무 자주 NaN입니다. 훈련을 중단합니다.")

            # 지정된 스텝 또는 시간 제한에 도달한 경우 체크포인트 저장
            if (
                global_step % config.train.save_model_steps == 0
                or (time.time() - training_start_time) / 3600 > 3.8
            ):
                accelerator.wait_for_everyone()
                if accelerator.is_main_process:
                    os.umask(0o000)
                    ckpt_saved_path = save_checkpoint(
                        osp.join(config.work_dir, "checkpoints"),
                        epoch=epoch,
                        step=global_step,
                        model=accelerator.unwrap_model(model),
                        optimizer=optimizer,
                        lr_scheduler=lr_scheduler,
                        generator=generator,
                        add_symlink=True,
                    )
                    if (
                        config.train.online_metric
                        and global_step % config.train.eval_metric_step == 0
                        and step > 1
                    ):
                        online_metric_monitor_dir = osp.join(
                            config.work_dir, config.train.online_metric_dir
                        )
                        os.makedirs(online_metric_monitor_dir, exist_ok=True)
                        with open(
                            f"{online_metric_monitor_dir}/{ckpt_saved_path.split('/')[-1]}.txt",
                            "w",
                        ) as f:
                            f.write(osp.join(config.work_dir, "config.py") + "\n")
                            f.write(ckpt_saved_path)

                # 시간 제한에 도달한 경우 훈련 중단
                if (time.time() - training_start_time) / 3600 > 3.8:
                    logger.info(
                        f"시간 제한으로 인해 에폭 {epoch}, 스텝 {global_step}에서 훈련을 중단합니다."
                    )
                    return

            # 지정된 스텝마다 시각화 수행
            if (
                config.train.visualize
                and (
                    global_step % config.train.eval_sampling_steps == 0
                    or (step + 1) == 1
                )
            ):
                accelerator.wait_for_everyone()
                if accelerator.is_main_process:
                    if validation_noise is not None:
                        log_validation(
                            accelerator=accelerator,
                            config=config,
                            model=model,
                            logger=logger,
                            step=global_step,
                            device=accelerator.device,
                            vae=vae,
                            init_noise=validation_noise,
                        )
                    else:
                        log_validation(
                            accelerator=accelerator,
                            config=config,
                            model=model,
                            logger=logger,
                            step=global_step,
                            device=accelerator.device,
                            vae=vae,
                        )

            # 다중 스케일 훈련 중 데드락을 방지하기 위한 조기 종료
            if (
                config.model.multi_scale
                and (
                    train_dataloader_len
                    - sampler.step_start
                    // config.train.train_batch_size
                    - step
                )
                < 30
            ):
                global_step = epoch * train_dataloader_len
                logger.info("현재 반복을 조기 종료합니다.")
                break

            # 데이터 로딩 타이머 리셋
            data_time_start = time.time()

        # 지정된 에폭마다 또는 훈련 종료 시 체크포인트 저장
        if (
            epoch % config.train.save_model_epochs == 0
            or (epoch == config.train.num_epochs and not config.debug)
        ):
            accelerator.wait_for_everyone()
            if accelerator.is_main_process:
                ckpt_saved_path = save_checkpoint(
                    osp.join(config.work_dir, "checkpoints"),
                    epoch=epoch,
                    step=global_step,
                    model=accelerator.unwrap_model(model),
                    optimizer=optimizer,
                    lr_scheduler=lr_scheduler,
                    generator=generator,
                    add_symlink=True,
                )

                online_metric_monitor_dir = osp.join(
                    config.work_dir, config.train.online_metric_dir
                )
                os.makedirs(online_metric_monitor_dir, exist_ok=True)
                with open(
                    f"{online_metric_monitor_dir}/{ckpt_saved_path.split('/')[-1]}.txt",
                    "w",
                ) as f:
                    f.write(osp.join(config.work_dir, "config.py") + "\n")
                    f.write(ckpt_saved_path)
        accelerator.wait_for_everyone()


@pyrallis.wrap()
def main(cfg: SanaConfig) -> None:
    """
    훈련 환경 설정, 모델 초기화, 데이터 로드, 훈련 시작을 수행하는 메인 함수
    
    Args:
        cfg: 명령줄 인자 또는 구성 파일에서 파싱된 설정 객체
    """
    global (
        train_dataloader_len,
        start_epoch,
        start_step,
        vae,
        generator,
        num_replicas,
        rank,
        training_start_time,
        load_vae_feat,
        load_text_feat,
        validation_noise,
        text_encoder,
        tokenizer,
        max_length,
        validation_prompts,
        latent_size,
        valid_prompt_embed_suffix,
        null_embed_path,
        image_size,
        cache_file,
        total_steps,
    )

    config = cfg
    args = cfg  # 접근성을 위해 별칭 설정

    training_start_time = time.time()
    load_from = True

    # 체크포인트 재개 처리
    if args.resume_from or config.model.resume_from:
        load_from = False
        config.model.resume_from = dict(
            checkpoint=args.resume_from or config.model.resume_from,
            load_ema=False,
            resume_optimizer=True,
            resume_lr_scheduler=True,
        )

    # 디버깅 설정
    if args.debug:
        config.train.log_interval = 1
        config.train.train_batch_size = min(64, config.train.train_batch_size)
        args.report_to = "tensorboard"

    # 작업 디렉토리 생성
    os.umask(0o000)
    os.makedirs(config.work_dir, exist_ok=True)

    # 프로세스 그룹 초기화 (타임아웃 확장)
    init_handler = InitProcessGroupKwargs()
    init_handler.timeout = datetime.timedelta(seconds=5400)  # 1.5시간

    # FSDP 플러그인이 필요한 경우 설정
    if config.train.use_fsdp:
        init_train = "FSDP"
        from accelerate import FullyShardedDataParallelPlugin
        from torch.distributed.fsdp.fully_sharded_data_parallel import (
            FullStateDictConfig,
        )

        set_fsdp_env()
        fsdp_plugin = FullyShardedDataParallelPlugin(
            state_dict_config=FullStateDictConfig(
                offload_to_cpu=False, rank0_only=False
            ),
        )
    else:
        init_train = "DDP"
        fsdp_plugin = None

    # Accelerator 인스턴스 생성
    accelerator = Accelerator(
        mixed_precision=config.model.mixed_precision,
        gradient_accumulation_steps=config.train.gradient_accumulation_steps,
        log_with=args.report_to,
        project_dir=osp.join(config.work_dir, "logs"),
        fsdp_plugin=fsdp_plugin,
        kwargs_handlers=[init_handler],
    )

    # 로거 설정
    log_name = "train_log.log"
    logger = get_root_logger(osp.join(config.work_dir, log_name))
    logger.info(accelerator.state)

    # 재현성을 위한 랜덤 시드 초기화
    config.train.seed = init_random_seed(
        getattr(config.train, "seed", None)
    )
    set_random_seed(config.train.seed + int(os.environ["LOCAL_RANK"]))
    generator = torch.Generator(device="cpu").manual_seed(config.train.seed)

    # 메인 프로세스인 경우 설정 파일 덤프 및 트래커 초기화
    if accelerator.is_main_process:
        pyrallis.dump(
            config, open(osp.join(config.work_dir, "config.yaml"), "w"), sort_keys=False, indent=4
        )
        if args.report_to == "wandb":
            import wandb

            wandb.init(
                project=args.tracker_project_name,
                name=args.name,
                resume="allow",
                id=args.name,
            )

    logger.info(f"설정: \n{config}")
    logger.info(f"월드 사이즈: {get_world_size()}, 시드: {config.train.seed}")
    logger.info(f"초기화 중: {init_train}으로 훈련 시작")

    # 설정에 따른 이미지 및 잠재 크기 정의
    image_size = config.model.image_size
    latent_size = int(image_size) // config.vae.vae_downsample_rate
    pred_sigma = getattr(config.scheduler, "pred_sigma", True)
    learn_sigma = getattr(config.scheduler, "learn_sigma", True) and pred_sigma
    max_length = config.text_encoder.model_max_length
    vae = None

    # 결정적 검증을 위한 초기 노이즈 준비
    validation_noise = (
        torch.randn(
            1,
            config.vae.vae_latent_dim,
            latent_size,
            latent_size,
            device="cpu",
            generator=generator,
        )
        if getattr(config.train, "deterministic_validation", False)
        else None
    )

    # 사전 계산된 VAE 피처를 로드하지 않는 경우 VAE 로드
    if not config.data.load_vae_feat:
        vae = get_vae(
            config.vae.vae_type,
            config.vae.vae_pretrained,
            accelerator.device,
        ).to(torch.float16)

    tokenizer = text_encoder = None
    if not config.data.load_text_feat:
        # 토크나이저 및 텍스트 인코더 초기화
        tokenizer, text_encoder = get_tokenizer_and_text_encoder(
            name=config.text_encoder.text_encoder_name,
            device=accelerator.device,
        )
        text_embed_dim = text_encoder.config.hidden_size
    else:
        text_embed_dim = config.text_encoder.caption_channels

    logger.info(f"VAE 타입: {config.vae.vae_type}")

    # 복잡한 인간 지시사항 처리
    if config.text_encoder.chi_prompt:
        chi_prompt = "\n".join(config.text_encoder.chi_prompt)
        logger.info(f"복잡한 인간 지시사항: {chi_prompt}")

    # 검증 임베딩 경로 준비
    os.makedirs(config.train.null_embed_root, exist_ok=True)
    null_embed_path = osp.join(
        config.train.null_embed_root,
        f"null_embed_diffusers_{config.text_encoder.text_encoder_name}_{max_length}token_{text_embed_dim}.pth",
    )

    # 시각화 임베딩 준비 (활성화된 경우)
    if config.train.visualize and len(config.train.validation_prompts):
        valid_prompt_embed_suffix = f"{max_length}token_{config.text_encoder.text_encoder_name}_{text_embed_dim}.pth"
        validation_prompts = config.train.validation_prompts
        skip = True

        # chi_prompt가 있는 경우 고유 식별자 생성
        if config.text_encoder.chi_prompt:
            uuid_chi_prompt = hashlib.sha256(chi_prompt.encode()).hexdigest()
        else:
            uuid_chi_prompt = hashlib.sha256(b"").hexdigest()

        config.train.valid_prompt_embed_root = osp.join(
            config.train.valid_prompt_embed_root, uuid_chi_prompt
        )
        Path(config.train.valid_prompt_embed_root).mkdir(parents=True, exist_ok=True)

        if config.text_encoder.chi_prompt:
            # 복잡한 인간 지시사항을 파일로 저장
            chi_prompt_file = osp.join(
                config.train.valid_prompt_embed_root, "chi_prompt.txt"
            )
            with open(chi_prompt_file, "w", encoding="utf-8") as f:
                f.write(chi_prompt)

        # 프롬프트 임베딩이 이미 존재하는지 확인
        for prompt in validation_prompts:
            prompt_embed_path = osp.join(
                config.train.valid_prompt_embed_root,
                f"{prompt[:50]}_{valid_prompt_embed_suffix}",
            )
            if not (osp.exists(prompt_embed_path) and osp.exists(null_embed_path)):
                skip = False
                logger.info("시각화를 위한 프롬프트 임베딩 준비 중...")
                break

        # 임베딩이 없는 경우 생성
        if accelerator.is_main_process and not skip:
            if config.data.load_text_feat and (
                tokenizer is None or text_encoder is None
            ):
                logger.info(
                    f"{config.text_encoder.text_encoder_name}에서 텍스트 인코더와 토크나이저 로드 중..."
                )
                tokenizer, text_encoder = get_tokenizer_and_text_encoder(
                    name=config.text_encoder.text_encoder_name
                )

            for prompt in validation_prompts:
                prompt_embed_path = osp.join(
                    config.train.valid_prompt_embed_root,
                    f"{prompt[:50]}_{valid_prompt_embed_suffix}",
                )
                if "T5" in config.text_encoder.text_encoder_name:
                    # T5 토크나이저 및 인코더를 사용하여 프롬프트 인코딩
                    txt_tokens = tokenizer(
                        prompt,
                        max_length=max_length,
                        padding="max_length",
                        truncation=True,
                        return_tensors="pt",
                    ).to(accelerator.device)
                    caption_emb = text_encoder(
                        txt_tokens.input_ids,
                        attention_mask=txt_tokens.attention_mask,
                    )[0]
                    caption_emb_mask = txt_tokens.attention_mask
                elif (
                    "gemma" in config.text_encoder.text_encoder_name
                    or "Qwen" in config.text_encoder.text_encoder_name
                ):
                    # GEMMA 또는 Qwen 토크나이저 및 인코더를 사용하여 프롬프트 인코딩
                    if not config.text_encoder.chi_prompt:
                        max_length_all = config.text_encoder.model_max_length
                        prompt = batch[1]
                    else:
                        chi_prompt = "\n".join(config.text_encoder.chi_prompt)
                        prompt = [chi_prompt + i for i in batch[1]]
                        num_chi_prompt_tokens = len(tokenizer.encode(chi_prompt))
                        max_length_all = (
                            num_chi_prompt_tokens
                            + config.text_encoder.model_max_length
                            - 2
                        )  # 특수 토큰 조정

                    txt_tokens = tokenizer(
                        prompt,
                        max_length=max_length_all,
                        padding="max_length",
                        truncation=True,
                        return_tensors="pt",
                    ).to(accelerator.device)
                    select_index = [0] + list(
                        range(-config.text_encoder.model_max_length + 1, 0)
                    )  # 첫 번째 BOS 토큰 및 마지막 N-1 토큰 선택
                    caption_emb = text_encoder(
                        txt_tokens.input_ids,
                        attention_mask=txt_tokens.attention_mask,
                    )[0][:, select_index]
                    caption_emb_mask = txt_tokens.attention_mask[:, select_index]
                else:
                    raise ValueError(
                        f"{config.text_encoder.text_encoder_name}은(는) 지원되지 않습니다!!"
                    )

                # 프롬프트 임베딩 저장
                torch.save(
                    {
                        "caption_embeds": caption_emb,
                        "emb_mask": caption_emb_mask,
                    },
                    prompt_embed_path,
                )

            # null (무조건적) 임베딩 생성 및 저장
            null_tokens = tokenizer(
                "", max_length=max_length, padding="max_length", truncation=True, return_tensors="pt"
            ).to(accelerator.device)
            if "T5" in config.text_encoder.text_encoder_name:
                null_token_emb = text_encoder(
                    null_tokens.input_ids, attention_mask=null_tokens.attention_mask
                )[0]
            elif (
                "gemma" in config.text_encoder.text_encoder_name
                or "Qwen" in config.text_encoder.text_encoder_name
            ):
                null_token_emb = text_encoder(
                    null_tokens.input_ids, attention_mask=null_tokens.attention_mask
                )[0]
            else:
                raise ValueError(
                    f"{config.text_encoder.text_encoder_name}은(는) 지원되지 않습니다!!"
                )
            torch.save(
                {
                    "uncond_prompt_embeds": null_token_emb,
                    "uncond_prompt_embeds_mask": null_tokens.attention_mask,
                },
                null_embed_path,
            )
            if config.data.load_text_feat:
                del tokenizer
                del text_encoder
            del null_token_emb
            del null_tokens
            flush()

    # 선형 어텐션 오토캐스트를 위한 환경 변수 설정
    os.environ["AUTOCAST_LINEAR_ATTN"] = (
        "true" if config.model.autocast_linear_attn else "false"
    )

    # 1. 디퓨전 스케줄러 빌드
    train_diffusion = Scheduler(
        str(config.scheduler.train_sampling_steps),
        noise_schedule=config.scheduler.noise_schedule,
        predict_v=config.scheduler.predict_v,
        learn_sigma=learn_sigma,
        pred_sigma=pred_sigma,
        snr=config.train.snr_loss,
        flow_shift=config.scheduler.flow_shift,
    )
    predict_info = f"v-예측: {config.scheduler.predict_v}, 노이즈 스케줄: {config.scheduler.noise_schedule}"
    if "flow" in config.scheduler.noise_schedule:
        predict_info += f", 플로우 시프트: {config.scheduler.flow_shift}"
    if config.scheduler.weighting_scheme in ["logit_normal", "mode"]:
        predict_info += (
            f", 플로우 가중치: {config.scheduler.weighting_scheme}, "
            f"logit-평균: {config.scheduler.logit_mean}, logit-표준편차: {config.scheduler.logit_std}"
        )
    logger.info(predict_info)

    # 2. 디퓨전 모델 빌드
    model_kwargs = {
        "pe_interpolation": config.model.pe_interpolation,
        "config": config,
        "model_max_length": max_length,
        "qk_norm": config.model.qk_norm,
        "micro_condition": config.model.micro_condition,
        "caption_channels": text_embed_dim,
        "y_norm": config.text_encoder.y_norm,
        "attn_type": config.model.attn_type,
        "ffn_type": config.model.ffn_type,
        "mlp_ratio": config.model.mlp_ratio,
        "mlp_acts": list(config.model.mlp_acts),
        "in_channels": config.vae.vae_latent_dim,
        "y_norm_scale_factor": config.text_encoder.y_norm_scale_factor,
        "use_pe": config.model.use_pe,
        "linear_head_dim": config.model.linear_head_dim,
        "pred_sigma": pred_sigma,
        "learn_sigma": learn_sigma,
    }
    model = build_model(
        config.model.model,
        config.train.grad_checkpointing,
        getattr(config.model, "fp32_attention", False),
        input_size=latent_size,
        **model_kwargs,
    ).train()  # 모델을 훈련 모드로 설정
    logger.info(
        colored(
            f"{model.__class__.__name__}:{config.model.model}, "
            f"모델 파라미터 수: {sum(p.numel() for p in model.parameters()) / 1e6:.2f}M",
            "green",
            attrs=["bold"],
        )
    )

    # 2-1. 모델 체크포인트 로드
    if args.load_from is not None:
        config.model.load_from = args.load_from
    if config.model.load_from is not None and load_from:
        _, missing, unexpected, _ = load_checkpoint(
            config.model.load_from,
            model,
            load_ema=config.model.resume_from.get("load_ema", False),
            null_embed_path=null_embed_path,
        )
        logger.warning(f"누락된 키: {missing}")
        logger.warning(f"예상치 못한 키: {unexpected}")

    # FSDP를 사용하는 경우 그래디언트 노름 계산을 위한 준비
    if accelerator.distributed_type == DistributedType.FSDP:
        for m in accelerator._models:
            m.clip_grad_norm_ = types.MethodType(clip_grad_norm_, m)

    # 3. 훈련 데이터 로더 빌드
    config.data.data_dir = (
        config.data.data_dir
        if isinstance(config.data.data_dir, list)
        else [config.data.data_dir]
    )
    config.data.data_dir = [
        data
        if data.startswith(("https://", "http://", "gs://", "/", "~"))
        else osp.abspath(osp.expanduser(data))
        for data in config.data.data_dir
    ]
    num_replicas = int(os.environ["WORLD_SIZE"])
    rank = int(os.environ["RANK"])
    dataset = build_dataset(
        asdict(config.data),
        resolution=image_size,
        aspect_ratio_type=config.model.aspect_ratio_type,
        real_prompt_ratio=config.train.real_prompt_ratio,
        max_length=max_length,
        config=config,
        caption_proportion=config.data.caption_proportion,
        sort_dataset=config.data.sort_dataset,
        vae_downsample_rate=config.vae.vae_downsample_rate,
    )
    accelerator.wait_for_everyone()

    if config.model.multi_scale:
        # 다중 스케일 훈련을 위한 캐싱 처리
        drop_last = True
        uuid = hashlib.sha256("-".join(config.data.data_dir).encode()).hexdigest()[:8]
        cache_dir = osp.expanduser(f"~/.cache/_wids_batchsampler_cache")
        os.makedirs(cache_dir, exist_ok=True)
        base_pattern = (
            f"{cache_dir}/{getpass.getuser()}-{uuid}-sort_dataset{config.data.sort_dataset}"
            f"-hq_only{config.data.hq_only}-valid_num{config.data.valid_num}"
            f"-aspect_ratio{len(dataset.aspect_ratio)}-droplast{drop_last}"
            f"dataset_len{len(dataset)}"
        )
        cache_file = f"{base_pattern}-num_replicas{num_replicas}-rank{rank}"
        for i in config.data.data_dir:
            cache_file += f"-{i}"
        cache_file += ".json"

        # 샘플러 및 배치 샘플러 초기화
        sampler = DistributedRangedSampler(dataset, num_replicas=num_replicas, rank=rank)
        batch_sampler = AspectRatioBatchSampler(
            sampler=sampler,
            dataset=dataset,
            batch_size=config.train.train_batch_size,
            aspect_ratios=dataset.aspect_ratio,
            drop_last=drop_last,
            ratio_nums=dataset.ratio_nums,
            config=config,
            valid_num=config.data.valid_num,
            hq_only=config.data.hq_only,
            cache_file=cache_file,
            caching=args.caching,
        )
        train_dataloader = build_dataloader(
            dataset, batch_sampler=batch_sampler, num_workers=config.train.num_workers
        )
        train_dataloader_len = len(train_dataloader)
        logger.info(f"rank-{rank} 캐시된 파일 길이: {len(train_dataloader.batch_sampler.cached_idx)}")
    else:
        # 일반 훈련을 위한 샘플러 및 데이터 로더 초기화
        sampler = DistributedRangedSampler(dataset, num_replicas=num_replicas, rank=rank)
        train_dataloader = build_dataloader(
            dataset,
            num_workers=config.train.num_workers,
            batch_size=config.train.train_batch_size,
            shuffle=False,
            sampler=sampler,
        )
        train_dataloader_len = len(train_dataloader)

    load_vae_feat = getattr(train_dataloader.dataset, "load_vae_feat", False)
    load_text_feat = getattr(train_dataloader.dataset, "load_text_feat", False)

    # 4. 옵티마이저 및 학습률 스케줄러 빌드
    lr_scale_ratio = 1
    if getattr(config.train, "auto_lr", None):
        lr_scale_ratio = auto_scale_lr(
            config.train.train_batch_size
            * get_world_size()
            * config.train.gradient_accumulation_steps,
            config.train.optimizer,
            **config.train.auto_lr,
        )
    optimizer = build_optimizer(model, config.train.optimizer)
    if (
        config.train.lr_schedule_args
        and config.train.lr_schedule_args.get("num_warmup_steps", None)
    ):
        config.train.lr_schedule_args["num_warmup_steps"] = (
            config.train.lr_schedule_args["num_warmup_steps"] * num_replicas
        )
    lr_scheduler = build_lr_scheduler(
        config.train, optimizer, train_dataloader, lr_scale_ratio
    )
    logger.warning(
        f"{colored(f'기본 설정: ', 'green', attrs=['bold'])}"
        f"학습률: {config.train.optimizer['lr']:.5f}, 배치 크기: {config.train.train_batch_size}, 그래디언트 체크포인팅: {config.train.grad_checkpointing}, "
        f"그래디언트 누적 단계: {config.train.gradient_accumulation_steps}, QK 노름: {config.model.qk_norm}, "
        f"FP32 어텐션: {config.model.fp32_attention}, 어텐션 타입: {config.model.attn_type}, FFN 타입: {config.model.ffn_type}, "
        f"텍스트 인코더: {config.text_encoder.text_encoder_name}, 캡션 비율: {config.data.caption_proportion}, 정밀도: {config.model.mixed_precision}"
    )

    timestamp = time.strftime("%Y-%m-%d_%H:%M:%S", time.localtime())

    # 트래커 초기화 (예: TensorBoard, WandB)
    if accelerator.is_main_process:
        tracker_config = dict(vars(config))
        try:
            accelerator.init_trackers(args.tracker_project_name, tracker_config)
        except:
            accelerator.init_trackers(f"tb_{timestamp}")

    start_epoch = 0
    start_step = 0
    total_steps = train_dataloader_len * config.train.num_epochs

    # 체크포인트에서 훈련 재개
    if config.model.resume_from is not None and config.model.resume_from["checkpoint"] is not None:
        rng_state = None
        ckpt_path = osp.join(config.work_dir, "checkpoints")
        check_flag = osp.exists(ckpt_path) and len(os.listdir(ckpt_path)) != 0
        if config.model.resume_from["checkpoint"] == "latest":
            if check_flag:
                checkpoints = os.listdir(ckpt_path)
                if (
                    "latest.pth" in checkpoints
                    and osp.exists(osp.join(ckpt_path, "latest.pth"))
                ):
                    config.model.resume_from["checkpoint"] = osp.realpath(
                        osp.join(ckpt_path, "latest.pth")
                    )
                else:
                    checkpoints = [
                        i for i in checkpoints if i.startswith("epoch_")
                    ]
                    checkpoints = sorted(
                        checkpoints, key=lambda x: int(x.replace(".pth", "").split("_")[3])
                    )
                    config.model.resume_from["checkpoint"] = osp.join(
                        ckpt_path, checkpoints[-1]
                    )
            else:
                config.model.resume_from["checkpoint"] = config.model.load_from

        if config.model.resume_from["checkpoint"] is not None:
            _, missing, unexpected, rng_state = load_checkpoint(
                **config.model.resume_from,
                model=model,
                optimizer=optimizer if check_flag else None,
                lr_scheduler=lr_scheduler if check_flag else None,
                null_embed_path=null_embed_path,
            )

            logger.warning(f"누락된 키: {missing}")
            logger.warning(f"예상치 못한 키: {unexpected}")

            path = osp.basename(config.model.resume_from["checkpoint"])
        try:
            start_epoch = int(path.replace(".pth", "").split("_")[1]) - 1
            start_step = int(path.replace(".pth", "").split("_")[3])
        except:
            pass

        # 랜덤 시드 상태 재개
        if rng_state:
            logger.info("랜덤 시드 상태를 재개합니다.")
            torch.set_rng_state(rng_state["torch"])
            torch.cuda.set_rng_state_all(rng_state["torch_cuda"])
            np.random.set_state(rng_state["numpy"])
            random.setstate(rng_state["python"])
            generator.set_state(rng_state["generator"])  # 제너레이터 상태 재개

    # Accelerator를 사용하여 모델, 옵티마이저, 스케줄러 준비
    model = accelerator.prepare(model)
    optimizer, lr_scheduler = accelerator.prepare(optimizer, lr_scheduler)

    # 훈련 프로세스 시작
    train(
        config=config,
        args=args,
        accelerator=accelerator,
        model=model,
        optimizer=optimizer,
        lr_scheduler=lr_scheduler,
        train_dataloader=train_dataloader,
        train_diffusion=train_diffusion,
        logger=logger,
    )


if __name__ == "__main__":

    main()

 

 

 

 

 

 

코드 요약 및 모델 구조

 

1) 개요

PyTorch와 accelerate 라이브러리를 사용하여 텍스트-이미지 디퓨전 모델을 훈련시키기 위한 포괄적인 파이프라인으로 구현되어 있다. 이 모델은 이미지 인코딩 및 디코딩을 위한 변분 오토인코더(VAE)와 텍스트 프롬프트를 처리하기 위한 텍스트 인코더(T5, GEMMA, Qwen 등)를 활용하는데, 훈련 파이프라인은 데이터 로딩, 모델 초기화, 훈련 루프, 검증, 로깅, 체크포인팅 등을 포함합니다.

 

2) 주요 구성 요소

(1) 모델 구성 요소

  • 변분 오토인코더(VAE): 입력 이미지를 잠재 공간으로 인코딩(vae_encode)하고, 잠재 벡터를 이미지로 디코딩(vae_decode)하는 역할
  • 텍스트 인코더: 텍스트 프롬프트를 임베딩으로 변환합니다. T5, GEMMA, Qwen 등 다양한 아키텍처를 지원
  • 디퓨전 스케줄러: 훈련 및 샘플링 중 디퓨전 프로세스를 관리합니다. DPMS, FlowEuler, FlowDPM-Solver 등의 샘플링 방법을 지원
  • 메인 모델: build_model 함수를 사용하여 다양한 설정(어텐션 유형, 정규화 레이어 등)을 포함한 디퓨전 모델을 구축

(2) 데이터 처리

  • 데이터셋 빌드: build_dataset 함수를 사용하여 이미지 해상도, 종횡비, 데이터 특정 설정 등을 처리
  • 데이터 샘플러: 분산 훈련을 위한 DistributedRangedSampler와 다중 스케일 훈련을 위한 AspectRatioBatchSampler를 사용하여 효율적인 데이터 로딩과 배칭을 보장
  • DataLoader: 적절한 샘플러, 배치 크기, 워커 수로 train_dataloader를 구축

(3) 훈련 루프 

  • 에폭 관리: 에폭과 배치를 반복하면서 그래디언트 누적과 분산 프로세스 간 동기화를 처리
  • 순전파: 이미지를 인코딩하고 텍스트를 처리하며, 랜덤 타임스텝을 샘플링하고 디퓨전 모델을 사용하여 손실을 계산
  • 역전파: 손실을 역전파하고 그래디언트 클리핑을 적용한 후 옵티마이저와 스케줄러를 업데이트
  • 로깅 및 시각화: 지정된 간격마다 훈련 메트릭을 로그하고, 검증 프롬프트를 사용하여 생성된 이미지를 시각화

(4) 분산 훈련

  • Accelerate 라이브러리: Accelerate 라이브러리를 사용하여 다중 GPU 및 다중 노드 훈련을 효율적으로 처리
  • FSDP 지원: 대형 모델의 메모리 효율성을 높이기 위해 Fully Sharded Data Parallel(FSDP)을 옵션으로 활성화

(2) 모델 아키텍처

모델 아키텍처는 모듈식이며 구성 가능하여 다양한 컴포넌트와 훈련 전략을 실험할 수 있도록 유연성을 제공합니다. 주요 구성 요소는 다음과 같습니다:

(1) 변분 오토인코더(VAE)

  • 인코더: 입력 이미지를 config.vae.vae_latent_dim 차원의 잠재 공간으로 압축
  • 디코더: 잠재 벡터에서 이미지를 재구성
  • 사용 용도: 입력 이미지를 잠재 벡터로 인코딩(vae_encode)하고, 잠재 벡터를 다시 이미지로 디코딩(vae_decode)

(2) 텍스트 인코더

  • 지원 아키텍처: T5, GEMMA, Qwen.
  • 기능: 텍스트 프롬프트를 디퓨전 프로세스를 조건화하기 위한 임베딩으로 변환
  • 정규화: 구성 가능한 정규화(y_norm)와 스케일링 팩터(y_norm_scale_factor).

(3) 디퓨전 스케줄러

  • 스케줄러 종류: DPMS, FlowEuler, FlowDPM-Solver.
  • 기능: 훈련 및 샘플링 중 디퓨전 프로세스를 관리하며, 노이즈 추가 및 노이즈 제거 단계를 포함
  • 매개변수: 샘플링 단계 수, 노이즈 스케줄 유형, 시그마 예측 여부(pred_sigma), 기타 설정.

(4) 메인 디퓨전 모델

  • 컴포넌트:
    • 어텐션 메커니즘: 구성 가능한 어텐션 유형(attn_type) 및 선택적인 혼합 정밀도 선형 어텐션(autocast_linear_attn).
    • 피드포워드 네트워크(FFN): 구성 가능한 FFN 유형(ffn_type) 및 활성화 함수(mlp_acts).
    • 정규화 레이어: 구성 가능한 정규화(qk_norm, y_norm).
  • 입력: VAE에서 나온 잠재 벡터와 텍스트 인코더에서 나온 텍스트 임베딩.
  • 출력: 설정에 따라 예측된 노이즈 또는 다른 디퓨전 관련 출력.

'A.I.(인공지능) & M.L.(머신러닝)' 카테고리의 다른 글

런웨이 & 루마 AI API  (1) 2024.09.22
CogVideoX  (0) 2024.09.04
  • 네이버 블로그 공유
  • 네이버 밴드 공유
  • 페이스북 공유