ホリケン's diary

趣味はでぃーぷらーにんぐ

Fast AutoAugment再現実装(PyTorch)

Fast AutoAugmentとは

すごい簡単ではあるがこのスライドを参考にすると大体の"キモチ"が伝わるはずだ。

実装

  • 公式実装(https://github.com/kakaobrain/fast-autoaugment)は公開されているのだがなぜか肝心なpolicy searchの部分が公開されていない...
  • policy search => trainまでフルパッケージで自分でコードを用意することにした。実装は僕のgithubにあげておいたが、肝となった部分を下に紹介する
  • 論文で言及されているように、この実装ではRayライブラリ(https://github.com/ray-project/ray)を使う。HyperOptのパッケージを使って何か書きたい人とかも参考にしてくれればと思う。

Rayとは

分散処理を簡単にかつシンプルに実行するためのフレームワークである; Tune(Hyperparameter Optimization Framework), RLlib(Scalable Reinforcement Learning)の2つのライブラリがあり、今回はTuneの中のHyperOptのパッケージを使用した。

Trainableクラスの定義

Ray-Tuneでは、Trainableクラスがtrialの度に呼び出される。HyperOptのパッケージでは呼び出されるごとにconfigの中のsearchしているパラメータが更新されていくため、そのパラメータを引きげる用にsetup()を書いた。dataloaderを毎回setup()で読み出してくると実行時間がだいぶかかるため、configで引き継げるようにした。

from ray.tune import Trainable

class TrainCIFAR(Trainable):
    def _setup(self, config):
        args = config.pop("args")
        args.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

        self.val_loader = config.pop("dataloader")
        augs = [[config["aug1"], config["p1"], config["value1"]],
                [config["aug2"], config["p2"], config["value2"]]]
        self.val_loader.dataset.transform.transforms.insert(0, Augmentation(augs))
        self.model = config["model"].to(args.device)

        self.criterion = nn.CrossEntropyLoss().to(args.device)
        self.args = args

    def _train(self):
        return self._test()

    def _test(self):
        val_loss = 0
        total = 0
        correct = 0
        self.model.eval()
        with torch.no_grad():
            for i, (inputs, targets) in enumerate(self.val_loader):
                inputs = inputs.to(self.args.device)
                targets = targets.to(self.args.device)
                outputs = self.model(inputs)
                loss = self.criterion(outputs, targets)

                val_loss += loss.item()
                _, predicted = outputs.max(1)
                total += targets.size(0)
                correct += predicted.eq(targets).sum().item()
        return {"mean_accuracy": 100.*correct/total, "mean_loss": val_loss/(i+1)}

    def _save(self, checkpoint_dir):
        return

    def _restore(self, checkpoint):
        return



mainループの定義

以下のようにHypeoOptSearchやconfigを定義することでベイズ最適化の要領でdata augのベストパラメータを探索させる。"config_per_trial"の部分でargs.gpu*gpu-memory=2500mbぐらいになるような値を設定すると一番効率的に分散処理を行うことができる。
ここでやっていることは、get_candidate_augmentによってsub-policyをサンプリングしてきて、hyperoptによってargs.samples回パラメータサーチを行って最適な値を探してくる、というのをargs.B(search depth)分だけ行うという風になっている。ここの実装に置ける工夫は学習済みモデル & dataloaderをconfigの中に渡してしまうことで読み込みの際の時間ロスをなくす点である。

# main()の中の一部

for b in range(args.B):
        ray.init()

        augs = get_candidate_augment(T=args.T)
        while augs in augs_hist:
            augs = get_candidate_augment(T=args.T)
        augs_hist.append(augs)


        config = {
            "resources_per_trial": {
                "cpu": args.cpu,
                "gpu": args.gpu
            },
            "num_samples": args.num_samples,
            "config": {
                "args": args,
                "model": model,
                "dataloader": val_loader,
                "iteration": 1,
            },
            "stop": {
                "training_iteration": 1
            }
        }
        space = {}
        for i in range(1, args.T + 1):
            space[f"p{i}"] = hp.choice(f"p{i}", [i / 10 for i in range(11)])
            space[f"value{i}"] = hp.choice(f"value{i}", [i / 10 for i in range(1, 11)])
            config["config"][f"aug{i}"] = augs[i-1][0].__name__

        algo = HyperOptSearch(
            space,
            max_concurrent=args.max_concurrent,
            metric="mean_accuracy",
            mode="max")
        scheduler = AsyncHyperBandScheduler(time_attr="training_iteration", metric="mean_accuracy", mode="max")
        tune.run(TrainCIFAR,
                 search_alg=algo,
                 scheduler=scheduler,
                 loggers=[JsonLogger],
                 verbose=args.verbose,
                 **config)

        ray.shutdown()

その他の実装

論文に記載されている内容で、AutoAugmentとの比較という観点や探索空間が小さすぎだろと個人的に納得しかねる部分があったのでそこを自分なりに改良して実装しました。さらに著者公開のpolicyでは最終的なsub-policyのセットが493とめちゃ多かったのでそれをAutoAugmentと合わせるようにと25にしました。それ以外で具体的にどの部分を指すかは探してみてください。

再現結果

()の中は著者が公開した、cifar10に対するpolicyを僕のコードで再現した結果。著者のpolicyの精度を超えることはできたが、AutoAugmentの結果には及ばなかった。

Model BaseLine Cutout AutoAugment Fast AutoAugment
Wide-ResNet-40-2 94.54 95.37 95.72 95.46 (95.37)

f:id:knto-h:20190722025340p:plain