ComfyUI上でDDIMの実装が見たくて探していたのですが、こんな感じでした。あれ・・・Eulerを呼び出してるだけ・・・?というわけで確認していきます。
以下の記事を前提とする。
さんぷらーについて - 勾配降下党青年局
DDIMのデフォルト設定()のとき時刻の更新式は、
一方Eulerはを分散発散型の変数とすると、
となります。ここで上記記事にあるように分散発散型を分散保存型に変換すると、
同じなんですね。全然知りませんでした。
ComfyUI上でDDIMの実装が見たくて探していたのですが、こんな感じでした。あれ・・・Eulerを呼び出してるだけ・・・?というわけで確認していきます。
以下の記事を前提とする。
さんぷらーについて - 勾配降下党青年局
DDIMのデフォルト設定()のとき時刻の更新式は、
一方Eulerはを分散発散型の変数とすると、
となります。ここで上記記事にあるように分散発散型を分散保存型に変換すると、
同じなんですね。全然知りませんでした。
なんか色々増えてきたのでまとめるよ。
LoRAは行列に対して、差分を学習します。このとき、という二つの行列でとすることで、学習対象のパラメータを大幅に削減します。1層目をdown層、2層目をup層と呼びます。学習時は入力に対して、出力をとします。これは計算量的には大きくなりますが、学習対象パラメータが大幅に削減されるメリットの方が大きいです。推論時はとすることで、計算量を増やさずに済みます。実際はLoRAの結果をでスケーリングします。はハイパーパラメータで、の増減でLoRAのスケールが変動することを防ぎ、を変えた時に学習率を変更せずに済むようになっています。
各アルゴリズムの総称として使われることもありますが、一番狭い意味では全結合層及び1×1畳み込み層に適用するものになっています。1×1畳み込み層にはピクセルごとに同じ全結合層を適用しているだけなので簡単に表現できます。
となります。このことからはrankと呼ばれています。上界だけどね。
のLoRAを考えてみましょう。するとLoRAは列ベクトルと行ベクトルの積になります。の各要素をとすると、
です。行基本変形を繰り返せば1行を残し全部0になることがすぐ分かり、になります。
一般ののときは、
みたいな感じになります。上の行目までで行基本変形を行えば三角行列的な形(足し算に対してだけど)にできるので、行目以降を0にできます。よってです。
LoConはフィルターサイズが1×1以外の畳み込み層にもLoRAを適用する方法です。たとえば3×3フィルターの場合、1ピクセルの計算にスポットを当てれば、9マス×入力チャンネル⇒出力チャンネルの全結合層です。3×3畳み込み層の重みはなので数は合いますよね。LoRAの場合はこの変換を9マス×入力チャンネル⇒チャンネル⇒出力チャンネルと二つの変換に分解します。として、3×3畳み込み層にもLoRAを拡張できます。入力を合わせるために、down層のstrideやpaddingは元のモジュールと同じものにします。up層は1×1畳み込みです。
LoHaは二つのLoRAのアダマール積をとる手法です。アダマール積とは単純に行列の要素ごとの積をとる演算です。
です。
LoRAの項目があった形になおすと、
個の足し算同士の掛け算によって、項が個になります。よってになります。パラメータ数が2倍になるのに対して、rankの上界が2乗になります。
LoKRはクロネッカー積を使う手法です。
クロネッカー積とは行列に対して、以下のような行列を返す演算です。
行列と行列に対してクロネッカー積を適用すると、行列になります。
となります。右だけLoRAです。左までLoRAにする設定もあるようです。
クロネッカー積に関して、となるので、rank的にはめちゃくちゃ効率よくなります。ただし重み共有を使っているので、rank的にいいからといって表現力が高いかというと微妙だと思いますけどね。
各モジュールの入力もしくは出力をチャンネルごとにスケーリングします。数式的には対角行列を重みの左側もしくは右側にかけます。パラメータ数は非常に小さいです。入出力両方に適用すれば素朴に表現力があがりそうですが、そういった設定はないようですね。
GLoRAは差分だけでなく、元の重みへの入力に対してもLoRAを適用します。
元論文ではもっと複雑な設計ですが、省略した実装のようですね。これをLoKR版にしたGLoKRもあります。
LoRAと比べて表現力は増えていないんですが、元の重みを利用することで学習が早くなったりするんですかね(よく分かりません)。
複数のrankのLoRAを同時に学習する方法です。LoRAはという形にできます。ここではの列目、はの行目です。DyLoRAではステップごとにランダムにを選び、とします。このように学習すると、任意のrankで、として生成できるようになります。のrankを学習するときに、以下のrankのLoRAに影響を与えないため勾配を無効にしています。rankは1ずつではなくブロック分けすることも可能です。
GroupNormやLayerNorm層も学習対象にします。重みやバイアスの差分を学習するというだけで特に難しいことはありません。
を元の重みと同じサイズの行列にします。つまりやっていることは普通のファインチューニングと同じです。差分をとっておくのはVRAMの浪費でしかないんですが、LyCORIS上でフルファインチューニングを実装するにはそうするしかなかったんでしょう。一応差分に対してドロップアウトを実行できるというメリットはありますけどお。
拡散モデルで定義される、データからノイズを加えていく拡散過程は以下のように定義されています。
ここで時刻の状態を一気にサンプリングできて、
となります。
この証明ですが、大抵の場合から始まる帰納法もしくはでごまかすやつを使って証明します。どっちにしろ条件の方のは固定します。
まあそれはそれでいいんですが、最近時刻へ1回でサンプリングをするのは、離れたタイムステップ間でサンプリングすることの特殊ケースでしかないことに気づきました。
より一般的な式で証明した方が後々分かりやすくない?ということでやってみます。
つまりでを動かすことで帰納法をやってみます。
離れたタイムステップ間のサンプリングにも対応する、より一般的な結果はこちらになります。
見づらいので、とします。
するとであることに注意しといてください。
帰納法で証明します。
証明したいことは、以下の通りです。
のとき、であるため定義通りになります。
のとき、題意が成り立つとすると、は以下のようにサンプリングできます。
これでの平均がとなることはわかりました。
分散については正規分布の再生性から、二項と三項の分散を足すだけでよくて、
となります。
よって、
帰納法のあれがあーいうやつにより、証明完了です。
各サンプラーの意味とかがなんとなく分かりたくて書いたものです。SDEやODEの導出に関する話はでてきません(分からんし)。
サンプラーによって使われている文字の意味が違うので、ここでは文字をあわせていいきたいと思います。そのため論文の式そのままにできないので間違っている可能性があります。
拡散過程は分散保存型と分散発散型の二種類に分かれています。
元の画像を弱めながらノイズを加えていく方法です。
二つの係数の二乗和をとると1になります。つまり分散がどの時刻でも固定されます。
この拡散過程は以下のようにを単独でサンプリングすることができます。
ここで、です。
この説明にはも使うことが多いです。ば~の有無にかかわらず、です。を使った方がシンプルになるときはを使います。
元の画像をそのままにして、ノイズを加えていく方法です。
Stable Diffusionは学習時に分散保存型を採用していますが、分散発散型であると仮定して生成することも可能です。
とすれば、とスケーリングすることで分散保存型の式になるので、それをUNetに入力すればよいです。出力は係数が掛けられていない元のノイズなので、スケーリングは必要ありません。
離れた時刻からへのサンプリングもできます。
は累積積なので、だったら、がそのままでてきますね。
1ステップ分の式を離れた時刻に対応する式に置き換えたい場合、をに置き換えればよいだけです。または自然数ではなく実数にすることができます(時刻を離散時間から連続時間にする)。UNetの時刻埋め込みはうまいこと連続になるよう設定されているので、離散的な時刻しか学習していなくても問題ないらしい。
逆拡散過程では、加えられたノイズをモデルに予測させ、その結果を使って次のステップの画像を推定します。その推定の仕方はサンプラーによって異なります。
ノイズの予測をとすると、ノイズのない画像の予測は、
となります。
モデルの予測結果から導出する変数は、など帽子をつけてあげます。
Stable Diffusionの画像生成はdiffusersでは以下のようなコードになっています。
# timestepの設定 self.scheduler.set_timesteps(num_inference_steps) timesteps = self.scheduler.timesteps # 初期ノイズの作成 latents = torch.randn(batch_size, 4, height // 8, width // 8) # 分散発散型の場合simga_1000をかける latents = latents * self.scheduler.init_noise_sigma # ノイズ除去ループ for i, t in enumerate(timesteps): #分散発散型の場合に入力を分散保存型に置き換える latent_model_input = self.scheduler.scale_model_input(latents, t) #ノイズ予測 noise_pred = self.unet(latent_model_input , t) # 次のステップの潜在変数を推定 latents = self.scheduler.step(noise_pred, t, latents).prev_sample
今回の話の主題ではない条件付き生成に関わる部分などは省略しています。時刻]を設定したステップ数分取り出して、として、と処理を繰り返します。このとき現在の画像から与えられたノイズを予測する役割を持つのがUNetで、次の時刻の状態を予測する役割を持つのがサンプラーになります。
DDPM、DDIMといった微分方程式を使わないサンプラーは、以下のように平均と分散を推定することで次の画像を予測します。
ただし分散は学習せず固定します。
DDPMSamplerは最初にできたサンプラーです。
またバリエーションとして、DDPMの導出中にでてくるからでてくる平均と分散を使ってを予測する式もあります。
よくわかりませんが、diffusersではこっちを使っているようです。後述のDDIMの特殊ケースになります。
拡散過程は今まで加えられてきたノイズと、次に加えるノイズが全くの無関係でした。つまりが依存するのはのみでした。逆に加えるノイズを全時刻で固定することを考えると、はのみに依存するようになります。DDIMではその間をとるような拡散過程を考えます。するとDDPMとDDIMはうまいこと最適解が一致するため、DDPMで学習したモデルに対してDDIMを根拠にした生成を行えます。DDIMSamplerは1000ステップ必要だったDDPMSamplerに比べて数十ステップくらいで生成できるようになりますた。
サンプリングは、
(1)のとき、DDPMの二つ目の式と同じになります。
計算
第二項は
の係数に着目すると、となりDDPMの二番目の式と一致しました^^
の係数に着目すると、となり
第一項のの係数と足すと、
DDPMの二番目の式の係数と一致しました^^
ノイズの係数は計算せずとも一致していることが分かります^^
(2)のとき、生成過程でランダムノイズ部分がなくなり、初期ノイズから決定的に生成するようになります。つまり全時刻で同じノイズを与えたと考えた時の生成法となります。
ニューラルネットワークによるノイズ予測が完全に正しければ、各ステップにおけるノイズ予測は定数になりステップをいくつスキップするかに関わらず同じ結果が得られます。現実的にはそんなネットワークは作れませんが(作れたとしても面白くない)、ある程度予測精度が高ければ、1ステップはさすがに無理だけど、数十ステップくらいで生成できるようになる、というイメージっぽいです(あってるか分からん)。
ここまで長々と説明してなんですが、どの実装でもデフォルトだとになります。他のパターンなんて知っててもどうしようもないですね。
拡散モデルの逆拡散過程は微分方程式としても表現できます。確率微分方程式(SDE)のまま計算する方法と、常微分方程式(ODE)として変形したものを計算する方法があるみたいですがよく分かりません。
ODEは分散発散型にすると式が簡単になります。そのためこのカテゴリのサンプラーは分散発散型を前提にしています。
肝心の微分方程式は以下の通りです。
画像の微小変化量=ノイズ予測にの微小変化をかけたもの、というめちゃくちゃシンプルな式ですね。
diffusersでは少しノイズを加えてとしたあとに、とするような実装になっていますが、どの実装でもデフォルトではなんであまり気にしなくてよさそうです。DDIMのと似たような意味なんじゃないかな。
一階微分だけで二階近似ができる方法です。一度Euler法で目標地点を推定した後、その場所でのノイズ予測との平均をとって更新します。Euler法に比べて計算時間が2倍になります。
dpm-solverは拡散モデル専用のソルバーで、時刻ではなくを変数としてODE化したものです。SNRは信号対雑音比とよばれ、ノイズがどれほど弱いか(小さくなると強い)を表す指標です。になります。ちなみにです。
全然分からんが厳密に求められる部分と近似が必要な部分を分離して、近似が必要な部分だけ近似することで精度をあげているらしい。
この更新式はDDIMのの場合と全く同じになります。
第一項について、
第二項について、
第一項と第二項を足すと、DDIMの更新式
になりました。
というわけで1階ではDDIMと同じなので通常使われません。使われるのは2階や3階です。
3階はよ―わからんけど同じようなことを3回にするんでしょう。
モデルの計算回数や時刻に応じて何階にするかを自動で決定する方式(ComfyUIのdpm-fast)やステップサイズを適応的に変更する方法(dpm-adaptive)も提案されています。
また後述しますがノイズの強さを調整することも考えられています。
dom-solverはCFGを使うと、精度が下がるという弱点があるらしく、それを改善するために予測ノイズではなく予測した元画像を利用する方式がDPM-Solver++です。
最適解自体はdpm-solverと同じですが、どの部分を近似するかというところが違うらしい。CFGを使う場合、ノイズ予測が学習時と生成時で一致しないので、ノイズ予測結果ではなく元の画像の予測結果を使ってるみたいな感じなんかね。
これも2階や3階があります。2階は論文にはより一般化された記述がなされていましたが、ComfyUIの実装をみると中点法を使ってるっぽいのでdpm-solverと同様ということで省略します。ただしこちらは過去のステップの情報を使うことで、1回分の計算で2階近似できるようになるMulti Steps(M)バージョンがあります。
SDE版のソルバーも提案されています。まあ全然わかりませんけど。
各ステップでノイズが加えられるので、決定的な生成になりません。DDIMのをいじるのと似ているのかな。
eulerやdpm-solver等で適用できるやつです。1ステップごとに普通より多めにノイズ除去をして、少しノイズを加えるということをします。これもDDIMのを調整するのと似ているのかな。がだけでなくにも依存するようにします(だから先祖サンプリング?)。これによって安定するけど収束しづらくなると意味不明な説明がされています(収束しづらいのに安定するってなんだよ)。
ステップをスキップするとき、時刻をどう分けるかという工夫の余地があります。細かい方法を除けば3つの方法があります。
以下の記事が詳しいです。
KSampler ノード - ComfyUI 解説 (wiki ではない)
時刻を普通に等分します。おわり。
Karrasさんが考えたスケジューラーです。の7乗根を等分します。normalに比べて中間のステップを飛ばす代わりに最初と最後らへんの精度をあげているようです。webuiでは最後にKarrasがつくサンプラーがこれを採用しています。
[2006.11239] Denoising Diffusion Probabilistic Models
[2010.02502] Denoising Diffusion Implicit Models
[2202.09778] Pseudo Numerical Methods for Diffusion Models on Manifolds
[2206.00927] DPM-Solver: A Fast ODE Solver for Diffusion Probabilistic Model Sampling in Around 10 Steps
[2211.01095] DPM-Solver++: Fast Solver for Guided Sampling of Diffusion Probabilistic Models
[2206.00364] Elucidating the Design Space of Diffusion-Based Generative Models
思いつくものをいい加減に貼っていく。
拡散モデル - 岩波書店https://hillbig.github.io/diffusion-models/
コンピュータビジョン最前線 Summer 2023 - 共立出版
改訂版 Stable Diffusionで使われているアルゴリズムの解説 - 過学習ショップ - BOOTH
もし生物情報科学専攻の大学院生が "StableDiffusion" を理解しようとしたら 9 ~Stable Diffusion②~ - 何だって、したしむ
DDPMの関連技術 | henatips
DiffusionによるText2Imageの系譜と生成画像が動き出すまで - Speaker Deck
【論文解説】Diffusion Modelを理解する | 楽しみながら理解するAI・機械学習入門
LoRA学習を効率化できるかもしれない方法を考えたので説明していきます。簡単に言うと従来はLoRAをdown層とup層に分けて二層を順次計算していましたが、down層とup層を合体して、さらに元の重みにマージしてから計算した方が効率が良くなるかもしれませんという話です。この記事では行列積の計算量とかで単純なアルゴリズムを前提にしていますが、実際はもっと最適化しているはずなので実装とのずれはあると思います。あと分からないことがありますが放置中です。
LoRAは以下のように計算できます。
は入出力チャンネル、はそれ以外の次元の積(画像の縦横やバッチサイズ、トークン長等)になります。
この3つの計算量について考えていきます。足し算の数はほぼ無視していいと思うので、掛け算の数だけで比較していきます。
前提として、行列積の計算の掛け算の数はになります。
(1)と(2)を比較すると、は同じとして、それ以外の部分はが掛けられるのはいっしょなので無視すると、との大小比較になります。がに比べて十分大きい場合、(1)の方が効率よくなります。
rankは大小関係には影響しませんが、どのくらいの差があるかはrankに比例して大きくなっていきます。
は画像のサイズ、トークン長、バッチサイズに影響します。これは学習設定やモデルの層によって変わっていきます。
バッチサイズや画像のサイズを大きくすればするほどが大きくなり、(1)式を使った方がよくなっていくと思います。
Stable DiffusionのUNetは内側の層ほど画像が縮小され、チャンネル数が増えます。つまりが減っていきは増えていきます。つまりUNetの外側の層ほど(1)式を使った方がよくなります。ちなみにテキストエンコーダははトークン長77×バッチサイズで固定、はだいたい768 or 1024 or 1280なので、バッチサイズ4~8くらいで逆転しそうですね。
入力に応じて適応的に式を選ぶのが一番よさそうです。はLoRAを作るときに計算できます。は入力に応じて逐一計算する必要があります。
この項はちょっと自信ないです。行列積時の逆伝搬のためのキャッシュはです。足し算も同じです。
の勾配に必要な項は無視しています(キャッシュされないはず?)。共通項を除くとの比較になります。これもbが大きいと(1)式が有利になります。
この項もちょっと自信ないですが、行列積の逆伝搬の計算量は順伝搬の2倍になるだけです。ただし(2)式ではは勾配不要なので、分減ります。あれえ?じゃあ(1)式効率悪くない?の比較はほとんど後者の方が小さくなりそうです。ただこの影響がどうも実験的には出てこないんですよね。よく分かりません。
3×3畳み込みの場合、を9倍すればいいだけです(多分)。畳み込みは画像の1チャンネルを9チャンネルに拡張したような計算になります。LoRAの場合down層の入力チャンネル数が9倍されます。
適当にやりました。Paperspace A6000で768画像をバッチサイズ8でsdv2モデルをrank=16のloconでxformersありで320ステップやりました。
(1)式のみを使う:327秒/27197MiB
(1)式をに基づいて適応:324秒/26749MiB
(1)式をに基づいて適応:318秒/26593MiB
従来通り(2)式:348秒/27427MiB
従来の計算よりも効率的であることが分かりますた。なんかVRAMを一番節約できそうな式が計算時間的にも一番良い結果になりました。まあ最適化とかを無視した計算をしているので理論通りにはならないのは当然といえば当然ですね。
backward側の計算量が(2)式の方が効率よくなりそうな話についてはよく分かりませんでした...。
LoHAとはアダマール積を使ったLoRAの応用手法です。琥珀青葉(KohakuBlueleaf)さんによってStable-diffusionで実装されました。LoHAの実装ではbackwardが定義されています。Pytorchでは特殊な関数を使わない限りbackwardを定義することはありません。しかしLoHAは特殊な関数を使うわけではありません。琥珀青葉さんが簡単に理由を説明してくれているのですが、そもそもbackwardを定義する方法も含めて私はよく知らなかったのでまとめてみます。
元論文
[2108.06098] FedPara: Low-Rank Hadamard Product for Communication-Efficient Federated Learning
琥珀青葉さんの解説
LyCORIS/Algo.md at main · KohakuBlueleaf/LyCORIS · GitHub
Pytorchは、の計算で逆伝搬のためにをキャッシュしてしまいます。しかし実際はをキャッシュしておけば十分なのでそうなるように実装されています。
LoRAは重み行列の差分をに分解します。パラメータは削減できますが、階数はで制限されてしまいます。そこでLoHAはアダマール積を使って、とします。パラメータ数が2倍になっていますが、階数の上界はになります。
詳しくは以前書いた記事を参考にしてください。
LoRAとLoHAの階数を比較する|gcem156
LoRAのは二つの行列積に分解して計算できます。そのためを直接計算する必要はありません。しかしLoHAの場合はは順次計算することはできず、を計算してから、を計算する必要があります。ここで注目しなければならないのはの微分です。
誤差逆伝搬法では連鎖律を用いて出力側の層の微分を入力側の層の微分の計算に用いることで、効率よく微分を行います。PytorchのTensorでは四則演算とか行列の積とかの基本的な関数には微分が定義されていて、それらの合成関数であるニューラルネットワークは誤差逆伝搬法を使って自動的に微分が定義されます。
自作関数は以下のように作れるようです。
class sugoi_function_class(torch.autograd.function): @staticmethod def forward(ctx, x) ctx.save_for_backward(x) y = sugoi_function(x) return y @staticmethod def backward(dL_dy) x, = ctx.saved_tensors dy_dx = sugoi_funcition_no_bibun(x) dL_dx = dL_dy * dy_dx return dL_dx
backwardの定義は、まず損失に対する出力の微分()が入力され、出力に対する入力微分()をかけるという流れです。この際計算に必要な順伝搬時の情報を記憶しておきます。
以下の記事を参考にしました。
【PyTorch】自作関数の勾配計算式(backward関数)の書き方① - Qiita
LyCORISではの計算式及びbackwardがここで実装されています。
実装は以下のようになってます。(こぴぺ)
class HadaWeight(torch.autograd.Function): @staticmethod def forward(ctx, w1a, w1b, w2a, w2b, scale=torch.tensor(1)): ctx.save_for_backward(w1a, w1b, w2a, w2b, scale) diff_weight = ((w1a@w1b)*(w2a@w2b)) * scale return diff_weight @staticmethod def backward(ctx, grad_out): (w1a, w1b, w2a, w2b, scale) = ctx.saved_tensors grad_out = grad_out * scale temp = grad_out*(w2a@w2b) grad_w1a = temp @ w1b.T grad_w1b = w1a.T @ temp temp = grad_out * (w1a@w1b) grad_w2a = temp @ w2b.T grad_w2b = w2a.T @ temp del temp return grad_w1a, grad_w1b, grad_w2a, grad_w2b, None
順伝搬の出力がであるため、backwardへの入力はです。
行列に注目すると、アダマール積に関する微分は以下の式で表されます。
要素ごとの積なので、連鎖律も行列積ではなく要素ごとの積になります。
行列積に関する微分は以下の式で表されます。
よって、
に関しても同様にして、
となります。
実装において、順伝搬時に保存する行列はになっています。
(scaleについては本質的ではないので省略。)
の微分は、アダマール積と行列積の合成関数になっています。この二つは特殊な関数ではないので、Pytorch上で当然backwardが定義されており、自動微分可能です。それではどうしてbackwardを定義する必要があるのか確認していきましょう。
import torch from torchviz import make_dot m=10 n=5 r=2 A = torch.randn((m,r)) B = torch.randn((r,n)) C = torch.randn((m,r)) D = torch.randn((r,n)) for mat in [A,B,C,D]: mat.requires_grad = True AB = A@B # 後の確認用にABを残しておく CD = C@D W = AB * CD image = make_dot(W) image.format = "png" image.render("test")
シンプルな実装で重みを作ってみました。このときの計算グラフをみてみると以下のような感じになりました。
グラフなんか見なくても想像できますが、4つの行列から2つずつ行列積をとりできた2つの行列でアダマール積をとっています。ここで重要なのはアダマール積(MulBackward0)の部分です。Pytorchの実装ではアダマール積をとるとき逆伝搬に必要な順伝搬の入力である、の二つを自動的にキャッシュしてしまいます。しかしその上の行列積部分ではをキャッシュしているため、はキャッシュせずに計算可能です。実際LyCORISの実装ではそうしていますね。ちゃんとキャッシュされているかは以下の実装で確かめられます。
print(W.grad_fn._saved_self is AB) print(W.grad_fn._saved_other is CD) print(W.grad_fn.next_functions[0][0]._saved_self is A) print(W.grad_fn.next_functions[0][0]._saved_mat2 is B) print(W.grad_fn.next_functions[1][0]._saved_self is C) print(W.grad_fn.next_functions[1][0]._saved_mat2 is D)
LoHAは行列分解する手法なので、に対してのサイズが非常に大きくなります(それぞれ学習対象パラメータと同じサイズになる)。そのためをキャッシュしてしまうのはてりぶるなしちゅえーしょんです。そのためbackwardを定義しているんですね。
sdv2のunetのみで768画像、バッチサイズ1、xformersあり、bfloat16でrank=16のconvを含むLoHAでVRAM使用量を確認してみました。
元の実装を傷つけずに比較するため、以下のようなclassを作ってみました。
class HadaWeight: @staticmethod def apply(orig_weight, w1a, w1b, w2a, w2b, scale=torch.tensor(1)): diff_weight = ((w1a@w1b)*(w2a@w2b)) * scale return orig_weight.reshape(diff_weight.shape) + diff_weight
結果ですが、元実装の9.3GBから12.3GBにまでVRAM使用量が増加しました。理論的には学習対象パラメータの2倍になります。UNetの重みは3.4GBですが、bfloat16なのとLoHAの学習対象に含まれないパラメータもあることを考えると3GBの増加というのは、だいたいあってそうですね。ちなみに計算時間の方はほとんど変わらないようです。キャッシュをやめたところでしょぼい行列計算が増えるだけなのであまり変わらないんでしょうね。ついでにgradient_checkpointingを有効にして比較してみましたが、VRAM使用量は全く変わりませんでした。これはgradient_checkpointingが順伝搬でのキャッシュを省略する手法なので、自然な結果ですね。
おわり。
SD XLの実装において、損失計算のアルゴリズムが異なり困惑したので、確認してみます。最終的には今までの損失と同値になることが分かりました。
時刻の潜在変数を、ノイズを、UNetをとします。diffusion modelでは、拡散過程tステップは以下のようにあらわされます。
ここではスケジューラが持っているハイパーパラメータであり、SDv1, SDv2, SD XLで数値は多分変わりません。(deepfloyd ifは違ったりする)
損失は、UNetがノイズを予測するよう学習させるため、ノイズと予測ノイズの二乗誤差になります。
(※sdv2-768系はちょっと違います。)
diffusersによる訓練を超簡易的に書くと、
for x_0, c in dataset: # 実画像や条件付けをサンプリング t = torch.randint(0,T) noise = torch.randn_like(x_0) x_t = scheduler.add_noise(x_0, t) noise_pred = unet(x_t, t, c) loss = mse(noise, noise_pred) ...
となりますが、それに対してSD XLでの損失計算は以下のような流れになっているっぽいです。
for x_0, c in dataset: # 実画像や条件付けをサンプリング t = torch.randint(0,T) noise = torch.randn_like(x_0) sigma = scheduler.get_sigmas(t) x_t_pre = x_0 + noise * sigma x_t /= (sigma ** 2 + 1) ** 0.5 noise_pred = unet(x_t, t, c) x_0_pred = x_t_pre - noise_pred * sigma loss = mse(x_0, x_0_pred) / (sigma ** 2) ...
違いとして、sigmaという新しい値がでてきたり、最後のlossがノイズではなく元の潜在変数の二乗誤差になっています。
sigmaは以下の数式で表されます。
ここで、とは信号対雑音比のことで、信号の分散をノイズの分散で割ったものです。
それではまず両コードのx_tが同じであることを確認します。
SD XL側の実装において、
であり、次に(sigma ** 2 + 1) ** 0.5ですが、これは数式にすると、
となります。
これが割られるので、x_0に掛けられる係数は、noiseに掛けられる係数は
となり、ちゃんとと同値になることが分かります。unetがノイズを予測できていれば、x_0_predは元のx_0に近づきます。損失はその二つの平均二乗誤差です。ただし損失にが掛けられています。これは以前書いた記事にちょうど関連する話であり、という関係から最終的な損失はノイズの二乗誤差になることが分かります。
何でこんな回りくどいやり方にしたの?と思われるかもしれませんが、どうやらEulerなどの微分方程式ソルバー系の実装に合わせているようですね。まあこの辺り私はよく分かってません。