LoRA学習を効率化できるかもしれない方法を考えたので説明していきます。簡単に言うと従来はLoRAをdown層とup層に分けて二層を順次計算していましたが、down層とup層を合体して、さらに元の重みにマージしてから計算した方が効率が良くなるかもしれませんという話です。この記事では行列積の計算量とかで単純なアルゴリズムを前提にしていますが、実際はもっと最適化しているはずなので実装とのずれはあると思います。あと分からないことがありますが放置中です。
計算量の考察
LoRAは以下のように計算できます。
は入出力チャンネル、はそれ以外の次元の積(画像の縦横やバッチサイズ、トークン長等)になります。
- (0)式は生成時に使われる場合があります。あらかじめLoRAをモデルにマージしておくことで、計算量増加を防ぎます。
- (1)式はLoHAの学習で使われます。LoRAの重みを計算して元の重みにマージしてから計算します。
- (2)式は通常のLoRA学習時に使われます。元の重みで計算した出力と、入力にdown層、up層の順で通した出力を足し合わせます。
この3つの計算量について考えていきます。足し算の数はほぼ無視していいと思うので、掛け算の数だけで比較していきます。
前提として、行列積の計算の掛け算の数はになります。
- (0)式:です。
- (1)式:部分の計算がです。あとは(0)と同じような計算であり、が足されます。
- (2)式:部分の計算がで、同じくが足されます。
(1)と(2)を比較すると、は同じとして、それ以外の部分はが掛けられるのはいっしょなので無視すると、との大小比較になります。がに比べて十分大きい場合、(1)の方が効率よくなります。
rankは大小関係には影響しませんが、どのくらいの差があるかはrankに比例して大きくなっていきます。
の大きさについて
は画像のサイズ、トークン長、バッチサイズに影響します。これは学習設定やモデルの層によって変わっていきます。
- 学習設定について
バッチサイズや画像のサイズを大きくすればするほどが大きくなり、(1)式を使った方がよくなっていくと思います。
- モデルの設計について
Stable DiffusionのUNetは内側の層ほど画像が縮小され、チャンネル数が増えます。つまりが減っていきは増えていきます。つまりUNetの外側の層ほど(1)式を使った方がよくなります。ちなみにテキストエンコーダははトークン長77×バッチサイズで固定、はだいたい768 or 1024 or 1280なので、バッチサイズ4~8くらいで逆転しそうですね。
入力に応じて適応的に式を選ぶのが一番よさそうです。はLoRAを作るときに計算できます。は入力に応じて逐一計算する必要があります。
VRAM使用量について
この項はちょっと自信ないです。行列積時の逆伝搬のためのキャッシュはです。足し算も同じです。
- (1)式は
- (2)式は
の勾配に必要な項は無視しています(キャッシュされないはず?)。共通項を除くとの比較になります。これもbが大きいと(1)式が有利になります。
backward側の計算について
この項もちょっと自信ないですが、行列積の逆伝搬の計算量は順伝搬の2倍になるだけです。ただし(2)式ではは勾配不要なので、分減ります。あれえ?じゃあ(1)式効率悪くない?の比較はほとんど後者の方が小さくなりそうです。ただこの影響がどうも実験的には出てこないんですよね。よく分かりません。
3×3の畳み込み層について
3×3畳み込みの場合、を9倍すればいいだけです(多分)。畳み込みは画像の1チャンネルを9チャンネルに拡張したような計算になります。LoRAの場合down層の入力チャンネル数が9倍されます。
実験
適当にやりました。Paperspace A6000で768画像をバッチサイズ8でsdv2モデルをrank=16のloconでxformersありで320ステップやりました。
(1)式のみを使う:327秒/27197MiB
(1)式をに基づいて適応:324秒/26749MiB
(1)式をに基づいて適応:318秒/26593MiB
従来通り(2)式:348秒/27427MiB
従来の計算よりも効率的であることが分かりますた。なんかVRAMを一番節約できそうな式が計算時間的にも一番良い結果になりました。まあ最適化とかを無視した計算をしているので理論通りにはならないのは当然といえば当然ですね。
backward側の計算量が(2)式の方が効率よくなりそうな話についてはよく分かりませんでした...。