勾配降下党青年局

万国のグラーディエントよ、降下せよ!

LoHAでbackwardを定義する理由

 LoHAとはアダマール積を使ったLoRAの応用手法です。琥珀青葉(KohakuBlueleaf)さんによってStable-diffusionで実装されました。LoHAの実装ではbackwardが定義されています。Pytorchでは特殊な関数を使わない限りbackwardを定義することはありません。しかしLoHAは特殊な関数を使うわけではありません。琥珀青葉さんが簡単に理由を説明してくれているのですが、そもそもbackwardを定義する方法も含めて私はよく知らなかったのでまとめてみます。
元論文
[2108.06098] FedPara: Low-Rank Hadamard Product for Communication-Efficient Federated Learning
琥珀青葉さんの解説
LyCORIS/Algo.md at main · KohakuBlueleaf/LyCORIS · GitHub

長いので結論を先に書く

 Pytorchは、AB\odot CDの計算で逆伝搬のためにAB, CDをキャッシュしてしまいます。しかし実際はA,B,C,Dをキャッシュしておけば十分なのでそうなるように実装されています。

LoHAについて

 LoRAは重み行列の差分\Delta W^{(m,n)}A^{(m,r)}B^{(r,n)}に分解します。パラメータは削減できますが、階数はrで制限されてしまいます。そこでLoHAはアダマール\odotを使って、\Delta W^{(m,n)}=A^{(m,r)}B^{(r,n)}\odot C^{(m,r)}D^{(r,n)}とします。パラメータ数が2倍になっていますが、階数の上界はr^2になります。
詳しくは以前書いた記事を参考にしてください。
LoRAとLoHAの階数を比較する|gcem156

LoHAの計算

 LoRAのABxは二つの行列積に分解して計算できます。そのため\Delta Wを直接計算する必要はありません。しかしLoHAの場合は(AB\odot CD)xは順次計算することはできず、\Delta Wを計算してから、\Delta Wxを計算する必要があります。ここで注目しなければならないのは\Delta W微分です。

Pytorchの自動微分

 誤差逆伝搬法では連鎖律を用いて出力側の層の微分を入力側の層の微分の計算に用いることで、効率よく微分を行います。PytorchのTensorでは四則演算とか行列の積とかの基本的な関数には微分が定義されていて、それらの合成関数であるニューラルネットワークは誤差逆伝搬法を使って自動的に微分が定義されます。
自作関数は以下のように作れるようです。

class sugoi_function_class(torch.autograd.function):
    @staticmethod
    def forward(ctx, x)
        ctx.save_for_backward(x)
        y = sugoi_function(x)
        return y

    @staticmethod
    def backward(dL_dy)
        x, = ctx.saved_tensors
        dy_dx = sugoi_funcition_no_bibun(x)
        dL_dx = dL_dy * dy_dx
        return dL_dx 

 backwardの定義は、まず損失に対する出力の微分(\frac{dL}{dy})が入力され、出力に対する入力微分(\frac{dy}{dx})をかけるという流れです。この際計算に必要な順伝搬時の情報xを記憶しておきます。

以下の記事を参考にしました。
【PyTorch】自作関数の勾配計算式(backward関数)の書き方① - Qiita

LoHAの微分

 LyCORISでは\Delta Wの計算式及びbackwardがここで実装されています。
実装は以下のようになってます。(こぴぺ)

class HadaWeight(torch.autograd.Function):
    @staticmethod
    def forward(ctx, w1a, w1b, w2a, w2b, scale=torch.tensor(1)):
        ctx.save_for_backward(w1a, w1b, w2a, w2b, scale)
        diff_weight = ((w1a@w1b)*(w2a@w2b)) * scale
        return diff_weight

    @staticmethod
    def backward(ctx, grad_out):
        (w1a, w1b, w2a, w2b, scale) = ctx.saved_tensors
        grad_out = grad_out * scale
        temp = grad_out*(w2a@w2b)
        grad_w1a = temp @ w1b.T
        grad_w1b = w1a.T @ temp

        temp = grad_out * (w1a@w1b)
        grad_w2a = temp @ w2b.T
        grad_w2b = w2a.T @ temp
        
        del temp
        return grad_w1a, grad_w1b, grad_w2a, grad_w2b, None

 順伝搬の出力がAB\odot CDであるため、backwardへの入力は\frac{dL}{d(AB\odot CD)}です。
