Pairwise Squared Euclidean Distance Computation Used in the “Taming Transformers for High-Resolution Image Synthesis” Paper, Explained


The PyTorch snippet below comes from the code released with the Taming Transformers paper, where it appears in the vector-quantization step that maps encoder outputs to codebook entries:

d = torch.sum(z_flattened**2, dim=1, keepdim=True) + \
    torch.sum(self.embedding.weight**2, dim=1) - \
    2 * torch.matmul(z_flattened, self.embedding.weight.t())

This snippet computes, in a vectorized way, the pairwise squared Euclidean distances between two sets of vectors: the rows of z_flattened and the rows of self.embedding.weight.
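The snippet works because the squared Euclidean distance can be expanded into three terms. Writing z_i for the i-th row of z_flattened and e_j for the j-th row of self.embedding.weight (notation introduced here for illustration), each entry of the result is:

```latex
d_{ij} = \lVert z_i - e_j \rVert_2^2
       = \sum_{k=1}^{D} (z_{ik} - e_{jk})^2
       = \sum_{k=1}^{D} z_{ik}^2 + \sum_{k=1}^{D} e_{jk}^2 - 2 \sum_{k=1}^{D} z_{ik} e_{jk}
       = \underbrace{\lVert z_i \rVert_2^2}_{\text{term 1}}
       + \underbrace{\lVert e_j \rVert_2^2}_{\text{term 2}}
       - \underbrace{2\, z_i^\top e_j}_{\text{term 3}}
```

Each term can be computed for all (i, j) pairs at once with row sums and a single matrix multiplication, so the (N, M) result never requires building an (N, M, D) tensor of differences.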

Let’s break down the code:

  1. z_flattened: a 2D tensor (matrix) of shape (N, D), where N is the number of vectors and D is their dimensionality. In the VQGAN quantizer, each row is the latent vector at one spatial position of the encoder output, so N = batch size × height × width.
  2. self.embedding.weight: a 2D tensor (matrix) of shape (M, D), the weight of an nn.Embedding layer that serves as the codebook. Each of its M rows is a learnable code vector with the same dimensionality D as the rows of z_flattened.
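A minimal sketch of where these shapes come from, assuming a VQGAN-style setup in which a (B, D, H, W) encoder feature map is flattened so that each spatial position becomes one row, and the codebook is an nn.Embedding. The tensor z and the sizes B=2, D=4, H=W=3, M=8 are hypothetical, chosen only for illustration.

```python
import torch
import torch.nn as nn

# Illustrative (hypothetical) sizes: batch, embedding dim, height, width, codebook size
B, D, H, W, M = 2, 4, 3, 3, 8

# Hypothetical encoder output in (B, D, H, W) layout
z = torch.randn(B, D, H, W)

# Move channels last, then flatten: each spatial position becomes one D-dimensional row
z_flattened = z.permute(0, 2, 3, 1).reshape(-1, D)  # shape (N, D) with N = B*H*W

# Codebook: an embedding table holding M learnable D-dimensional vectors
embedding = nn.Embedding(M, D)

print(z_flattened.shape)       # torch.Size([18, 4])
print(embedding.weight.shape)  # torch.Size([8, 4])
```

Rows are ordered batch first, then height, then width, so row 1 of z_flattened is the vector at position (0, 0, 1) of the original map.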

Now let’s analyze the calculations:

  1. torch.sum(z_flattened**2, dim=1, keepdim=True): squares every element and sums across dimension 1 (the D feature columns), giving the squared L2 norm of each row of z_flattened. Because keepdim=True, the result has shape (N, 1), a column vector.
  2. torch.sum(self.embedding.weight**2, dim=1): likewise gives the squared L2 norm of each codebook row. Without keepdim, the result has shape (M,), which broadcasting treats as a row vector of shape (1, M).
  3. 2 * torch.matmul(z_flattened, self.embedding.weight.t()): multiplies z_flattened (N, D) by the transposed codebook (D, M). The resulting (N, M) tensor holds the dot product of every row of z_flattened with every codebook row, and each entry is then multiplied by 2.
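The three terms can be computed separately to check their shapes. This sketch uses random stand-in tensors; codebook is a hypothetical name standing in for self.embedding.weight.

```python
import torch

N, M, D = 5, 7, 3  # illustrative sizes
z_flattened = torch.randn(N, D)
codebook = torch.randn(M, D)  # stands in for self.embedding.weight

term1 = torch.sum(z_flattened**2, dim=1, keepdim=True)  # (N, 1): squared norm of each input row
term2 = torch.sum(codebook**2, dim=1)                    # (M,): squared norm of each code vector
term3 = 2 * torch.matmul(z_flattened, codebook.t())      # (N, M): twice every pairwise dot product

print(term1.shape, term2.shape, term3.shape)
# torch.Size([5, 1]) torch.Size([7]) torch.Size([5, 7])
```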

Finally, the code combines the three terms, adding the two squared-norm terms and subtracting the cross term. Broadcasting expands the (N, 1) and (M,) tensors to (N, M) so that they line up with the matrix product:

d = torch.sum(z_flattened**2, dim=1, keepdim=True) + \
    torch.sum(self.embedding.weight**2, dim=1) - \
    2 * torch.matmul(z_flattened, self.embedding.weight.t())

d is a tensor of shape (N, M) that contains the pairwise squared Euclidean distances between the data points in z_flattened and the reference points in self.embedding.weight.
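One way to check the expanded form is to compare it with a direct computation. The sketch below (random stand-in data; codebook is a hypothetical stand-in for self.embedding.weight) builds the full (N, M, D) tensor of differences via broadcasting, and also compares against torch.cdist, which returns non-squared Euclidean distances, so its output is squared first.

```python
import torch

torch.manual_seed(0)
N, M, D = 6, 4, 5
z_flattened = torch.randn(N, D, dtype=torch.float64)
codebook = torch.randn(M, D, dtype=torch.float64)  # stands in for self.embedding.weight

# Expanded form, as in the snippet
d = torch.sum(z_flattened**2, dim=1, keepdim=True) + \
    torch.sum(codebook**2, dim=1) - \
    2 * torch.matmul(z_flattened, codebook.t())

# Direct form: materialize all (N, M, D) differences, then square and sum over D
diff = z_flattened[:, None, :] - codebook[None, :, :]
d_direct = (diff**2).sum(dim=-1)

# torch.cdist computes Euclidean (p=2) distances; square them to compare
d_cdist = torch.cdist(z_flattened, codebook, p=2) ** 2

print(d.shape)                      # torch.Size([6, 4])
print(torch.allclose(d, d_direct))  # True
print(torch.allclose(d, d_cdist))   # True
```

The expanded form avoids the (N, M, D) intermediate, which matters when both the number of vectors and the codebook are large. The trade-off is floating-point cancellation when a vector lies very close to a code vector, which can make some entries slightly negative.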


How is this different from a simple L2 loss?

The snippet computes pairwise squared Euclidean distances: every vector in one set is compared with every vector in the other. The L2 loss, usually implemented as Mean Squared Error (MSE) loss, instead compares paired vectors (the i-th prediction with the i-th target) and averages the squared element-wise differences.

Here are the main differences:

  1. Purpose:
    • The code snippet calculates pairwise squared Euclidean distances, which can be useful in various contexts such as nearest neighbor search, clustering, or quantization.
    • L2 loss is a common loss function used in regression problems and neural networks to measure the difference between predicted values and target values. It is used for optimization during training to update the model’s parameters.
  2. Output:
    • The code snippet returns a 2D tensor (matrix) of shape (N, M) containing the pairwise squared Euclidean distances between the data points in z_flattened and the reference points in self.embedding.weight.
    • L2 (MSE) loss returns a scalar: the average of the squared differences between two sets of vectors, typically predictions and ground truth (this is the behavior of PyTorch's nn.MSELoss with its default reduction='mean').
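To make the quantization use case concrete: in a vector quantizer, each row of the distance matrix is typically reduced with argmin to pick the closest code vector, which then replaces the input vector. This is a minimal sketch with random stand-in data (codebook is a hypothetical stand-in for self.embedding.weight); it shows only the nearest-neighbor lookup and omits training details such as the straight-through gradient.

```python
import torch

torch.manual_seed(0)
N, M, D = 5, 3, 2
z_flattened = torch.randn(N, D)
codebook = torch.randn(M, D)  # stands in for self.embedding.weight

# Pairwise squared distances, shape (N, M)
d = torch.sum(z_flattened**2, dim=1, keepdim=True) + \
    torch.sum(codebook**2, dim=1) - \
    2 * torch.matmul(z_flattened, codebook.t())

# For each input row, the index of the nearest code vector, shape (N,)
indices = torch.argmin(d, dim=1)

# Replace every input vector by its nearest code vector, shape (N, D)
z_q = codebook[indices]

print(indices.shape, z_q.shape)  # torch.Size([5]) torch.Size([5, 2])
```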

To calculate the L2 loss between two sets of vectors, you would do the following in PyTorch:

import torch
import torch.nn as nn

N, D = 4, 3  # illustrative sizes
predictions = torch.randn(N, D)  # stand-in for a tensor of shape (N, D) containing predicted values
targets = torch.randn(N, D)      # stand-in for a tensor of shape (N, D) containing target values

mse_loss = nn.MSELoss()
l2_loss = mse_loss(predictions, targets)

In this example, predictions and targets are both tensors of shape (N, D). The L2 loss (MSE loss) is calculated as the average of the squared differences between the corresponding elements in these two tensors.
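What nn.MSELoss computes with its default reduction='mean' can be written out by hand. The sketch also shows one way the two ideas connect: the matched-pair squared distances ||p_i - t_i||^2 sit on the diagonal of the pairwise distance matrix between predictions and targets, so the MSE equals the mean of that diagonal divided by D. The data is random and the sizes are illustrative.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
N, D = 4, 3
predictions = torch.randn(N, D, dtype=torch.float64)
targets = torch.randn(N, D, dtype=torch.float64)

# Built-in MSE: mean over all N*D elements (default reduction='mean')
l2_loss = nn.MSELoss()(predictions, targets)

# The same value written out by hand
manual = ((predictions - targets) ** 2).mean()

# Pairwise squared distances between every prediction and every target, shape (N, N)
d = torch.sum(predictions**2, dim=1, keepdim=True) + \
    torch.sum(targets**2, dim=1) - \
    2 * torch.matmul(predictions, targets.t())

# MSE only uses the matched pairs (i, i); each diagonal entry sums over D, hence / D
from_pairwise = torch.diagonal(d).mean() / D

print(torch.allclose(l2_loss, manual))         # True
print(torch.allclose(l2_loss, from_pairwise))  # True
```

Building the full (N, N) matrix only to read its diagonal is wasteful; it is shown here purely to relate the two quantities.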

In summary, the two differ in purpose and output. The snippet compares every vector in one set with every vector in the other and returns an (N, M) matrix of squared distances, while L2 (MSE) loss compares matched pairs and reduces them to a single average that can be minimized during training.

What are examples of papers where this distance computation is used?

The snippet computes pairwise squared Euclidean distances, a fundamental building block of many machine-learning algorithms. It is rarely a loss function on its own; more often it serves as a distance metric inside a larger algorithm. Here are some examples of papers that use it:

  1. “Product Quantization for Nearest Neighbor Search” by Hervé Jégou, Matthijs Douze, and Cordelia Schmid (2011)
  2. “Large-Scale Image Retrieval with Compressed Fisher Vectors” by Florent Perronnin, Yan Liu, and Jorge Sánchez (2010)
    • This paper uses squared Euclidean distance as a similarity measure to compare high-dimensional descriptors derived from local features (compressed Fisher vectors) for large-scale image retrieval.
    • Paper: https://www.di.ens.fr/willow/pdfs/cvpr10c.pdf
  3. “Approximate Nearest Neighbors: Towards Removing the Curse of Dimensionality” by Piotr Indyk and Rajeev Motwani (1998)

These papers utilize the concept of pairwise squared Euclidean distance in their proposed methods and algorithms.
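The nearest-neighbor search and retrieval settings cited above usually need the k closest database vectors for each query rather than only the single closest one. As a minimal sketch (exact brute-force search on random stand-in data, not the approximate methods those papers propose; queries, database, and k are hypothetical names), the same pairwise formula combines with torch.topk:

```python
import torch

torch.manual_seed(0)
num_queries, num_db, D, k = 3, 10, 8, 4
queries = torch.randn(num_queries, D)
database = torch.randn(num_db, D)

# Pairwise squared Euclidean distances, shape (num_queries, num_db)
d = torch.sum(queries**2, dim=1, keepdim=True) + \
    torch.sum(database**2, dim=1) - \
    2 * torch.matmul(queries, database.t())

# k smallest distances per query, in ascending order; largest=False selects nearest, not farthest
dists, idx = torch.topk(d, k, dim=1, largest=False)

print(idx.shape)  # torch.Size([3, 4])
```

Because the first term is the same for every database vector of a given query, it does not change the ranking and could be dropped when only the indices are needed; it is kept here so that dists holds true squared distances.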

