勾配降下党青年局

万国のグラーディエントよ、降下せよ!

Euler vs DDIM

ComfyUI上でDDIMの実装が見たくて探していたのですが、こんな感じでした。あれ・・・Eulerを呼び出してるだけ・・・?というわけで確認していきます。

以下の記事を前提とする。
さんぷらーについて - 勾配降下党青年局


DDIMのデフォルト設定(\eta=0)のとき時刻t\to t-1の更新式は、
 
\begin{eqnarray}
x_{t-1} &=& \sqrt{\bar{\alpha}_{t-1}}\frac{1}{\sqrt{\bar{\alpha}_t}}(x_t - \sqrt{1-\bar{\alpha}_t}\epsilon_\theta(x_t, t))+\sqrt{1-\bar{\alpha}_{t-1}}\epsilon_\theta(x_t, t) \\
&=& \frac{1}{\sqrt{\alpha_t}}(x_t - \sqrt{1-\bar{\alpha}_t}\epsilon_\theta(x_t, t))+\sqrt{1-\bar{\alpha}_{t-1}}\epsilon_\theta(x_t, t) \\
&=&  \frac{x_t}{\sqrt{\alpha_t}} - \frac{\sqrt{1-\bar{\alpha}_t}}{\sqrt{\alpha_t}}\epsilon_\theta(x_t, t)+\sqrt{1-\bar{\alpha}_{t-1}}\epsilon_\theta(x_t, t) \\
&=&  \frac{x_t}{\sqrt{\alpha_t}} +(\sqrt{1-\bar{\alpha}_{t-1}} - \frac{\sqrt{1-\bar{\alpha}_t}}{\sqrt{\alpha_t}})\epsilon_\theta(x_t, t) \\
\end{eqnarray}

一方Eulerはx'_tを分散発散型の変数とすると、
 
\begin{eqnarray}
x'_{t-1} &=& x'_t + (\sigma_{t-1} - \sigma_{t})\epsilon_\theta(x_t, t) \\
&=& x'_t + (\frac{\sqrt{1-\bar{\alpha}_{t-1}}}{\sqrt{\bar{\alpha}_{t-1}}} - \frac{\sqrt{1-\bar{\alpha}_t}}{\sqrt{\bar{\alpha}_t}})\epsilon_\theta(x_t, t) \\
\end{eqnarray}
となります。ここで上記記事にあるように分散発散型を分散保存型に変換x_t = \sqrt{\bar{\alpha}_t}x'_tすると、
 
\begin{eqnarray}
\frac{x_{t-1}}{\sqrt{\bar{\alpha}_{t-1}}} &=& \frac{1}{\sqrt{\bar{\alpha}_t}}x_t + (\frac{\sqrt{1-\bar{\alpha}_{t-1}}}{\sqrt{\bar{\alpha}_{t-1}}} - \frac{\sqrt{1-\bar{\alpha}_t}}{\sqrt{\bar{\alpha}_t}})\epsilon_\theta(x_t, t) \\
&=& \frac{x_t}{\sqrt{\alpha_t}} +(\sqrt{1-\bar{\alpha}_{t-1}} - \frac{\sqrt{1-\bar{\alpha}_t}}{\sqrt{\alpha_t}})\epsilon_\theta(x_t, t)  \\
\end{eqnarray}

同じなんですね。全然知りませんでした。

LyCORISのアルゴリズムまとめ

なんか色々増えてきたのでまとめるよ。

りこりこ

LoRA

LoRAは行列W\in \mathbb{R}^{(m,n)}に対して、差分\Delta W \in \mathbb{R}^{(m,n)}を学習します。このとき、A\in\mathbb{R}^{(m,r)}, \ B\in\mathbb{R}^{(r,n)}という二つの行列で\Delta W = ABとすることで、学習対象のパラメータを大幅に削減します。1層目をdown層、2層目をup層と呼びます。学習時は入力x\in\mathbb{(r,b)}に対して、出力をy=Wx+A(Bx)とします。これは計算量的には大きくなりますが、学習対象パラメータが大幅に削減されるメリットの方が大きいです。推論時はW'=W+ABとすることで、計算量を増やさずに済みます。実際はLoRAの結果を\frac{\alpha}{r}でスケーリングします。\alphaはハイパーパラメータで、rの増減でLoRAのスケールが変動することを防ぎ、rを変えた時に学習率を変更せずに済むようになっています。
アルゴリズムの総称として使われることもありますが、一番狭い意味では全結合層及び1×1畳み込み層に適用するものになっています。1×1畳み込み層にはピクセルごとに同じ全結合層を適用しているだけなので簡単に表現できます。
r \le \mathrm{rank}(\Delta W)となります。このことからrはrankと呼ばれています。上界だけどね。
r=1のLoRAを考えてみましょう。するとLoRAは列ベクトルと行ベクトルの積になります。Aの各要素をa_iとすると、
\Delta W = \begin{pmatrix}
a_1B \\ \vdots \\  a_mB
\end{pmatrix}
です。行基本変形を繰り返せば1行を残し全部0になることがすぐ分かり、\mathrm{rank}(\Delta W)=1になります。
一般のrのときは、
\Delta W = \begin{pmatrix}
a_{11}B_1 + a_{12}B_2 + \cdots + a_{1r}B_r \\ a_{21}B_1 + a_{22}B_2 + \cdots + a_{2r}B_r \\ \vdots \\  a_{m1}B_1 + a_{m2}B_2 + \cdots + a_{rm}B_r
\end{pmatrix}
みたいな感じになります。上のr行目までで行基本変形を行えば三角行列的な形(足し算に対してだけど)にできるので、r+1行目以降を0にできます。よって\mathrm{rank}(\Delta W)\le rです。

LoCon

LoConはフィルターサイズが1×1以外の畳み込み層にもLoRAを適用する方法です。たとえば3×3フィルターの場合、1ピクセルの計算にスポットを当てれば、9マス×入力チャンネル⇒出力チャンネルの全結合層です。3×3畳み込み層の重みは\mathbb{R}^{(m,n,3,3)}なので数は合いますよね。LoRAの場合はこの変換を9マス×入力チャンネル⇒rチャンネル⇒出力チャンネルと二つの変換に分解します。A\in\mathbb{R}^{(m,r)}, B\in\mathbb{R}^{(r,n,3,3)}として、3×3畳み込み層にもLoRAを拡張できます。入力を合わせるために、down層のstrideやpaddingは元のモジュールと同じものにします。up層は1×1畳み込みです。

LoHa

LoHaは二つのLoRAのアダマール積をとる手法です。アダマール積とは単純に行列の要素ごとの積をとる演算です。
\Delta W = AB \odot CDです。
LoRAの項目があった形になおすと、
\Delta W = \begin{pmatrix}
a_{11}B_1 + a_{12}B_2 + \cdots + a_{1r}B_r \\ a_{21}B_1 + a_{22}B_2 + \cdots + a_{2r}B_r \\ \vdots \\  a_{m1}B_1 + a_{m2}B_2 + \cdots + a_{rm}B_r
\end{pmatrix}
\odot
\begin{pmatrix}
c_{11}D_1 + c_{12}D_2 + \cdots + c_{1r}D_r \\ c_{21}D_1 + c_{22}D_2 + \cdots + c_{2r}D_r \\ \vdots \\  c_{m1}D_1 + c_{m2}D_2 + \cdots + c_{rm}D_r
\end{pmatrix}
r個の足し算同士の掛け算によって、項がr^2個になります。よって\Delta W \le r^2になります。パラメータ数が2倍になるのに対して、rankの上界が2乗になります。