行列ABに注目すると、アダマール積に関する微分は以下の式で表されます。
 \frac{dL}{d(AB)}=\frac{dL}{d(AB\odot CD)}\odot\frac{d(AB\odot CD)}{dAB} = \frac{dL}{d(AB\odot CD)} \odot CD
要素ごとの積なので、連鎖律も行列積ではなく要素ごとの積になります。
 行列積に関する微分は以下の式で表されます。
 \frac{dL}{dA} =  \frac{dL}{dAB} B^T,\frac{dL}{dB} = A^T \frac{dL}{dAB}

よって、
\frac{dL}{dA} = (\frac{dL}{d(AB\odot CD)} \odot CD) B^T
\frac{dL}{dB} = A^T (\frac{dL}{d(AB\odot CD)} \odot CD)
C,Dに関しても同様にして、
\frac{dL}{dC} = (\frac{dL}{d(AB\odot CD)} \odot AB) D^T
\frac{dL}{dD} = C^T (\frac{dL}{d(AB\odot CD)} \odot AB)
となります。

 実装において、順伝搬時に保存する行列はA,B,C,Dになっています。
(scaleについては本質的ではないので省略。)

Pytorchの実装をそのまま使う場合

 AB\odot CD微分は、アダマール積と行列積の合成関数になっています。この二つは特殊な関数ではないので、Pytorch上で当然backwardが定義されており、自動微分可能です。それではどうしてbackwardを定義する必要があるのか確認していきましょう。

import torch
from torchviz import make_dot
m=10
n=5
r=2

A = torch.randn((m,r))
B = torch.randn((r,n))
C = torch.randn((m,r))
D = torch.randn((r,n))
for mat in [A,B,C,D]:
    mat.requires_grad = True

AB = A@B # 後の確認用にABを残しておく
CD = C@D
W = AB * CD

image = make_dot(W)
image.format = "png"
image.render("test")

シンプルな実装で重みを作ってみました。このときの計算グラフをみてみると以下のような感じになりました。

 グラフなんか見なくても想像できますが、4つの行列から2つずつ行列積をとりできた2つの行列でアダマール積をとっています。ここで重要なのはアダマール積(MulBackward0)の部分です。Pytorchの実装ではアダマール積をとるとき逆伝搬に必要な順伝搬の入力である、AB, CDの二つを自動的にキャッシュしてしまいます。しかしその上の行列積部分ではA,B,C,Dをキャッシュしているため、AB, CDはキャッシュせずに計算可能です。実際LyCORISの実装ではそうしていますね。ちゃんとキャッシュされているかは以下の実装で確かめられます。

print(W.grad_fn._saved_self is AB)
print(W.grad_fn._saved_other is CD)
print(W.grad_fn.next_functions[0][0]._saved_self is A)
print(W.grad_fn.next_functions[0][0]._saved_mat2 is B)
print(W.grad_fn.next_functions[1][0]._saved_self is C)
print(W.grad_fn.next_functions[1][0]._saved_mat2 is D)

LoHAは行列分解する手法なので、A,B,C,Dに対してAB,CDのサイズが非常に大きくなります(それぞれ学習対象パラメータと同じサイズになる)。そのためAB,CDをキャッシュしてしまうのはてりぶるなしちゅえーしょんです。そのためbackwardを定義しているんですね。

実装を比較してみる

 sdv2のunetのみで768画像、バッチサイズ1、xformersあり、bfloat16でrank=16のconvを含むLoHAでVRAM使用量を確認してみました。

元の実装を傷つけずに比較するため、以下のようなclassを作ってみました。

class HadaWeight:
    @staticmethod
    def apply(orig_weight, w1a, w1b, w2a, w2b, scale=torch.tensor(1)):
        diff_weight = ((w1a@w1b)*(w2a@w2b)) * scale
        return orig_weight.reshape(diff_weight.shape) + diff_weight

 結果ですが、元実装の9.3GBから12.3GBにまでVRAM使用量が増加しました。理論的には学習対象パラメータの2倍になります。UNetの重みは3.4GBですが、bfloat16なのとLoHAの学習対象に含まれないパラメータもあることを考えると3GBの増加というのは、だいたいあってそうですね。ちなみに計算時間の方はほとんど変わらないようです。キャッシュをやめたところでしょぼい行列計算が増えるだけなのであまり変わらないんでしょうね。ついでにgradient_checkpointingを有効にして比較してみましたが、VRAM使用量は全く変わりませんでした。これはgradient_checkpointingが順伝搬でのキャッシュを省略する手法なので、自然な結果ですね。

おわり。