動機
- Auto-Encoderに最近興味があり試してみたかったから
- 画像を入力データとして異常行動を検知してみたかったから
- (World modelと関連があるから)
LSTMベースの異常検知アプローチ
以下の二つのアプローチがある(参考)
LSTMを分類器として、正常か異常の2値分類
これは単純に時系列データを与えて分類する方法である。このやり方は正常データ・異常データが共に十分に揃っているときに有効となる。LSTMを予測モデルとして、エラーの大きさで判断
今回行うやり方はこれである。時系列データの入力からこれはデータがあまり揃っていない場合に適している。
モデル詳細
今回は自分のパソコンのカメラからの映像に人が写り込んだことを検知しようと思う。データとしては、普段の部屋の動画を正常時のもの、僕が入り込んだ動画を異常として動画をとった。
全体構成
- Auto-Encoderによって画像データを低次元ベクトルに抽象化
- 抽象化された特徴ベクトルを入力とした予測モデル(LSTM)を学習
- 新規データに対するエラー値から正常異常判断
VAE
β-VAEを用いる。
まず、VAEの定義
class VAE(nn.Module): def __init__(self, z_dim): """ image shape: (3,128,128) """ super(VAE, self).__init__() ## Encoder self.conv1 = nn.Conv2d(3, 32, 4, stride=2) self.conv2 = nn.Conv2d(32, 64, 4, stride=2) self.conv3 = nn.Conv2d(64, 128, 4, stride=2) self.conv4 = nn.Conv2d(128, 256, 4, stride=2) ## Latent representation of mean and std self.fc1 = nn.Linear(256 * 6 * 6, z_dim) self.fc2 = nn.Linear(256 * 6 * 6, z_dim) self.fc3 = nn.Linear(z_dim, 256 * 6 * 6) ## Decoder self.deconv1 = nn.ConvTranspose2d(256 * 6 * 6, 128, 5, stride=2) self.deconv2 = nn.ConvTranspose2d(128, 64, 5, stride=2) self.deconv3 = nn.ConvTranspose2d(64, 32, 5, stride=2) self.deconv4 = nn.ConvTranspose2d(32, 16, 6, stride=2) self.deconv5 = nn.ConvTranspose2d(16, 3, 6, stride=2) def encode(self, x): h = F.relu(self.conv1(x)) h = F.relu(self.conv2(h)) h = F.relu(self.conv3(h)) h = F.relu(self.conv4(h)) h = h.view(-1, 256 * 6 * 6) return self.fc1(h), self.fc2(h) def reparameterize(self, mu, logvar): std = torch.exp(0.5 * logvar) eps = torch.randn_like(std) return eps * std + mu def decode(self, z): h = self.fc3(z).view(-1, 256 * 6 * 6, 1, 1) h = F.relu(self.deconv1(h)) h = F.relu(self.deconv2(h)) h = F.relu(self.deconv3(h)) h = F.relu(self.deconv4(h)) h = F.sigmoid(self.deconv5(h)) return h def forward(self, x, encode=True, mean=True): mu, logvar = self.encode(x) z = self.reparameterize(mu, logvar) if encode: if mean: return mu return z return self.decode(z), mu, logvar
そしてLossの定義である。BETAは3で学習させた。
def loss_function(recon_x, x, mu, logvar): batch_size = recon_x.size()[0] BCE = F.binary_cross_entropy(recon_x.view(batch_size, -1), x.view(batch_size, -1), size_average=True) # 0.5 * sum(1 + log(sigma^2) - mu^2 - sigma^2) KLD = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp()) BCE /= batch_size KLD /= batch_size return BCE + BETA * KLD
- LSTM
ベースはLSTMとして出力にMixture Density Networl(MDN)を組み込み実装した。
class LSTM(nn.Module): def __init__(self, sequence_len, hidden_units, z_dim, num_layers, n_gaussians, hidden_dim): super(LSTM, self).__init__() self.n_gaussians = n_gaussians self.num_layers = num_layers self.z_dim = z_dim self.hidden_dim = hidden_dim self.hidden_units = hidden_units self.sequence_len = sequence_len self.hidden = self.init_hidden(self.sequence_len) ## Encoding self.fc1 = nn.Linear(self.z_dim, self.hidden_dim) self.lstm = nn.LSTM(self.hidden_dim, hidden_units, num_layers) ## Output self.z_pi = nn.Linear(hidden_units, n_gaussians * self.z_dim) self.z_sigma = nn.Linear(hidden_units, n_gaussians * self.z_dim) self.z_mu = nn.Linear(hidden_units, n_gaussians * self.z_dim) def init_hidden(self, sequence): hidden = torch.zeros(self.num_layers, sequence, self.hidden_units, device=device) cell = torch.zeros(self.num_layers, sequence, self.hidden_units, device=device) return hidden, cell def forward(self, x): self.lstm.flatten_parameters() x = F.relu(self.fc1(x)) z, self.hidden = self.lstm(x, self.init_hidden(self.sequence_len)) sequence = x.size()[1] pi = self.z_pi(z).view(-1, sequence, self.n_gaussians, self.z_dim) pi = F.softmax(pi, dim=2) sigma = torch.exp(self.z_sigma(z)).view(-1, sequence, self.n_gaussians, self.z_dim) mu = self.z_mu(z).view(-1, sequence, self.n_gaussians, self.z_dim) return pi, sigma, mu
実験1
Auto-Encoderのみで異常検知を行う。正常時の画像のみをVAEで学習させ、テストデータとして異常も含まれた動画を入力した場合の損失値の変化から異常を検出する。
学習方法
正常時の画像のみを使って学習させた。この場合は僕の部屋の背景である。下の図のように再現できた。上の図が元画像で、下が復元した画像である。
結果
元画像とdecoderによって生成された画像とのLossの推移から異常を検知する。人が一端に写り込んだぐらいではエラーの極端な上昇は見られないが、人がカメラを支配するようになってきて初めてエラーが急増した。
実験2
Auto-Encoder+VAEで異常検知を行う。
学習方法
VAEは正常+異常データ双方を用いて学習させる。LSTMはVAEを通した特徴量を用い、正常画像のみを用いて学習させる。結果
人が途中で映り込む動画を入力させて、それのロスの推移を見た。実験1の時と同じようなグラフが得られた。
課題
学習データが少ないため、信頼性のある結果が得られなかった
正常時以外の物体が写り込んでもエラーの突発的な増加は見られなかった
明らかに変な映像となった時にエラーが急増することがわかったので、それぞれのモデルの学習をよりうまくやることで精度を上げられるかも
参考
コード参考:https://dylandjian.github.io/world-models/
時系列異常検知の論文:https://www.elen.ucl.ac.be/Proceedings/esann/esannpdf/es2015-56.pdf
参考サイト:https://www.quora.com/How-do-I-use-LSTM-Networks-for-time-series-anomaly-detection