LoKr

LoKRはクロネッカー積を使う手法です。
クロネッカー積とは行列A,Bに対して、以下のような行列を返す演算です。
A \otimes B = 
\begin{pmatrix}
a_{11}B & \cdots &  a_{1n}B \\
\vdots & \ddots & \vdots \\
a_{m1}B &\cdots & a_{mn}B \\
\end{pmatrix}

(m,n)行列と(p,q)行列に対してクロネッカー積を適用すると、(mp,nq)行列になります。

\Delta W = W_1 \otimes ABとなります。右だけLoRAです。左までLoRAにする設定もあるようです。
クロネッカー積に関して、\mathrm{rank}(A\otimes B) = \mathrm{rank}(A)\mathrm{rank}(B)となるので、rank的にはめちゃくちゃ効率よくなります。ただし重み共有を使っているので、rank的にいいからといって表現力が高いかというと微妙だと思いますけどね。

IA3

各モジュールの入力もしくは出力をチャンネルごとにスケーリングします。数式的には対角行列を重みの左側もしくは右側にかけます。パラメータ数は非常に小さいです。入出力両方に適用すれば素朴に表現力があがりそうですが、そういった設定はないようですね。

GLoRA

GLoRAは差分だけでなく、元の重みへの入力に対してもLoRAを適用します。
W' = W(x + ABx) + CDx = (W+WAB+CD)x
元論文ではもっと複雑な設計ですが、省略した実装のようですね。これをLoKR版にしたGLoKRもあります。
LoRAと比べて表現力は増えていないんですが、元の重みを利用することで学習が早くなったりするんですかね(よく分かりません)。

DyLoRA

複数のrankのLoRAを同時に学習する方法です。LoRAはAB=\displaystyle{\sum_{i=1}^{r}A_iB_i}という形にできます。ここでA_iAi列目、B_iBi行目です。DyLoRAではステップごとにランダムにkを選び、\Delta W = \mathrm{no\_grad}(\displaystyle{\sum_{i=1}^{k-1}A_iB_i}) + \mathrm{enable\_grad}(A_kB_k)とします。このように学習すると、任意のrank(k)で、\Delta W=\displaystyle{\sum_{i=1}^{k}A_iB_i}として生成できるようになります。kのrankを学習するときに、k-1以下のrankのLoRAに影響を与えないため勾配を無効にしています。rankは1ずつではなくブロック分けすることも可能です。

Norm

GroupNormやLayerNorm層も学習対象にします。重みやバイアスの差分を学習するというだけで特に難しいことはありません。

Full

\Delta Wを元の重みと同じサイズの行列にします。つまりやっていることは普通のファインチューニングと同じです。差分をとっておくのはVRAMの浪費でしかないんですが、LyCORIS上でフルファインチューニングを実装するにはそうするしかなかったんでしょう。一応差分に対してドロップアウトを実行できるというメリットはありますけどお。

時刻をスキップするときの拡散過程

拡散モデルで定義される、データからノイズを加えていく拡散過程は以下のように定義されています。
q(x_t|x_{t-1})=N(x_t;\sqrt{\alpha_t}, (1-\alpha_t)\boldsymbol{I})

ここで時刻tの状態を一気にサンプリングできて、
q(x_t|x_0)=N(x_t;\sqrt{\displaystyle{\prod_{i=1}^{t}\alpha_i}}, (1-\displaystyle{\prod_{i=1}^{t}\alpha_i})\boldsymbol{I})
となります。

この証明ですが、大抵の場合q(x_1|x_0)から始まる帰納法もしくは\cdotsでごまかすやつを使って証明します。どっちにしろ条件の方のx_0は固定します。
まあそれはそれでいいんですが、最近時刻tへ1回でサンプリングをするのは、離れたタイムステップ間でサンプリングすることの特殊ケースでしかないことに気づきました。
より一般的な式で証明した方が後々分かりやすくない?ということでやってみます。
つまりq(x_t|x_{t-k})kを動かすことで帰納法をやってみます。

離れたタイムステップ間のサンプリングにも対応する、より一般的な結果はこちらになります。
q(x_t|x_{t-n})=N(x_t;\sqrt{\displaystyle{\prod_{i=t-n+1}^{t}\alpha_i}}, (1-\displaystyle{\prod_{i=t-n+1}^{t}\alpha_i})\boldsymbol{I})

見づらいので、\alpha_{t|t-k}=\displaystyle{\prod_{i=t-k+1}^{t}\alpha_i}とします。
すると\alpha_{t|t-k}\alpha_{t-k} =(\displaystyle{\prod_{i=t-k+1}^{t}\alpha_i})\alpha_{t-k} = \displaystyle{\prod_{i=t-k}^{t}\alpha_i} = \alpha_{t|t-(k+1)}であることに注意しといてください。

帰納法で証明します。
証明したいことは、以下の通りです。
q(x_t|x_{t-n})=N(x_t;\sqrt{\alpha_{t|t-n}}x_{t-n}, (1-\alpha_{t|t-n})\boldsymbol{I})
n=1のとき、\alpha_{t|t-1}=\displaystyle{\prod_{i=t}^{t}\alpha_i}=\alpha_tであるため定義通りになります。
n=kのとき、題意が成り立つとすると、x_tは以下のようにサンプリングできます。


\begin{eqnarray}
x_t &= &\sqrt{\alpha_{t|t-k}}x_{t-k}+\sqrt{1-\alpha_{t|t-k}}\epsilon_1 \\
&=& \sqrt{\alpha_{t|t-k}}(\sqrt{\alpha_{t-k}}x_{t-(k+1)} + \sqrt{1-\alpha_{t-k}}\epsilon_2)+ \sqrt{1-\alpha_{t|t-k}}\epsilon_1 \\
& = & \sqrt{\alpha_{t|t-k}}\sqrt{\alpha_{t-k}}x_{t-(k+1)}  + \sqrt{\alpha_{t|t-k}}\sqrt{1-\alpha_{t-k}}\epsilon_2 + \sqrt{1-\alpha_{t|t-k}}\epsilon_1 \\
& = & \sqrt{\alpha_{t|t-(k+1)}}x_{t-(k+1)}  + \sqrt{\alpha_{t|t-k}}\sqrt{1-\alpha_{t-k}}\epsilon_2 + \sqrt{1-\alpha_{t|t-k}}\epsilon_1
 \end{eqnarray}

これでq(x_t|x_{t-(k+1)})の平均が\sqrt{\alpha_{t|t-(k+1)}}となることはわかりました。
分散については正規分布の再生性から、二項と三項の分散を足すだけでよくて、

\alpha_{t|t-k}(1-\alpha_{t-k})+1-\alpha_{t|t-k}=1-\alpha_{t|t-(k+t)}となります。
よって、q(x_t|x_{t-(k+1)})=N(x_t;\sqrt{\alpha_{t|t-(k+1)}}x_{t-(k+1)}, (1-\alpha_{t|t-(k+1)})\boldsymbol{I})

帰納法のあれがあーいうやつにより、証明完了です。

さんぷらーについて

 各サンプラーの意味とかがなんとなく分かりたくて書いたものです。SDEやODEの導出に関する話はでてきません(分からんし)。

