Convergence Lab.株式会社 CEOの木村優志です。今回は、PyTorchをもちいた、Deep Metric Learningの解説をします。ある程度、Deep Learningに詳しい読者を対象としています。

はじめに

Deep Metric Learningとは、その名の通り、Deep Learningを用いてMetric Learningを行う手法です。Metric Learningとは、タスクに特化した距離空間を学習する方法のことです。主に異常検知や話者識別などに用いられる手法です。

実際に、今回学習する、距離空間の例を見てみましょう。下の図は、今回扱うDeep Metric Learningで学習された 距離空間になります。これは犬の種類を距離空間にマッピングしたものです。マッピングしたのは、「パピヨン、チワワ、シェットランド・シープドッグ、コリー、ダルメシアン」の5種類です。下の図をみると、真ん中にシェットランド・シープドッグ、右上にコリー、右下にダルメシアン、下にチワワ、左側にパピヨンが位置するような距離空間ができています。

今回は、このような距離空間を学習する、Deep Metric Learningについて解説します。

なお、今回のコードは、https://github.com/Convergence-Lab-Inc/blog_metric_learning_explain から得られます。なお、データセットはスクレイピングして集めたものです、リポジトリには付属しておりません。

 Deep Metric Learning

 通常、Deep Learningモデルの出力はクラスのIDとなることが多いです。一方、Deep Metric Learningでは、入力データを距離空間に埋め込んだ時の座標地が出力となります。 

Deep Metric Learningのモデル

今回のDeep Learningモデルは以下のようになります。 今回は2次元空間に埋め込むため、metric_sizeは2次元です。basenetとして ResNet18を用いています。ここまでは、通常の方法と大きな違いはありません。なお、コード例はPyTorchを用いています。

import torch
from torch import nn

class DogModel(nn.Module):
    def __init__(self, metric_size=2):
        super(DogModel, self).__init__()
        self.basenet = torch.hub.load('pytorch/vision:v0.6.0', 'resnet18', pretrained=True)
        self.basenet.fc = nn.Linear(512, metric_size)

    def __call__(self, x):
        x = self.basenet(x)
        return x

Deep Metric Learningのデータと損失関数

通常、よく見る学習のコードは以下のようなものでしょう。学習データ x を Deep Learningモデル net にとおした結果、predと labelとの間で lossをもとめます。

...
for batch in data_loader:
    x, label = batch
    pred = net(x)
    loss = criterion(x, label)
    ...

 一方、Deep Metric Learningでは、以下のようになります。 x, pos, neg の3つのデータを Deep Learning モデル net に通して、それぞれの結果から、 lossをもとめます。

...
for batch in tqdm(trainloader):
    x, pos, neg, label = batch
    x = net(x)
    pos = net(pos)
    neg = net(neg)
    loss = criterion(x, pos, neg)
...

ここで、 xのことを anchorと呼びます。今回学習したいデータのことです。pos は anchorと「近い」データのことです。ここでは単純に anchorと同じクラスの別のデータと考えて問題ありません。 negは anchorと「遠い」データです。ここでは単純に、anchorとは違うクラスのデータです。

たとえば、anchorがパピヨンの画像であった時、 posは同じパピヨンの画像、negはパピヨン以外のたとえばチワワなどの画像です。

これらをそれぞれ、netにとおして、座標データを得ます。

座標データをたあとは、 anchorとposは近くなるように、 anchorとnegは遠くなるように学習します。そのために、以下のような、Triplet Lossとよばれる損失関数を用います。

loss max{d(a,p)d(a,n)+margin,0}

ここで、aは anchor, p は positive, n は negative データです。d(・,・)は 入力の2つのデータの距離を計算する関数です。今回は 1 から cosine similarityを 引いた関数を用いました。コード上では、以下のように表されます。

def distance(x1, x2):
    return 1.0 - F.cosine_similarity(x1, x2, dim=1)

criterion = nn.TripletMarginWithDistanceLoss(distance_function=distance, margin=0.5)

あとは、通常のPyTorchのコードと大きな違いはありません。学習結果を図示してみましょう。今回は、以下のようなコードで学習結果の図示を行いました。冒頭に提示した画像を表示するコードです。

from model import DogModel
from data import DogTestDataset, test_collator

import numpy as np
from tqdm import tqdm
import torch
from torch import nn, optim
import torch.nn.functional as F
from torch.utils.data import DataLoader
from torchvision import transforms
from PIL import Image


def change_coordinate(x):
    x =  (x + 16) * 15
    x = x.astype(np.int32)
    # x = (x[0], x[0]+32, x[1]+32, x[1])
    return (x[0], x[1])

def predict():
    test_transform = transforms.Compose([
            transforms.Resize((128, 128)),
            transforms.ToTensor(),
        ])

    test_dataset = DogTestDataset(transform=test_transform, train=False)
    testloader = DataLoader(test_dataset, batch_size=1, collate_fn=test_collator)

    device = "cuda" if torch.cuda.is_available() else "cpu"

    net = DogModel(metric_size=2)
    checkpoint = torch.load("net.pt")
    net.load_state_dict(checkpoint["net"])
    net.to(device)

    im = Image.new(mode="RGBA", size=(512, 512), color=(255, 255, 255, 255))

    net.eval()
    test_loss = 0
    for batch in tqdm(testloader):
        label, x, img = batch
        with torch.no_grad():
            x = x.to(device)
            x = net(x)
        x = x.cpu().numpy()
        im2 = Image.open(img[0])
        im2 = im2.resize((64, 64)).convert("RGBA")
        coord = change_coordinate(x[0])
        Image.Image.paste(im, im2, coord)
    im.save("img.png")

if __name__ == "__main__":
    predict()

結果として、以下のよう距離空間が得られます。

 

Pin It
Keywords:

Related Articles/Posts

TAGS: