An Empirical Study of Self-supervised Learning with
Wasserstein Distance
Abstract
In this study, we delve into the problem of self-supervised learning (SSL) utilizing the 1-Wasserstein distance on a tree structure (a.k.a. the Tree-Wasserstein distance (TWD)), where TWD is defined as the L1 distance between two tree-embedded vectors. SSL methods often use cosine similarity as an objective function; however, the use of the Wasserstein distance has not been well studied. Training with the Wasserstein distance is numerically challenging. Thus, this study empirically investigates strategies for optimizing SSL with the Wasserstein distance and identifies a stable training procedure. More specifically, we evaluate the combination of two types of TWD (total variation and ClusterTree) and several probability models, including the softmax function, the ArcFace probability model [12], and simplicial embedding [28]. We propose a simple yet effective Jeffrey divergence-based regularization method to stabilize optimization. Through empirical experiments on STL10, CIFAR10, CIFAR100, and SVHN, we find that a simple combination of the softmax function and TWD yields significantly lower accuracy than standard SimCLR. Moreover, a simple combination of TWD and SimSiam fails to train the model. We find that model performance depends on the combination of TWD and probability model, and that Jeffrey divergence regularization aids model training. Finally, we show that an appropriate combination of TWD and probability model outperforms cosine similarity-based representation learning.
1 Introduction
Self-supervised learning (SSL) algorithms, including SimCLR [7], Bootstrap Your Own Latent (BYOL) [19], MoCo [9, 20], SwAV [5], SimSiam [8], and DINO [6], can be regarded as unsupervised learning methods.
One of the main families of self-supervised algorithms adopts contrastive learning. In contrastive learning, two data points are systematically generated from a common data source. Lower-dimensional representations are then found by maximizing the similarity between positive pairs while minimizing the similarity between negative pairs. Depending on the context, positive and negative pairs can be defined differently. For example, in SimCLR [7], positive pairs correspond to images generated by independently applying different visual transformations, such as rotation and cropping, to the same image. In multimodal learning, by contrast, positive pairs are defined as corresponding examples in different modalities, such as images and text [22]. The flexibility of formulating positive and negative pairs also makes contrastive learning widely applicable beyond the image domain. Contrastive learning is a powerful pre-training method because it does not require label information and can therefore be trained on large amounts of unlabeled data.
In addition to contrastive learning-based SSL, non-contrastive approaches, such as BYOL [19], SwAV [5], and SimSiam [8], have been widely used. Rather than relying on negative sampling, non-contrastive approaches typically use momentum encoders and/or stop-gradient operations to prevent mode collapse. Many of these approaches employ negative cosine similarity as a loss function. Only a limited number of SSL methods, such as DINO [6] and simplicial embedding [28], utilize distribution measures such as cross-entropy.
In this paper, building on the idea of distribution measures, we empirically investigate, for the first time, SSL performance using the Wasserstein distance. The Wasserstein distance is a widely adopted optimal transport-based distance for measuring distributional discrepancies. It is useful in various machine-learning tasks, including generative adversarial networks [1], document classification [27, 40], image matching [31, 38], and algorithmic fairness [47, 51]. The 1-Wasserstein distance is also known as the earth mover's distance (EMD) and, when applied to documents, as the word mover's distance (WMD) [27].
In this study, we consider an SSL framework with the 1-Wasserstein distance under a tree metric (i.e., the Tree-Wasserstein distance (TWD)) [21, 30]. TWD includes the sliced Wasserstein distance [36, 24] and total variation as special cases, and can be written as the L1 distance between two tree-embedded vectors. TWD is a non-differentiable function, because the L1 norm is not differentiable at zero. As a result, learning simplicial representations through backpropagation of TWD is challenging. Moreover, the Wasserstein distance is computed from probability vectors, and there are several ways of producing probability vectors, so it is difficult to determine which is most suitable for SSL training. Hence, we investigate combinations of probability models and TWD structures. Specifically, we consider total variation and ClusterTree as TWD structures and show that total variation is equivalent to a robust variant of TWD. For the probability representation, we evaluate the softmax function, an ArcFace-based probability model [12], and simplicial embedding (SEM) [28]. Finally, to stabilize training, we propose a Jeffrey divergence-based regularization. Through SSL experiments, we find that the standard softmax formulation with backpropagation yields poor results. In particular, in the non-contrastive SSL setting, a simple combination of the Wasserstein distance and the softmax function fails to train the model. For total variation, the ArcFace-based model performs well. By contrast, SEM is suitable for ClusterTree, whereas ArcFace-based models achieve only modest performance there. Moreover, models trained with the proposed regularization significantly outperform their non-regularized counterparts.
Contribution: The contributions of this study are summarized below:
- We investigate the combination of probability models and TWD (total variation and ClusterTree).
- We propose a robust variant of TWD (RTWD) and show that RTWD is equivalent to total variation.
- We propose the Jeffrey divergence regularization for TWD minimization, and find that the regularization significantly stabilizes training.
2 Related Work
The proposed method involves unsupervised representation learning and optimal transport.
Unsupervised Representation Learning: Representation learning is an important research topic in machine learning and involves several methods. The autoencoder [26] and its variational version [23] are widely employed in unsupervised representation learning. Current mainstream SSL approaches are based on a cross-view prediction framework [4], and contrastive learning has emerged as a prominent SSL paradigm.
In contrastive learning, a model learns by contrasting positive samples (similar instances) with negative samples (dissimilar instances), as in SimCLR [7]. SimCLR employs data augmentation and a similarity measure to encourage the model to project similar instances close together while pushing dissimilar instances apart. This approach has demonstrated efficacy across various domains, including computer vision and natural language processing, and enables learning without explicit labels. SimCLR employs the InfoNCE loss [34]. Following SimCLR, several alternative approaches have been proposed, including Barlow Twins [50]. The Barlow Twins loss drives the cross-correlation matrix between the embeddings of two augmented views toward the identity matrix. It maximizes the correlation between corresponding embedding dimensions while reducing redundancy among different dimensions, and it does not require negative samples. Barlow Twins is closely related to the Hilbert–Schmidt independence criterion, a kernel-based independence measure [18, 44].
One drawback of SimCLR is its reliance on numerous negative samples. To address this issue, recent research has focused on approaches that eliminate the need for negative sampling, such as BYOL [19], SwAV [5], and DINO [6]. For example, BYOL trains representations by minimizing a prediction loss between an online network and a target network. The target network is maintained as a moving average of the online network parameters, which eliminates the need for negative samples. Surprisingly, BYOL showed favorable results compared with SimCLR. SimSiam, introduced by Chen and He [8], trains a Siamese network and stops the gradient through one of its branches.
Both contrastive and non-contrastive approaches typically learn low-dimensional, real-valued vector embeddings and employ cosine similarity as the similarity measure. Recently, Lavoie et al. [28] proposed simplicial embedding (SEM), which consists of multiple concatenated softmax functions and learns high-dimensional, sparse, nonnegative representations. This design significantly improves classification accuracy.
All of these approaches employ either negative cosine similarity or cross-entropy as a loss function. In contrast, the use of the Wasserstein distance in this context has not been studied.
Divergence and optimal transport: Measuring the divergence between two probability distributions is a fundamental research problem in machine learning. It has utility for various downstream applications, including document classification [27, 40], image matching [38], and algorithmic fairness [51, 47]. One widely adopted divergence measure is Kullback–Leibler (KL) divergence [10]. However, KL divergence can diverge to infinity when the supports of the two input probability distributions do not overlap.
The Wasserstein distance, known in the computer vision community as the EMD, can handle differences in supports between probability distributions. Another key advantage of the Wasserstein distance over KL is that the optimal transport plan identifies matches between data samples. For example, Sarlin et al. [38] proposed SuperGlue, which leverages optimal transport to determine correspondences between local feature sets. Liu et al. [31] addressed semantic correspondence using optimal transport.
In NLP, Kusner et al. [27] introduced WMD, a pioneering use of the Wasserstein distance for textual similarity tasks. WMD is widely utilized, including for text generation evaluation [52]. Sato et al. [40] further studied the properties of WMD. Another interesting approach is the word rotator's distance (WRD) [49], which utilizes the norm of word vectors as a simplicial representation and significantly improves on the performance of WMD. However, these methods incur cubic-order computational costs, rendering them unsuitable for large-scale distribution-comparison tasks.
The Wasserstein distance can be computed exactly via linear programming, but with cubic complexity. To speed up EMD and Wasserstein distance computation, Cuturi [11] introduced the Sinkhorn algorithm. It solves the entropic regularized optimization problem and achieves quadratic-order Wasserstein distance computation ($O(n^2)$), where $n$ is the number of data points. Moreover, the Sinkhorn solution is obtained by a simple iterative algorithm, so it can be easily incorporated into deep-learning applications, making optimal transport widely applicable. One limitation of the Sinkhorn algorithm is that it still requires quadratic-time computation. In addition, the final solution depends strongly on the regularization parameter.
An alternative approach is the sliced Wasserstein distance (SWD) [36, 24], which solves the optimal transport problem within projected one-dimensional subspaces. The Wasserstein distance between distributions on the real line can be computed using sorting as a subroutine, so SWD offers $O(n \log n)$ computation per projection. Extensions of SWD include the generalized sliced Wasserstein distance [25], which uses nonlinear projections, and the max-sliced Wasserstein distance [33, 13], which finds the one-dimensional subspace that maximizes the transport cost. A further extension is the subspace-robust Wasserstein distance [35].
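The following minimal NumPy sketch shows where the $O(n \log n)$ cost comes from. It is an illustration only, not the implementation of any cited work. It computes the 1D 1-Wasserstein distance between two equal-size, uniformly weighted samples by sorting, and averages that distance over random unit directions. The function names, the use of $p = 1$, and the number of projections are illustrative assumptions.

```python
import numpy as np


def wasserstein_1d(x, y):
    """1-Wasserstein distance between two 1D empirical distributions with the
    same number of equally weighted samples: sort both and match in order."""
    x_sorted = np.sort(np.asarray(x, dtype=float))
    y_sorted = np.sort(np.asarray(y, dtype=float))
    return float(np.mean(np.abs(x_sorted - y_sorted)))


def sliced_wasserstein(X, Y, n_projections=100, seed=0):
    """Monte Carlo estimate of the sliced 1-Wasserstein distance between point
    clouds X and Y of shape (n, d): average the 1D distance over random unit
    directions. Each projection costs O(n log n) because of the sort."""
    rng = np.random.default_rng(seed)
    thetas = rng.normal(size=(n_projections, X.shape[1]))
    thetas /= np.linalg.norm(thetas, axis=1, keepdims=True)  # unit directions
    return float(np.mean([wasserstein_1d(X @ t, Y @ t) for t in thetas]))


X = np.random.default_rng(1).normal(size=(200, 3))
print(wasserstein_1d([0, 1, 2], [1, 2, 3]))  # 1.0: every sample moves by 1
print(sliced_wasserstein(X, X))              # 0.0: identical clouds
print(sliced_wasserstein(X, X + 2.0))        # positive for a shifted cloud
```

Sorting solves the 1D problem exactly because, on the real line, the monotone (sorted-order) coupling is optimal for convex costs such as $|x - y|$.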
The 1-Wasserstein distance with a tree metric (also known as the Tree-Wasserstein distance (TWD)) is a generalization of the sliced Wasserstein distance and total variation [21, 15, 30]. TWD is also known as the UniFrac distance [32], in which a phylogenetic tree is assumed to be given beforehand. An important property of TWD is that it has an analytical solution: the L1 distance between tree-embedded vectors.
TWD was originally studied in theoretical computer science in the form of the QuadTree algorithm [21]. It has recently been extended by the ML community to unbalanced TWD [39, 29], supervised Wasserstein training [41], and tree barycenters [42]. These approaches focus on approximating the 1-Wasserstein distance through tree construction and often utilize constant edge weights. Among approaches that consider non-uniform edge weights, Backurs et al. [2] introduced FlowTree, which combines QuadTree with the ground-cost matrix and outperforms QuadTree. They provided guarantees that QuadTree and FlowTree approximate nearest neighbors under the 1-Wasserstein distance. Dey and Zhang [14] proposed an L1-embedding for approximating the 1-Wasserstein distance between persistence diagrams. Finally, Yamada et al. [48] proposed a computationally efficient tree weight estimation technique for TWD. They empirically demonstrated that TWD can attain performance comparable to the Wasserstein distance while being several orders of magnitude faster than linear programming computation of the Wasserstein distance.
Most existing studies on TWD have focused on tree construction [21, 30, 41] and edge weight estimation [48]. Takezawa et al. [42] proposed a barycenter method based on TWD in which the set of simplicial representations is given. Frogner et al. [17] and Toyokuni et al. [43] considered using the Wasserstein distance for multi-label classification. These studies focused on supervised learning and employed softmax as the probability model. In this study, we investigate the Wasserstein distance within an SSL framework and evaluate various probability models.
3 Background
3.1 Self-supervised Learning methods
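SimCLR [7]: Given input vectors $\{\boldsymbol{x}_i\}_{i=1}^n$, where $\boldsymbol{x}_i \in \mathbb{R}^d$, define the data transformation functions $\boldsymbol{u} = \boldsymbol{\phi}_a(\boldsymbol{x}) \in \mathbb{R}^d$ and $\boldsymbol{u}' = \boldsymbol{\phi}_{a'}(\boldsymbol{x}) \in \mathbb{R}^d$. In the context of image applications, $\boldsymbol{\phi}_a$ and $\boldsymbol{\phi}_{a'}$ can be understood as two image transformations applied to a given image, such as translation, rotation, or blurring. The neural network model is denoted as $\boldsymbol{f}_\theta(\boldsymbol{u}) \in \mathbb{R}^{d_{\text{out}}}$, where $\theta$ is a learnable parameter.
SimCLR trains the model to learn features such that $\boldsymbol{f}_\theta(\boldsymbol{u}_i)$ and $\boldsymbol{f}_\theta(\boldsymbol{u}'_i)$ are close after the feature mapping. At the same time, both should be distant from the feature map $\boldsymbol{f}_\theta(\boldsymbol{u}_k)$ of a negative sample $\boldsymbol{u}_k$ generated from a different input image. To this end, SimCLR employs the InfoNCE loss [34]:
$$\ell_{\text{InfoNCE}}(\boldsymbol{u}_i, \boldsymbol{u}'_i) = -\log \frac{\exp\left(\mathrm{sim}(\boldsymbol{f}_\theta(\boldsymbol{u}_i), \boldsymbol{f}_\theta(\boldsymbol{u}'_i))/\tau\right)}{Z}, \qquad Z = \sum_{k=1}^{2R} \delta_{k \neq i} \exp\left(\mathrm{sim}(\boldsymbol{f}_\theta(\boldsymbol{u}_i), \boldsymbol{f}_\theta(\boldsymbol{u}_k))/\tau\right),$$
where $Z$ is the normalizer, $R$ is the batch size, and $\{\boldsymbol{u}_k\}_{k=1}^{2R}$ are the $2R$ augmented samples in the mini-batch (two views per input). $\mathrm{sim}(\boldsymbol{y}, \boldsymbol{y}')$ is a similarity function that takes a large positive value when $\boldsymbol{y}$ and $\boldsymbol{y}'$ are similar and a smaller (positive or negative) value when they are dissimilar. $\tau$ is the temperature parameter, and $\delta_{k \neq i}$ is a delta function that takes a value of 1 when $k \neq i$ and 0 otherwise.
In SimCLR, the parameters $\theta$ are learned by minimizing the InfoNCE loss. The loss can be written as $-\mathrm{sim}(\boldsymbol{f}_\theta(\boldsymbol{u}_i), \boldsymbol{f}_\theta(\boldsymbol{u}'_i))/\tau + \log Z$. The first term is proportional to the negative similarity between $\boldsymbol{u}_i$ and $\boldsymbol{u}'_i$, so minimizing it pulls the positive pair together. The second term, $\log Z$, is a log-sum-exp function, i.e., a smooth approximation of the maximum similarity between $\boldsymbol{u}_i$ and the other samples in the batch. Minimizing it makes $\boldsymbol{u}_i$ and its negative samples dissimilar.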
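The InfoNCE loss can be sketched for a mini-batch as follows. This is a minimal forward-pass illustration in NumPy with no gradients. It assumes cosine similarity as `sim`, and that row $i$ of `Z` and row $i$ of `Z_prime` are the two views of the same input. The function names and $\tau = 0.1$ are illustrative assumptions.

```python
import numpy as np


def cosine_sim(y, y_prime):
    """Cosine similarity between two vectors."""
    return float(y @ y_prime / (np.linalg.norm(y) * np.linalg.norm(y_prime)))


def info_nce(Z, Z_prime, sim=cosine_sim, tau=0.1):
    """InfoNCE loss averaged over all 2R anchors in the batch.

    Z, Z_prime: (R, d) embeddings f_theta(u_i) and f_theta(u'_i)."""
    H = np.vstack([Z, Z_prime])              # the 2R augmented samples
    n, R = H.shape[0], Z.shape[0]
    losses = []
    for i in range(n):
        pos = (i + R) % n                    # index of the positive partner
        logits = np.array([sim(H[i], H[k]) / tau for k in range(n)])
        others = logits[np.arange(n) != i]   # delta_{k != i}: drop k = i
        m = others.max()                     # stabilized log-sum-exp = log Z
        log_Z = m + np.log(np.sum(np.exp(others - m)))
        losses.append(log_Z - logits[pos])   # -sim/tau + log Z
    return float(np.mean(losses))


Z = np.eye(4)
print(info_nce(Z, Z))                        # matched views: loss close to 0
print(info_nce(Z, np.roll(Z, 1, axis=0)))    # mismatched views: close to 1/tau
```

The positive pair is included in $Z$, so the loss is always nonnegative. It approaches zero only when the positive similarity dominates all negatives.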
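SimSiam [8]: SimSiam is a non-contrastive learning method; it does not use negative sampling to prevent mode collapse. Instead, SimSiam employs a stop-gradient operation. Specifically, the loss function is given by
$$\ell_{\text{SimSiam}}(\theta, \vartheta) = -\frac{1}{2}\mathrm{sim}\left(\boldsymbol{h}_\vartheta(\boldsymbol{z}), \widetilde{\boldsymbol{z}}'\right) - \frac{1}{2}\mathrm{sim}\left(\boldsymbol{h}_\vartheta(\boldsymbol{z}'), \widetilde{\boldsymbol{z}}\right),$$
where $\mathrm{sim}$ is the cosine similarity and $\boldsymbol{h}_\vartheta$ is the MLP head (predictor) with parameter $\vartheta$. $\boldsymbol{z} = \boldsymbol{f}_\theta(\boldsymbol{u})$ and $\boldsymbol{z}' = \boldsymbol{f}_\theta(\boldsymbol{u}')$ are latent variables, and $\widetilde{\boldsymbol{z}}$ and $\widetilde{\boldsymbol{z}}'$ are the same latent variables with a stop gradient applied.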
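Below is a forward-pass sketch of the symmetrized SimSiam loss. It assumes that `p1` and `p2` are predictor outputs $\boldsymbol{h}_\vartheta(\boldsymbol{z})$ and $\boldsymbol{h}_\vartheta(\boldsymbol{z}')$. NumPy has no automatic differentiation, so the stop-gradient appears only in comments. In a framework such as PyTorch, it would correspond to detaching `z1` and `z2`. The function names and the random linear predictor are hypothetical.

```python
import numpy as np


def neg_cosine(p, z):
    """Negative cosine similarity -sim(p, z). In an autodiff framework, z would
    be wrapped in a stop-gradient so no gradient flows into the target branch."""
    return -float(p @ z / (np.linalg.norm(p) * np.linalg.norm(z)))


def simsiam_loss(p1, p2, z1, z2):
    """Symmetrized SimSiam loss: each view's predictor output is compared with
    the (stop-gradient) encoder output of the other view."""
    return 0.5 * neg_cosine(p1, z2) + 0.5 * neg_cosine(p2, z1)


rng = np.random.default_rng(0)
z1, z2 = rng.normal(size=8), rng.normal(size=8)
W = rng.normal(size=(8, 8))                  # hypothetical linear predictor h
print(simsiam_loss(W @ z1, W @ z2, z1, z2))  # always lies in [-1, 1]
print(simsiam_loss(z1, z1, z1, z1))          # perfectly aligned: approximately -1
```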
3.2 $p$-Wasserstein distance
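The $p$-Wasserstein distance between two discrete measures $\mu = \sum_{i=1}^{n} a_i \delta_{\boldsymbol{x}_i}$ and $\mu' = \sum_{j=1}^{m} a'_j \delta_{\boldsymbol{x}'_j}$ is given by
$$W_p(\mu, \mu') = \left(\min_{\boldsymbol{\Pi} \in U(\mu, \mu')} \sum_{i=1}^{n}\sum_{j=1}^{m} \pi_{ij}\, d(\boldsymbol{x}_i, \boldsymbol{x}'_j)^p\right)^{1/p},$$
where $U(\mu, \mu') = \{\boldsymbol{\Pi} \in \mathbb{R}_{+}^{n \times m} : \boldsymbol{\Pi}\mathbf{1}_m = \boldsymbol{a}, \boldsymbol{\Pi}^\top \mathbf{1}_n = \boldsymbol{a}'\}$ denotes the set of transport plans and $d(\cdot, \cdot)$ is the ground distance. The Wasserstein distance can be computed using a linear program. However, this requires solving an optimization problem, so computing the Wasserstein distance at every training iteration is computationally expensive.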
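An alternative approach is entropic regularization [11]. For the 1-Wasserstein distance, the entropic regularized variant is given as
$$W_\epsilon(\mu, \mu') = \min_{\boldsymbol{\Pi} \in U(\mu, \mu')} \sum_{i=1}^{n}\sum_{j=1}^{m} \pi_{ij}\, d(\boldsymbol{x}_i, \boldsymbol{x}'_j) + \epsilon \sum_{i=1}^{n}\sum_{j=1}^{m} \pi_{ij}(\log \pi_{ij} - 1),$$
where $\epsilon > 0$ is the regularization parameter.
This optimization problem can be solved efficiently using the Sinkhorn algorithm [11] at a computational cost of $O(nm)$ per iteration. More importantly, the solution of the Sinkhorn algorithm is given as a series of matrix multiplications, which makes it easy to backpropagate through. For this reason, the Sinkhorn algorithm is widely employed in deep learning algorithms.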
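The Sinkhorn iterations for the entropic problem can be sketched as follows. This is a minimal NumPy version that works in the standard (not log) domain. The parameter `eps` plays the role of $\epsilon$, and the iteration count is an illustrative assumption. Each iteration consists of two matrix-vector products, which is where the $O(nm)$ per-iteration cost comes from.

```python
import numpy as np


def sinkhorn(a, b, C, eps=0.1, n_iters=500):
    """Entropic-regularized OT between probability vectors a (n,) and b (m,)
    with ground-cost matrix C (n, m). Returns the plan P and <P, C>."""
    K = np.exp(-C / eps)                 # Gibbs kernel
    u, v = np.ones_like(a), np.ones_like(b)
    for _ in range(n_iters):
        u = a / (K @ v)                  # rescale to match row marginals
        v = b / (K.T @ u)                # rescale to match column marginals
    P = u[:, None] * K * v[None, :]
    return P, float(np.sum(P * C))


x = np.arange(3.0)
C = np.abs(x[:, None] - x[None, :])      # |i - j| ground distance on a line
a = np.array([1.0, 0.0, 0.0])
b = np.array([0.0, 0.0, 1.0])
P, cost = sinkhorn(a, b, C)
print(round(cost, 6))                    # all mass moves from 0 to 2: cost 2.0
```

For small `eps` relative to the costs, `np.exp(-C / eps)` can underflow to zero. Practical implementations therefore often run the same updates in the log domain.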
3.3 1-Wasserstein distance with tree metric (Tree-Wasserstein Distance)
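Another 1-Wasserstein distance is based on trees [21, 30]. Let $\mu = \sum_{j=1}^{N_{\text{leaf}}} a_j \delta_{\boldsymbol{x}_j}$ and $\mu' = \sum_{j=1}^{N_{\text{leaf}}} a'_j \delta_{\boldsymbol{x}_j}$ be supported on the leaf nodes $\boldsymbol{x}_j$ of a tree $\mathcal{T}$. The 1-Wasserstein distance with the tree metric is defined as the L1 distance between the tree-embedded probability vectors:
$$W_{\mathcal{T}}(\mu, \mu') = \|\mathrm{diag}(\boldsymbol{w})\boldsymbol{B}(\boldsymbol{a} - \boldsymbol{a}')\|_1,$$
where $\boldsymbol{B} \in \{0, 1\}^{N_{\text{node}} \times N_{\text{leaf}}}$ is a tree parameter. $B_{i,j} = 1$ if node $i$ is an ancestor of leaf node $j$ (a leaf being counted as its own ancestor) and zero otherwise. $N_{\text{node}}$ is the total number of nodes in the tree, and $N_{\text{leaf}}$ is the number of leaf nodes. $\boldsymbol{w} \in \mathbb{R}_{+}^{N_{\text{node}}}$ is the edge weight vector, where $w_i$ is the weight of the edge connecting node $i$ to its parent.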
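As a concrete check of the formula $\|\mathrm{diag}(\boldsymbol{w})\boldsymbol{B}(\boldsymbol{a} - \boldsymbol{a}')\|_1$, the NumPy sketch below uses a hypothetical two-level tree with four leaves and illustrative weights. For two Dirac masses on leaves, TWD should equal the weighted shortest-path length between those leaves, and the example checks this. Each node's weight is taken to be the weight of the edge to its parent; the root has no such edge, so its weight is set to zero.

```python
import numpy as np


def twd(a, a_prime, B, w):
    """Tree-Wasserstein distance ||diag(w) B (a - a')||_1.
    B[i, j] = 1 if node i is leaf j itself or one of its ancestors;
    w[i] is the weight of the edge from node i to its parent."""
    return float(np.sum(w * np.abs(B @ (a - a_prime))))


# Hypothetical tree: root -> {n1, n2}; n1 -> {leaf0, leaf1}; n2 -> {leaf2, leaf3}.
# Node (row) order: root, n1, n2, leaf0, leaf1, leaf2, leaf3.
B = np.array([
    [1, 1, 1, 1],   # root is an ancestor of every leaf
    [1, 1, 0, 0],   # n1
    [0, 0, 1, 1],   # n2
    [1, 0, 0, 0],   # leaf0
    [0, 1, 0, 0],   # leaf1
    [0, 0, 1, 0],   # leaf2
    [0, 0, 0, 1],   # leaf3
], dtype=float)
w = np.array([0.0, 1.0, 1.0, 0.5, 0.5, 0.5, 0.5])  # root has no parent edge

e = np.eye(4)  # Dirac masses on the leaves
print(twd(e[0], e[1], B, w))  # path leaf0-n1-leaf1: 0.5 + 0.5 = 1.0
print(twd(e[0], e[2], B, w))  # path leaf0-n1-root-n2-leaf2: 0.5+1+1+0.5 = 3.0
```

The cost is a single sparse matrix-vector product, i.e., linear in the number of nodes. This is the closed-form advantage over solving a transport problem.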
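For illustration, we provide two examples of the matrix $\boldsymbol{B}$ by considering a tree with a depth of one and a ClusterTree, as shown in Figure 1. If all edge weights in the left panel of Figure 1 are $w_i = 1/2$, then $\boldsymbol{B} = [\mathbf{1}_{N_{\text{leaf}}}, \boldsymbol{I}_{N_{\text{leaf}}}]^\top$, i.e., a row of ones for the root stacked on top of the identity matrix. Substituting this result into the TWD, we obtain
$$W_{\mathcal{T}}(\mu, \mu') = \frac{1}{2}\left|\mathbf{1}^\top(\boldsymbol{a} - \boldsymbol{a}')\right| + \frac{1}{2}\|\boldsymbol{a} - \boldsymbol{a}'\|_1 = \frac{1}{2}\|\boldsymbol{a} - \boldsymbol{a}'\|_1 = \mathrm{TV}(\mu, \mu'),$$
because $\mathbf{1}^\top\boldsymbol{a} = \mathbf{1}^\top\boldsymbol{a}' = 1$.
Thus, total variation is a special case of TWD. In this setting, the shortest-path distance in the tree corresponds to the Hamming distance. Note that Raginsky et al. [37] also state that the 1-Wasserstein distance with the Hamming metric is equivalent to total variation (Proposition 3.4.1).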
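Here is a quick numerical check of this special case. It assumes the depth-one tree with all edge weights equal to $1/2$; the probability vectors are arbitrary illustrative values.

```python
import numpy as np

n_leaf = 3
B = np.vstack([np.ones((1, n_leaf)), np.eye(n_leaf)])  # root row of ones + identity
w = np.full(n_leaf + 1, 0.5)                           # all edge weights 1/2

a = np.array([0.5, 0.3, 0.2])
a_prime = np.array([0.2, 0.3, 0.5])

twd_value = np.sum(w * np.abs(B @ (a - a_prime)))  # root term vanishes: 1^T(a - a') = 0
tv_value = 0.5 * np.sum(np.abs(a - a_prime))       # total variation
print(twd_value, tv_value)  # both approximately 0.3
```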
The key advantage of the tree-based approach is that the Wasserstein distance is written in closed form, which is computationally efficient. A chain is a special case of a tree, so the widely employed sliced Wasserstein distance is also a special case of TWD (Figure 2). Moreover, it has been empirically reported that TWD- and Sinkhorn-based approaches perform similarly in multilabel classification tasks [43].
4 SSL with 1-Wasserstein Distance
In this section, we first formulate SSL using TWD. We then introduce ArcFace-based probability models and Jeffrey divergence regularization.
4.1 SimCLR with Tree Wasserstein Distance
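Let $\boldsymbol{a}_\theta(\boldsymbol{x}_i)$ and $\boldsymbol{a}_\theta(\boldsymbol{x}_j)$ be the simplicial embedding vectors of $\boldsymbol{x}_i$ and $\boldsymbol{x}_j$ (i.e., $\mathbf{1}^\top \boldsymbol{a}_\theta(\boldsymbol{x}) = 1$ and $\boldsymbol{a}_\theta(\boldsymbol{x}) \geq \mathbf{0}$). These vectors define the measures $\mu_i = \sum_{k} [\boldsymbol{a}_\theta(\boldsymbol{x}_i)]_k \delta_{\boldsymbol{e}_k}$ and $\mu_j = \sum_{k} [\boldsymbol{a}_\theta(\boldsymbol{x}_j)]_k \delta_{\boldsymbol{e}_k}$, respectively. Here, $\boldsymbol{e}_k$ is the virtual embedding corresponding to the $k$th entry of $\boldsymbol{a}_\theta(\boldsymbol{x}_i)$ or $\boldsymbol{a}_\theta(\boldsymbol{x}_j)$, and $\boldsymbol{e}_k$ is assumed unavailable in the problem setup. The main idea of this paper is to adopt the negative Wasserstein distance between $\mu_i$ and $\mu_j$ as the similarity score for SimCLR.
We assume that $\boldsymbol{B}$ and $\boldsymbol{w}$ are given; that is, both the tree structure and the edge weights are known. In particular, we consider the trees shown in Figure 1.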
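Following the original design of the InfoNCE loss and substituting the negative Wasserstein distance as the similarity score, we obtain the following simplified loss function:
$$\ell_{\text{InfoNCE}}(\boldsymbol{u}_i, \boldsymbol{u}'_i) = \frac{1}{\tau}\|\mathrm{diag}(\boldsymbol{w})\boldsymbol{B}(\boldsymbol{a}_\theta(\boldsymbol{u}_i) - \boldsymbol{a}_\theta(\boldsymbol{u}'_i))\|_1 + \log \sum_{k=1}^{2R} \delta_{k \neq i}\exp\left(-\frac{1}{\tau}\|\mathrm{diag}(\boldsymbol{w})\boldsymbol{B}(\boldsymbol{a}_\theta(\boldsymbol{u}_i) - \boldsymbol{a}_\theta(\boldsymbol{u}_k))\|_1\right),$$
where $\tau$ is the temperature parameter for the InfoNCE loss. Although we mainly focus on the InfoNCE loss, the proposed negative Wasserstein distance can also serve as a similarity measure in other contrastive losses, e.g., Barlow Twins.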
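The simplified loss can be sketched in vectorized NumPy form. Because $\boldsymbol{w} \geq \mathbf{0}$, each TWD equals the L1 distance between the embedded vectors $\mathrm{diag}(\boldsymbol{w})\boldsymbol{B}\boldsymbol{a}$, so all pairwise distances in the batch can be computed at once. This is a forward-pass sketch only. The star tree (which gives total variation), the function names, and $\tau = 0.1$ are illustrative assumptions.

```python
import numpy as np


def pairwise_twd(A, B_tree, w):
    """Pairwise TWD between the rows of A (n, N_leaf): L1 distances between
    the tree-embedded vectors diag(w) B a, returned as an (n, n) matrix."""
    E = (A @ B_tree.T) * w                  # row k is diag(w) B a_k
    return np.abs(E[:, None, :] - E[None, :, :]).sum(axis=2)


def info_nce_twd(A, A_prime, B_tree, w, tau=0.1):
    """InfoNCE with sim = -TWD. A, A_prime: (R, N_leaf) probability vectors
    a_theta(u_i), a_theta(u'_i); row i of each forms a positive pair."""
    H = np.vstack([A, A_prime])
    n, R = H.shape[0], A.shape[0]
    logits = -pairwise_twd(H, B_tree, w) / tau
    np.fill_diagonal(logits, -np.inf)       # delta_{k != i}: exclude k = i
    m = logits.max(axis=1, keepdims=True)
    log_Z = (m + np.log(np.exp(logits - m).sum(axis=1, keepdims=True))).ravel()
    pos = (np.arange(n) + R) % n            # index of each anchor's positive
    return float(np.mean(log_Z - logits[np.arange(n), pos]))


# Star tree with weights 1/2, so TWD equals total variation.
B_tree = np.vstack([np.ones((1, 4)), np.eye(4)])
w = np.full(5, 0.5)
F = 5.0 * np.eye(4)                          # stand-in network outputs
A = np.exp(F) / np.exp(F).sum(axis=1, keepdims=True)    # softmax rows
print(info_nce_twd(A, A, B_tree, w))                      # matched views: small
print(info_nce_twd(A, np.roll(A, 1, axis=0), B_tree, w))  # mismatched: large
```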
4.2 SimSiam with Tree Wasserstein Distance
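Here, we consider a combination of SimSiam and TWD. The loss function of the proposed approach is expressed as
$$\ell_{\text{SimSiam-TWD}}(\theta, \vartheta) = \frac{1}{2} W_{\mathcal{T}}\left(\boldsymbol{a}(\boldsymbol{h}_\vartheta(\boldsymbol{z})), \boldsymbol{a}(\widetilde{\boldsymbol{z}}')\right) + \frac{1}{2} W_{\mathcal{T}}\left(\boldsymbol{a}(\boldsymbol{h}_\vartheta(\boldsymbol{z}')), \boldsymbol{a}(\widetilde{\boldsymbol{z}})\right),$$
where $\boldsymbol{a}(\cdot)$ maps a vector to the probability simplex using one of the probability models in Section 4.4, and $W_{\mathcal{T}}(\boldsymbol{a}, \boldsymbol{a}') = \|\mathrm{diag}(\boldsymbol{w})\boldsymbol{B}(\boldsymbol{a} - \boldsymbol{a}')\|_1$.
The difference from the original SimSiam is that our formulation employs the Wasserstein distance, whereas the original formulation uses cosine similarity.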
4.3 Robust Variant of Tree Wasserstein Distance
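In our setup, it is difficult to estimate the tree structure and edge weights because the embedding vectors are unavailable. To address this problem, we consider a robust estimate of TWD in the spirit of the subspace-robust Wasserstein distance (SRWD) [35]. The key idea of SRWD is to solve an optimal transport problem in a subspace in which the distance is maximized. In the TWD case, we can consider solving the optimal transport problem for the maximum shortest-path distance. Specifically, for a given $\boldsymbol{B}$, we propose the robust TWD (RTWD) as follows:
$$\mathrm{RTWD}(\mu, \mu') = \frac{1}{2}\min_{\boldsymbol{\Pi} \in U(\mu, \mu')} \max_{\boldsymbol{w} \in \mathcal{W}} \sum_{i=1}^{N_{\text{leaf}}}\sum_{j=1}^{N_{\text{leaf}}} \pi_{ij}\, d_{\mathcal{T}}(\boldsymbol{e}_i, \boldsymbol{e}_j),$$
where $\mathcal{W} = \{\boldsymbol{w} \in \mathbb{R}_{+}^{N_{\text{node}}} : \boldsymbol{B}^\top \boldsymbol{w} = \mathbf{1}_{N_{\text{leaf}}}\}$ and $d_{\mathcal{T}}(\boldsymbol{e}_i, \boldsymbol{e}_j)$ is the shortest-path distance between $\boldsymbol{e}_i$ and $\boldsymbol{e}_j$, which are embedded in a tree $\mathcal{T}$. The constraint $\boldsymbol{w} \in \mathcal{W}$ implies that, for each leaf node, the weights of its ancestor nodes are non-negative and sum to one.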
Proposition 1
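The robust variant of TWD (RTWD) is equivalent to total variation:
$$\mathrm{RTWD}(\mu, \mu') = \mathrm{TV}(\mu, \mu') = \frac{1}{2}\|\boldsymbol{a} - \boldsymbol{a}'\|_1,$$
where $\mathrm{TV}(\mu, \mu')$ denotes the total variation.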
The proof is provided in the Appendix. Based on this proposition, RTWD is equivalent to total variation and does not depend on the tree structure $\boldsymbol{B}$. That is, if we do not have prior information about the tree structure, using total variation is a reasonable choice.
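A short derivation sketch may help build intuition for Proposition 1. It assumes RTWD is defined as $\frac{1}{2}\min_{\boldsymbol{\Pi}}\max_{\boldsymbol{w} \geq \mathbf{0},\, \boldsymbol{B}^\top\boldsymbol{w} = \mathbf{1}} \sum_{i,j}\pi_{ij}\, d_{\mathcal{T}}(\boldsymbol{e}_i, \boldsymbol{e}_j)$; the full proof in the Appendix may proceed differently. For $i \neq j$, the path between two leaves uses only edges above leaf $i$ or above leaf $j$, and the weights above each leaf sum to one. Therefore,

$$
d_{\mathcal{T}}(\boldsymbol{e}_i, \boldsymbol{e}_j) \le \sum_{k:\,B_{k,i}=1} w_k + \sum_{k:\,B_{k,j}=1} w_k = 2 \qquad (i \neq j),
$$

with equality for every pair when all weight is placed on the leaf edges ($w_k = 1$ for leaf nodes and $0$ otherwise). The inner maximum therefore equals $2\sum_{i \neq j}\pi_{ij}$. At most $\min(a_j, a'_j)$ mass can stay on leaf $j$, so

$$
\mathrm{RTWD}(\mu, \mu') = \min_{\boldsymbol{\Pi} \in U(\mu, \mu')} \sum_{i \neq j} \pi_{ij} = 1 - \sum_{j} \min(a_j, a'_j) = \frac{1}{2}\|\boldsymbol{a} - \boldsymbol{a}'\|_1 ,
$$

where the last step uses $\min(a_j, a'_j) = \frac{1}{2}\left(a_j + a'_j - |a_j - a'_j|\right)$.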
4.4 Probability models
In this section, we discuss several choices of probability models for InfoNCE loss and SimSiam loss.
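Softmax: The softmax function for simplicial representation is given by
$$\boldsymbol{a}_\theta(\boldsymbol{x}) = \mathrm{softmax}(\boldsymbol{f}_\theta(\boldsymbol{x})), \qquad [\mathrm{softmax}(\boldsymbol{f})]_k = \frac{\exp(f_k)}{\sum_{k'} \exp(f_{k'})},$$
where $\boldsymbol{f}_\theta(\boldsymbol{x})$ is a neural network model.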
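Simplicial Embedding: Lavoie et al. [28] proposed a simple yet efficient simplicial embedding method. Assume that the output dimensionality of a neural network model is $d_{\text{out}} = LV$. Then, SEM applies the softmax function to each $V$-dimensional block of $\boldsymbol{f}_\theta(\boldsymbol{x})$, yielding $L$ probability vectors. The $\ell$th softmax function is thus defined as follows:
$$\boldsymbol{a}_\theta^{(\ell)}(\boldsymbol{x}) = \frac{1}{L}\mathrm{softmax}\left(\boldsymbol{f}_\theta^{(\ell)}(\boldsymbol{x})\right), \qquad \ell = 1, \ldots, L,$$
where $\boldsymbol{f}_\theta^{(\ell)}(\boldsymbol{x}) \in \mathbb{R}^{V}$ is the $\ell$th block of the neural network output, and $\boldsymbol{a}_\theta(\boldsymbol{x}) = [\boldsymbol{a}_\theta^{(1)}(\boldsymbol{x})^\top, \ldots, \boldsymbol{a}_\theta^{(L)}(\boldsymbol{x})^\top]^\top$. We normalize each softmax output by $L$ because $\boldsymbol{a}_\theta(\boldsymbol{x})$ must satisfy the sum-to-one constraint.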
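The two probability models above can be sketched in NumPy as follows. The block sizes $L = 3$ and $V = 4$ in the example are illustrative values, not the settings used in the experiments.

```python
import numpy as np


def softmax(f):
    """Numerically stable softmax over the last axis."""
    e = np.exp(f - f.max(axis=-1, keepdims=True))
    return e / e.sum(axis=-1, keepdims=True)


def simplicial_embedding(f, L, V):
    """SEM: split f (last axis of length L*V) into L blocks of size V, apply a
    softmax to each block, and divide by L so the concatenation sums to one."""
    assert f.shape[-1] == L * V
    blocks = f.reshape(*f.shape[:-1], L, V)
    return (softmax(blocks) / L).reshape(*f.shape[:-1], L * V)


f = np.arange(12.0)                      # stand-in for f_theta(x) with d_out = 12
a_soft = softmax(f)
a_sem = simplicial_embedding(f, L=3, V=4)
print(a_soft.sum(), a_sem.sum())         # both approximately 1
print(a_sem.reshape(3, 4).sum(axis=1))   # each block sums to 1/L
```

Unlike a single softmax, SEM enforces the sum-to-one constraint block by block. The resulting vector is therefore spread over $L$ groups rather than concentrating on a single coordinate.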
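ArcFace model (AF): As an alternative to SEM, we propose to employ the ArcFace probability model [12]:
$$\boldsymbol{a}_\theta(\boldsymbol{x}) = \mathrm{softmax}\left(\frac{\overline{\boldsymbol{K}}^\top \bar{\boldsymbol{f}}_\theta(\boldsymbol{x})}{\eta}\right),$$
where $\boldsymbol{K}$ is a learnable key matrix and $\overline{\boldsymbol{K}}$ is its column-normalized version (each column divided by its L2 norm). $\bar{\boldsymbol{f}}_\theta(\boldsymbol{x}) = \boldsymbol{f}_\theta(\boldsymbol{x})/\|\boldsymbol{f}_\theta(\boldsymbol{x})\|_2$ is the normalized output of a model, and $\eta$ is the temperature parameter. Note that AF has a structure similar to that of attention in transformers [3, 46]. The key difference from the original notion of attention is the normalization of the key matrix $\boldsymbol{K}$ and the query vector $\boldsymbol{f}_\theta(\boldsymbol{x})$.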
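The AF probability model can be sketched as follows. The sketch assumes a key matrix `K` of shape (d, M) whose columns are the keys, and a temperature `eta`; the value 0.1 and the shapes are illustrative. No additive angular margin is applied here, unlike in the ArcFace classification loss. The sketch uses only the normalized inner products.

```python
import numpy as np


def arcface_prob(f, K, eta=0.1):
    """Softmax over cosine similarities between the L2-normalized model output
    f (d,) and the L2-normalized columns of the key matrix K (d, M)."""
    f_bar = f / np.linalg.norm(f)                         # normalized query
    K_bar = K / np.linalg.norm(K, axis=0, keepdims=True)  # normalized keys
    logits = K_bar.T @ f_bar / eta                        # each in [-1/eta, 1/eta]
    e = np.exp(logits - logits.max())
    return e / e.sum()


rng = np.random.default_rng(0)
f, K = rng.normal(size=16), rng.normal(size=(16, 10))
a = arcface_prob(f, K)
print(a.sum())                                   # approximately 1
print(np.allclose(a, arcface_prob(3.0 * f, K)))  # True: invariant to the scale of f
```

Because the logits are bounded by $\pm 1/\eta$, the ratio between the largest and smallest probabilities is at most $\exp(2/\eta)$. A plain softmax on unnormalized outputs has no such bound.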
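AF with Positional Encoding: The AF model effectively adds one more (learnable) linear layer, the key matrix, before the softmax function, so its output is similar to that of the standard softmax function. Here, we propose replacing the key matrix $\boldsymbol{K}$ with a normalized positional encoding matrix $\overline{\boldsymbol{K}}_{\text{PE}}$, that is, a transformer-style positional encoding matrix whose columns are normalized to unit L2 norm:
$$\boldsymbol{a}_\theta(\boldsymbol{x}) = \mathrm{softmax}\left(\frac{\overline{\boldsymbol{K}}_{\text{PE}}^\top \bar{\boldsymbol{f}}_\theta(\boldsymbol{x})}{\eta}\right).$$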
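One possible construction of the normalized positional encoding key matrix is sketched below. It assumes the sinusoidal positional encoding of the original transformer (base 10000, sine on even dimensions and cosine on odd ones), with column $j$ encoding position $j$ and then normalized to unit L2 norm. The exact encoding used in the experiments is not specified, so treat this function as hypothetical.

```python
import numpy as np


def positional_encoding_keys(d, M, base=10000.0):
    """Hypothetical K_PE of shape (d, M): column j is the sinusoidal positional
    encoding of position j, normalized to unit L2 norm."""
    pos = np.arange(M)[None, :]                 # positions index the columns
    dim = np.arange(d)[:, None]                 # embedding dimensions index rows
    angle = pos / base ** (2 * (dim // 2) / d)
    K = np.where(dim % 2 == 0, np.sin(angle), np.cos(angle))
    return K / np.linalg.norm(K, axis=0, keepdims=True)


K_pe = positional_encoding_keys(d=16, M=10)
print(K_pe.shape)                                       # (16, 10)
print(np.allclose(np.linalg.norm(K_pe, axis=0), 1.0))   # True
```

The matrix is fixed rather than learned. As a result, the probability model cannot collapse back into an arbitrary extra linear layer, which is the motivation given above.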
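AF with Discrete Cosine Transform Matrix: Another natural approach is to use an orthogonal matrix as $\boldsymbol{K}$. Therefore, we propose adopting a discrete cosine transform (DCT) matrix as $\boldsymbol{K}$. The (orthonormal DCT-II) matrix is expressed as follows:
$$[\boldsymbol{K}_{\text{DCT}}]_{kn} = \begin{cases} \sqrt{1/d}, & k = 0,\\ \sqrt{2/d}\,\cos\left(\frac{\pi(2n+1)k}{2d}\right), & k = 1, \ldots, d-1, \end{cases} \qquad n = 0, \ldots, d-1,$$
where $d$ is the dimensionality of $\bar{\boldsymbol{f}}_\theta(\boldsymbol{x})$.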
One of the contributions of this study is the finding that combining positional encoding or the DCT matrix with the ArcFace model significantly boosts performance. In contrast, the standard ArcFace model without these additions performs similarly to the softmax function.
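The orthonormal DCT-II matrix can be built in NumPy as sketched below, and it can serve as a fixed key matrix in the AF model. Its rows and columns are orthonormal, so the key normalization in AF leaves it unchanged. The dimension in the example is arbitrary.

```python
import numpy as np


def dct_matrix(d):
    """Orthonormal DCT-II matrix of shape (d, d); row k is the k-th cosine basis."""
    k = np.arange(d)[:, None]
    n = np.arange(d)[None, :]
    D = np.sqrt(2.0 / d) * np.cos(np.pi * (2 * n + 1) * k / (2 * d))
    D[0, :] = np.sqrt(1.0 / d)                  # constant (DC) basis vector
    return D


D = dct_matrix(8)
print(np.allclose(D @ D.T, np.eye(8)))              # True: orthogonal matrix
print(np.allclose(np.linalg.norm(D, axis=0), 1.0))  # True: unit-norm columns
```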
4.5 Jeffrey-Divergence Regularization
We empirically observed that optimizing the loss functions described above is extremely challenging. In particular, the L1 distance is not differentiable at 0. Figure 3(b) illustrates the learning curve for standard optimization using the softmax function model.
To stabilize optimization, we propose including the Jeffrey divergence (JD) as a regularization term. JD is an upper bound of the square of the 1-Wasserstein distance.
Proposition 2
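For $\boldsymbol{B} \in \{0, 1\}^{N_{\text{node}} \times N_{\text{leaf}}}$, $\boldsymbol{w} \in \mathbb{R}_{+}^{N_{\text{node}}}$, and probability vectors $\boldsymbol{a}$ and $\boldsymbol{a}'$, the squared TWD $W_{\mathcal{T}}(\boldsymbol{a}, \boldsymbol{a}')^2$ is upper-bounded by $\mathrm{JD}(\boldsymbol{a} \,\|\, \boldsymbol{a}')$ multiplied by a constant that depends only on $\boldsymbol{w}$. Here, $\mathrm{JD}(\boldsymbol{a} \,\|\, \boldsymbol{a}') = \mathrm{KL}(\boldsymbol{a} \,\|\, \boldsymbol{a}') + \mathrm{KL}(\boldsymbol{a}' \,\|\, \boldsymbol{a})$ is the Jeffrey divergence.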
This result indicates that minimizing the symmetric KL divergence (i.e., the Jeffrey divergence) also minimizes an upper bound of the tree-Wasserstein distance. Because the Jeffrey divergence is smooth, the gradient of the upper bound is easier to compute. For presentation, we denote the Jeffrey divergence by $\mathrm{JD}(\cdot \,\|\, \cdot)$.
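The following sketch checks one bound of this type numerically.

- Each entry of $\boldsymbol{B}(\boldsymbol{a}-\boldsymbol{a}')$ is a difference of subset masses, so its absolute value is at most $\frac{1}{2}\|\boldsymbol{a}-\boldsymbol{a}'\|_1$.
- Applying Pinsker's inequality in both directions gives $\|\boldsymbol{a}-\boldsymbol{a}'\|_1^2 \leq \mathrm{JD}(\boldsymbol{a}\,\|\,\boldsymbol{a}')$.
- Together, these yield squared TWD $\leq \frac{1}{4}\|\boldsymbol{w}\|_1^2\,\mathrm{JD}$.

This constant is derived here for illustration and is not necessarily the one stated in Proposition 2. The regularized objective `twd + lam * jeffrey` and `lam = 0.1` are likewise illustrative assumptions.

```python
import numpy as np


def kl(p, q):
    """KL divergence for strictly positive probability vectors."""
    return float(np.sum(p * np.log(p / q)))


def jeffrey(p, q):
    """Jeffrey divergence: KL(p || q) + KL(q || p)."""
    return kl(p, q) + kl(q, p)


def twd(a, a_prime, B, w):
    """Tree-Wasserstein distance ||diag(w) B (a - a')||_1 (non-smooth)."""
    return float(np.sum(w * np.abs(B @ (a - a_prime))))


def regularized_loss(a, a_prime, B, w, lam=0.1):
    """Hypothetical per-pair objective: TWD plus lam times the smooth JD term."""
    return twd(a, a_prime, B, w) + lam * jeffrey(a, a_prime)


rng = np.random.default_rng(0)
B = np.vstack([np.ones((1, 5)), np.eye(5)])  # depth-one tree with 5 leaves
w = rng.uniform(0.1, 2.0, size=6)            # arbitrary nonnegative weights
for _ in range(1000):
    a, a_prime = rng.dirichlet(np.ones(5)), rng.dirichlet(np.ones(5))
    # Bound: TWD^2 <= (||w||_1^2 / 4) * JD
    assert twd(a, a_prime, B, w) ** 2 <= 0.25 * w.sum() ** 2 * jeffrey(a, a_prime) + 1e-12
print(regularized_loss(a, a_prime, B, w))
```

Because JD is smooth in the probability vectors, adding it gives informative gradients even where the L1 term is flat or non-differentiable.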
Note that Frogner et al. [17] considered a multilabel classification problem using a regularized Wasserstein loss. They proposed Kullback–Leibler divergence-based regularization to stabilize training. In contrast, we derive the Jeffrey divergence regularizer from the TWD, and JD regularization includes a simple KL divergence-based regularization as a special case. Moreover, we propose employing JD regularization in SSL frameworks, a setting that has not been extensively studied.
5 Experiments
This section evaluates SSL methods with different probability models.
5.1 Performance comparison for SimCLR
For all experiments, we employed the ResNet18 model with an output dimension of () and implemented all methods based on a standard SimCLR implementation (https://github.com/sthalles/SimCLR). We used the Adam optimizer and set the learning rate to 0.0003, the weight decay parameter to 1e-4, and the temperature to 0.07. For the proposed method, we compared two variants of TWD: total variation and ClusterTree (see Figure 1). As part of the model evaluation, we assessed the conventional softmax function, the ArcFace model (AF), and simplicial embedding (SEM) [28], and set the temperature parameter for all experiments. For SEM, we set and .
We also evaluated JD regularization, where we set the regularization parameter for all experiments. For reference, we compared against cosine similarity as the similarity function of SimCLR. For all approaches, we utilized the KNN classifier of the scikit-learn package (https://scikit-learn.org/stable/modules/generated/sklearn.neighbors.KNeighborsClassifier.html), where the number of nearest neighbors was set to . In the KNN classifier, we used the L1 distance for Wasserstein distance-based models and cosine similarity for non-probability-based models. All experiments were run on A6000 GPUs. We ran all experiments three times and report the average scores.
| Similarity function | Probability model | STL10 | CIFAR10 | CIFAR100 | SVHN |
|---|---|---|---|---|---|
| Cosine similarity | N/A | 75.77 ± 0.47 | 67.39 ± 0.46 | 32.06 ± 0.06 | 76.35 ± 0.39 |
| | Softmax | 70.12 ± 0.04 | 63.20 ± 0.23 | 26.88 ± 0.26 | 74.46 ± 0.62 |
| | SEM | 71.33 ± 0.45 | 61.13 ± 0.56 | 26.08 ± 0.07 | 74.28 ± 1.13 |
| | AF (DCT) | 72.95 ± 0.31 | 65.92 ± 0.65 | 25.96 ± 0.13 | 76.51 ± 0.24 |
| TWD (TV) | Softmax | 65.54 ± 0.47 | 59.72 ± 0.39 | 26.07 ± 0.19 | 72.67 ± 0.33 |
| | SEM | 65.35 ± 0.31 | 56.56 ± 0.46 | 24.31 ± 0.43 | 73.36 ± 1.19 |
| | AF | 65.61 ± 0.56 | 60.92 ± 0.42 | 26.33 ± 0.42 | 75.01 ± 0.32 |
| | AF (PE) | 71.71 ± 0.17 | 64.68 ± 0.33 | 26.38 ± 0.37 | 76.44 ± 0.45 |
| | AF (DCT) | 73.28 ± 0.27 | 67.03 ± 0.24 | 25.85 ± 0.39 | 77.62 ± 0.40 |
| | Softmax + JD | 72.64 ± 0.27 | 67.08 ± 0.14 | 27.82 ± 0.22 | 77.69 ± 0.46 |
| | SEM + JD | 71.79 ± 0.92 | 63.60 ± 0.50 | 26.14 ± 0.40 | 75.64 ± 0.44 |
| | AF + JD | 72.64 ± 0.37 | 67.15 ± 0.27 | 27.45 ± 0.37 | 78.00 ± 0.15 |
| | AF (PE) + JD | 74.47 ± 0.10 | 67.28 ± 0.65 | 27.01 ± 0.39 | 78.12 ± 0.48 |
| | AF (DCT) + JD | 76.28 ± 0.07 | 68.60 ± 0.36 | 26.49 ± 0.24 | 79.70 ± 0.23 |
| TWD (ClusterTree) | Softmax | 69.15 ± 0.45 | 62.33 ± 0.40 | 24.47 ± 0.40 | 74.87 ± 0.13 |
| | SEM | 72.88 ± 0.12 | 63.82 ± 0.32 | 22.55 ± 0.28 | 77.47 ± 0.92 |
| | AF | 70.40 ± 0.40 | 63.28 ± 0.57 | 24.28 ± 0.15 | 75.24 ± 0.52 |
| | AF (PE) | 72.37 ± 0.28 | 65.08 ± 0.74 | 23.33 ± 0.35 | 76.67 ± 0.26 |
| | AF (DCT) | 71.95 ± 0.46 | 65.89 ± 0.11 | 21.87 ± 0.19 | 77.92 ± 0.24 |
| | Softmax + JD | 73.52 ± 0.16 | 66.76 ± 0.29 | 24.96 ± 0.07 | 77.65 ± 0.53 |
| | SEM + JD | 75.93 ± 0.14 | 67.68 ± 0.46 | 22.96 ± 0.28 | 79.19 ± 0.53 |
| | AF + JD | 73.66 ± 0.23 | 66.61 ± 0.32 | 24.55 ± 0.14 | 77.64 ± 0.19 |
| | AF (PE) + JD | 73.92 ± 0.57 | 67.00 ± 0.13 | 23.83 ± 0.42 | 77.87 ± 0.29 |
| | AF (DCT) + JD | 74.29 ± 0.30 | 67.50 ± 0.49 | 22.89 ± 0.12 | 78.31 ± 0.72 |
Figure 3 illustrates the training loss and top-1 accuracy for three methods: cosine + real-valued embedding, TV + Softmax, and TV + AF (DCT). This experiment revealed that the convergence speed of the loss function was nearly identical across all methods. Regarding training top-1 accuracy, cosine + real-valued embedding achieves the highest accuracy, followed by the Softmax function, while AF (DCT) lags behind. This behavior is expected because real-valued embeddings offer the most flexibility, followed by Softmax, with AF models exhibiting the least freedom. For all TWD-based methods, JD regularization significantly aids the training process, particularly for the Softmax function. However, for AF (DCT), the improvement is relatively small. This is likely because AF (DCT) itself acts as a form of regularization.
Table 1 presents the experimental results for the test classification accuracy using KNN. The first observation is that the simple implementation of the conventional Softmax function performs poorly compared to cosine similarity, with accuracy up to approximately 10 points lower (on STL10). As expected, because AF has only one more layer than the simple Softmax model, it performs similarly to Softmax. Compared to Softmax and AF, AF (PE) and AF (DCT) significantly improve the classification accuracy in both the total variation and ClusterTree cases. However, in the ClusterTree case, AF (PE) achieves better classification performance, whereas the improvement of AF (DCT) over the softmax model is limited. In the ClusterTree case, SEM performs particularly well when combined with JD regularization.
Overall, without using real-valued vector embeddings, the proposed method outperforms cosine similarity when the number of classes is relatively small (i.e., STL10, CIFAR10, and SVHN). By contrast, the performance of the proposed method degrades on CIFAR100, and the results for ClusterTree are particularly poor. The Wasserstein distance can be made small even when the model cannot fit the data well. It is therefore natural for Wasserstein distance-based representations to make more mistakes when the number of classes is large.
Next, we evaluated the Jeffrey divergence regularization. Surprisingly, this simple regularization consistently improves the classification performance of all the probability models, in many cases substantially. These results support the idea that the main problem with Wasserstein distance-based representation learning is its numerical instability.
Among the methods, the proposed AF (DCT) + JD with total variation achieves the highest classification accuracy on STL10, CIFAR10, and SVHN, comparable to or better than the cosine similarity baseline. It also improves on the naive implementation with the Softmax function by more than 10 points on STL10. Moreover, the probability models combined with cosine similarity tend to yield lower classification error than the same probability models combined with TWD. Based on our empirical study, we propose utilizing TWD (TV) + AF models or TWD (ClusterTree) + SEM for probability-based representation learning tasks.
| Similarity | Probability model | Linear classifier |
|---|---|---|
| Cosine | N/A | 91.13 ± 0.14 |
| TWD (TV) | Softmax + JD | 9.99 ± 0.00 |
| | AF (DCT) + JD | 90.60 ± 0.02 |
5.2 Performance comparison for SimSiam
Next, we evaluated the performance in a non-contrastive setup. For all experiments, we utilized the Resnet18-Cifar-Variant1 model with an output dimension of () and implemented all methods based on a standard SimSiam framework (https://github.com/PatrickHua/SimSiam). Optimization was performed using the SGD optimizer with a base learning rate of 0.03, a weight decay parameter of 0.00005, a momentum parameter of 0.9, a batch size of 512, and 800 epochs. For the proposed method, we employed the total variation as the loss function, along with the softmax function and the ArcFace model (AF). The temperature parameter was set to 0.1 for all experiments. Additionally, we assessed JD regularization with the regularization parameter set to 0.1 across all experiments. A100 GPUs were used for all experiments, and each experiment was run three times; we report the average scores.
We compared the proposed methods, TWDSimSiam (Softmax + JD) and TWDSimSiam (AF + JD), with the original SimSiam method, which employs a cosine similarity loss. We observe that learning the total variation with softmax encounters numerical issues, even with JD regularization (see Figures 3(a) and 3(c) in the Appendix). Conversely, the AF + JD combination successfully trains the models, as shown in Figures 3(b) and 3(c) in the Appendix. One potential reason for the failure of TWD with Softmax is that the total variation can easily become zero because, unlike AF, the softmax model does not normalize its input. For TWDSimSiam (AF + JD), normalization within the AF model likely prevents convergence to a poor local minimum. In terms of performance, cosine similarity and total variation (TV) yield comparable results. However, a key contribution of this study is a practical approach for stabilizing model training with the Wasserstein distance, specifically through total variation. This finding is potentially useful in various SSL tasks.
6 Conclusion
This study investigates SSL using TWD. We empirically evaluate several benchmark datasets and find that a simple combination of the softmax function and TWD performs poorly. To address this, we employ simplicial embedding [28] and ArcFace-based models [12] as probability models. Moreover, to mitigate the difficulty of optimizing TWD, we incorporate the Jeffrey divergence, an upper bound on the squared 1-Wasserstein distance, as a regularization term. Overall, the combination of ArcFace and DCT outperforms its cosine similarity counterparts. Finally, we find that the combination of TWD (ClusterTree) and SEM yields favorable performance.
References
- Arjovsky et al. [2017] Martin Arjovsky, Soumith Chintala, and Léon Bottou. Wasserstein generative adversarial networks. In ICML, 2017.
- Backurs et al. [2020] Arturs Backurs, Yihe Dong, Piotr Indyk, Ilya Razenshteyn, and Tal Wagner. Scalable nearest neighbor search for optimal transport. In ICML, 2020.
- Bahdanau et al. [2014] Dzmitry Bahdanau, Kyunghyun Cho, and Yoshua Bengio. Neural machine translation by jointly learning to align and translate. arXiv preprint arXiv:1409.0473, 2014.
- Becker and Hinton [1992] Suzanna Becker and Geoffrey E Hinton. Self-organizing neural network that discovers surfaces in random-dot stereograms. Nature, 355(6356):161–163, 1992.
- Caron et al. [2020] Mathilde Caron, Ishan Misra, Julien Mairal, Priya Goyal, Piotr Bojanowski, and Armand Joulin. Unsupervised learning of visual features by contrasting cluster assignments. In NeurIPS, 2020.
- Caron et al. [2021] Mathilde Caron, Hugo Touvron, Ishan Misra, Hervé Jégou, Julien Mairal, Piotr Bojanowski, and Armand Joulin. Emerging properties in self-supervised vision transformers. In ICCV, 2021.
- Chen et al. [2020a] Ting Chen, Simon Kornblith, Mohammad Norouzi, and Geoffrey Hinton. A simple framework for contrastive learning of visual representations. In ICML, 2020a.
- Chen and He [2021] Xinlei Chen and Kaiming He. Exploring simple siamese representation learning. In CVPR, 2021.
- Chen et al. [2020b] Xinlei Chen, Haoqi Fan, Ross Girshick, and Kaiming He. Improved baselines with momentum contrastive learning. arXiv preprint arXiv:2003.04297, 2020b.
- Cover and Thomas [2012] Thomas M Cover and Joy A Thomas. Elements of information theory. John Wiley & Sons, 2012.
- Cuturi [2013] Marco Cuturi. Sinkhorn distances: Lightspeed computation of optimal transport. In NIPS, 2013.
- Deng et al. [2019] Jiankang Deng, Jia Guo, Niannan Xue, and Stefanos Zafeiriou. ArcFace: Additive angular margin loss for deep face recognition. In CVPR, 2019.
- Deshpande et al. [2019] Ishan Deshpande, Yuan-Ting Hu, Ruoyu Sun, Ayis Pyrros, Nasir Siddiqui, Sanmi Koyejo, Zhizhen Zhao, David Forsyth, and Alexander G Schwing. Max-sliced Wasserstein distance and its use for GANs. In CVPR, 2019.
- Dey and Zhang [2022] Tamal K Dey and Simon Zhang. Approximating 1-Wasserstein distance between persistence diagrams by graph sparsification. In ALENEX, 2022.
- Evans and Matsen [2012] Steven N Evans and Frederick A Matsen. The phylogenetic Kantorovich–Rubinstein metric for environmental sequence samples. Journal of the Royal Statistical Society: Series B (Statistical Methodology), 74(3):569–592, 2012.
- Fan [1953] Ky Fan. Minimax theorems. Proceedings of the National Academy of Sciences, 39(1):42–47, 1953.
- Frogner et al. [2015] Charlie Frogner, Chiyuan Zhang, Hossein Mobahi, Mauricio Araya, and Tomaso A Poggio. Learning with a Wasserstein loss. In NIPS, 2015.
- Gretton et al. [2005] Arthur Gretton, Olivier Bousquet, Alex Smola, and Bernhard Schölkopf. Measuring statistical dependence with Hilbert-Schmidt norms. In ALT, 2005.
- Grill et al. [2020] Jean-Bastien Grill, Florian Strub, Florent Altché, Corentin Tallec, Pierre Richemond, Elena Buchatskaya, Carl Doersch, Bernardo Avila Pires, Zhaohan Guo, Mohammad Gheshlaghi Azar, et al. Bootstrap your own latent: A new approach to self-supervised learning. In NeurIPS, 2020.
- He et al. [2020] Kaiming He, Haoqi Fan, Yuxin Wu, Saining Xie, and Ross Girshick. Momentum contrast for unsupervised visual representation learning. In CVPR, 2020.
- Indyk and Thaper [2003] Piotr Indyk and Nitin Thaper. Fast image retrieval via embeddings. In 3rd international workshop on statistical and computational theories of vision, volume 2, page 5. Nice, France, 2003.
- Jiang et al. [2023] Qian Jiang, Changyou Chen, Han Zhao, Liqun Chen, Qing Ping, Son Dinh Tran, Yi Xu, Belinda Zeng, and Trishul Chilimbi. Understanding and constructing latent modality structures in multi-modal representation learning. In CVPR, 2023.
- Kingma and Welling [2013] Diederik P Kingma and Max Welling. Auto-encoding variational bayes. arXiv preprint arXiv:1312.6114, 2013.
- Kolouri et al. [2016] Soheil Kolouri, Yang Zou, and Gustavo K Rohde. Sliced Wasserstein kernels for probability distributions. In CVPR, 2016.
- Kolouri et al. [2019] Soheil Kolouri, Kimia Nadjahi, Umut Simsekli, Roland Badeau, and Gustavo Rohde. Generalized sliced Wasserstein distances. In NeurIPS, 2019.
- Kramer [1991] Mark A Kramer. Nonlinear principal component analysis using autoassociative neural networks. AIChE Journal, 37(2):233–243, 1991.
- Kusner et al. [2015] Matt Kusner, Yu Sun, Nicholas Kolkin, and Kilian Weinberger. From word embeddings to document distances. In ICML, 2015.
- Lavoie et al. [2022] Samuel Lavoie, Christos Tsirigotis, Max Schwarzer, Ankit Vani, Michael Noukhovitch, Kenji Kawaguchi, and Aaron Courville. Simplicial embeddings in self-supervised learning and downstream classification. arXiv preprint arXiv:2204.00616, 2022.
- Le and Nguyen [2021] Tam Le and Truyen Nguyen. Entropy partial transport with tree metrics: Theory and practice. In AISTATS, 2021.
- Le et al. [2019] Tam Le, Makoto Yamada, Kenji Fukumizu, and Marco Cuturi. Tree-sliced variants of Wasserstein distances. In NeurIPS, 2019.
- Liu et al. [2020] Yanbin Liu, Linchao Zhu, Makoto Yamada, and Yi Yang. Semantic correspondence as an optimal transport problem. In CVPR, 2020.
- Lozupone and Knight [2005] Catherine Lozupone and Rob Knight. UniFrac: A new phylogenetic method for comparing microbial communities. Applied and Environmental Microbiology, 71(12):8228–8235, 2005.
- Mueller and Jaakkola [2015] Jonas W Mueller and Tommi Jaakkola. Principal differences analysis: Interpretable characterization of differences between distributions. In NIPS, 2015.
- Oord et al. [2018] Aaron van den Oord, Yazhe Li, and Oriol Vinyals. Representation learning with contrastive predictive coding. arXiv preprint arXiv:1807.03748, 2018.
- Paty and Cuturi [2019] François-Pierre Paty and Marco Cuturi. Subspace robust Wasserstein distances. In ICML, 2019.
- Rabin et al. [2011] Julien Rabin, Gabriel Peyré, Julie Delon, and Marc Bernot. Wasserstein barycenter and its application to texture mixing. In International Conference on Scale Space and Variational Methods in Computer Vision, pages 435–446. Springer, 2011.
- Raginsky et al. [2013] Maxim Raginsky, Igal Sason, et al. Concentration of measure inequalities in information theory, communications, and coding. Foundations and Trends® in Communications and Information Theory, 10(1-2):1–246, 2013.
- Sarlin et al. [2020] Paul-Edouard Sarlin, Daniel DeTone, Tomasz Malisiewicz, and Andrew Rabinovich. SuperGlue: Learning feature matching with graph neural networks. In CVPR, 2020.
- Sato et al. [2020] Ryoma Sato, Makoto Yamada, and Hisashi Kashima. Fast unbalanced optimal transport on tree. In NeurIPS, 2020.
- Sato et al. [2022] Ryoma Sato, Makoto Yamada, and Hisashi Kashima. Re-evaluating word mover’s distance. In ICML, 2022.
- Takezawa et al. [2021] Yuki Takezawa, Ryoma Sato, and Makoto Yamada. Supervised tree-Wasserstein distance. In ICML, 2021.
- Takezawa et al. [2022] Yuki Takezawa, Ryoma Sato, Zornitsa Kozareva, Sujith Ravi, and Makoto Yamada. Fixed support tree-sliced Wasserstein barycenter. In AISTATS, 2022.
- Toyokuni et al. [2021] Ayato Toyokuni, Sho Yokoi, Hisashi Kashima, and Makoto Yamada. Computationally efficient Wasserstein loss for structured labels. In EACL: Student Research Workshop, April 2021.
- Tsai et al. [2021] Yao-Hung Hubert Tsai, Shaojie Bai, Louis-Philippe Morency, and Ruslan Salakhutdinov. A note on connecting Barlow Twins with negative-sample-free contrastive learning. arXiv preprint arXiv:2104.13712, 2021.
- v. Neumann [1928] J. v. Neumann. Zur Theorie der Gesellschaftsspiele. Mathematische Annalen, 100(1):295–320, 1928.
- Vaswani et al. [2017] Ashish Vaswani, Noam Shazeer, Niki Parmar, Jakob Uszkoreit, Llion Jones, Aidan N Gomez, Łukasz Kaiser, and Illia Polosukhin. Attention is all you need. In NIPS, 2017.
- Xian et al. [2023] Ruicheng Xian, Lang Yin, and Han Zhao. Fair and optimal classification via post-processing. In ICML, 2023.
- Yamada et al. [2022] Makoto Yamada, Yuki Takezawa, Ryoma Sato, Han Bao, Zornitsa Kozareva, and Sujith Ravi. Approximating 1-Wasserstein distance with trees. Transactions on Machine Learning Research, 2022.
- Yokoi et al. [2020] Sho Yokoi, Ryo Takahashi, Reina Akama, Jun Suzuki, and Kentaro Inui. Word rotator’s distance. In EMNLP, 2020.
- Zbontar et al. [2021] Jure Zbontar, Li Jing, Ishan Misra, Yann LeCun, and Stéphane Deny. Barlow Twins: Self-supervised learning via redundancy reduction. In ICML, 2021.
- Zhao [2022] Han Zhao. Costs and benefits of fair regression. Transactions on Machine Learning Research, 2022.
- Zhao et al. [2019] Wei Zhao, Maxime Peyrard, Fei Liu, Yang Gao, Christian M Meyer, and Steffen Eger. MoverScore: Text generation evaluating with contextualized embeddings and earth mover distance. In EMNLP-IJCNLP, 2019.
Appendix A
A.1 Proof of Proposition 1
(Proof) Let and . The shortest-path distance between leaves and can be represented as [48]
That is, is represented by a linear function with respect to for a given and the constraints on and are convex. Thus, strong duality holds, and we obtain the following representation using the minimax theorem [45, 16]:
where .
Without loss of generality, we consider . First, we rewrite the norm as
where denotes the set of descendants of node (including itself). Using the triangle inequality, we obtain
where is the set of ancestors of leaf (including itself). By rewriting the constraint as for any , we obtain:
The latter inequality holds for any weight vector . Therefore, considering the vector such that if and 0 otherwise, which satisfies the constraint , we obtain
This completes the proof of the proposition.
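The proof works with the subtree form of TWD. Consider a tree in which $w_v$ is the weight of the edge between node $v$ and its parent. The 1-Wasserstein distance between distributions $\mu$ and $\nu$ supported on its nodes has the well-known closed form $W_T(\mu,\nu) = \sum_{v} w_v \,|\mu(D(v)) - \nu(D(v))|$, where $D(v)$ is the set of descendants of $v$ (including $v$ itself). The sketch below evaluates this formula in plain Python. The parent-array encoding and the requirement that every node be numbered after its parent are our own conventions for the illustration. On a star tree with all edge weights 1/2 and mass only on the leaves, the formula reduces to the total variation distance, which is how the total variation arises as a special case of TWD.

```python
def tree_wasserstein(parent, weight, mu, nu):
    """Tree-Wasserstein distance via subtree mass differences.

    parent[v] is the parent of node v; node 0 is the root (parent[0] == -1),
    and every other node is numbered after its parent (0 <= parent[v] < v).
    weight[v] is the length of the edge (parent[v], v); weight[0] is unused.
    mu and nu give the probability mass placed on each node.
    """
    if any(not 0 <= parent[v] < v for v in range(1, len(parent))):
        raise ValueError("node 0 must be the root and nodes must follow their parents")
    diff = [m - q for m, q in zip(mu, nu)]
    # Bottom-up pass (children have larger indices): afterwards diff[v] = mu(D(v)) - nu(D(v)).
    for v in range(len(parent) - 1, 0, -1):
        diff[parent[v]] += diff[v]
    # Each edge (parent[v], v) contributes its length times the mass imbalance below it.
    return sum(weight[v] * abs(diff[v]) for v in range(1, len(parent)))

# Root 0 with internal nodes 1 and 2; leaves 3 and 4 hang under 1, and leaf 5 under 2.
parent = [-1, 0, 0, 1, 1, 2]
weight = [0.0, 1.0, 2.0, 0.5, 0.25, 3.0]
dirac_3 = [0, 0, 0, 1, 0, 0]
dirac_5 = [0, 0, 0, 0, 0, 1]
# Between two Dirac masses, TWD equals the shortest-path distance 0.5 + 1 + 2 + 3.
print(tree_wasserstein(parent, weight, dirac_3, dirac_5))  # 6.5
```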
A.2 Proof of Proposition 2
(Proof) The following holds if with the probability vector (such that ).
Then, using Pinsker’s inequality, we derive the following inequalities:
and
Thus,
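The role of Pinsker's inequality can be checked numerically. Pinsker's inequality gives $\|a-b\|_1^2 \le 2\,\mathrm{KL}(a\|b)$ (in nats), and the same holds with $a$ and $b$ swapped. Adding the two yields $\|a-b\|_1^2 \le \mathrm{KL}(a\|b) + \mathrm{KL}(b\|a) = \mathrm{JD}(a,b)$. The sketch below verifies this L1 form on random probability vectors. It illustrates the mechanism behind the JD regularizer and is not claimed to reproduce the exact constants of Proposition 2. The helper names and the random test setup are ours.

```python
import math
import random

def l1(a, b):
    return sum(abs(x - y) for x, y in zip(a, b))

def kl(a, b):
    """KL(a || b) in nats; assumes b_i > 0 wherever a_i > 0."""
    return sum(x * math.log(x / y) for x, y in zip(a, b) if x > 0)

def jeffrey(a, b):
    """Jeffrey divergence: the symmetrized KL divergence."""
    return kl(a, b) + kl(b, a)

def random_simplex(rng, k):
    """Random strictly positive probability vector of length k."""
    w = [rng.random() + 1e-3 for _ in range(k)]
    s = sum(w)
    return [x / s for x in w]

def pinsker_jd_bound_holds(trials=1000, k=10, seed=0):
    """Check ||a - b||_1^2 <= JD(a, b) on random pairs of distributions."""
    rng = random.Random(seed)
    for _ in range(trials):
        a, b = random_simplex(rng, k), random_simplex(rng, k)
        if l1(a, b) ** 2 > jeffrey(a, b) + 1e-12:
            return False
    return True

a, b = [0.5, 0.5], [0.9, 0.1]
print(l1(a, b) ** 2, jeffrey(a, b))
print(pinsker_jd_bound_holds())
```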
A.3 Ablation study
| Similarity Function | STL10 | CIFAR10 | CIFAR100 | SVHN |
|---|---|---|---|---|
| TWD (TV) | 73.28 ± 0.27 | 67.03 ± 0.24 | 25.85 ± 0.39 | 77.62 ± 0.40 |
| | 76.28 ± 0.07 | 68.60 ± 0.36 | 26.49 ± 0.24 | 79.70 ± 0.23 |
| | 77.40 ± 0.17 | 68.48 ± 0.11 | 25.59 ± 0.16 | 79.67 ± 0.26 |
| | 77.67 ± 0.06 | 68.26 ± 0.51 | 24.21 ± 0.35 | 79.91 ± 0.42 |
A.3.1 Effect of number of nearest neighbors
In this section, we assess the performance of the KNN classifier while varying the number of nearest neighbors k, setting it to 10 or 50. The results for are presented in Table 4, and Table 5 compares the best models across the different numbers of nearest neighbors. Our experiments revealed that utilizing tends to enhance the performance, and that the relative order of the results remains consistent regardless of the number of nearest neighbors.
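To make the KNN evaluation protocol concrete, the sketch below shows a minimal k-nearest-neighbor classifier over learned embeddings. Each test embedding receives the majority label of its k closest training embeddings, and accuracy is reported in percent. Several details are our assumptions, since this section does not specify the distance or voting scheme used in the experiments: the L1 distance (chosen to match the form of TWD on tree-embedded vectors), the deterministic tie-breaking rule, and the toy data.

```python
from collections import Counter

def l1_distance(u, v):
    return sum(abs(x - y) for x, y in zip(u, v))

def knn_predict(query, train_feats, train_labels, k=10, dist=l1_distance):
    """Majority vote among the k training embeddings closest to the query."""
    nearest = sorted(range(len(train_feats)), key=lambda i: dist(query, train_feats[i]))[:k]
    votes = Counter(train_labels[i] for i in nearest)
    # Ties are broken deterministically: highest vote count first, then the smallest label.
    return min(votes.items(), key=lambda kv: (-kv[1], kv[0]))[0]

def knn_accuracy(test_feats, test_labels, train_feats, train_labels, k=10, dist=l1_distance):
    """Top-1 KNN accuracy in percent."""
    correct = sum(knn_predict(q, train_feats, train_labels, k, dist) == y
                  for q, y in zip(test_feats, test_labels))
    return 100.0 * correct / len(test_labels)

# Two well-separated toy clusters standing in for learned embeddings.
train_feats = [[0.0, 0.0], [0.0, 1.0], [1.0, 0.0], [10.0, 10.0], [10.0, 11.0], [11.0, 10.0]]
train_labels = [0, 0, 0, 1, 1, 1]
test_feats = [[0.5, 0.5], [10.0, 10.5]]
test_labels = [0, 1]
print(knn_accuracy(test_feats, test_labels, train_feats, train_labels, k=3))  # 100.0
```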
A.3.2 Effect of the regularization parameter for the Jeffrey divergence
In this experiment, we evaluated model performance while varying the regularization parameter of the Jeffrey divergence term, denoted as . Introducing the regularization noticeably improves performance. However, performance does not change significantly across different values of , and setting yielded favorable results.
| Similarity Function | Probability model | STL10 | CIFAR10 | CIFAR100 | SVHN |
|---|---|---|---|---|---|
| Cosine Similarity | N/A | 75.44 ± 0.21 | 66.96 ± 0.45 | 31.63 ± 0.25 | 74.71 ± 0.31 |
| | Softmax | 71.25 ± 0.30 | 63.80 ± 0.48 | 26.18 ± 0.36 | 73.06 ± 0.47 |
| | SEM | 71.34 ± 0.31 | 61.26 ± 0.42 | 25.40 ± 0.06 | 73.41 ± 0.95 |
| | AF (DCT) | 72.15 ± 0.53 | 65.52 ± 0.45 | 24.93 ± 0.24 | 75.68 ± 0.13 |
| TWD (TV) | Softmax | 63.42 ± 0.24 | 59.03 ± 0.58 | 24.95 ± 0.31 | 70.87 ± 0.29 |
| | SEM | 63.72 ± 0.17 | 55.57 ± 0.35 | 23.40 ± 0.36 | 71.69 ± 0.75 |
| | AF | 63.97 ± 0.05 | 59.96 ± 0.44 | 25.29 ± 0.17 | 73.44 ± 0.35 |
| | AF (PE) | 71.04 ± 0.37 | 64.28 ± 0.14 | 25.71 ± 0.20 | 75.70 ± 0.42 |
| | AF (DCT) | 72.75 ± 0.11 | 67.01 ± 0.03 | 24.95 ± 0.17 | 76.98 ± 0.44 |
| | Softmax + JD | 72.05 ± 0.30 | 66.61 ± 0.20 | 26.91 ± 0.19 | 76.65 ± 0.56 |
| | SEM + JD | 70.73 ± 0.89 | 62.75 ± 0.61 | 24.83 ± 0.27 | 74.71 ± 0.43 |
| | AF + JD | 71.74 ± 0.19 | 66.74 ± 0.20 | 26.68 ± 0.35 | 77.10 ± 0.04 |
| | AF (PE) + JD | 74.10 ± 0.20 | 66.82 ± 0.36 | 26.17 ± 0.00 | 77.55 ± 0.50 |
| | AF (DCT) + JD | 76.24 ± 0.22 | 68.62 ± 0.40 | 25.70 ± 0.14 | 79.28 ± 0.22 |
| TWD (Clust) | Softmax | 67.95 ± 0.42 | 61.59 ± 0.29 | 23.34 ± 0.26 | 73.88 ± 0.05 |
| | SEM | 72.43 ± 0.11 | 63.63 ± 0.42 | 21.29 ± 0.28 | 77.04 ± 0.77 |
| | AF | 69.09 ± 0.05 | 62.49 ± 0.45 | 22.56 ± 0.25 | 74.31 ± 0.40 |
| | AF (PE) | 72.08 ± 0.07 | 64.56 ± 0.31 | 22.51 ± 0.29 | 75.98 ± 0.23 |
| | AF (DCT) | 71.64 ± 0.15 | 65.51 ± 0.36 | 21.04 ± 0.10 | 77.59 ± 0.25 |
| | Softmax + JD | 73.07 ± 0.13 | 66.38 ± 0.27 | 23.97 ± 0.11 | 76.82 ± 0.50 |
| | SEM + JD | 75.50 ± 0.15 | 67.44 ± 0.10 | 21.90 ± 0.19 | 78.91 ± 0.30 |
| | AF + JD | 72.70 ± 0.08 | 66.12 ± 0.26 | 23.50 ± 0.21 | 76.92 ± 0.06 |
| | AF (PE) + JD | 73.66 ± 0.47 | 66.58 ± 0.01 | 22.86 ± 0.02 | 77.44 ± 0.30 |
| | AF (DCT) + JD | 73.79 ± 0.12 | 67.34 ± 0.38 | 21.96 ± 0.34 | 78.00 ± 0.60 |
| Similarity Function | Nearest neighbors (k) | STL10 | CIFAR10 | CIFAR100 | SVHN |
|---|---|---|---|---|---|
| TWD (TV) | | 76.24 ± 0.22 | 68.62 ± 0.40 | 25.70 ± 0.14 | 79.28 ± 0.22 |
| | | 76.28 ± 0.07 | 68.60 ± 0.36 | 26.49 ± 0.24 | 79.70 ± 0.23 |