勾配降下党青年局

万国のグラーディエントよ、降下せよ!

Stable diffusion 2.0で使われた漸進的蒸留をなんとなく理解したような気がする可能性があるかもしれない。

 SD2.0では、512×512の画像で普通に学習した後、768×786の画像を下記の論文による手法で学習したそうです。ざっくりと説明すると、samplerのステップ数を半分にしても生成できるようにする方法です。それにもかかわらず、Stability AIがステップ数を減らしてもうまく生成できるようになったと謳っていないのがよく分からないです。

追記:たぶんv_predictionを使っただけで蒸留は使ってないんだろうね

arxiv.org

  英語はよくわからないので、数式とにらめっこします。

https://arxiv.org/abs/2202.00512

 左が普通のノイズ推定で、右がこの論文の方法ですね。左から読む。texなんて久しぶりに使うなあ。めんどくさいので太字とかローマン体とかなし。細かいところ違うかもしれないけど全体的な流れはあってるはず。。。

通常のdiffusion model

\hat{x}_\theta (z_t)が、モデルの推論結果を表します。中に入れてるのがノイズ込みの画像(latent diffusionの場合潜在変数)ですね。学習ループを見ていきます。

xが学習用画像で、tが時間(何ステップ目か)で\epsilonがノイズですね。画像にノイズを加えたものが、z_t = \alpha _t x + \sigma _t \epsilonです。 \alpha _t \sigma _tはハイパーパラメータです。Stable diffusion 2.0では、\beta _tを0.00085から0.0120まで線形に増やして、\alpha _t = \sqrt{1-\beta _t},\  \sigma _t = \sqrt{\beta _t}としているはずです。targetは、xそのものです、次の二行は理解していない本質的じゃなさそうなのでカット。L_\thetaが実際のノイズと推定したノイズの二乗誤差であり、これの勾配を降りればいい感じですね。

漸進的蒸留

右について、緑で塗られているところがが普通と違うところです。まず教師モデルとかいうのが出てきてます。蒸留なので、当たり前ですが先生が必要です。SD2.0では512×512の画像を学習させたものを先生にしています。教師がいるので、生徒もいます。生徒が学習対象のモデルです。教師モデルのパラメータを\eta、生徒モデルのパラメータを\thetaで表します。とりあえずwhileループを見ていきます。生徒モデルが教師モデルでの2ステップ分のノイズ除去を1ステップでできるようにすることを目的にしています。x,\ z_t,\ \epsilonは左と同じです。 tはなんか難しく書いてあるけど同じようなもんでしょ。左の画像では、targetはxそのものでしたが、今回は教師モデル2回分のノイズ除去をしたものがtargetになります。# 2stepなんたらからの説明。

t'が0.5ステップ前、t''が1ステップ前です。教師モデル的にはそれぞれ1ステップ前、2ステップ前になります。(0.5ステップ前、1ステップ前で統一します。)

教師モデルによる0.5ステップ前の推定画像は\hat{x}_\eta (z_t)です。

よって0.5ステップ前のノイズ \epsilon _{t'}\ =\frac{1}{\sigma_t} (z_t - \alpha _t \hat{x}_\eta (z_t))になります。

そのため、教師モデルによる0.5ステップ前でのノイズ付き画像が、 z_{t'} = \alpha_{t'}\hat{x}_\eta (z_t) + \sigma _{t'}\epsilon_{t'} = \alpha_{t'}\hat{x}_\eta (z_t) + \frac{\sigma _{t'}}{\sigma_t} (z_t - \alpha _t \hat{x}_\eta (z_t))

同様の計算で、教師モデルによる1ステップ前でのノイズ付き画像z_{t''}が求まります。

一方生徒モデルによる1ステップ前でのノイズ付き画像は、教師モデルとちがってt \to t''と1ステップで推論するので、

 \hat{z_{t''}} = \alpha_{t''}\hat{x}_\theta (z_t) + \frac{\sigma _{t''}}{\sigma_t} (z_t - \alpha _t \hat{x}_\theta (z_t))となります。

教師モデルの2ステップと生徒モデルの1ステップを一致させるには、\hat{z_{t''}}=z_{t''}となればよいので、生徒モデルの出力に求めるtarget \hat{x}は、

 \hat{z_{t''}} = \alpha_{t''}\hat{x}+ \frac{\sigma _{t''}}{\sigma_t} (z_t - \alpha _t \hat{x})=z_{t''}

 (\alpha_{t''}- \alpha _t\frac{\sigma _{t''}}{\sigma_t})\hat{x} = z_{t''} - \frac{\sigma _{t''}}{\sigma_t}z_{t}

変形すれば上の画像の通りのtargetになります。

あとは普通のやつと同じですね。

whileループが収束したら、出来上がった生徒が次世代の教師になり、ステップ数を半分にして同じことをやる・・・ということを繰り返してどんどん蒸留していきます。Stable diffusion 2.0ではこのwhileループを何回やったのか書いてないのでよくわからない・・・。

ファインチューニングの実装

 こんな話しをした理由は、下記記事のSD2.0版DreamBoothでやっていることを理解したかったからです。普通のDreamBoothなので蒸留ではありません。

note.com

 見たまま編集モードだとコードの色付けできなくてわろたー。

targetの計算が、target = (alpha_prod_t ** 0.5) * noise - (beta_prod_t ** 0.5) * latents

というふうになってます。

 noiseからlatentを引いたのがtarget?なんだこりゃ。これってようするに、

 \hat{y} = \alpha _t \epsilon - \sigma _t xをtargetにしてるということですよね。うーんなんでかわかりません。そもそもノイズと潜在変数をごちゃまぜしたものを返してどうやって潜在変数だけとりだすんや?

 

またnnablaの動画で復習だなこりゃ

 

つづくかもしれない・・・

 

追記:

 \phi = argtan(\alpha / \sigma)とすると、\alpha = \cos (\phi), \sigma = \sin (\phi)となり、

 

v = \frac{d}{d\phi}z=\frac{d}{d\phi}\cos (\phi)x +\frac{d}{d\phi}\sin (\phi )\epsilon=\cos (\phi)\epsilon - \sin (\phi)x

 

をtargetとするのがv_predictionのようですね。

 

ごちゃごちゃやると

 

x = \sin (\phi)z - \cos (\phi)vになるようです。

 

何でこうした方がいいのかよく分かりません