拡散過程の定義

サンプラーによって使われている文字の意味が違うので、ここでは文字をあわせていいきたいと思います。そのため論文の式そのままにできないので間違っている可能性があります。
拡散過程は分散保存型と分散発散型の二種類に分かれています。

分散保存型

元の画像を弱めながらノイズを加えていく方法です。
 x_{t} = \sqrt{\alpha_t}x_{t-1} + \sqrt{1-\alpha_t}\epsilon

二つの係数の二乗和をとると1になります。つまり分散がどの時刻でも固定されます。

この拡散過程は以下のようにx_tを単独でサンプリングすることができます。
 x_{t} = \sqrt{\bar{\alpha}_t}x_{0} + \sqrt{1-\bar{\alpha}_t}\epsilon
ここで、\bar{\alpha}_t = \displaystyle{\prod_{s=1}^{t}\alpha_s}です。

この説明には\betaも使うことが多いです。ば~の有無にかかわらず、\alpha = 1-\betaです。\betaを使った方がシンプルになるときは\betaを使います。

分散発散型

元の画像をそのままにして、ノイズを加えていく方法です。
x_t = x_{0} + \sigma_t \epsilon

Stable Diffusionは学習時に分散保存型を採用していますが、分散発散型であると仮定して生成することも可能です。
\sigma_t = \frac{\sqrt{1-\bar{\alpha}_t}}{\sqrt{\bar{\alpha}_t}}とすれば、\sqrt{\bar{\alpha}_t}x_t = \sqrt{\bar{\alpha}_t}x_{0}+ \sqrt{1-\bar{\alpha}_t}\epsilonとスケーリングすることで分散保存型の式になるので、それをUNetに入力すればよいです。出力は係数が掛けられていない元のノイズなので、スケーリングは必要ありません。

離れたタイムステップ間のサンプリング

離れた時刻t-hから tへのサンプリングもできます。
x_{t} = \sqrt{\frac{\bar{\alpha}_{t}}{\bar{\alpha}_{t-h}}}x_{t-h} + \sqrt{1-\frac{\bar{\alpha}_{t}}{\bar{\alpha}_{t-h}}}\epsilon

\bar{\alpha}は累積積なので、h = 1だったら、\alpha_tがそのままでてきますね。
1ステップ分の式を離れた時刻に対応する式に置き換えたい場合、\alpha_t\frac{\bar{\alpha}_{t}}{\bar{\alpha}_{t-h}}に置き換えればよいだけです。またh自然数ではなく実数にすることができます(時刻を離散時間から連続時間にする)。UNetの時刻埋め込みはうまいこと連続になるよう設定されているので、離散的な時刻しか学習していなくても問題ないらしい。

逆拡散過程

逆拡散過程では、加えられたノイズをモデルに予測させ、その結果を使って次のステップの画像を推定します。その推定の仕方はサンプラーによって異なります。

ノイズの予測を\epsilon(x_t, t)とすると、ノイズのない画像の予測\hat{x_0}は、
\hat{x}_0 = \frac{1}{\sqrt{\bar{\alpha}_t}}(x_t - \sqrt{1-\bar{\alpha}_t}\epsilon(x_t, t))
となります。

モデルの予測結果から導出する変数は、\hat{x}_0, \hat{\epsilon}など帽子をつけてあげます。

画像生成の流れ

Stable Diffusionの画像生成はdiffusersでは以下のようなコードになっています。

# timestepの設定
self.scheduler.set_timesteps(num_inference_steps)
timesteps = self.scheduler.timesteps

