勾配降下党青年局

万国のグラーディエントよ、降下せよ!

SD XLの損失関数について確認

SD XLの実装において、損失計算のアルゴリズムが異なり困惑したので、確認してみます。最終的には今までの損失と同値になることが分かりました。
時刻t\in(0,\cdots,T)の潜在変数をx_t、ノイズを\epsilon、UNetをfとします。diffusion modelでは、拡散過程tステップは以下のようにあらわされます。
x_t = \sqrt{\bar{\alpha_t}}x_{0} + \sqrt{1-\bar{\alpha_t}}\epsilon
ここで\bar{\alpha_t}はスケジューラが持っているハイパーパラメータであり、SDv1, SDv2, SD XLで数値は多分変わりません。(deepfloyd ifは違ったりする)

損失は、UNetがノイズを予測するよう学習させるため、ノイズと予測ノイズの二乗誤差になります。
L = \| \epsilon - f(x_t) \|^2
(※sdv2-768系はちょっと違います。)
diffusersによる訓練を超簡易的に書くと、

for x_0, c in dataset: # 実画像や条件付けをサンプリング
    t = torch.randint(0,T)
    noise = torch.randn_like(x_0)
    x_t = scheduler.add_noise(x_0, t)
    noise_pred = unet(x_t, t, c)
    loss = mse(noise, noise_pred)
    ...

となりますが、それに対してSD XLでの損失計算は以下のような流れになっているっぽいです。

for x_0, c in dataset: # 実画像や条件付けをサンプリング
    t = torch.randint(0,T)
    noise = torch.randn_like(x_0)
    sigma = scheduler.get_sigmas(t)
    x_t_pre = x_0 + noise * sigma
    x_t /= (sigma ** 2 + 1) ** 0.5
    noise_pred = unet(x_t, t, c)
    x_0_pred = x_t_pre - noise_pred * sigma 
    loss = mse(x_0, x_0_pred) / (sigma ** 2)
    ...

違いとして、sigmaという新しい値がでてきたり、最後のlossがノイズではなく元の潜在変数の二乗誤差になっています。
sigmaは以下の数式で表されます。
\sigma_t = \frac{\sqrt{1-\bar{\alpha_t}}}{\sqrt{\bar{\alpha_t}}} = \frac{1}{\sqrt{\mathrm{SNR}(t)}}
ここで、\mathrm{SNR}(t)とは信号対雑音比のことで、信号の分散をノイズの分散で割ったものです。
それではまず両コードのx_tが同じであることを確認します。
SD XL側の実装において、
x_{t_{pre}} = x_0 + \epsilon\frac{\sqrt{1-\bar{\alpha_t}}}{\sqrt{\bar{\alpha_t}}}
であり、次に(sigma ** 2 + 1) ** 0.5ですが、これは数式にすると、
 \sqrt{(\frac{\sqrt{1-\bar{\alpha_t}}}{\sqrt{\bar{\alpha_t}}})^2 + 1} = \sqrt{(\frac{1-\bar{\alpha_t}}{\bar{\alpha_t}}) + 1} = \frac{1}{\sqrt{\bar{\alpha_t}}}
となります。

これが割られるので、x_0に掛けられる係数は\sqrt{\bar{\alpha_t}}、noiseに掛けられる係数は\frac{\sqrt{1-\bar{\alpha_t}}}{\sqrt{\bar{\alpha_t}}}\sqrt{\bar{\alpha_t}} =\sqrt{1-\bar{\alpha_t}}
となり、ちゃんとx_tと同値になることが分かります。unetがノイズを予測できていれば、x_0_predは元のx_0に近づきます。損失はその二つの平均二乗誤差です。ただし損失に\frac{1}{\sigma^2}=\mathrm{SNR}(t)が掛けられています。これは以前書いた記事にちょうど関連する話であり、\|\epsilon - \hat{\epsilon}\|^2=\mathrm{SNR}(t)\|x_0 - \hat{x_0}\|^2という関係から最終的な損失はノイズの二乗誤差になることが分かります。

 何でこんな回りくどいやり方にしたの?と思われるかもしれませんが、どうやらEulerなどの微分方程式ソルバー系の実装に合わせているようですね。まあこの辺り私はよく分かってません。