LoHAでbackwardを定義する理由
LoHAとはアダマール積を使ったLoRAの応用手法です。琥珀青葉(KohakuBlueleaf)さんによってStable-diffusionで実装されました。LoHAの実装ではbackwardが定義されています。Pytorchでは特殊な関数を使わない限りbackwardを定義することはありません。しかしLoHAは特殊な関数を使うわけではありません。琥珀青葉さんが簡単に理由を説明してくれているのですが、そもそもbackwardを定義する方法も含めて私はよく知らなかったのでまとめてみます。
元論文
[2108.06098] FedPara: Low-Rank Hadamard Product for Communication-Efficient Federated Learning
琥珀青葉さんの解説
LyCORIS/Algo.md at main · KohakuBlueleaf/LyCORIS · GitHub
長いので結論を先に書く
Pytorchは、の計算で逆伝搬のためにをキャッシュしてしまいます。しかし実際はをキャッシュしておけば十分なのでそうなるように実装されています。
LoHAについて
LoRAは重み行列の差分をに分解します。パラメータは削減できますが、階数はで制限されてしまいます。そこでLoHAはアダマール積を使って、とします。パラメータ数が2倍になっていますが、階数の上界はになります。
詳しくは以前書いた記事を参考にしてください。
LoRAとLoHAの階数を比較する|gcem156
LoHAの計算
LoRAのは二つの行列積に分解して計算できます。そのためを直接計算する必要はありません。しかしLoHAの場合はは順次計算することはできず、を計算してから、を計算する必要があります。ここで注目しなければならないのはの微分です。
Pytorchの自動微分
誤差逆伝搬法では連鎖律を用いて出力側の層の微分を入力側の層の微分の計算に用いることで、効率よく微分を行います。PytorchのTensorでは四則演算とか行列の積とかの基本的な関数には微分が定義されていて、それらの合成関数であるニューラルネットワークは誤差逆伝搬法を使って自動的に微分が定義されます。
自作関数は以下のように作れるようです。
class sugoi_function_class(torch.autograd.function): @staticmethod def forward(ctx, x) ctx.save_for_backward(x) y = sugoi_function(x) return y @staticmethod def backward(dL_dy) x, = ctx.saved_tensors dy_dx = sugoi_funcition_no_bibun(x) dL_dx = dL_dy * dy_dx return dL_dx
backwardの定義は、まず損失に対する出力の微分()が入力され、出力に対する入力微分()をかけるという流れです。この際計算に必要な順伝搬時の情報を記憶しておきます。
以下の記事を参考にしました。
【PyTorch】自作関数の勾配計算式(backward関数)の書き方① - Qiita
LoHAの微分
LyCORISではの計算式及びbackwardがここで実装されています。
実装は以下のようになってます。(こぴぺ)
class HadaWeight(torch.autograd.Function): @staticmethod def forward(ctx, w1a, w1b, w2a, w2b, scale=torch.tensor(1)): ctx.save_for_backward(w1a, w1b, w2a, w2b, scale) diff_weight = ((w1a@w1b)*(w2a@w2b)) * scale return diff_weight @staticmethod def backward(ctx, grad_out): (w1a, w1b, w2a, w2b, scale) = ctx.saved_tensors grad_out = grad_out * scale temp = grad_out*(w2a@w2b) grad_w1a = temp @ w1b.T grad_w1b = w1a.T @ temp temp = grad_out * (w1a@w1b) grad_w2a = temp @ w2b.T grad_w2b = w2a.T @ temp del temp return grad_w1a, grad_w1b, grad_w2a, grad_w2b, None
順伝搬の出力がであるため、backwardへの入力はです。
行列に注目すると、アダマール積に関する微分は以下の式で表されます。
要素ごとの積なので、連鎖律も行列積ではなく要素ごとの積になります。
行列積に関する微分は以下の式で表されます。
よって、
に関しても同様にして、
となります。
実装において、順伝搬時に保存する行列はになっています。
(scaleについては本質的ではないので省略。)
Pytorchの実装をそのまま使う場合
の微分は、アダマール積と行列積の合成関数になっています。この二つは特殊な関数ではないので、Pytorch上で当然backwardが定義されており、自動微分可能です。それではどうしてbackwardを定義する必要があるのか確認していきましょう。
import torch from torchviz import make_dot m=10 n=5 r=2 A = torch.randn((m,r)) B = torch.randn((r,n)) C = torch.randn((m,r)) D = torch.randn((r,n)) for mat in [A,B,C,D]: mat.requires_grad = True AB = A@B # 後の確認用にABを残しておく CD = C@D W = AB * CD image = make_dot(W) image.format = "png" image.render("test")
シンプルな実装で重みを作ってみました。このときの計算グラフをみてみると以下のような感じになりました。
グラフなんか見なくても想像できますが、4つの行列から2つずつ行列積をとりできた2つの行列でアダマール積をとっています。ここで重要なのはアダマール積(MulBackward0)の部分です。Pytorchの実装ではアダマール積をとるとき逆伝搬に必要な順伝搬の入力である、の二つを自動的にキャッシュしてしまいます。しかしその上の行列積部分ではをキャッシュしているため、はキャッシュせずに計算可能です。実際LyCORISの実装ではそうしていますね。ちゃんとキャッシュされているかは以下の実装で確かめられます。
print(W.grad_fn._saved_self is AB) print(W.grad_fn._saved_other is CD) print(W.grad_fn.next_functions[0][0]._saved_self is A) print(W.grad_fn.next_functions[0][0]._saved_mat2 is B) print(W.grad_fn.next_functions[1][0]._saved_self is C) print(W.grad_fn.next_functions[1][0]._saved_mat2 is D)
LoHAは行列分解する手法なので、に対してのサイズが非常に大きくなります(それぞれ学習対象パラメータと同じサイズになる)。そのためをキャッシュしてしまうのはてりぶるなしちゅえーしょんです。そのためbackwardを定義しているんですね。
実装を比較してみる
sdv2のunetのみで768画像、バッチサイズ1、xformersあり、bfloat16でrank=16のconvを含むLoHAでVRAM使用量を確認してみました。
元の実装を傷つけずに比較するため、以下のようなclassを作ってみました。
class HadaWeight: @staticmethod def apply(orig_weight, w1a, w1b, w2a, w2b, scale=torch.tensor(1)): diff_weight = ((w1a@w1b)*(w2a@w2b)) * scale return orig_weight.reshape(diff_weight.shape) + diff_weight
結果ですが、元実装の9.3GBから12.3GBにまでVRAM使用量が増加しました。理論的には学習対象パラメータの2倍になります。UNetの重みは3.4GBですが、bfloat16なのとLoHAの学習対象に含まれないパラメータもあることを考えると3GBの増加というのは、だいたいあってそうですね。ちなみに計算時間の方はほとんど変わらないようです。キャッシュをやめたところでしょぼい行列計算が増えるだけなのであまり変わらないんでしょうね。ついでにgradient_checkpointingを有効にして比較してみましたが、VRAM使用量は全く変わりませんでした。これはgradient_checkpointingが順伝搬でのキャッシュを省略する手法なので、自然な結果ですね。
おわり。