# 初期ノイズの作成
latents = torch.randn(batch_size, 4, height // 8, width // 8)

# 分散発散型の場合simga_1000をかける
latents = latents * self.scheduler.init_noise_sigma

# ノイズ除去ループ
for i, t in enumerate(timesteps):
    #分散発散型の場合に入力を分散保存型に置き換える
    latent_model_input = self.scheduler.scale_model_input(latents, t)

    #ノイズ予測
    noise_pred = self.unet(latent_model_input , t)

    # 次のステップの潜在変数を推定
    latents = self.scheduler.step(noise_pred, t, latents).prev_sample

今回の話の主題ではない条件付き生成に関わる部分などは省略しています。時刻t \in [1, 2, \cdots, 1000]を設定したステップ数n分取り出して、\tau_1 , \cdots, \tau_nとして、x_{\tau_n}\to x_{\tau_{n-1}} \to \cdots \to x_{\tau_{1}}と処理を繰り返します。このとき現在の画像から与えられたノイズを予測する役割を持つのがUNetで、次の時刻の状態を予測する役割を持つのがサンプラーになります。

微分方程式サンプラー

DDPM、DDIMといった微分方程式を使わないサンプラーは、以下のように平均と分散を推定することで次の画像を予測します。
x_{t-1} = \mu_\theta(x_t,t) + \beta \epsilon
ただし分散は学習せず固定します。

DDPM

DDPMSamplerは最初にできたサンプラーです。

DDPM1
\hat{x}_{t-1} = \frac{1}{\sqrt{\alpha_{t}}}(\hat{x}_{t} - \frac{1-\alpha_{t}}{\sqrt{1-\bar{\alpha}_{t}}}\hat{\epsilon}) + \sqrt{1-\bar{\alpha}_{t}}\epsilon

またバリエーションとして、DDPMの導出中にでてくるq(x_{t-1}|x_t,x_0)からでてくる平均と分散を使ってx_{t-1}を予測する式もあります。

DDPM2
2. \hat{x}_{t-1} = \frac{\sqrt{\bar{\alpha}_{t-1}}\beta_t}{\bar{\beta}_t}\hat{x_0}+\frac{\sqrt{\alpha_t}\bar{\beta}_{t-1}}{\bar{\beta}_t}x_t + \sqrt{\frac{\bar{\beta}_{t-1}}{\bar{\beta}_{t}}\beta_t}\epsilon

よくわかりませんが、diffusersではこっちを使っているようです。後述のDDIMの特殊ケースになります。

DDIM

拡散過程は今まで加えられてきたノイズと、次に加えるノイズが全くの無関係でした。つまりx_tが依存するのはx_{t-1}のみでした。逆に加えるノイズを全時刻で固定することを考えると、x_tx_0のみに依存するようになります。DDIMではその間をとるような拡散過程を考えます。するとDDPMとDDIMはうまいこと最適解が一致するため、DDPMで学習したモデルに対してDDIMを根拠にした生成を行えます。DDIMSamplerは1000ステップ必要だったDDPMSamplerに比べて数十ステップくらいで生成できるようになりますた。

サンプリングは、

DDIM
1. はいぱら\etaに対して、s_t = \eta\sqrt{\frac{\bar{\beta}_{t-1}}{\bar{\beta}_t}\beta_t}とする
2. \hat{x}_{t-1} = \sqrt{\bar{\alpha}_{t-1}}\hat{x}_{0} + \sqrt{\bar{\beta}_{t-1}-{s_t}^2}\hat{\epsilon} + s_t\epsilon

(1)\eta=1のとき、DDPMの二つ目の式と同じになります。
計算
第二項は

\begin{eqnarray}
\sqrt{\bar{\beta}_{t-1}-\frac{\bar{\beta}_{t-1}}{\bar{\beta}_t}\beta_t}\hat{\epsilon} & = & \sqrt{\frac{(\bar{\beta}_{t}-\beta_t)\bar{\beta}_{t-1}}{\bar{\beta}_t}}\hat{\epsilon} \\
& = & \sqrt{\frac{(\bar{\beta}_{t}-\beta_t)\bar{\beta}_{t-1}}{\bar{\beta}_t}}\frac{x_t-\sqrt{\bar{\alpha_t}}\hat{x}_0}{\sqrt{\hat\beta_{t}}} \\
& = & \frac{\sqrt{(\bar{\beta}_{t}-\beta_t)\bar{\beta}_{t-1}}}{\bar{\beta}_t}(x_t-\sqrt{\bar{\alpha_t}}\hat{x}_0) \\
& = & \frac{\sqrt{(1-\bar{\alpha}_{t}- (1- \alpha_t))\bar{\beta}_{t-1}}}{\bar{\beta}_t}(x_t-\sqrt{\bar{\alpha_t}}\hat{x}_0) \\
& = & \frac{\sqrt{\alpha_t(1-\bar{\alpha}_{t-1})\bar{\beta}_{t-1}}}{\bar{\beta}_t}(x_t-\sqrt{\bar{\alpha_t}}\hat{x}_0) \\
& = & \frac{\sqrt{\alpha_t\bar{\beta}_{t-1}\bar{\beta}_{t-1}}}{\bar{\beta}_t}(x_t-\sqrt{\bar{\alpha_t}}\hat{x}_0) \\
& = & \sqrt{\alpha_t}\frac{\bar{\beta}_{t-1}}{\bar{\beta_{t}}}x_t - -\sqrt{\alpha_t\bar{\alpha_t}}\frac{\bar{\beta}_{t-1}}{\bar{\beta_{t}}}\hat{x_0}\\
\end{eqnarray}

x_tの係数に着目すると、\sqrt{\alpha_t}\frac{\bar{\beta}_{t-1}}{\bar{\beta_{t}}}となりDDPMの二番目の式と一致しました^^

\hat{x_0}の係数に着目すると、-\sqrt{\alpha_t\bar{\alpha_t}}\frac{\bar{\beta}_{t-1}}{\bar{\beta_{t}}}となり
第一項の\hat{x_0}の係数と足すと、
\sqrt{\bar{\alpha}_{t-1}}-\sqrt{\alpha_t\bar{\alpha_t}}\frac{\bar{\beta}_{t-1}}{\bar{\beta_{t}}}=\sqrt{\bar{\alpha}_{t-1}}\frac{\bar{\beta}_{t}-\alpha_t\bar{\beta}_{t-1}}{\bar{\beta}_{t}}=\sqrt{\bar{\alpha}_{t-1}}\frac{1-\bar{\alpha}_{t}-\alpha_t + \alpha_t\bar{\alpha}_{t-1}}{\bar{\beta}_{t}}
=\sqrt{\bar{\alpha}_{t-1}}\frac{\beta_{t}}{\bar{\beta}_{t}}
DDPMの二番目の式の係数と一致しました^^
ノイズの係数は計算せずとも一致していることが分かります^^

(2)\eta=0のとき、生成過程でランダムノイズ部分がなくなり、初期ノイズから決定的に生成するようになります。つまり全時刻で同じノイズを与えたと考えた時の生成法となります。
\hat{x}_{t-1} = \sqrt{\bar{\alpha}_{t-1}}\hat{x}_{0} + \sqrt{\bar{\beta}_{t-1}}\hat{\epsilon}
ニューラルネットワークによるノイズ予測が完全に正しければ、各ステップにおけるノイズ予測は定数になりステップをいくつスキップするかに関わらず同じ結果が得られます。現実的にはそんなネットワークは作れませんが(作れたとしても面白くない)、ある程度予測精度が高ければ、1ステップはさすがに無理だけど、数十ステップくらいで生成できるようになる、というイメージっぽいです(あってるか分からん)。
ここまで長々と説明してなんですが、どの実装でもデフォルトだと\eta=0になります。他のパターンなんて知っててもどうしようもないですね。

PNDM(PLMS)

PNDMはDDIMの更新式をODE化して初期値問題を解く方法です。非微分方程式じゃないんですが、DDIMの拡張みたいなんでここに入れました。式を厳密に入れるととんでもないことになるので省略しますが、最初の3ステップはるんげくったを使って計算し、その後は過去3ステップ分の予測結果を利用した多段法を使います。
更新式の計算は以下の通りです。DDIMの\eta=0の式からはじまります。

\begin{eqnarray}
\hat{x}_{t-h} &=& \sqrt{\bar{\alpha}_{t-h}}\hat{x}_{0} + \sqrt{\bar{\beta}_{t-h}}\hat{\epsilon} \\
 &=& \sqrt{\bar{\alpha}_{t-h}}\frac{1}{\sqrt{\bar{\alpha}_t}}(x_t - \sqrt{1-\bar{\alpha}_t}\hat{\epsilon}) + \sqrt{\bar{\beta}_{t-h}}\hat{\epsilon} \\
 &=&  \frac{\sqrt{\bar{\alpha}_{t-h}}}{\sqrt{\bar{\alpha}_t}}x_t -  \frac{\sqrt{\bar{\alpha}_{t-h}(1-\bar{\alpha}_t)}- \sqrt{\bar{\alpha}_t(1-\bar{\alpha}_{t-h})}}{\sqrt{\bar{\alpha}_t}}\\
 &=&  \frac{\sqrt{\bar{\alpha}_{t-h}}}{\sqrt{\bar{\alpha}_t}}x_t -  \frac{\bar{\alpha}_{t-h}(1-\bar{\alpha}_t)- \bar{\alpha}_t(1-\bar{\alpha}_{t-h})}{\sqrt{\bar{\alpha}_t}(\sqrt{\bar{\alpha}_{t-h}(1-\bar{\alpha}_t)}+ \sqrt{\bar{\alpha}_t(1-\bar{\alpha}_{t-h})})}\\
 &=&  \frac{\sqrt{\bar{\alpha}_{t-h}}}{\sqrt{\bar{\alpha}_t}}x_t -  \frac{\bar{\alpha}_{t-h}-\bar{\alpha}_t}{\sqrt{\bar{\alpha}_t}(\sqrt{\bar{\alpha}_{t-h}(1-\bar{\alpha}_t)}+ \sqrt{\bar{\alpha}_t(1-\bar{\alpha}_{t-h})})}\\
\end{eqnarray}
最後の2行は分子の有理化?っぽい計算であまり意味はないんですが、微分方程式としてちゃんと導出する場合は最後の行が出てくるんだと思います。

微分方程式サンプラー(dpm-solver以外)

拡散モデルの逆拡散過程は微分方程式としても表現できます。確率微分方程式(SDE)のまま計算する方法と、常微分方程式(ODE)として変形したものを計算する方法があるみたいですがよく分かりません。

ODEは分散発散型にすると式が簡単になります。そのためこのカテゴリのサンプラーは分散発散型を前提にしています。
肝心の微分方程式は以下の通りです。
dx = \epsilon_\theta(x_t,t)\frac{d\sigma}{dt}dt=\hat{\epsilon}d\sigma
画像の微小変化量=ノイズ予測に\sigmaの微小変化をかけたもの、というめちゃくちゃシンプルな式ですね。

Euler

Euler
1. x_{t-h}=x_{t}-dx=x_{t}-\epsilon_\theta(x_t,t)(\sigma_{t}-\sigma_{t-h})
Euler法は1階微分により傾きを求めてその傾き通り更新ということをちびちびとするだけです。勾配降下法に似ていますね。
1階の近似なので早いけど精度はあんまりという感じですね。

diffusersでは少しノイズを加えてx_{t+g}=x_t + \epsilon\sqrt{{\sigma_{t+g}}^2-{\sigma_{t}}^2}としたあとに、x_{t-h}=x_{t+g}-\hat{\epsilon}(\sigma_{t+g}-\sigma_{t-h})とするような実装になっていますが、どの実装でもデフォルトではg=0なんであまり気にしなくてよさそうです。DDIMの\etaと似たような意味なんじゃないかな。

Heun

Heun
1. {x'}_{t-h}=x_{t}-\epsilon_\theta(x_{t},t)(\sigma_{t}-\sigma_{t-h})
2. x_{t-h}=x_{t}-\frac{\epsilon_\theta({x'}_{t-h},t-h)+\epsilon_\theta(x_{t},t)}{2}(\sigma_{t}-\sigma_{t-h})

一階微分だけで二階近似ができる方法です。一度Euler法で目標地点を推定した後、その場所でのノイズ予測との平均をとって更新します。Euler法に比べて計算時間が2倍になります。

DPM-Solver系サンプラー

dpm-solverは拡散モデル専用のソルバーで、時刻tではなく\lambda_t=\log \sqrt{\mathrm{SNR}(t)}を変数としてODE化したものです。SNRは信号対雑音比とよばれ、ノイズがどれほど弱いか(小さくなると強い)を表す指標です。\mathrm{SNR}(t)=\frac{\bar{\alpha}_t}{1-\bar{\alpha}_t}=-\frac{1}{\sigma_t^2}になります。ちなみに\lambda_t = - \log \sigma_tです。

全然分からんが厳密に求められる部分と近似が必要な部分を分離して、近似が必要な部分だけ近似することで精度をあげているらしい。

DPM-Solver

DPM-Solver-1
x_{t-1}=\frac{1}{\sqrt{\alpha_{t}}}x_{t}-\sqrt{\bar{\beta}_{t-1}}(e^{\lambda_{t-1}-\lambda_{t}}-1)\epsilon_\theta(x_t,t)
※離れたタイムステップに関しては、\alpha_t\frac{\bar{\alpha}_{t}}{\bar{\alpha}_{t-h}}に置き換えればよいだけです。

この更新式はDDIMの\eta=0の場合と全く同じになります。

第一項について、

\begin{eqnarray}
\frac{1}{\sqrt{\alpha_{t}}}x_{t} &=& \frac{1}{\sqrt{\alpha_{t}}}(\sqrt{\bar{\alpha}_t}\hat{x_0}+\sqrt{\bar{\beta}_{t}}\hat{\epsilon}) \\
&=& \sqrt{\bar{\alpha}_{t-1}}\hat{x_0}+\frac{\sqrt{\bar{\beta}_{t}}}{\sqrt{\alpha_{t}}}\hat{\epsilon} \\
 
\end{eqnarray}
第二項について、
\begin{eqnarray}-\sqrt{\bar{\beta}_{t-1}}(e^{\lambda_{t-1}-\lambda_{t}}-1)\hat{\epsilon} &=& -\sqrt{\bar{\beta}_{t-1}}(\frac{e^{\log \sigma_{t}}}{e^{\log \sigma_{t-1}}}-1)\hat{\epsilon}  \\
&=& -\sqrt{\bar{\beta}_{t-1}}(\frac{\sigma_{t}}{\sigma_{t-1}}-1)\hat{\epsilon} \\
&=& -\sqrt{\bar{\beta}_{t-1}}(\frac{\sqrt{1-\bar{\alpha}_{t}}}{\sqrt{\bar{\alpha}_{t}}}\frac{\sqrt{\bar{\alpha}_{t-1}}}{\sqrt{1-\bar{\alpha}_{t-1}}}-1)\hat{\epsilon} \\
&=& -\sqrt{\bar{\beta}_{t-1}}(\frac{\sqrt{\bar{\beta}_{t}}}{\sqrt{\alpha_{t}}\sqrt{\bar{\beta}_{t-1}}}-1)\hat{\epsilon} \\
&=& - \frac{\sqrt{\bar{\beta}_{t}}}{\sqrt{\alpha_t}}\hat{\epsilon} +\sqrt{\bar{\beta}_{t-1}}\hat{\epsilon}  \\
\end{eqnarray}

第一項と第二項を足すと、DDIMの更新式

\hat{x}_{t-1} = \sqrt{\bar{\alpha}_{t-1}}\hat{x}_{0} + \sqrt{\bar{\beta}_{t-1}}\hat{\epsilon}

になりました。

というわけで1階ではDDIMと同じなので通常使われません。使われるのは2階や3階です。

DPM-Solver-2
1.s=t(\frac{\lambda_{t-1}+\lambda_{t}}{2})
2.x_{s}=\frac{\sqrt{\bar{\alpha}_{s}}}{\sqrt{\bar{\alpha}_{t}}}x_{t}-\sqrt{\bar{\beta}_{s}}(e^{\frac{\lambda_{t-1}-\lambda_{t}}{2}}-1)\epsilon_\theta(x_{t},t)
3.x_{t-1}=\frac{\sqrt{\bar{\alpha}_{t-1}}}{\sqrt{\bar{\alpha}_{t}}}x_{t}-\sqrt{\bar{\beta}_{t-1}}(e^{\lambda_{t-1}-\lambda_{t}}-1)\epsilon_\theta(x_{s},s)
2階は中点法を使います。まず次のステップまでの中点(対数平方根SNR基準)を1階のDPM-solverで推定し、その地点でのノイズ予測結果を利用して次のステップの画像を1階のDPM-solverで推定します。

3階はよ―わからんけど同じようなことを3回にするんでしょう。

モデルの計算回数や時刻に応じて何階にするかを自動で決定する方式(ComfyUIのdpm-fast)やステップサイズを適応的に変更する方法(dpm-adaptive)も提案されています。
また後述しますがノイズの強さを調整することも考えられています。

DPM-Solver++

dom-solverはCFGを使うと、精度が下がるという弱点があるらしく、それを改善するために予測ノイズではなく予測した元画像を利用する方式がDPM-Solver++です。

DPM-Solver++-1
x_{t-1}=\sqrt{\frac{\bar{\beta}_{t-1}}{\bar{\beta}_{t}}}x_{t}-\sqrt{\bar{\alpha}_{t-1}}(e^{\lambda_{t}-\lambda_{t-1}}-1)\hat{x}

最適解自体はdpm-solverと同じですが、どの部分を近似するかというところが違うらしい。CFGを使う場合、ノイズ予測が学習時と生成時で一致しないので、ノイズ予測結果ではなく元の画像の予測結果を使ってるみたいな感じなんかね。

これも2階や3階があります。2階は論文にはより一般化された記述がなされていましたが、ComfyUIの実装をみると中点法を使ってるっぽいのでdpm-solverと同様ということで省略します。ただしこちらは過去のステップの情報を使うことで、1回分の計算で2階近似できるようになるMulti Steps(M)バージョンがあります。

DPM-Solver++-2M
1. h_{t-1} = \lambda_{t-1}-\lambda_{t}
2.r_{t-1} = \frac{h_{t}}{h_{t-1}}
3.D_{t-1} = (1+\frac{1}{2r_{t-1}})x_\theta(x_{t-1},t-1)-\frac{1}{2r_{t-1}}x_\theta(x_{t},t)
4.x_{t-1}=\sqrt{\frac{\bar{\beta}_{t-1}}{\bar{\beta}_{t}}}x_{t}-\sqrt{\bar{\alpha}_{t-1}}(e^{-h_{t-1}}-1)D_{t-1}
Adams-Bashforth法というやつらしいよ。式中にモデルの予測が2つでてきますが、1つは過去の情報を記憶しておけばよいので1回の計算で済みます。これに対して中点法など過去のステップの情報を使わない方法はSingle stepと呼ばれます。2SとかのSのことです。

DPM-solver-SDE

SDE版のソルバーも提案されています。まあ全然わかりませんけど。

DPM-Solver-1-SDE
x_{t-1}=\frac{1}{\sqrt{\alpha_{t}}}x_{t}-2\sqrt{\bar{\beta}_{t-1}}(e^{\lambda_{t-1}-\lambda_{t}}-1)\hat{\epsilon}+\sqrt{\bar{\beta}_{t-1}}\sqrt{(e^{2(\lambda_{t-1}-\lambda_{t})}-1)}\epsilon

DPM-Solver++-1-SDE
x_{t-1}=\frac{1}{\sqrt{\beta_{t}}}e^{\lambda_{t}-\lambda_{t-1}}x_{t}+\sqrt{\bar{\alpha}_{t-1}}(1-e^{\lambda_{t}-\lambda_{t-1}})\hat{x}+\sqrt{\bar{\beta}_{t-1}}\sqrt{1-e^{2(\lambda_{t}-\lambda_{t-1})}}\epsilon

各ステップでノイズが加えられるので、決定的な生成になりません。DDIMの\etaをいじるのと似ているのかな。

ancestral sampling

eulerやdpm-solver等で適用できるやつです。1ステップごとに普通より多めにノイズ除去をして、少しノイズを加えるということをします。これもDDIMの\etaを調整するのと似ているのかな。x_tx_0だけでなくx_{t-1}にも依存するようにします(だから先祖サンプリング?)。これによって安定するけど収束しづらくなると意味不明な説明がされています(収束しづらいのに安定するってなんだよ)。

ancestral sampling
1. \sigma_{to} = \sigma_{t-h},\ \  \sigma_{from} = \sigma_t
2. \sigma_{up} = \sqrt{\frac{\sigma_{to}^2}{\sigma_{from}^2}(\sigma_{from}^2-\sigma_{to}^2)}
3. \sigma_{down} = \sqrt{\sigma_{to}^2-\sigma_{up}^2}
4. x_{t(\sigma_{down})}=\mathrm{sampler}(x_{t(\sigma_{from})}, t(\sigma_{from})\to t(\sigma_{down}))
5. x_{t(\sigma_{to})} = x_{t(\sigma_{down})} + \sigma_{up} \epsilon

ノイズスケジュール

ステップをスキップするとき、時刻をどう分けるかという工夫の余地があります。細かい方法を除けば3つの方法があります。
以下の記事が詳しいです。
KSampler ノード - ComfyUI 解説 (wiki ではない)

normal

時刻を普通に等分します。おわり。

karras

Karrasさんが考えたスケジューラーです。\sigmaの7乗根を等分します。normalに比べて中間のステップを飛ばす代わりに最初と最後らへんの精度をあげているようです。webuiでは最後にKarrasがつくサンプラーがこれを採用しています。

exponential

ComfyUIにあります。対数平方根SNRを等分します。dpm-solverで提案されたやつです。ノイズの大きい部分を飛ばしてノイズの小さい部分をより細かくします。

LoRA学習の効率化法?

 LoRA学習を効率化できるかもしれない方法を考えたので説明していきます。簡単に言うと従来はLoRAをdown層とup層に分けて二層を順次計算していましたが、down層とup層を合体して、さらに元の重みにマージしてから計算した方が効率が良くなるかもしれませんという話です。この記事では行列積の計算量とかで単純なアルゴリズムを前提にしていますが、実際はもっと最適化しているはずなので実装とのずれはあると思います。あと分からないことがありますが放置中です。

計算量の考察

LoRAは以下のように計算できます。
 
\begin{eqnarray}
y^{(b,m)} &=& x^{(b,n)}W_{\mathrm{merge}}^{(n,m)} && \cdots (0) \\
&=& x^{(b,n)}(W_{\mathrm{org}}^{(n,m)}+W_{\mathrm{down}}^{(n,r)}W_{\mathrm{up}}^{(r,m)}) && \cdots (1) \\
&=& x^{(b,n)}W_{\mathrm{org}}^{(n,m)} + (x^{(b,n)}W_{\mathrm{down}}^{(n,r)})W_{\mathrm{up}}^{(r,m)} && \cdots (2)
\end{eqnarray}

(n,m)は入出力チャンネル、bはそれ以外の次元の積(画像の縦横やバッチサイズ、トークン長等)になります。

  • (0)式は生成時に使われる場合があります。あらかじめLoRAをモデルにマージしておくことで、計算量増加を防ぎます。
  • (1)式はLoHAの学習で使われます。LoRAの重みを計算して元の重みにマージしてから計算します。
  • (2)式は通常のLoRA学習時に使われます。元の重みで計算した出力と、入力にdown層、up層の順で通した出力を足し合わせます。

この3つの計算量について考えていきます。足し算の数はほぼ無視していいと思うので、掛け算の数だけで比較していきます。
前提として、行列積X^{(a,b)}Y^{(b,c)}の計算の掛け算の数はabcになります。

  • (0)式:bnmです。
  • (1)式:W_{\mathrm{down}}^{(n,r)}W_{\mathrm{up}}^{(r,m)}部分の計算がnrmです。あとは(0)と同じような計算であり、bnmが足されます。
  • (2)式:(x^{(b,n)}W_{\mathrm{down}}^{(n,r)})W_{\mathrm{up}}^{(r,m)}部分の計算がbnr+brm=br(n+m)で、同じくbnmが足されます。


(1)と(2)を比較すると、bnmは同じとして、それ以外の部分はrが掛けられるのはいっしょなので無視すると、nmb(n+m)の大小比較になります。bn,mに比べて十分大きい場合、(1)の方が効率よくなります。
rankは大小関係には影響しませんが、どのくらいの差があるかはrankに比例して大きくなっていきます。

bの大きさについて

bは画像のサイズ、トークン長、バッチサイズに影響します。これは学習設定やモデルの層によって変わっていきます。

  • 学習設定について

バッチサイズや画像のサイズを大きくすればするほどbが大きくなり、(1)式を使った方がよくなっていくと思います。

  • モデルの設計について

Stable DiffusionのUNetは内側の層ほど画像が縮小され、チャンネル数が増えます。つまりbが減っていきn,mは増えていきます。つまりUNetの外側の層ほど(1)式を使った方がよくなります。ちなみにテキストエンコーダはbトークン長77×バッチサイズで固定、(n,m)はだいたい768 or 1024 or 1280なので、バッチサイズ4~8くらいで逆転しそうですね。
入力に応じて適応的に式を選ぶのが一番よさそうです。nmはLoRAを作るときに計算できます。bは入力に応じて逐一計算する必要があります。

VRAM使用量について

この項はちょっと自信ないです。行列積X^{(a,b)}Y^{(b,c)}時の逆伝搬のためのキャッシュはab+bcです。足し算も同じです。

  • (1)式はnb+nm+rn+rm + nm
  • (2)式はnm + nb+nr + rb+rm + 2bm

W_{\mathrm{org}}の勾配に必要な項は無視しています(キャッシュされないはず?)。共通項を除くとnm, rb+2bmの比較になります。これもbが大きいと(1)式が有利になります。

backward側の計算について

この項もちょっと自信ないですが、行列積の逆伝搬の計算量は順伝搬の2倍になるだけです。ただし(2)式ではW_{\mathrm{org}}は勾配不要なので、bnm分減ります。あれえ?じゃあ(1)式効率悪くない?nm(b+2), 2br(n+m)の比較はほとんど後者の方が小さくなりそうです。ただこの影響がどうも実験的には出てこないんですよね。よく分かりません。

3×3の畳み込み層について

3×3畳み込みの場合、nを9倍すればいいだけです(多分)。畳み込みは画像の1チャンネルを9チャンネルに拡張したような計算になります。LoRAの場合down層の入力チャンネル数が9倍されます。

実験

適当にやりました。Paperspace A6000で768画像をバッチサイズ8でsdv2モデルをrank=16のloconでxformersありで320ステップやりました。
(1)式のみを使う:327秒/27197MiB
(1)式をnm \le b(n+m)に基づいて適応:324秒/26749MiB
(1)式をnm \le rb+2bmに基づいて適応:318秒/26593MiB
従来通り(2)式:348秒/27427MiB

従来の計算よりも効率的であることが分かりますた。なんかVRAMを一番節約できそうな式が計算時間的にも一番良い結果になりました。まあ最適化とかを無視した計算をしているので理論通りにはならないのは当然といえば当然ですね。
backward側の計算量が(2)式の方が効率よくなりそうな話についてはよく分かりませんでした...。

LoHAでbackwardを定義する理由

 LoHAとはアダマール積を使ったLoRAの応用手法です。琥珀青葉(KohakuBlueleaf)さんによってStable-diffusionで実装されました。LoHAの実装ではbackwardが定義されています。Pytorchでは特殊な関数を使わない限りbackwardを定義することはありません。しかしLoHAは特殊な関数を使うわけではありません。琥珀青葉さんが簡単に理由を説明してくれているのですが、そもそもbackwardを定義する方法も含めて私はよく知らなかったのでまとめてみます。
元論文
[2108.06098] FedPara: Low-Rank Hadamard Product for Communication-Efficient Federated Learning
琥珀青葉さんの解説
LyCORIS/Algo.md at main · KohakuBlueleaf/LyCORIS · GitHub

長いので結論を先に書く

 Pytorchは、AB\odot CDの計算で逆伝搬のためにAB, CDをキャッシュしてしまいます。しかし実際はA,B,C,Dをキャッシュしておけば十分なのでそうなるように実装されています。

LoHAについて

 LoRAは重み行列の差分\Delta W^{(m,n)}A^{(m,r)}B^{(r,n)}に分解します。パラメータは削減できますが、階数はrで制限されてしまいます。そこでLoHAはアダマール\odotを使って、\Delta W^{(m,n)}=A^{(m,r)}B^{(r,n)}\odot C^{(m,r)}D^{(r,n)}とします。パラメータ数が2倍になっていますが、階数の上界はr^2になります。
詳しくは以前書いた記事を参考にしてください。
LoRAとLoHAの階数を比較する|gcem156

LoHAの計算

 LoRAのABxは二つの行列積に分解して計算できます。そのため\Delta Wを直接計算する必要はありません。しかしLoHAの場合は(AB\odot CD)xは順次計算することはできず、\Delta Wを計算してから、\Delta Wxを計算する必要があります。ここで注目しなければならないのは\Delta W微分です。

Pytorchの自動微分

 誤差逆伝搬法では連鎖律を用いて出力側の層の微分を入力側の層の微分の計算に用いることで、効率よく微分を行います。PytorchのTensorでは四則演算とか行列の積とかの基本的な関数には微分が定義されていて、それらの合成関数であるニューラルネットワークは誤差逆伝搬法を使って自動的に微分が定義されます。
自作関数は以下のように作れるようです。

class sugoi_function_class(torch.autograd.function):
    @staticmethod
    def forward(ctx, x)
        ctx.save_for_backward(x)
        y = sugoi_function(x)
        return y

    @staticmethod
    def backward(dL_dy)
        x, = ctx.saved_tensors
        dy_dx = sugoi_funcition_no_bibun(x)
        dL_dx = dL_dy * dy_dx
        return dL_dx 

 backwardの定義は、まず損失に対する出力の微分(\frac{dL}{dy})が入力され、出力に対する入力微分(\frac{dy}{dx})をかけるという流れです。この際計算に必要な順伝搬時の情報xを記憶しておきます。

以下の記事を参考にしました。
【PyTorch】自作関数の勾配計算式(backward関数)の書き方① - Qiita

LoHAの微分

 LyCORISでは\Delta Wの計算式及びbackwardがここで実装されています。
実装は以下のようになってます。(こぴぺ)

class HadaWeight(torch.autograd.Function):
    @staticmethod
    def forward(ctx, w1a, w1b, w2a, w2b, scale=torch.tensor(1)):
        ctx.save_for_backward(w1a, w1b, w2a, w2b, scale)
        diff_weight = ((w1a@w1b)*(w2a@w2b)) * scale
        return diff_weight

    @staticmethod
    def backward(ctx, grad_out):
        (w1a, w1b, w2a, w2b, scale) = ctx.saved_tensors
        grad_out = grad_out * scale
        temp = grad_out*(w2a@w2b)
        grad_w1a = temp @ w1b.T
        grad_w1b = w1a.T @ temp

        temp = grad_out * (w1a@w1b)
        grad_w2a = temp @ w2b.T
        grad_w2b = w2a.T @ temp
        
        del temp
        return grad_w1a, grad_w1b, grad_w2a, grad_w2b, None

 順伝搬の出力がAB\odot CDであるため、backwardへの入力は\frac{dL}{d(AB\odot CD)}です。
行列ABに注目すると、アダマール積に関する微分は以下の式で表されます。
 \frac{dL}{d(AB)}=\frac{dL}{d(AB\odot CD)}\odot\frac{d(AB\odot CD)}{dAB} = \frac{dL}{d(AB\odot CD)} \odot CD
要素ごとの積なので、連鎖律も行列積ではなく要素ごとの積になります。
 行列積に関する微分は以下の式で表されます。
 \frac{dL}{dA} =  \frac{dL}{dAB} B^T,\frac{dL}{dB} = A^T \frac{dL}{dAB}

よって、
\frac{dL}{dA} = (\frac{dL}{d(AB\odot CD)} \odot CD) B^T
\frac{dL}{dB} = A^T (\frac{dL}{d(AB\odot CD)} \odot CD)
C,Dに関しても同様にして、
\frac{dL}{dC} = (\frac{dL}{d(AB\odot CD)} \odot AB) D^T
\frac{dL}{dD} = C^T (\frac{dL}{d(AB\odot CD)} \odot AB)
となります。

 実装において、順伝搬時に保存する行列はA,B,C,Dになっています。
(scaleについては本質的ではないので省略。)

Pytorchの実装をそのまま使う場合

 AB\odot CD微分は、アダマール積と行列積の合成関数になっています。この二つは特殊な関数ではないので、Pytorch上で当然backwardが定義されており、自動微分可能です。それではどうしてbackwardを定義する必要があるのか確認していきましょう。

import torch
from torchviz import make_dot
m=10
n=5
r=2

A = torch.randn((m,r))
B = torch.randn((r,n))
C = torch.randn((m,r))
D = torch.randn((r,n))
for mat in [A,B,C,D]:
    mat.requires_grad = True

AB = A@B # 後の確認用にABを残しておく
CD = C@D
W = AB * CD

image = make_dot(W)
image.format = "png"
image.render("test")

シンプルな実装で重みを作ってみました。このときの計算グラフをみてみると以下のような感じになりました。

 グラフなんか見なくても想像できますが、4つの行列から2つずつ行列積をとりできた2つの行列でアダマール積をとっています。ここで重要なのはアダマール積(MulBackward0)の部分です。Pytorchの実装ではアダマール積をとるとき逆伝搬に必要な順伝搬の入力である、AB, CDの二つを自動的にキャッシュしてしまいます。しかしその上の行列積部分ではA,B,C,Dをキャッシュしているため、AB, CDはキャッシュせずに計算可能です。実際LyCORISの実装ではそうしていますね。ちゃんとキャッシュされているかは以下の実装で確かめられます。

print(W.grad_fn._saved_self is AB)
print(W.grad_fn._saved_other is CD)
print(W.grad_fn.next_functions[0][0]._saved_self is A)
print(W.grad_fn.next_functions[0][0]._saved_mat2 is B)
print(W.grad_fn.next_functions[1][0]._saved_self is C)
print(W.grad_fn.next_functions[1][0]._saved_mat2 is D)

LoHAは行列分解する手法なので、A,B,C,Dに対してAB,CDのサイズが非常に大きくなります(それぞれ学習対象パラメータと同じサイズになる)。そのためAB,CDをキャッシュしてしまうのはてりぶるなしちゅえーしょんです。そのためbackwardを定義しているんですね。

実装を比較してみる

 sdv2のunetのみで768画像、バッチサイズ1、xformersあり、bfloat16でrank=16のconvを含むLoHAでVRAM使用量を確認してみました。

元の実装を傷つけずに比較するため、以下のようなclassを作ってみました。

class HadaWeight:
    @staticmethod
    def apply(orig_weight, w1a, w1b, w2a, w2b, scale=torch.tensor(1)):
        diff_weight = ((w1a@w1b)*(w2a@w2b)) * scale
        return orig_weight.reshape(diff_weight.shape) + diff_weight

 結果ですが、元実装の9.3GBから12.3GBにまでVRAM使用量が増加しました。理論的には学習対象パラメータの2倍になります。UNetの重みは3.4GBですが、bfloat16なのとLoHAの学習対象に含まれないパラメータもあることを考えると3GBの増加というのは、だいたいあってそうですね。ちなみに計算時間の方はほとんど変わらないようです。キャッシュをやめたところでしょぼい行列計算が増えるだけなのであまり変わらないんでしょうね。ついでにgradient_checkpointingを有効にして比較してみましたが、VRAM使用量は全く変わりませんでした。これはgradient_checkpointingが順伝搬でのキャッシュを省略する手法なので、自然な結果ですね。

おわり。

SD XLの損失関数について確認

SD XLの実装において、損失計算のアルゴリズムが異なり困惑したので、確認してみます。最終的には今までの損失と同値になることが分かりました。
時刻t\in(0,\cdots,T)の潜在変数をx_t、ノイズを\epsilon、UNetをfとします。diffusion modelでは、拡散過程tステップは以下のようにあらわされます。
x_t = \sqrt{\bar{\alpha_t}}x_{0} + \sqrt{1-\bar{\alpha_t}}\epsilon
ここで\bar{\alpha_t}はスケジューラが持っているハイパーパラメータであり、SDv1, SDv2, SD XLで数値は多分変わりません。(deepfloyd ifは違ったりする)

損失は、UNetがノイズを予測するよう学習させるため、ノイズと予測ノイズの二乗誤差になります。
L = \| \epsilon - f(x_t) \|^2
(※sdv2-768系はちょっと違います。)
diffusersによる訓練を超簡易的に書くと、

for x_0, c in dataset: # 実画像や条件付けをサンプリング
    t = torch.randint(0,T)
    noise = torch.randn_like(x_0)
    x_t = scheduler.add_noise(x_0, t)
    noise_pred = unet(x_t, t, c)
    loss = mse(noise, noise_pred)
    ...

となりますが、それに対してSD XLでの損失計算は以下のような流れになっているっぽいです。

for x_0, c in dataset: # 実画像や条件付けをサンプリング
    t = torch.randint(0,T)
    noise = torch.randn_like(x_0)
    sigma = scheduler.get_sigmas(t)
    x_t_pre = x_0 + noise * sigma
    x_t /= (sigma ** 2 + 1) ** 0.5
    noise_pred = unet(x_t, t, c)
    x_0_pred = x_t_pre - noise_pred * sigma 
    loss = mse(x_0, x_0_pred) / (sigma ** 2)
    ...

違いとして、sigmaという新しい値がでてきたり、最後のlossがノイズではなく元の潜在変数の二乗誤差になっています。
sigmaは以下の数式で表されます。
\sigma_t = \frac{\sqrt{1-\bar{\alpha_t}}}{\sqrt{\bar{\alpha_t}}} = \frac{1}{\sqrt{\mathrm{SNR}(t)}}
ここで、\mathrm{SNR}(t)とは信号対雑音比のことで、信号の分散をノイズの分散で割ったものです。
それではまず両コードのx_tが同じであることを確認します。
SD XL側の実装において、
x_{t_{pre}} = x_0 + \epsilon\frac{\sqrt{1-\bar{\alpha_t}}}{\sqrt{\bar{\alpha_t}}}
であり、次に(sigma ** 2 + 1) ** 0.5ですが、これは数式にすると、
 \sqrt{(\frac{\sqrt{1-\bar{\alpha_t}}}{\sqrt{\bar{\alpha_t}}})^2 + 1} = \sqrt{(\frac{1-\bar{\alpha_t}}{\bar{\alpha_t}}) + 1} = \frac{1}{\sqrt{\bar{\alpha_t}}}
となります。

これが割られるので、x_0に掛けられる係数は\sqrt{\bar{\alpha_t}}、noiseに掛けられる係数は\frac{\sqrt{1-\bar{\alpha_t}}}{\sqrt{\bar{\alpha_t}}}\sqrt{\bar{\alpha_t}} =\sqrt{1-\bar{\alpha_t}}
となり、ちゃんとx_tと同値になることが分かります。unetがノイズを予測できていれば、x_0_predは元のx_0に近づきます。損失はその二つの平均二乗誤差です。ただし損失に\frac{1}{\sigma^2}=\mathrm{SNR}(t)が掛けられています。これは以前書いた記事にちょうど関連する話であり、\|\epsilon - \hat{\epsilon}\|^2=\mathrm{SNR}(t)\|x_0 - \hat{x_0}\|^2という関係から最終的な損失はノイズの二乗誤差になることが分かります。

 何でこんな回りくどいやり方にしたの?と思われるかもしれませんが、どうやらEulerなどの微分方程式ソルバー系の実装に合わせているようですね。まあこの辺り私はよく分かってません。