動機
「TSG Solt Identification(kaggle)」のDiscussionでこのLoss関数が話題に上がっていて ちょうど良い機会だったから
目的とする指標ごとに適切なLoss関数が異なることに面白みを覚えたから
事前知識
segmentation
segmentationとは画像系のタスクの一種で、画像の中から対象領域を色塗りするものである。 例えばこのように左列の写真をいれたときに、犬や人の部分を画素レベルで把握している。 これがsegmentation。
Jaccard index (intersection-over-union score, IoU)
Jaccard係数とは以下の式で定義される指標である。積集合を和集合で割ったものである。 簡単のため2クラスでの式にしたが、複数クラスの場合でも同様である。 この評価の値域は[0,1]でありground true(GT)をより正確に予測できるほど値が大きくなる。 そして従来は以下で定義するようなJaccard Lossを最適化するような方法も取られてる。 この指標の特徴としては、GT-pixelを全て予測できたとしても余分な予測(差集合)をしてしまったら、 その分だけ値が小さくなる点があげられる。つまり、予測しすぎをより防ぐような指標である。
何故Jaccard indexを直接最適化しないのか
先に述べたように、Jaccard Lossをとって最適化する方法は使われており、複数枚のデータの ピクセルから、いわばグローバルに計算することができ、またper-imageごとにも最適化 (比較的こっちのほうが精度は良いらしい)が可能だ。評価指標をそのまま最適化できるなら それが一番理想なのだが、これの欠点としてDeep-Learningのような大規模なパラメータの 最適化を行うのには適していないことがあげられる(IoU算出の際に予測値からmaskを計算するときに 離散的な遷移があるのが原因かと思われる)。segmentationのタスク上Deep-learningを用いない ことはありえないので、従来の研究ではしばしばbinary-cross-entropyが使用されている。
提案手法
Jaccard-Lossを指標とした最適化。うえであげた欠点を補うため、離散値であるこのLossをなめらかな連続空間で 表現できるよう工夫(Lovasz-extention)を加えた。 ※foregound-background segmentationのほうを扱っている。Multiclassの方はこれの拡張ととらえてもらえば。 大事なのはこの式:
これがLovasz-extention。は"vector of all pixel errors",は"Jaccard loss"。
codeをみつつこの式に戻ると理解が深まると思われる。
原論文:https://arxiv.org/pdf/1705.08790.pdf
原論文2: https://arxiv.org/abs/1512.07797
github:https://github.com/bermanmaxim/LovaszSoftmax
- 予測値のエラー計算
GTの値域を[0,1]から[1,2]に拡張し、エラー計算をする。エラーが正だと予測値が異なっていて、 負だと予測値があっていることを示している。error_sortedが上の式のmを表している。
signs = 2. * labels.float() - 1. errors = (1. - logits * Variable(signs)) errors_sorted, perm = torch.sort(errors, dim=0, descending=True)
- Lovasz-extension
ここでの計算をしている。
gts = gt_sorted.sum() ntersection = gts - gt_sorted.float().cumsum(0) union = gts + (1 - gt_sorted).float().cumsum(0) jaccard = 1. - intersection / union jaccard[1:p] = jaccard[1:p] - jaccard[0:-1]
3. Lossの算出
loss = torch.dot(F.relu(errors_sorted), Variable(grad))
結果
データ:Pascal VOC, Network: DeeplabV2を用いBinary segmentationを行った。 以下のような結果になり、Lovasz-hinge(提案手法)をLoss関数として最適化をおこなった 場合に一番IoUが高い結果となった。
参考
Can someone please explain Lovasz loss? | Kaggle 【技術解説】集合の類似度(Jaccard係数,Dice係数,Simpson係数) - ミエルカAI は、自然言語処理技術を中心とした、RPA開発・サイト改善・流入改善レコメンドエンジンを開発