0. Background
Distance metric learning의 예제로써 'Large Margin Nearest Neighbor (LMNN) classification'을 살펴보자.
- 해당 논문은 2005년 NIPS에 발표되었다. [Link]
LMNN의 training 과정:
- 각 입력 $\vec{x}_i$에 대해, 같은 class label $y_i$을 가진 점들 중 $k$개의 target neighbors를 식별한다. (prior knowledge가 없을 경우 Euclidean distance 사용)
- Target neighbors가 모든 다른 labels을 가진 점들보다 $\vec{x}_i$에 가깝게 측정되는 distance metric을 찾는다.
Distance metric 추정을 위해 Mahalanobis distance metric을 다음과 같이 표현하여 convex 문제로 푼다.
임의의 두점 $\vec{x}_i, \vec{x}_j$에 대해,
$$ \begin{equation} d^2_{\mathbf{M}}(\vec{x}_i, \vec{x}_j)= (\vec{x}_i - \vec{x}_j)^\top \mathbf{M} (\vec{x}_i - \vec{x}_j) \end{equation} \tag{1}$$
여기서 $\mathbf{M}$이 항등행렬일 경우 Eq. (1)은 유클리디안 거리가 된다.
$\mathbf{M}$을 추정하기 위해 다음과 같이 semidefinite program (SDP) 문제를 정의한다.
Minimize $\sum_{j \leadsto i} [ d^2_{\mathbf{M}}(\vec{x}_i, \vec{x}_j)+\mu \sum_l (1-y_{il}) \xi_{ijl} ]$ subject to: (a) $d^2_{\mathbf{M}}(\vec{x}_i, \vec{x}_l) - d^2_{\mathbf{M}}(\vec{x}_i, \vec{x}_j) \geq 1- \xi_{ijl}$ (b) $\xi_{ijl} \geq 0$ (c) $\mathbf{M} \succeq 0$ |
- $j \leadsto i$는 $\vec{x}_j$가 $\vec{x}_i$의 target neighbor라는 것을 의미
- $\vec{x}_i$와 $\vec{x}_j$가 같은 label이면 $y_{ij}$는 1, 아니면 0.
- $\xi_{ijl}$는 $\vec{x}_i$을 기준으로 그것의 target neighbor $\vec{x}_j$ 들이 형성하는 어떤 경계선안에 다른 label을 가진 $\vec{x}_l$이 얼마나 있는지를 측정한다. (추후 다시 설명)
결국 loss function은 target neghbors은 가깝게 (first term), 서로다른 label의 점들끼린 멀게 (second term, margin을 높이면서) 측정되는 $\mathbf{M}$을 찾는것을 목적으로 한다.
1. Introduction
[Problem]: 기존 LMNN방법은 all training instances을 다 체크하고 (입력 공간 차원 $\times$ 입력 공간 차원)의 $\mathbf{M}$을 가지므로 large scale datasets에 적합하지 않음 (너무 오래걸림)
[Contribution]:
- SDP 문제를 효율적으로 풀기 위해 특정한 instances만 체크하는 solver를 제안하여 large scale datasets에 대한 학습을 가능하게 함
- Ball tree + low-rank approximation을 사용하여 training time과 testing time 감소
- 입력 공간을 여러 부분공간으로 나누고 각각에 대해 distance metric을 최적화함으로써 error rate 감소
2. Methodology
2.1. Solver
[목적]: SDP 과정을 통한 $\mathbf{M}$의 탐색 과정에서 일부 instances에 대한 sub-gradient를 계산하여 효율성 향상
위에서 소개한 SDP의 목적 함수를 $\mathbf{M}$의 함수로 표현할 수 있도록 먼저 slack 변수를 다음과 같이 표현한다.
$$ \begin{equation} \xi_{ijl}(\mathbf{M}) = [1+d^2_{\mathbf{M}}(\vec{x}_i, \vec{x}_j)-d^2_{\mathbf{M}}(\vec{x}_i, \vec{x}_l)]_+ \end{equation} \tag{2}$$
여기서 $[z]_+$는 $z>0$일 때 $z$이고 $z<0$일 때 0인 함수이다.
→ (다른 label을 가진 점과의 거리 - target neighbor과의 거리)가 1보다 작아야 값을 가진다. "Margin violation"
Slack 변수를 Eq. (2)로 대체하는 것에 의해 목적함수는 다음과 같이 formulation된다.
$$ \begin{equation} \varepsilon(\mathbf{M})=\sum_{j \leadsto i} [ d^2_{\mathbf{M}}(\vec{x}_i, \vec{x}_j)+\mu \sum_l (1-y_{il}) \xi_{ijl}(\mathbf{M}) ] \end{equation} \tag{3}$$
→ Second term은 다른 label을 가진 점들 중 $\vec{x}_i,$를 기준으로 (다른 label을 가진 점과의 거리 - target neighbor과의 거리)가 1보다 작을 경우에만 $\mathbf{M}$을 업데이트 한다.
→ 즉 서로 다른 label들 간의 점 사이의 "margin"을 1이상 가지도록하는 distance metric을 찾는다.
Eq. (3)은 미분 불가능하지만 convex이기 때문에, sub-gradient를 계산하는 것에 의해 standard descent algorithm을 사용할 수 있다.
먼저, Eq. (1)은 다음과 같이 나타낸다.
$$ \begin{equation} d^2_{\mathbf{M}}(\vec{x}_i, \vec{x}_j)= tr(\mathbf{C}_{ij}\mathbf{M}) \end{equation} \tag{4}$$
여기서 $\mathbf{C}_{ij} = (\vec{x}_i - \vec{x}_j)(\vec{x}_i - \vec{x}_j)^\top$이다.
각 iteration에서 $\mathcal{N}^t$를 triplet indices의 집합이라 하자. s.t. $(i,j,l)\in \mathcal{N}^t, \xi_{ijl}(\mathbf{M}^t)>0$.
이때, $t^{th}$ iteration에서, gradient $\mathbf{G}^t = \frac{\partial \varepsilon}{\partial \mathbf{M}}\vert_{\mathbf{M}^t}$는 다음과 같이 계산된다.
$$ \begin{equation} \mathbf{G}^t = \sum_{j \leadsto i} \mathbf{C}_{ij} + \mu \sum_{(i,j,l)\in\mathcal{N}^t} ( \mathbf{C}_{ij}-\mathbf{C}_{il})\end{equation} \tag{5}$$
여기서 margin violation이 얼만큼 되었는지에 대한 degree는 gradient값에 영향을 주지 않는다.
→ Eq. (4)를 $\mathbf{M}$에 관해 미분하면 0이됨
따라서 iteration이 다음으로 넘어갈때, gradient의 변화는 $\mathcal{N}^t$와 $\mathcal{N}^{t+1}$ 사이의 차에 의해 결정된다. 즉,
$$ \begin{equation} \mathbf{G}^{t+1} = \mathbf{G}^t - \mu \sum_{(i,j,l)\in\mathcal{N}^t-\mathcal{N}^{t+1}} ( \mathbf{C}_{ij}-\mathbf{C}_{il}) + \mu \sum_{(i,j,l)\in\mathcal{N}^{t+1}-\mathcal{N}^t} ( \mathbf{C}_{ij}-\mathbf{C}_{il}) \end{equation} \tag{6}$$
결과적으로 $\mathcal{N}$이 변할때 해당되는 instances에 대해서만 계산하여 업데이트가 가능하다.
2.2. Tree-Based Search
Nearest neighbor 탐색 과정은 instances을 트리같은 계층적 데이터 구조들로 저장하는 것에 의해 가속화될 수 있다.
이 논문에서는 ball tree방법을 적용하였다. (잘 알려진 방법이므로 위키로 대체 [Link])
추가로 SDP 과정에서 $\mathbf{M}$의 low-rank approximation을 사용하여 더욱 speedup시켰다.
2.3. Multiple Metrics
Error rate를 줄이기 위해 class label마다 서로 다른 distance metric을 사용한다.
자세한 실험들은 본논문 참고 [Link]