勾配降下党青年局

万国のグラーディエントよ、降下せよ!

LoRA学習の効率化法?

 LoRA学習を効率化できるかもしれない方法を考えたので説明していきます。簡単に言うと従来はLoRAをdown層とup層に分けて二層を順次計算していましたが、down層とup層を合体して、さらに元の重みにマージしてから計算した方が効率が良くなるかもしれませんという話です。この記事では行列積の計算量とかで単純なアルゴリズムを前提にしていますが、実際はもっと最適化しているはずなので実装とのずれはあると思います。あと分からないことがありますが放置中です。

計算量の考察

LoRAは以下のように計算できます。
 
\begin{eqnarray}
y^{(b,m)} &=& x^{(b,n)}W_{\mathrm{merge}}^{(n,m)} && \cdots (0) \\
&=& x^{(b,n)}(W_{\mathrm{org}}^{(n,m)}+W_{\mathrm{down}}^{(n,r)}W_{\mathrm{up}}^{(r,m)}) && \cdots (1) \\
&=& x^{(b,n)}W_{\mathrm{org}}^{(n,m)} + (x^{(b,n)}W_{\mathrm{down}}^{(n,r)})W_{\mathrm{up}}^{(r,m)} && \cdots (2)
\end{eqnarray}

(n,m)は入出力チャンネル、bはそれ以外の次元の積(画像の縦横やバッチサイズ、トークン長等)になります。

  • (0)式は生成時に使われる場合があります。あらかじめLoRAをモデルにマージしておくことで、計算量増加を防ぎます。
  • (1)式はLoHAの学習で使われます。LoRAの重みを計算して元の重みにマージしてから計算します。
  • (2)式は通常のLoRA学習時に使われます。元の重みで計算した出力と、入力にdown層、up層の順で通した出力を足し合わせます。

この3つの計算量について考えていきます。足し算の数はほぼ無視していいと思うので、掛け算の数だけで比較していきます。
前提として、行列積X^{(a,b)}Y^{(b,c)}の計算の掛け算の数はabcになります。

  • (0)式:bnmです。
  • (1)式:W_{\mathrm{down}}^{(n,r)}W_{\mathrm{up}}^{(r,m)}部分の計算がnrmです。あとは(0)と同じような計算であり、bnmが足されます。
  • (2)式:(x^{(b,n)}W_{\mathrm{down}}^{(n,r)})W_{\mathrm{up}}^{(r,m)}部分の計算がbnr+brm=br(n+m)で、同じくbnmが足されます。


(1)と(2)を比較すると、bnmは同じとして、それ以外の部分はrが掛けられるのはいっしょなので無視すると、nmb(n+m)の大小比較になります。bn,mに比べて十分大きい場合、(1)の方が効率よくなります。
rankは大小関係には影響しませんが、どのくらいの差があるかはrankに比例して大きくなっていきます。

bの大きさについて

bは画像のサイズ、トークン長、バッチサイズに影響します。これは学習設定やモデルの層によって変わっていきます。

  • 学習設定について

バッチサイズや画像のサイズを大きくすればするほどbが大きくなり、(1)式を使った方がよくなっていくと思います。

  • モデルの設計について

Stable DiffusionのUNetは内側の層ほど画像が縮小され、チャンネル数が増えます。つまりbが減っていきn,mは増えていきます。つまりUNetの外側の層ほど(1)式を使った方がよくなります。ちなみにテキストエンコーダはbトークン長77×バッチサイズで固定、(n,m)はだいたい768 or 1024 or 1280なので、バッチサイズ4~8くらいで逆転しそうですね。
入力に応じて適応的に式を選ぶのが一番よさそうです。nmはLoRAを作るときに計算できます。bは入力に応じて逐一計算する必要があります。

VRAM使用量について

この項はちょっと自信ないです。行列積X^{(a,b)}Y^{(b,c)}時の逆伝搬のためのキャッシュはab+bcです。足し算も同じです。

  • (1)式はnb+nm+rn+rm + nm
  • (2)式はnm + nb+nr + rb+rm + 2bm

W_{\mathrm{org}}の勾配に必要な項は無視しています(キャッシュされないはず?)。共通項を除くとnm, rb+2bmの比較になります。これもbが大きいと(1)式が有利になります。

backward側の計算について

この項もちょっと自信ないですが、行列積の逆伝搬の計算量は順伝搬の2倍になるだけです。ただし(2)式ではW_{\mathrm{org}}は勾配不要なので、bnm分減ります。あれえ?じゃあ(1)式効率悪くない?nm(b+2), 2br(n+m)の比較はほとんど後者の方が小さくなりそうです。ただこの影響がどうも実験的には出てこないんですよね。よく分かりません。

3×3の畳み込み層について

3×3畳み込みの場合、nを9倍すればいいだけです(多分)。畳み込みは画像の1チャンネルを9チャンネルに拡張したような計算になります。LoRAの場合down層の入力チャンネル数が9倍されます。

実験

適当にやりました。Paperspace A6000で768画像をバッチサイズ8でsdv2モデルをrank=16のloconでxformersありで320ステップやりました。
(1)式のみを使う:327秒/27197MiB
(1)式をnm \le b(n+m)に基づいて適応:324秒/26749MiB
(1)式をnm \le rb+2bmに基づいて適応:318秒/26593MiB
従来通り(2)式:348秒/27427MiB

従来の計算よりも効率的であることが分かりますた。なんかVRAMを一番節約できそうな式が計算時間的にも一番良い結果になりました。まあ最適化とかを無視した計算をしているので理論通りにはならないのは当然といえば当然ですね。
backward側の計算量が(2)式の方が効率よくなりそうな話についてはよく分かりませんでした...。