某大学生の機械学習日記

趣味はでぃーぷらーにんぐ

DNNアンサンブル学習手法"Stochastic Weight Averaging(SWA)"で遊んでみた

動機

  • 最近学校の課題が忙しく、ガッツリ系の実装ができていない中ちょうど良さげなものが見つかったから。

  • これ本当か?と感覚的には理解できなかったので実際に試してみて調べたかったから

Stochastic Weight Averaging(SWA)の概要

従来のアンサンブル学習ではモデル・アンサンブルだったのに対して、SWAではウェイト・アンサンブルをとった。つまり、複数のモデルの重みの平均を取り込んだモデルで結果を予測する。感覚的に俄には信じられないが、Resnet150大型のモデルで効力を発揮しているようだ。

従来のニュールモデルのアンサンブル学習の二つの手法
  1. Snapshot Ensembling
    Learning Rate Schedulingのサイクルが終わるごとにその時のモデルの"weight"を保存する。そのサイクルをいくつか繰り返し、保存したweightをもつモデルのアンサンブルによって推定を行う手法。
  2. Fast Geometric Ensembling (FGE) Snapshot Ensemblingと似ているが、異なる点が二つある。一つ目は、サイクルの周期が数エポックごとにしていて計算時間の短縮をした。二つ目に"inear piecewise cyclical learning rate schedule"を起用している(Snapshotではcosine)。ちょっとした違いだが、Snapshotのより精度が高くなったらしい。
SWA

前述の手法の欠点はweightを保つ必要があったためメモリを食ってしまうところ。SWAではモデル二つぶんのweightsのみでアンサンブル学習()を実現し、かつ前述の良いところを引き継いだ構造となっている。
二のつモデルの役割:
* 一つ目のモデルは、cyclical learning scheduleに基づいてベストなweightを探索するもの(式中のw)
* 二つめのモデルは、1サイクルごとに求められたweightを蓄えるモデル(式中のw_swa)
以下の式が具体的な更新方法 f:id:knto-h:20180516131052p:plain:w400

実装

モデルの定義
今回は小さいデータセットでどんな挙動を示すかが気になったので、fashion-mnistを用いることにした。それにさいしモデルを以下のように(半ば適当に)定義。

class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.conv1 = nn.Conv2d(1,16,kernel_size=4,stride=2) # 12x12
        self.bn1 = nn.BatchNorm2d(16)
        self.conv2 = nn.Conv2d(16,32,kernel_size=4, stride=2) # 5x5
        self.bn2 = nn.BatchNorm2d(32)
        self.fc1 = nn.Linear(32*5*5, 120)
        self.fc2 = nn.Linear(120,84)
        self.fc3 = nn.Linear(84,10)

    def forward(self, x):
        B,_,_,_ = x.size()
        x = F.relu(self.bn1(self.conv1(x)))
        x = F.relu(self.bn2(self.conv2(x)))
        x = x.view(B, -1)
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        out = self.fc3(x)
        return out

SWA関数の定義
二つのモデル(model_swa, model)のパラメータ処理を定義。

def SWA(model_swa, model, n):
    _n = n
    for w_swa, w in zip(model_swa.parameters(), model.parameters()):
        w_swa.data *= _n/(_n+1)
        w_swa.data += w.data/(_n+1)

メイン関数とハイパーパラメータ
main()にてループを回す。一回のサイクルあたり4回エポックを回し、それを10回繰り返した。つまり、10回model_swaのパラメータを更新した。

def main():
    model = Net().to(device)
    model_swa = Net().to(device)

    ...
    ...

    epochs = 4
    cycles = 10
    loss_data = []
    for cycle in range(cycles):
        print("-------  Cycle {0}  -------".format(cycle+1))
        optimizer = optim.SGD(model.parameters(), lr=0.1)
        scheduler = optim.lr_scheduler.StepLR(optimizer, 1, gamma=0.1)
        for epuch in range(epochs):
            scheduler.step()
            train(model, criterion,optimizer,train_loader)
            average_accu = validate(model, criterion, val_loader)
        SWA(model_swa,model, cycle)
        validate(model_swa, criterion, val_loader)

結果

  • 普通のモデルでの精度は大体99.8%(train), 99.1%(test)ぐらいをうろちょろしてた。

  • SWAモデルではサイクルごとに[86.41, 85.39, 84.58, 83.52, 83.77, 84.09, 84.24, 84.62, 84.52, 84.58]とかだった。。。

  • あまりに大雑把にやりすぎたせいか結果がでなかったから時間があるときにもっと研究してみたいなー。

  • 比較的小型のモデルでは精度の向上はみられないのか(?)

参考

原論文 → https://arxiv.org/pdf/1803.05407.pdf
記事 → Stochastic Weight Averaging — a New Way to Get State of the Art Results in Deep Learning