公開機械学習錬金術からの記事
focal loss
Focal Loss関数 密集物体検出タスクのためにFocal Lossが提案されています。
もちろん、ターゲット検出では、検出すべき物体のカテゴリが1000あるかもしれませんが、識別したい物体はカテゴリの1つでしかありません。
そしてフォーカルロスは、サンプル数が極端に偏っているという問題の解決策に過ぎません。
サンプルの不均衡の解決というと、混同行列のf1-scoreをご存知かと思いますが、これは訓練での損失としては使えないようです。そして、フォーカルロスは、少数のターゲットカテゴリが重みを増加させるように、誤分類されたサンプルが重みを増加させるように、トレーニングで使用することができます。
まず、単純なバイナリー・クロスエントロピーの損失を見てみましょう:
- y' はモデルによって与えられる予測カテゴリ確率で, y は真の標本.つまり,標本の真のカテゴリが1で予測確率が0.9の場合,-logがこの損失.
- 公平を期すため、私は一般的にバイナリー・クロスエントロピーを例として使うのは好きではなく、多カテゴリー・クロスエントロピーを例として使う方がしっくりきます。
フォーカルロスのさらなる改善:ここでアルファ値が追加されます。このアルファ値は論文では0.25とされていますが、これはサンプルの不均衡の問題を解決するために、単純に正または負のサンプルの重みを減らすためです。
この2つの組み合わせは、サンプルの不均衡問題を解決できる損失FOCAL LOSSです。
概要
- ααはサンプルの不均衡問題を解決します;
- ββはハードサンプルとイージーサンプルのアンバランスを解消します。ハードなサンプルを重視し、イージーなサンプルを無視します。
- 要するに、フォーカルロスは、「サンプル数が少ないと分類しにくい」、「サンプル数が多いと分類しにくい」、「サンプル数が少ないと分類しやすい」、「サンプル数が多いと分類しやすい」という順番で懸念されることになります。
GHM
- GHMとはGradient Harmonising Mechanismの略。
このGHMは、フォーカルロスに存在する問題のいくつかを解決するために設計されています。
フォーカルロスの欠点1特に分類が難しいサンプルにモデルが集中しすぎることは問題になることがあります。サンプルには異常値や外れ値があります。そのため、モデルは、これらの非常に適合しにくい異常値に適合させるために、オーバーフィッティングする危険性があります。
GHM
Focal Lossは信頼度pの観点から損失を減衰させます。一方、GHMは信頼度pのある範囲内で損失を減衰させるサンプル数です。
まず、勾配係数と呼ばれる変数gが定義されます:この勾配係数は、実際、モデルによって与えられた信頼度p*p*と、このサンプルの真のラベルとの差であることがわかります。gが小さいほど、予測はより正確であり、サンプルが分類しやすいことを示します。
gとサンプル数の関係を下図に示します:
グラフからわかるように
- 勾配係数が0に近いサンプルが多い、つまり分類しやすいサンプルが非常に多い。
- その後、勾配モードの長さが長くなるにつれて、サンプル数は急速に減少します。
- その後、勾配モードの長さが1に近づくにつれて、サンプル数は再び増加し始めます。
GHMは勾配モード長の小さい分類しやすいサンプルを無視することを考えていますが、FOCAL LOSSは分類しにくいサンプルに焦点を当てすぎています。重要なのは、実際には分類しにくいサンプルがたくさんあるということです!もしモデルが分類しにくいサンプルを学習し続ければ、モデルの精度は下がるかもしれません。ですから、GHMは分類しにくいサンプルに対しても減衰があります。
そうすると、GHMは分類しやすいサンプルと分類しにくいサンプルの両方を減衰させるので、本当に注目されるサンプルは分類しにくく、分類しにくいサンプルということになります。また、抑制の度合いはサンプル数で決めることができます。
ここではGD(Gradient Density)を定義しています:
G
- G
D GDは、勾配gの位置での勾配密度を計算します; - δ
δはサンプルkの勾配gkgkが区間[g-2ϵ,g+2ϵ]にあるかどうかです。 - l lは区間[g-2ϵ,g+2ϵ]の長さ、すなわちϵϵ。
つまり、GDは-ϵ,g+ϵ[g-2ϵg+ϵ]内の勾配係数を持つサンプルの総数をϵで割ったものです。
次に、各サンプルのクロスエントロピー損失を対応する勾配密度で割るだけです。L
- C
E CEはi番目のサンプルのクロスエントロピー損失を表します; - G
D GDはi番目のサンプルの勾配密度を表します;
論文におけるGHM
論文については、勾配モードの長さは、信頼水準pが0から1であるため、10領域に分割されているので、勾配密度の領域の長さは0.1であるため、例えば、領域の0から0.1です。
以下は論文にある比較表:
グラフから
- 緑はクロスエントロピー損失;
- 青色は焦点損失で、勾配モード長を小さくした損失減衰が有効であることがわかりました;
- 赤はGHMのクロスエントロピー損失で、0付近と1付近の勾配係数の著しい減衰が見られます。
もちろん、GHMが更新する勾配密度を計算するために、サンプル全体のモデル推定を必要とするように見えることは考えられます。つまり、ミニバッチはGHMと一緒に使うことはできないようです。
これはGHMのオリジナル記事にも書かれていることですが、ミニバッチを単独で使用した場合、バランスが崩れる可能性があります。
私が個人的に感じているハンドリング
- このエポックでは、前のエポックの勾配密度を使用できます;
- あるいは、最初にミニバッチを使って勾配密度を計算し、モデルの収束率が低下した後に最初のアプローチを使って更新することもできます。
python
上記の物語で重要なのは、焦点損失によって達成される機能性です:
- 正しく分類されたサンプルの損失重みは小さく、間違って分類されたサンプルの損失重みは大きくなります。
- オーバーサンプリングされたカテゴリのウェイトが低い
Focal Lossは、CenterNetのセントロイドの位置を予測するときにも使われますが、少し修正されています。
概要
ここではフットマーカーを無視して、上で話したことに近いです。
Y=Y=1とすると、予測値Y^が1に近いほど、予測がほぼ正しいことを示し、αは小さくなり、正しく分類されたサンプルの損失重みが小さいことを反映します。 - センターネットのヒートマップについては、後ほどコードセクションで説明します。そうすれば、このことが理解できるでしょう。
コード説明
コードで理解する方法をご紹介します:
class FocalLoss(nn.Module):
def __init__(self):
super().__init__()
self.neg_loss = _neg_loss
def forward(self, output, target, mask):
output = torch.sigmoid(output)
loss = self.neg_loss(output, target, mask)
return loss
ここでの出力は1チャンネルの特徴マップとして解釈でき、各ピクセルの値はモデルによって与えられた信頼度であり、シグモイド関数によって0から1の区間の信頼度に変換されます。
そして、ターゲットはセンターネットのヒートマップです。例えば、10×10の特徴マップが全て0だとすると、この特徴マップの中で1つの画素だけが1だとすると、この画素の位置が検出対象の中心になります。1がいくつあるということは、このマップの中に検出すべき対象物が複数あるということです。
もし特徴マップがゼロでいっぱいで、1が少ししかないような場合、それはあまりにまばらで、直感的に非常に滑らかではありません。CenterNetのヒートマップは、この1を中心としたガウシアンも必要で、これはスムージングとして見ることができます。これは1中心のガウス平滑化です。
ここで上記のβに戻ります:数字の1については、最初の行を使って損失を計算するのが自然ですが、1に近い他の点については、βを考慮するときが来ます。
本題に戻りますが、出力をシグモイドした後、出力と一緒にneg_lossに入れます。neg_lossとは?
def _neg_loss(pred, gt, mask):
pos_inds = gt.eq(1).float() * mask
neg_inds = gt.lt(1).float() * mask
neg_weights = torch.pow(1 - gt, 4)
loss = 0
pos_loss = torch.log(pred) * torch.pow(1 - pred, 2) * pos_inds
neg_loss = torch.log(1 - pred) * torch.pow(pred, 2) * \
neg_weights * neg_inds
num_pos = pos_inds.float().sum()
pos_loss = pos_loss.sum()
neg_loss = neg_loss.sum()
if num_pos == 0:
loss = loss - neg_loss
else:
loss = loss - (pos_loss + neg_loss) / num_pos
return loss
まず、この中のマスクは、特定のタスクに追加された小さな機能で、そのタスクではイメージの中に損失計算をする必要のない部分があるので、その部分をフィルタリングするために最初にマスクを使うという事実に基づいています。ここではマスクは無視してください。
neg_weights = torch.pow(1 - gt, 4)
β=β=4であることがわかり、下のコードからα=α=2であることを推測するのは難しくありません。
各ピクセルの損失を合計し、ターゲットオブジェクトの数で割るだけです。