
License: CC BY 4.0
arXiv:2310.10143v2 [stat.ML] 05 Feb 2024

An Empirical Study of Self-supervised Learning with Wasserstein Distance

Makoto Yamada$^{1,2}$, Yuki Takezawa$^{1,3}$, Guillaume Houry$^{1,4}$, Kira Michaela Düsterwald$^{5}$
Deborah Sulem$^{6}$, Han Zhao$^{7}$, Yao-Hung Hubert Tsai$^{8}$
$^{1}$Okinawa Institute of Science and Technology, $^{2}$RIKEN AIP, $^{3}$Kyoto University
$^{4}$École Normale Supérieure, $^{5}$University College London, $^{6}$Universitat Pompeu Fabra,
$^{7}$University of Illinois at Urbana-Champaign, $^{8}$Carnegie Mellon University
(February 5, 2024)
Abstract

In this study, we investigate self-supervised learning (SSL) with the 1-Wasserstein distance on a tree structure (a.k.a. the Tree-Wasserstein distance (TWD)), where TWD is defined as the L1 distance between two tree-embedded vectors. SSL methods commonly use cosine similarity in their objective functions; however, the use of the Wasserstein distance in this setting has not been well studied. Training with the Wasserstein distance is numerically challenging. Thus, this study empirically investigates strategies for optimizing SSL with the Wasserstein distance and identifies a stable training procedure. More specifically, we evaluate combinations of two types of TWD (total variation and ClusterTree) and several probability models, including the softmax function, the ArcFace probability model [12], and simplicial embedding [28]. We propose a simple yet effective Jeffrey divergence-based regularization method to stabilize optimization. Through experiments on STL10, CIFAR10, CIFAR100, and SVHN, we find that a simple combination of the softmax function and TWD yields significantly lower performance than standard SimCLR. Moreover, a simple combination of TWD and SimSiam fails to train the model. We find that model performance depends on the combination of TWD and probability model, and that the Jeffrey divergence regularization aids model training. Finally, we show that an appropriate combination of TWD and probability model outperforms cosine similarity-based representation learning.

1 Introduction

Self-supervised learning (SSL) algorithms, including SimCLR [9], Bootstrap Your Own Latent (BYOL) [19], MoCo [9, 20], SwAV [5], SimSiam [8], and DINO [6], learn representations without label information and can therefore be regarded as unsupervised learning methods.

One of the main families of self-supervised algorithms adopts contrastive learning, in which two data points are systematically generated from a common data source, and lower-dimensional representations are found by maximizing the similarity between positive pairs while minimizing the similarity between negative pairs. Depending on the context, positive and negative pairs can be defined differently. For example, in SimCLR [9], positive pairs correspond to images generated by independently applying different visual transformations, such as rotation and cropping, to the same image. In multimodal learning, by contrast, positive pairs are defined as corresponding views of the same example in different modalities, such as an image and its text [22]. The flexibility in defining positive and negative pairs also makes contrastive learning widely applicable beyond the image domain. Contrastive learning is a powerful pre-training method because it does not require label information and can therefore be trained on large amounts of unlabeled data.

In addition to contrastive SSL, non-contrastive approaches such as BYOL [19], SwAV [5], and SimSiam [8] have been widely used. Instead of relying on negative sampling, non-contrastive approaches use momentum and/or stop-gradient techniques to prevent mode collapse. Many of these approaches employ negative cosine similarity as a loss function. However, only a limited number of SSL methods, such as DINO [6] and simplicial embedding [28], use distributional measures such as cross-entropy.

In this paper, building on the idea of distributional measures, we empirically investigate, for the first time, SSL performance using the Wasserstein distance. The Wasserstein distance, a widely adopted optimal transport-based distance for measuring distributional discrepancies, is useful in various machine learning tasks, including generative adversarial networks [1], document classification [27, 40], image matching [31, 38], and algorithmic fairness [47, 51]. The 1-Wasserstein distance is also known as the earth mover's distance (EMD) and, when applied to word embeddings, as the word mover's distance (WMD) [27].

In this study, we consider an SSL framework with the 1-Wasserstein distance under a tree metric (i.e., the Tree-Wasserstein distance (TWD)) [21, 30]. TWD includes the sliced Wasserstein distance [36, 24] and total variation as special cases, and can be represented as the $\ell_1$ distance between two vectors. Because TWD is a non-differentiable function (the $\ell_1$ norm is not differentiable at zero), learning simplicial representations by backpropagating through TWD is challenging. Moreover, because the Wasserstein distance is computed from probability vectors, and there are several ways to represent probability vectors, it is difficult to determine which representation is most suitable for SSL training. Hence, we investigate combinations of probability models and TWD structures. Specifically, we consider total variation and ClusterTree as TWD structures and show that total variation is equivalent to a robust variant of TWD. For the probability representations, we consider the softmax function, an ArcFace-based probability model [12], and simplicial embedding (SEM) [28]. Finally, to stabilize training, we propose a Jeffrey divergence-based regularization. Through SSL experiments, we find that the standard softmax formulation with backpropagation yields poor results. In particular, in the non-contrastive SSL setting, a simple combination of the Wasserstein distance and the softmax function fails to train the model. For total variation, the ArcFace-based model performs well. By contrast, SEM is suitable for ClusterTree, for which ArcFace-based models achieve only modest performance. Moreover, models trained with the proposed regularization significantly outperform their non-regularized counterparts.

Contributions: The contributions of this study are summarized below:

  • We investigate the combination of probability models and TWD (total variation and ClusterTree).

  • We propose a robust variant of TWD (RTWD) and show that RTWD is equivalent to total variation.

  • We propose the Jeffrey divergence regularization for TWD minimization, and find that the regularization significantly stabilizes training.

2 Related Work

The proposed method relates to unsupervised representation learning and optimal transport.

Unsupervised Representation Learning: Representation learning is an important research topic in machine learning, and many methods have been proposed for it. The autoencoder [26] and its variational version [23] are widely employed for unsupervised representation learning. Current mainstream SSL approaches are based on a cross-view prediction framework [4], and contrastive learning has emerged as a prominent SSL paradigm.

In contrastive learning, a model learns by contrasting positive samples (similar instances) with negative samples (dissimilar instances), as in SimCLR [7]. SimCLR employs data augmentation and a similarity measure to encourage the model to project similar instances close together while pushing dissimilar instances apart. This approach has demonstrated efficacy across various domains, including computer vision and natural language processing, enabling learning without explicit labels. SimCLR employs the InfoNCE loss [34]. Following SimCLR, several alternative approaches have been proposed, including Barlow Twins [50]. The Barlow Twins loss drives the cross-correlation matrix between the embeddings of two augmented views toward the identity matrix, thereby maximizing the correlation between corresponding embedding dimensions of positive pairs while decorrelating different dimensions, without using negative samples. Barlow Twins is closely related to the Hilbert–Schmidt independence criterion, a kernel-based independence measure [18, 44].
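As a minimal sketch of the Barlow Twins objective [50], the NumPy function below standardizes each embedding dimension over the batch, forms the cross-correlation matrix between the two views, and penalizes its deviation from the identity. The function name and the off-diagonal weight `lambd` (default 5e-3) are illustrative choices, not settings taken from this paper.

```python
import numpy as np

def barlow_twins_loss(z1: np.ndarray, z2: np.ndarray, lambd: float = 5e-3) -> float:
    """Barlow Twins loss for two (batch, dim) embeddings of paired views."""
    n = z1.shape[0]
    # Standardize each dimension across the batch (zero mean, unit variance).
    z1n = (z1 - z1.mean(axis=0)) / z1.std(axis=0)
    z2n = (z2 - z2.mean(axis=0)) / z2.std(axis=0)
    # Cross-correlation matrix between the two views, shape (dim, dim).
    c = z1n.T @ z2n / n
    on_diag = np.sum((1.0 - np.diag(c)) ** 2)             # invariance term
    off_diag = np.sum(c ** 2) - np.sum(np.diag(c) ** 2)   # redundancy-reduction term
    return float(on_diag + lambd * off_diag)

# Two identical views whose dimensions are standardized and uncorrelated.
z = np.array([[1.0, 1.0], [1.0, -1.0], [-1.0, 1.0], [-1.0, -1.0]])
print(barlow_twins_loss(z, z))            # 0.0: cross-correlation is the identity
print(barlow_twins_loss(z, z[[1, 0, 3, 2]]))  # positive: second dimension is anti-correlated
```

No negative pairs appear anywhere in the computation; collapse is discouraged only through the correlation structure.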

One drawback of SimCLR is its reliance on numerous negative samples. To address this issue, recent research has focused on approaches that eliminate the need for negative sampling, such as BYOL [19], SwAV [5], and DINO [6]. For example, BYOL trains representations by minimizing a loss between the outputs of an online network and a target network. The target network is maintained as a moving average of the online network parameters, which eliminates the need for negative samples. Surprisingly, BYOL showed favorable results compared with SimCLR. SimSiam, introduced by Chen and He [8], uses a Siamese network and applies a stop-gradient to one of its two branches, so that gradients flow through only the other branch.
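The target-network update described above can be sketched as an exponential moving average over parameter arrays. The dictionary-of-arrays layout, the helper name `ema_update`, and the momentum values are illustrative assumptions rather than BYOL's exact implementation or schedule; only the online network would receive gradients, while the target network changes solely through this rule.

```python
import numpy as np

def ema_update(target: dict, online: dict, momentum: float = 0.99) -> dict:
    """Return target <- momentum * target + (1 - momentum) * online, per parameter."""
    return {name: momentum * target[name] + (1.0 - momentum) * online[name]
            for name in target}

target = {"w": np.zeros(3)}
online = {"w": np.ones(3)}
for _ in range(2):
    # The target slowly tracks the (here fixed) online parameters.
    target = ema_update(target, online, momentum=0.9)
print(target["w"])
```

A momentum close to 1 makes the target change slowly, which provides a stable regression target for the online network.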

Both families of approaches focus on learning low-dimensional, real-valued vector embeddings and use cosine similarity as the similarity measure. Recently, Lavoie et al. [28] proposed simplicial embedding (SEM), which concatenates multiple softmax outputs to learn high-dimensional, sparse, nonnegative representations and was reported to significantly improve classification accuracy.
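The "multiple concatenated softmax functions" of SEM can be sketched as follows: a vector is split into groups, a temperature-scaled softmax is applied within each group, and the results are concatenated, so every group lies on its own probability simplex. The names `n_groups` and `tau` and the example sizes are illustrative assumptions, not the settings used in [28] or in this paper.

```python
import numpy as np

def simplicial_embedding(z: np.ndarray, n_groups: int, tau: float = 1.0) -> np.ndarray:
    """Split each row of z into n_groups chunks and apply a softmax to each chunk."""
    batch, dim = z.shape
    assert dim % n_groups == 0, "dimension must be divisible by n_groups"
    groups = z.reshape(batch, n_groups, dim // n_groups) / tau
    groups = groups - groups.max(axis=-1, keepdims=True)  # numerical stability
    probs = np.exp(groups)
    probs /= probs.sum(axis=-1, keepdims=True)
    # Concatenate the per-group probability vectors back into one vector.
    return probs.reshape(batch, dim)

z = np.random.default_rng(0).normal(size=(2, 12))
s = simplicial_embedding(z, n_groups=3, tau=0.5)
print(s.reshape(2, 3, 4).sum(axis=-1))  # each of the 3 groups sums to 1
```

Lower temperatures make each group closer to one-hot, which is one way the representation becomes sparse.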

All of these approaches employ either negative cosine similarity or cross-entropy as the loss function. In contrast, the use of the Wasserstein distance in this context has not been studied.

Divergence and optimal transport: Measuring the divergence between two probability distributions is a fundamental research problem in machine learning, with various downstream applications, including document classification [27, 40], image matching [38], and algorithmic fairness [51, 47]. One widely adopted divergence measure is the Kullback–Leibler (KL) divergence [10]. However, $\mathrm{KL}(p\,\|\,q)$ becomes infinite when $p$ assigns positive probability to a point where $q$ assigns zero probability, for example, when the supports of the two distributions do not overlap.

The Wasserstein distance, known in the computer vision community as the EMD, can handle distributions with differing supports. Another key advantage of the Wasserstein distance over KL divergence is that its optimal transport plan identifies correspondences between data samples. For example, Sarlin et al. [38] proposed SuperGlue, which leverages optimal transport to determine correspondences between local feature sets. Liu et al. [31] proposed a semantic correspondence method based on optimal transport.
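To illustrate the support issue, the sketch below compares KL divergence and the 1-Wasserstein distance for two histograms on a common 1D grid whose supports do not overlap. It uses the standard fact that on the real line $W_1$ equals the integral of the absolute difference between the two cumulative distribution functions; the grid and histograms are made up for illustration.

```python
import numpy as np

def kl_divergence(p: np.ndarray, q: np.ndarray) -> float:
    """KL(p || q) for histograms; infinite if q is zero where p is positive."""
    mask = p > 0
    if np.any(q[mask] == 0):
        return float("inf")
    return float(np.sum(p[mask] * np.log(p[mask] / q[mask])))

def wasserstein_1d(x: np.ndarray, p: np.ndarray, q: np.ndarray) -> float:
    """W1 between histograms p and q on sorted grid points x (CDF formula)."""
    cdf_gap = np.abs(np.cumsum(p) - np.cumsum(q))[:-1]
    return float(np.sum(cdf_gap * np.diff(x)))

x = np.arange(5.0)                       # grid points 0, 1, 2, 3, 4
p = np.array([1.0, 0.0, 0.0, 0.0, 0.0])  # all mass at 0
q = np.array([0.0, 0.0, 0.0, 1.0, 0.0])  # all mass at 3
print(kl_divergence(p, q))      # inf
print(wasserstein_1d(x, p, q))  # 3.0
```

KL only reports that the supports differ, whereas $W_1$ reports how far the mass has to move.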

In natural language processing, Kusner et al. [27] introduced WMD, a pioneering application of the Wasserstein distance to textual similarity that is widely used, including for text generation evaluation [52]. Sato et al. [40] further studied the properties of WMD. Another interesting approach is the word rotator's distance (WRD) [49], which uses the norms of word vectors to define the probability mass of each word and significantly outperforms WMD. However, these methods incur cubic computational cost, rendering them unsuitable for large-scale distribution comparison.

The Wasserstein distance can be computed exactly via linear programming, but at roughly cubic cost in the number of data points. To speed up EMD and Wasserstein distance computation, Cuturi [11] introduced the Sinkhorn algorithm, which solves the entropy-regularized optimization problem with quadratic-order computation per iteration, $O(\bar{n}^2)$, where $\bar{n}$ is the number of data points. Moreover, because the Sinkhorn solution is obtained through a simple iterative algorithm, it can be easily incorporated into deep learning applications, making optimal transport widely applicable. One limitation of the Sinkhorn algorithm is that it still requires quadratic-time computation, and the final solution depends strongly on the regularization parameter.

An alternative approach is the sliced Wasserstein distance (SWD) [36, 24], which solves the optimal transport problem in projected one-dimensional subspaces. Because the Wasserstein distance between distributions on the real line can be computed by sorting, SWD requires only $O(\bar{n}\log\bar{n})$ computation per projection. Extensions of SWD include the generalized sliced Wasserstein distance [25], which replaces linear projections with nonlinear ones; the max-sliced Wasserstein distance, which searches for the one-dimensional subspace that maximizes the projected transport cost [33, 13]; and the subspace-robust Wasserstein distance [35].
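The sorting idea can be sketched as a Monte Carlo estimate of the sliced 1-Wasserstein distance between two point clouds. This minimal NumPy version assumes equal-size point clouds with uniform weights, for which matching sorted projections gives the exact 1D distance, and averages over random unit directions; the function name, the number of projections, and the seed are illustrative.

```python
import numpy as np

def sliced_wasserstein(X: np.ndarray, Y: np.ndarray,
                       n_projections: int = 200, seed: int = 0) -> float:
    """Monte Carlo sliced W1 between equal-size, uniformly weighted point clouds."""
    assert X.shape == Y.shape, "sketch assumes equal-size point clouds"
    rng = np.random.default_rng(seed)
    theta = rng.normal(size=(n_projections, X.shape[1]))
    theta /= np.linalg.norm(theta, axis=1, keepdims=True)  # random unit directions
    # Project and sort: O(n log n) per direction.
    x_proj = np.sort(X @ theta.T, axis=0)  # shape (n, n_projections)
    y_proj = np.sort(Y @ theta.T, axis=0)
    # In 1D, the i-th smallest point is matched to the i-th smallest point.
    return float(np.mean(np.abs(x_proj - y_proj)))

rng = np.random.default_rng(1)
X = rng.normal(size=(200, 2))
Y = X + np.array([3.0, 0.0])  # pure shift c along the first axis
print(sliced_wasserstein(X, Y))  # each direction theta contributes |theta . c|
```

The estimate becomes less noisy as `n_projections` grows, at a linear increase in cost.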

The 1-Wasserstein distance with a tree metric (also known as the Tree-Wasserstein distance (TWD)) generalizes the sliced Wasserstein distance and total variation [21, 15, 30]. TWD is also known as the UniFrac distance [32], in which case a phylogenetic tree is assumed to be given beforehand. An important property of TWD is that it has a closed-form solution given by the L1 distance between tree-embedded vectors.
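The closed form $W_{\mathcal{T}}(\mu, \mu') = \|\textnormal{diag}(\bm{w})\bm{B}\bm{a} - \textnormal{diag}(\bm{w})\bm{B}\bm{a}'\|_1$ needs only one matrix-vector product per distribution. The sketch below assumes a convention stated here as an assumption: $B_{vj} = 1$ when leaf $j$ lies in the subtree rooted at node $v$, and $w_v$ is the weight of the edge between node $v$ and its parent (the root's weight never matters, since both distributions put total mass 1 on it). With a star tree and leaf weights $\frac{1}{2}$, as in the left tree of Figure 1, the result coincides with the total variation $\frac{1}{2}\|\bm{a} - \bm{a}'\|_1$. The helper `star_tree` and the two-cluster tree are hypothetical examples.

```python
import numpy as np

def tree_wasserstein(a, a_prime, B, w) -> float:
    """|| diag(w) B a - diag(w) B a' ||_1 for leaf histograms a and a'."""
    return float(np.sum(np.abs(w * (B @ a) - w * (B @ a_prime))))

def star_tree(n_leaf: int, leaf_weight: float = 0.5):
    """Hypothetical helper: root (row 0) connected directly to every leaf."""
    B = np.vstack([np.ones((1, n_leaf)), np.eye(n_leaf)])
    w = np.concatenate([[0.0], np.full(n_leaf, leaf_weight)])
    return B, w

a = np.array([0.7, 0.2, 0.1])
a_prime = np.array([0.1, 0.2, 0.7])
B, w = star_tree(3)
print(tree_wasserstein(a, a_prime, B, w))  # equals the total variation below
print(0.5 * np.abs(a - a_prime).sum())

# Hypothetical two-cluster tree: root, 2 cluster nodes, 4 leaves, unit edge weights.
B_cluster = np.vstack([np.ones((1, 4)), [[1, 1, 0, 0], [0, 0, 1, 1]], np.eye(4)])
w_cluster = np.array([0.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0])
left = np.array([0.5, 0.5, 0.0, 0.0])
right = np.array([0.0, 0.0, 0.5, 0.5])
# Mass travels leaf -> cluster -> root -> cluster -> leaf (path length 4).
print(tree_wasserstein(left, right, B_cluster, w_cluster))  # 4.0
```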

TWD was originally studied in theoretical computer science in the form of the QuadTree algorithm [21]. It has recently been extended by the ML community to unbalanced TWD [39, 29], supervised Wasserstein training [41], and tree barycenters [42]. These approaches focus on approximating the 1-Wasserstein distance through tree construction and often use constant edge weights. Among approaches that consider non-uniform edge weights, Backurs et al. [2] introduced FlowTree, which combines QuadTree with the original cost matrix and outperforms QuadTree. They also provided guarantees that QuadTree and FlowTree approximate nearest neighbor search under the 1-Wasserstein distance. Dey and Zhang [14] proposed an L1 embedding that approximates the 1-Wasserstein distance for persistence diagrams. Finally, Yamada et al. [48] proposed a computationally efficient tree weight estimation technique for TWD and empirically demonstrated that TWD can achieve performance comparable to the Wasserstein distance while being several orders of magnitude faster than computing the Wasserstein distance by linear programming.

Most existing studies on TWD have focused on tree construction [21, 30, 41] and edge weight estimation [48]. Takezawa et al. [42] proposed a TWD-based barycenter method for a given set of simplicial representations. Frogner et al. [17] and Toyokuni et al. [43] considered using the Wasserstein distance for multi-label classification. These studies focused on supervised learning and used softmax as the probability model. In this study, we investigate the Wasserstein distance within an SSL framework and evaluate various probability models.

3 Background

3.1 Self-supervised Learning methods

SimCLR [9]: Given $n$ input vectors $\{\bm{x}_i\}_{i=1}^{n}$, where $\bm{x}_i \in \mathbb{R}^d$, define the data transformation functions $\bm{u}^{(1)} = \bm{\phi}_1(\bm{x}) \in \mathbb{R}^d$ and $\bm{u}^{(2)} = \bm{\phi}_2(\bm{x}) \in \mathbb{R}^d$. In image applications, $\bm{u}^{(1)}$ and $\bm{u}^{(2)}$ can be understood as two transformed versions of a given image (e.g., by translation, rotation, or blurring). The neural network model is denoted as $\bm{f}_{\bm{\theta}}(\bm{u}) \in \mathbb{R}^{d_{\text{out}}}$, where $\bm{\theta}$ is the learnable parameter.

SimCLR attempts to train the model by learning features such that $\bm{z}^{(1)} = \bm{f}_{\bm{\theta}}(\bm{u}^{(1)})$ and $\bm{z}^{(2)} = \bm{f}_{\bm{\theta}}(\bm{u}^{(2)})$ are close after the feature mapping, while ensuring that both are distant from the feature map of $\bm{u}'$, where $\bm{u}'$ is a negative sample generated from a different input image. To this end, the InfoNCE loss [34] is employed in the SimCLR model:

$$\ell_{\text{InfoNCE}}\big(\bm{z}_i^{(1)}, \bm{z}_i^{(2)}\big) = -\log \frac{\exp\Big(\text{sim}\big(\bm{z}_i^{(1)}, \bm{z}_i^{(2)}\big)/\tau\Big)}{\bar{Z}},$$

where $\bar{Z} = \sum_{k=1}^{2R} \delta_{k\neq i} \exp\big(\text{sim}(\bm{z}_i^{(1)}, \tilde{\bm{z}}_k)/\tau\big)$ is the normaliser, $R$ is the batch size, $\{\tilde{\bm{z}}_k\}_{k=1}^{2R}$ denotes the $2R$ embeddings of both views of the batch, and $\text{sim}(\bm{z}, \bm{z}')$ is a similarity function (cosine similarity in SimCLR) that takes larger values when $\bm{z}$ and $\bm{z}'$ are similar and smaller (possibly negative) values when they are dissimilar. $\tau$ is the temperature parameter, and $\delta_{k\neq i}$ is an indicator function that equals 1 when $k \neq i$ and 0 otherwise.

In SimCLR, the parameters are learned by minimizing the InfoNCE loss:

$$\widehat{\bm{\theta}} := \mathop{\mathrm{argmin}}_{\bm{\theta}} \sum_{i=1}^{n} \ell_{\text{InfoNCE}}\big(\bm{f}_{\bm{\theta}}(\bm{u}_i^{(1)}), \bm{f}_{\bm{\theta}}(\bm{u}_i^{(2)})\big).$$

The loss can be written as $-\text{sim}(\bm{z}_i^{(1)}, \bm{z}_i^{(2)})/\tau + \log\bar{Z}$. The first term decreases as the similarity between the positive pair $\bm{z}_i^{(1)}$ and $\bm{z}_i^{(2)}$ increases. The second term is a log-sum-exp, that is, a smooth approximation of the maximum similarity between $\bm{z}_i^{(1)}$ and the other samples in the batch; minimizing it therefore makes $\bm{z}_i^{(1)}$ dissimilar to its negative samples.
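A minimal NumPy sketch of this loss follows, assuming cosine similarity for $\text{sim}$ and, as in SimCLR, treating each of the $2R$ embeddings in turn as the anchor and averaging. The function name `info_nce` and the default $\tau = 0.5$ are illustrative; the diagonal mask implements $\delta_{k\neq i}$, and the denominator is computed with a numerically stable log-sum-exp.

```python
import numpy as np

def info_nce(z1: np.ndarray, z2: np.ndarray, tau: float = 0.5) -> float:
    """Mean InfoNCE loss over the 2R views of a batch of R positive pairs."""
    R = z1.shape[0]
    z = np.concatenate([z1, z2], axis=0)              # the 2R embeddings z~_k
    z = z / np.linalg.norm(z, axis=1, keepdims=True)  # cosine similarity via unit vectors
    logits = z @ z.T / tau                            # sim(z_i, z~_k) / tau
    np.fill_diagonal(logits, -np.inf)                 # delta_{k != i}: drop k == i
    pos = np.concatenate([np.arange(R, 2 * R), np.arange(R)])  # index of the partner view
    row_max = logits.max(axis=1, keepdims=True)
    log_z_bar = row_max[:, 0] + np.log(np.exp(logits - row_max).sum(axis=1))
    losses = -logits[np.arange(2 * R), pos] + log_z_bar
    return float(losses.mean())

eye = np.eye(4)
print(info_nce(eye, eye))                 # aligned pairs: log(1 + 6 e^{-2}), about 0.594
print(info_nce(eye, eye[[1, 2, 3, 0]]))   # mismatched pairs: larger loss
```

In the aligned case each anchor has similarity 1 with its partner and 0 with the six negatives, which gives the value in the first comment.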

SimSiam [8]: SimSiam is a non-contrastive learning method; instead of relying on negative samples to prevent mode collapse, it employs a stop-gradient operation. Specifically, the loss function is given by

$$L_{\textnormal{SimSiam}}(\bm{\theta}) = \frac{1}{2}L_1(\bm{\theta}) + \frac{1}{2}L_2(\bm{\theta}),$$
$$L_1(\bm{\theta}) = -\frac{1}{n}\sum_{i=1}^{n} \frac{\bm{h}(\bm{z}_i)^\top \bar{\bm{z}}'_i}{\|\bm{h}(\bm{z}_i)\|_2\,\|\bar{\bm{z}}'_i\|_2}, \qquad L_2(\bm{\theta}) = -\frac{1}{n}\sum_{i=1}^{n} \frac{\bar{\bm{z}}_i^\top \bm{h}(\bm{z}'_i)}{\|\bar{\bm{z}}_i\|_2\,\|\bm{h}(\bm{z}'_i)\|_2},$$

where $\bm{h}(\cdot)$ is the MLP prediction head, $\bm{z}_i$ and $\bm{z}'_i$ are the latent representations of the two augmented views of the $i$-th input, and $\bar{\bm{z}}_i = \textnormal{StopGradient}(\bm{z}_i)$ (and likewise $\bar{\bm{z}}'_i$) is the latent representation treated as a constant during backpropagation.
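The sketch below evaluates the value of this loss in NumPy. NumPy has no automatic differentiation, so StopGradient is simply the identity in the forward pass here; in an autodiff framework, $\bar{\bm{z}}_i$ and $\bar{\bm{z}}'_i$ would be wrapped in that framework's stop-gradient (detach) operation so that no gradient flows through them. The `predictor` argument is a hypothetical stand-in for the head $\bm{h}$.

```python
import numpy as np

def neg_cosine(p: np.ndarray, z: np.ndarray) -> np.ndarray:
    """Row-wise negative cosine similarity between p and z."""
    p = p / np.linalg.norm(p, axis=1, keepdims=True)
    z = z / np.linalg.norm(z, axis=1, keepdims=True)
    return -np.sum(p * z, axis=1)

def simsiam_loss(z: np.ndarray, z_prime: np.ndarray, predictor) -> float:
    """0.5 * L1 + 0.5 * L2 for latent batches z, z_prime of shape (n, d)."""
    # Forward value of StopGradient(.) is the identity; only gradients differ.
    z_bar, z_prime_bar = z, z_prime
    L1 = neg_cosine(predictor(z), z_prime_bar).mean()
    L2 = neg_cosine(predictor(z_prime), z_bar).mean()
    return float(0.5 * L1 + 0.5 * L2)

rng = np.random.default_rng(0)
W = rng.normal(size=(16, 16))                # hypothetical linear predictor weights
z, z_prime = rng.normal(size=(2, 8, 16))     # two views, batch of 8
print(simsiam_loss(z, z_prime, predictor=lambda v: v @ W))  # lies in [-1, 1]
```

The minimum value of -1 is reached when the predicted vector of each view points in the same direction as the other view's latent representation.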

3.2 $p$-Wasserstein distance

The $p$-Wasserstein distance between two discrete measures $\mu = \sum_{i=1}^{\bar{n}} a_i \delta_{\bm{x}_i}$ and $\mu' = \sum_{j=1}^{\bar{m}} a'_j \delta_{\bm{y}_j}$ is given by

$$\mathcal{W}_p(\mu, \mu') = \left(\min_{\bm{\Pi} \in U(\mu, \mu')} \sum_{i=1}^{\bar{n}} \sum_{j=1}^{\bar{m}} \pi_{ij}\, d(\bm{x}_i, \bm{y}_j)^p\right)^{1/p},$$

where $U(\mu, \mu') = \{\bm{\Pi} \in \mathbb{R}_+^{\bar{n}\times\bar{m}} : \bm{\Pi}\bm{1}_{\bar{m}} = \bm{a}, \bm{\Pi}^\top\bm{1}_{\bar{n}} = \bm{a}'\}$ denotes the set of transport plans and $d(\cdot, \cdot)$ is the ground distance. The Wasserstein distance can be computed by solving a linear program. However, because this requires solving an optimization problem, computing the Wasserstein distance at every training iteration is computationally expensive.

An alternative approach is entropic regularization [11]. For the 1-Wasserstein distance, the entropy-regularized variant is given by

$$\min_{\bm{\Pi} \in U(\mu, \mu')} \sum_{i=1}^{\bar{n}} \sum_{j=1}^{\bar{m}} \pi_{ij}\, d(\bm{x}_i, \bm{y}_j) + \lambda \sum_{i=1}^{\bar{n}} \sum_{j=1}^{\bar{m}} \pi_{ij}\big(\log(\pi_{ij}) - 1\big).$$

This optimization problem can be solved efficiently using the Sinkhorn algorithm [11] at a computational cost of $O(\bar{n}\bar{m})$ per iteration. More importantly, each Sinkhorn iteration consists only of matrix-vector multiplications, which makes the algorithm easy to differentiate through and to run on GPUs; hence, it is widely employed in deep learning.
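A minimal Sinkhorn sketch for the entropic problem above follows. It alternately rescales the rows and columns of the Gibbs kernel $\bm{K} = \exp(-\bm{C}/\lambda)$ so that the plan $\textnormal{diag}(\bm{u})\bm{K}\textnormal{diag}(\bm{v})$ matches the marginals $\bm{a}$ and $\bm{a}'$ (called `b` in the code); each iteration costs $O(\bar{n}\bar{m})$. The fixed iteration count is an illustrative choice, and for small $\lambda$ a log-domain implementation is advisable because the kernel can underflow.

```python
import numpy as np

def sinkhorn(a: np.ndarray, b: np.ndarray, C: np.ndarray,
             lam: float, n_iter: int = 1000) -> np.ndarray:
    """Entropic OT plan between histograms a (n,) and b (m,) with cost C (n, m)."""
    K = np.exp(-C / lam)            # Gibbs kernel
    u = np.ones_like(a)
    v = np.ones_like(b)
    for _ in range(n_iter):
        u = a / (K @ v)             # rescale rows to match a
        v = b / (K.T @ u)           # rescale columns to match b
    return u[:, None] * K * v[None, :]   # diag(u) K diag(v)

x = np.array([0.0, 1.0])
y = np.array([1.0, 2.0])
C = np.abs(x[:, None] - y[None, :])   # ground distance d(x_i, y_j) = |x_i - y_j|
a = np.array([0.5, 0.5])
b = np.array([0.5, 0.5])
P = sinkhorn(a, b, C, lam=1.0)
print(P.sum(axis=1), P.sum(axis=0))   # approximately a and b
print(np.sum(P * C))                  # transport cost; every feasible plan costs 1 here
```

As $\lambda$ decreases, the plan approaches an optimal plan of the unregularized problem, at the price of slower convergence.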

Figure 1: The left tree corresponds to the total variation if the edge weights are set to $w_i = \frac{1}{2}, \forall i$. The right tree is a ClusterTree (2 classes).
Figure 2: Tree for the sliced Wasserstein distance with $N_{\text{leaf}} = 3$. The left figure is a chain, and the right figure is its tree representation with internal nodes ($w_4 = w_5 = w_6 = 0$).

3.3 1-Wasserstein distance with tree metric (Tree-Wasserstein Distance)

Another variant of the 1-Wasserstein distance is based on trees [21, 30]. The 1-Wasserstein distance with a tree metric is defined as the L1 distance between the tree-embedded vectors of two probability distributions $\mu = \sum_i a_i \delta_{\bm{x}_i}$ and $\mu' = \sum_j a'_j \delta_{\bm{y}_j}$:

$$W_{\mathcal{T}}(\mu,\mu') = \big\|\mathrm{diag}(\bm{w})\bm{B}\bm{a} - \mathrm{diag}(\bm{w})\bm{B}\bm{a}'\big\|_1,$$

where $\bm{B} \in \{0,1\}^{N_{\text{node}} \times N_{\text{leaf}}}$ is the tree parameter, with $[\bm{B}]_{i,j} = 1$ if node $i$ is an ancestor of leaf node $j$ and zero otherwise. $N_{\text{node}}$ is the total number of nodes in the tree, $N_{\text{leaf}}$ is the number of leaf nodes, and $\bm{w} \in \mathbb{R}_+^{N_{\text{node}}}$ is the vector of edge weights.
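The closed form can be evaluated directly from $\bm{B}$ and $\bm{w}$. The NumPy sketch below makes the following assumptions, none of which come from the source:

- Each leaf counts as its own ancestor, which is what makes $\bm{B} = \bm{I}$ for the depth-one tree.
- $w_i$ is the length of the edge from node $i$ to its parent, and the root is omitted.
- The 4-leaf, 2-cluster ClusterTree and its uniform weights of $\frac{1}{2}$ are hypothetical.

```python
import numpy as np

def tree_wasserstein(a, a_prime, B, w):
    """TWD = || diag(w) B a - diag(w) B a' ||_1 (illustrative sketch).

    B: (N_node, N_leaf) 0/1 ancestor matrix; a leaf is treated as its own ancestor.
    w: (N_node,) non-negative weight of the edge above each node.
    """
    return np.abs(w * (B @ a) - w * (B @ a_prime)).sum()

# Hypothetical ClusterTree: clusters {0, 1} and {2, 3}, plus the four leaf nodes.
B = np.vstack([np.array([[1, 1, 0, 0],
                         [0, 0, 1, 1]], dtype=float),
               np.eye(4)])
w = np.full(6, 0.5)
e = np.eye(4)  # point masses on the leaves
same_cluster = tree_wasserstein(e[0], e[1], B, w)
other_cluster = tree_wasserstein(e[0], e[2], B, w)
```

With these weights, moving a unit of mass to a leaf in the same cluster costs two edges (1.0). Moving it to the other cluster costs four edges (2.0), which matches the shortest-path distance in the tree.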

For illustration, we give two examples of the $\bm{B}$ matrix: a tree of depth one and a ClusterTree, both shown in Figure 1. For the depth-one tree in the left panel of Figure 1, $\bm{B} = \bm{I}$. If all edge weights are additionally set to $w_1 = w_2 = \cdots = w_N = \frac{1}{2}$, substituting into the TWD gives

$$W_{\mathcal{T}}(\mu,\mu') = \frac{1}{2}\|\bm{a}-\bm{a}'\|_1 = \|\bm{a}-\bm{a}'\|_{\mathrm{TV}}.$$

Thus, the total variation is a special case of TWD. In this setting, the shortest-path distance between any two distinct leaves is 1, so the tree metric coincides with the Hamming distance. Note that Raginsky et al. [37] also state that the 1-Wasserstein distance with the Hamming metric $d(\bm{x},\bm{y}) = \delta_{\bm{x} \neq \bm{y}}$ is equivalent to the total variation (Proposition 3.4.1).
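The Hamming-metric equivalence can be checked with a short derivation. We assume that $\mu$ and $\mu'$ are supported on the same points $\bm{x}_i$. The key observation is that $\pi_{ii}$ can be at most $\min(a_i, a'_i)$, because it is bounded by both the row sum $a_i$ and the column sum $a'_i$. This bound is attained by placing $\min(a_i, a'_i)$ on the diagonal and transporting the remaining mass arbitrarily. Since both weight vectors sum to one:

```latex
% Hamming cost d(x_i, x_j) = \delta_{i \neq j}; mu and mu' share the support points x_i.
\begin{aligned}
\min_{\bm{\Pi}\in U(\mu,\mu')}\sum_{i\neq j}\pi_{ij}
  &= 1 - \max_{\bm{\Pi}\in U(\mu,\mu')}\sum_{i}\pi_{ii}
   = 1 - \sum_{i}\min(a_i, a'_i) \\
  &= 1 - \sum_{i}\frac{a_i + a'_i - |a_i - a'_i|}{2}
   = \frac{1}{2}\|\bm{a}-\bm{a}'\|_1 = \|\bm{a}-\bm{a}'\|_{\mathrm{TV}}.
\end{aligned}
```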

The key advantage of the tree-based approach is that the Wasserstein distance has a closed form, which makes it computationally efficient. Because a chain is a special case of a tree, the widely employed sliced Wasserstein distance is also a special case of TWD (Figure 2). Moreover, TWD- and Sinkhorn-based approaches have been empirically reported to perform similarly on multilabel classification tasks [43].
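The chain special case can be checked numerically. The sketch below encodes sorted 1D support points as a chain rooted at the first point. Here, non-root node $k$ is an ancestor of every point $j \ge k$, and its edge length is the gap to its predecessor. The sketch then compares the TWD closed form with the 1D Wasserstein distance $\int |F - G|$ computed from the CDFs.

This chain construction is our own assumption. Masses sit directly on the chain's nodes, since the tree-metric formula does not require them to be leaves. It does not reproduce the exact node numbering of Figure 2.

```python
import numpy as np

def tree_wasserstein(a, a_prime, B, w):
    # || diag(w) B a - diag(w) B a' ||_1, written using linearity in the masses.
    return np.abs(w * (B @ (a - a_prime))).sum()

def chain_tree(positions):
    """Hypothetical chain encoding of sorted 1D points, rooted at positions[0].

    Row k-1 of B corresponds to node k (k = 1..n-1): it is an ancestor of points j >= k.
    w[k-1] is the edge length positions[k] - positions[k-1].
    """
    n = len(positions)
    B = np.triu(np.ones((n, n)))[1:]   # drop the root row
    w = np.diff(positions)
    return B, w

def wasserstein_1d(positions, a, a_prime):
    # W1 on the real line: |F - G| integrated between consecutive support points.
    cdf_diff = np.cumsum(a - a_prime)[:-1]
    return np.sum(np.abs(cdf_diff) * np.diff(positions))

pos = np.array([0.0, 1.0, 3.0])
a = np.array([0.6, 0.3, 0.1])
ap = np.array([0.1, 0.3, 0.6])
B, w = chain_tree(pos)
twd_chain = tree_wasserstein(a, ap, B, w)
w1_line = wasserstein_1d(pos, a, ap)
```

Both quantities evaluate to 1.5 here. The mass below the edge entering node $k$ equals one minus the CDF just before point $k$, so the two formulas agree term by term.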

4 SSL with 1-Wasserstein Distance

In this section, we first formulate SSL using TWD. We then introduce ArcFace-based probability models and Jeffrey divergence regularization.

4.1 SimCLR with Tree Wasserstein Distance

Let $\bm{a}$ and $\bm{a}'$ be the simplicial embedding vectors of $\bm{x}$ and $\bm{x}'$ (i.e., $\bm{1}^\top\bm{a} = 1$ and $\bm{1}^\top\bm{a}' = 1$), with $\mu = \sum_j a_j \delta_{\bm{e}_j}$ and $\mu' = \sum_j a'_j \delta_{\bm{e}_j}$, respectively. Here, $\bm{e}_j$ is the virtual embedding corresponding to $a_j$ or $a'_j$; the embeddings $\bm{e}_j$ are assumed to be unavailable in our problem setup. The main idea of this paper is to adopt the negative Wasserstein distance between $\mu$ and $\mu'$ as the similarity score for SimCLR:

$$\mathrm{sim}(\mu,\mu') = -W_{\mathcal{T}}(\mu,\mu').$$

We assume that $\bm{B}$ and $\bm{w}$ are given; that is, both the tree structure and the edge weights are known. In particular, we consider the trees shown in Figure 1.

Following the original design of the InfoNCE loss and by substituting the similarity score given by the negative Wasserstein distance, we obtain the following simplified loss function:

$$\widehat{\bm{\theta}} := \mathop{\mathrm{argmin}}_{\bm{\theta}}\sum_{i=1}^{n}\left(W_{\mathcal{T}}\big(\mu_i^{(1)},\mu_i^{(2)}\big)/\tau + \log\sum_{k=1}^{2N}\delta_{k\neq i}\exp\Big(-W_{\mathcal{T}}\big(\mu_i^{(1)},\mu_k^{(2)}\big)/\tau\Big)\right),$$

where $\tau > 0$ is the temperature parameter of the InfoNCE loss. (Footnote 1: Although we mainly focus on the InfoNCE loss, the negative Wasserstein distance can serve as the similarity measure in other contrastive losses as well, e.g., Barlow Twins.)
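The loss above can be sketched in NumPy as follows. We make three assumptions:

- Total variation ($\bm{B} = \bm{I}$, $w_i = \frac{1}{2}$) serves as the TWD.
- We follow the common SimCLR arrangement: both views of the $n$ images form $2n$ candidates, only the anchor itself is excluded ($\delta_{k \neq i}$), and the loss is averaged over anchors from both views.
- The source formula sums over first-view anchors only, so the symmetric averaging is our own choice.

Only the forward value is shown. Training would run the same computation in an autodiff framework.

```python
import numpy as np

def softmax(z, axis=-1):
    z = z - z.max(axis=axis, keepdims=True)
    e = np.exp(z)
    return e / e.sum(axis=axis, keepdims=True)

def pairwise_tv(P, Q):
    """Total variation (TWD with B = I, w = 1/2) between every row of P and every row of Q."""
    return 0.5 * np.abs(P[:, None, :] - Q[None, :, :]).sum(-1)

def twd_info_nce(a1, a2, tau=0.07):
    """InfoNCE with sim = -TWD; a1, a2: (n, d) probability vectors of the two views."""
    n = a1.shape[0]
    z = np.concatenate([a1, a2], axis=0)                  # (2n, d) candidates
    logits = -pairwise_tv(z, z) / tau                     # sim / tau
    np.fill_diagonal(logits, -np.inf)                     # delta_{k != i}: drop the anchor itself
    pos = np.concatenate([np.arange(n, 2 * n), np.arange(n)])  # index of the other view
    m = logits.max(axis=1, keepdims=True)                 # stable log-sum-exp
    log_denom = m[:, 0] + np.log(np.exp(logits - m).sum(axis=1))
    # -logits[i, pos] equals W_T(positive pair) / tau
    return np.mean(-logits[np.arange(2 * n), pos] + log_denom)

rng = np.random.default_rng(0)
a1 = softmax(rng.normal(size=(8, 16)))
loss_matched = twd_info_nce(a1, a1.copy())
loss_shuffled = twd_info_nce(a1, np.roll(a1, 1, axis=0))
```

When the second view is a permutation of the first, the set of candidates is unchanged and only the positive-pair term differs. The matched pairing therefore gives the lower loss.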

4.2 SimSiam with Tree Wasserstein Distance

Here, we consider a combination of SimSiam and TWD. The loss function of the proposed approach is expressed as

$$L_{\text{TWDSimSiam}}(\bm{\theta}) = \frac{1}{2}L_1(\bm{\theta}) + \frac{1}{2}L_2(\bm{\theta}),$$
$$L_1(\bm{\theta}) = \frac{1}{n}\sum_{i=1}^{n}W_{\mathcal{T}}\big(\mu_i^{(1)},\bar{\mu}_i^{(2)}\big),\qquad L_2(\bm{\theta}) = \frac{1}{n}\sum_{i=1}^{n}W_{\mathcal{T}}\big(\bar{\mu}_i^{(1)},\mu_i^{(2)}\big).$$

The difference from the original SimSiam is that our formulation employs the Wasserstein distance, whereas the original formulation uses cosine similarity.
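A forward-only NumPy sketch of the symmetric TWD SimSiam loss follows. It uses total variation as the TWD. We assume that $\mu_i^{(v)}$ comes from the predictor branch and $\bar{\mu}_i^{(v)}$ from the target branch. In an autodiff framework, the target branch would be wrapped in a stop-gradient, as in the original SimSiam. The section does not define $\bar{\mu}$, so this mapping is an assumption.

```python
import numpy as np

def tv_distance(p, q):
    # Row-wise total variation = TWD with B = I and w = 1/2.
    return 0.5 * np.abs(p - q).sum(axis=-1)

def twd_simsiam_loss(p1, p2, z1, z2):
    """Symmetric SimSiam-style loss with TWD in place of cosine similarity (sketch).

    p1, p2: (n, d) probability vectors for mu_i^(1), mu_i^(2) (assumed predictor branch).
    z1, z2: (n, d) probability vectors for bar-mu_i^(1), bar-mu_i^(2) (assumed target branch;
            these would receive a stop-gradient during training).
    """
    L1 = tv_distance(p1, z2).mean()
    L2 = tv_distance(z1, p2).mean()
    return 0.5 * L1 + 0.5 * L2
```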

4.3 Robust Variant of Tree Wasserstein Distance

In our setup, it is difficult to estimate the tree structure $\bm{B}$ and the edge weights $\bm{w}$ because the embedding vectors $\bm{e}_1, \bm{e}_2, \ldots, \bm{e}_{d_{\text{out}}}$ are unavailable. To address this problem, we consider a robust variant of TWD inspired by robust estimators of the Wasserstein distance, such as the subspace-robust Wasserstein distance (SRWD) [35]. The key idea of SRWD is to solve the optimal transport problem in the subspace in which the distance is maximized. Analogously, in the TWD case we can solve the optimal transport problem under the edge weights that maximize the shortest-path distance. Specifically, for a given $\bm{B}$, we propose the robust TWD (RTWD) as follows:

$$\mathrm{RTWD}(\mu,\mu') = \frac{1}{2}\min_{\bm{\Pi}\in U(\mu,\mu')}\max_{\bm{w}\in\mathcal{B}}\sum_{i=1}^{N_{\text{leaf}}}\sum_{j=1}^{N_{\text{leaf}}}\pi_{ij}\,d_{\mathcal{T}}(\bm{e}_i,\bm{e}_j),$$

where $\mathcal{B} = \{\bm{w}\in\mathbb{R}_+^{N_{\text{node}}} : \bm{B}^\top\bm{w} = \bm{1}\ \text{and}\ \bm{w}\geq\bm{0}\}$, $d_{\mathcal{T}}(\bm{e}_i,\bm{e}_j)$ is the shortest-path distance between $\bm{e}_i$ and $\bm{e}_j$, and $\bm{e}_i$ and $\bm{e}_j$ are embedded in a tree $\mathcal{T}$. The constraint means that, for every leaf node $j$, the weights of its ancestor nodes are non-negative and sum to one.

Proposition 1

The robust variant of TWD (RTWD) is equivalent to total variation:

$$\mathrm{RTWD}(\mu,\mu') = \|\bm{a}-\bm{a}'\|_{\mathrm{TV}},$$

where $\|\bm{a}-\bm{a}'\|_{\mathrm{TV}} = \frac{1}{2}\|\bm{a}-\bm{a}'\|_1$ denotes the total variation.

The proof is provided in the Appendix. By this proposition, RTWD equals the total variation and does not depend on the tree structure $\bm{B}$. Hence, if no prior information about the tree structure is available, the total variation is a reasonable choice.
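One way to see why Proposition 1 holds is sketched below. This is our own argument, not the proof in the Appendix. We assume that leaf nodes appear as rows of $\bm{B}$, with each leaf counting as its own ancestor, and we write $w_n$ for the weight of the edge above node $n$.

The constraint $\bm{B}^\top\bm{w} = \bm{1}$ places every leaf at distance one from the root, so no pair of leaves can be farther apart than 2. Setting $w_n = 1$ on every leaf node and $0$ elsewhere attains this bound for all pairs simultaneously. The inner maximization therefore reduces the problem to optimal transport with the Hamming metric, which by the result of Raginsky et al. quoted in Section 3.3 equals the total variation:

```latex
% For 0/1 entries, |x - y| = x + y - 2xy; B^T w = 1 gives sum_n w_n [B]_{n,i} = 1.
\begin{aligned}
d_{\mathcal{T}}(\bm{e}_i,\bm{e}_j)
  &= \sum_{n} w_n \big|[\bm{B}]_{n,i}-[\bm{B}]_{n,j}\big|
   = (\bm{B}^\top\bm{w})_i + (\bm{B}^\top\bm{w})_j - 2\sum_{n} w_n [\bm{B}]_{n,i}[\bm{B}]_{n,j} \\
  &= 2 - 2\sum_{n} w_n [\bm{B}]_{n,i}[\bm{B}]_{n,j} \;\le\; 2\,\delta_{i\neq j},
\end{aligned}
\qquad
\mathrm{RTWD}(\mu,\mu') = \frac{1}{2}\min_{\bm{\Pi}\in U(\mu,\mu')}\sum_{i\neq j} 2\,\pi_{ij}
  = \|\bm{a}-\bm{a}'\|_{\mathrm{TV}}.
```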

4.4 Probability models

In this section, we discuss several choices of probability models for the InfoNCE and SimSiam losses.

Softmax: The softmax function for simplicial representation is given by

$$\bm{a}_{\bm{\theta}}(\bm{x}) = \mathrm{Softmax}\big(\bm{f}_{\bm{\theta}}(\bm{x})\big),$$

where $\bm{f}_{\bm{\theta}}(\bm{x})$ is a neural network model.

Simplicial Embedding: Lavoie et al. [28] proposed a simple yet efficient simplicial embedding (SEM) method. Assume that the output dimensionality of the neural network model is $d_{\text{out}}$. SEM splits $\bm{f}_{\bm{\theta}}(\bm{x})$ into $L = d_{\text{out}}/V$ blocks of dimension $V$ and applies the softmax function to each block, yielding $L$ probability vectors. The resulting vector is defined as follows:

$$\bm{a}_{\bm{\theta}}(\bm{x}) = \Big[\bm{a}^{(1)}_{\bm{\theta}}(\bm{x})^\top, \bm{a}^{(2)}_{\bm{\theta}}(\bm{x})^\top, \ldots, \bm{a}^{(L)}_{\bm{\theta}}(\bm{x})^\top\Big]^\top \quad \text{with} \quad \bm{a}^{(\ell)}_{\bm{\theta}}(\bm{x}) = \mathrm{Softmax}\big(\bm{f}^{(\ell)}_{\bm{\theta}}(\bm{x})\big)/L,$$

where $\bm{f}^{(\ell)}_{\bm{\theta}}(\bm{x}) \in \mathbb{R}^V$ is the $\ell$-th block of the neural network output. We divide each softmax output by $L$ so that the concatenated vector $\bm{a}_{\bm{\theta}}(\bm{x})$ satisfies the sum-to-one constraint.
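The SEM probability model is a blockwise softmax followed by division by $L$. A NumPy sketch follows. We assume that the $\ell$-th block consists of the contiguous entries $[\ell V, (\ell+1)V)$ of the network output; the function name is ours. The usage mirrors the experimental setting $d_{\text{out}} = 256$, $L = V = 16$.

```python
import numpy as np

def simplicial_embedding(f, L, V):
    """SEM sketch: softmax over each V-dim block, each block scaled by 1/L.

    f: (batch, L * V) network outputs. Returns (batch, L * V) vectors that sum to one.
    """
    blocks = f.reshape(f.shape[0], L, V)                   # contiguous blocks (assumption)
    blocks = blocks - blocks.max(axis=-1, keepdims=True)   # numerical stability
    e = np.exp(blocks)
    a = e / e.sum(axis=-1, keepdims=True) / L              # each block sums to 1/L
    return a.reshape(f.shape[0], L * V)

rng = np.random.default_rng(0)
a = simplicial_embedding(rng.normal(size=(4, 256)), L=16, V=16)
```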

ArcFace model (AF): As an alternative to SEM, we propose employing the ArcFace probability model [12]:

$$\bm{a}_{\bm{\theta}}(\bm{x}) = \mathrm{Softmax}\big(\bm{K}^\top\bm{f}_{\bm{\theta}}(\bm{x})/\eta\big),$$

where $\bm{K} = [\bm{k}_1, \bm{k}_2, \ldots, \bm{k}_{d_{\text{prob}}}] \in \mathbb{R}^{d_{\text{out}} \times d_{\text{prob}}}$ is a learnable parameter, $\bm{f}_{\bm{\theta}}(\bm{x})$ is the normalized output of the model ($\bm{f}_{\bm{\theta}}(\bm{x})^\top\bm{f}_{\bm{\theta}}(\bm{x}) = 1$), and $\eta$ is the temperature parameter. Note that AF has a structure similar to that of the attention mechanism in transformers [3, 46]. The key difference from the original notion of attention in transformers is the normalization of the key matrix $\bm{K}$ and the query vector $\bm{f}_{\bm{\theta}}(\bm{x})$.
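A minimal NumPy sketch of the AF probability model follows. It normalizes both the query $\bm{f}_{\bm{\theta}}(\bm{x})$ and each key column $\bm{k}_j$ to unit length, so every logit is a cosine similarity scaled by $1/\eta$. The function name and the default `eta=0.1` are illustrative choices, not values taken from the method description.

```python
import numpy as np

def softmax(z):
    z = z - z.max(axis=-1, keepdims=True)
    e = np.exp(z)
    return e / e.sum(axis=-1, keepdims=True)

def arcface_probs(f, K, eta=0.1):
    """AF sketch: Softmax(K^T f / eta) with unit-norm query and unit-norm key columns.

    f: (batch, d_out) raw network outputs; K: (d_out, d_prob) key matrix.
    """
    f = f / np.linalg.norm(f, axis=1, keepdims=True)   # query normalization
    K = K / np.linalg.norm(K, axis=0, keepdims=True)   # each key k_j has unit norm
    logits = f @ K / eta                               # cosine similarities / eta
    return softmax(logits)

rng = np.random.default_rng(0)
probs = arcface_probs(rng.normal(size=(5, 8)), rng.normal(size=(8, 3)))
```

Because the logits are confined to $[-1/\eta, 1/\eta]$, the temperature $\eta$ alone controls how peaked the output distribution can become.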

AF with Positional Encoding: If one more linear layer is added to the AF model before the softmax function is applied, the output becomes similar to that of the standard softmax function. Here, we instead propose replacing the key matrix with a normalized positional encoding matrix ($\bm{k}_i^\top\bm{k}_i = 1, \forall i$):

$$\bm{k}_i = \bar{\bm{k}}_i / \|\bar{\bm{k}}_i\|_2,$$

where $\bar{k}_i^{(2j)} = \sin\big(i/10000^{2j/d_{\text{out}}}\big)$ and $\bar{k}_i^{(2j+1)} = \cos\big(i/10000^{2j/d_{\text{out}}}\big)$.
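The positional-encoding key matrix can be built as in the following sketch. We assume that $d_{\text{out}}$ is even and that key indices start at $i = 0$; the source does not state the starting index. Since $\sin^2 + \cos^2 = 1$, every $\bar{\bm{k}}_i$ has norm exactly $\sqrt{d_{\text{out}}/2}$, so the normalization amounts to a constant rescaling.

```python
import numpy as np

def positional_encoding_keys(d_out, d_prob):
    """Key matrix K (d_out x d_prob); column i is the normalized sinusoidal encoding k_i.

    Assumes d_out is even and key indices i = 0, ..., d_prob - 1.
    """
    i = np.arange(d_prob)[:, None]                 # key index
    j = np.arange(d_out // 2)[None, :]             # frequency index
    angle = i / 10000 ** (2 * j / d_out)
    k_bar = np.empty((d_prob, d_out))
    k_bar[:, 0::2] = np.sin(angle)                 # entries 2j
    k_bar[:, 1::2] = np.cos(angle)                 # entries 2j + 1
    k = k_bar / np.linalg.norm(k_bar, axis=1, keepdims=True)
    return k.T

K_pe = positional_encoding_keys(256, 10)
```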

AF with Discrete Cosine Transform Matrix: Another natural choice is to use an orthogonal matrix as $\bm{K}$. Specifically, we propose adopting the discrete cosine transform (DCT) matrix as $\bm{K}$, which is expressed as follows:

$$k_i^{(j)} = \begin{cases} 1/\sqrt{d_{\text{out}}} & (i = 0), \\ \sqrt{\dfrac{2}{d_{\text{out}}}}\cos\dfrac{\pi(2j+1)i}{2d_{\text{out}}} & (1 \leq i \leq d_{\text{out}}-1). \end{cases}$$
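The DCT key matrix can be built directly from the formula, and its orthonormality checked numerically. This is a NumPy sketch: row $i$ of `C` holds $k_i^{(j)}$ for $j = 0, \ldots, d_{\text{out}}-1$, and the returned `K` has $\bm{k}_i$ as its $i$-th column.

```python
import numpy as np

def dct_keys(d_out):
    """Orthonormal DCT-II key matrix K (d_out x d_out); column i is k_i."""
    i = np.arange(d_out)[:, None]        # key index
    j = np.arange(d_out)[None, :]        # entry index
    C = np.sqrt(2.0 / d_out) * np.cos(np.pi * (2 * j + 1) * i / (2 * d_out))
    C[0, :] = 1.0 / np.sqrt(d_out)       # the i = 0 row
    return C.T

K_dct = dct_keys(256)
```

Because every key already has unit norm, the AF normalization leaves this matrix unchanged. Multiplying a vector by `K_dct.T` should coincide with SciPy's `scipy.fft.dct(x, norm="ortho")`, which uses the same orthonormal DCT-II scaling.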

One of the contributions of this study is the finding that combining the ArcFace model with either the positional encoding or the DCT key matrix significantly boosts performance, whereas the standard ArcFace model without these modifications performs similarly to the softmax function.

4.5 Jeffrey-Divergence Regularization

We empirically observed that optimizing the loss functions described above is extremely challenging. In particular, the L1 distance is not differentiable at 0. Figure 3(b) illustrates the learning curve of standard optimization with the softmax probability model.

To stabilize optimization, we propose including the Jeffrey divergence (JD) as a regularization term. JD is an upper bound on the squared tree-Wasserstein distance:

Proposition 2

For $\bm{B}^\top\bm{w} = \bm{1}$ and probability vectors $\bm{a}_i$ and $\bm{a}_j$, we have

$$W^2_{\mathcal{T}}(\mu_i,\mu_j) \leq \mathrm{JD}\big(\mathrm{diag}(\bm{w})\bm{B}\bm{a}_i \,\|\, \mathrm{diag}(\bm{w})\bm{B}\bm{a}_j\big),$$

where $\mathrm{JD}(\mathrm{diag}(\bm{w})\bm{B}\bm{a}_i \,\|\, \mathrm{diag}(\bm{w})\bm{B}\bm{a}_j) = \mathrm{KL}(\mathrm{diag}(\bm{w})\bm{B}\bm{a}_i \,\|\, \mathrm{diag}(\bm{w})\bm{B}\bm{a}_j) + \mathrm{KL}(\mathrm{diag}(\bm{w})\bm{B}\bm{a}_j \,\|\, \mathrm{diag}(\bm{w})\bm{B}\bm{a}_i)$ is the Jeffrey divergence.

This result indicates that minimizing the symmetric KL divergence (i.e., the Jeffrey divergence) also minimizes an upper bound on the tree-Wasserstein distance. Because the Jeffrey divergence is smooth, the gradient of the upper bound is easier to compute. For ease of presentation, we write $W_{\mathcal{T}}(\mu^{(1)},\mu^{(2)}) = W_{\mathcal{T}}(\bm{a}^{(1)},\bm{a}^{(2)})$.
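The bound in Proposition 2 can be checked numerically. One route to it, which is our own sketch and not necessarily the paper's proof, runs as follows:

- The condition $\bm{B}^\top\bm{w} = \bm{1}$ gives $\bm{1}^\top\mathrm{diag}(\bm{w})\bm{B}\bm{a} = (\bm{B}^\top\bm{w})^\top\bm{a} = 1$, so $\mathrm{diag}(\bm{w})\bm{B}\bm{a}$ is a probability vector over the nodes.
- Pinsker's inequality applied in both directions then gives $\|p - q\|_1^2 \le 2\,\mathrm{KL}(p\|q)$ and $\|p - q\|_1^2 \le 2\,\mathrm{KL}(q\|p)$.
- Adding the two inequalities yields $\|p - q\|_1^2 \le \mathrm{JD}$.

The NumPy code below uses the same hypothetical 4-leaf ClusterTree as before, with all weights equal to $\frac{1}{2}$ so that $\bm{B}^\top\bm{w} = \bm{1}$.

```python
import numpy as np

def kl(p, q):
    # KL divergence for strictly positive probability vectors.
    return np.sum(p * np.log(p / q))

def twd_and_jd(a_i, a_j, B, w):
    """Return (TWD, JD) between the tree-embedded vectors diag(w) B a_i and diag(w) B a_j."""
    p = w * (B @ a_i)
    q = w * (B @ a_j)
    return np.abs(p - q).sum(), kl(p, q) + kl(q, p)

# Hypothetical ClusterTree: clusters {0, 1}, {2, 3}, then the four leaves.
B = np.vstack([np.array([[1, 1, 0, 0],
                         [0, 0, 1, 1]], dtype=float),
               np.eye(4)])
w = np.full(6, 0.5)   # every leaf's ancestor weights sum to one

rng = np.random.default_rng(0)
pairs = []
for _ in range(100):
    z1, z2 = rng.normal(size=4) * 3, rng.normal(size=4) * 3
    a_i, a_j = np.exp(z1) / np.exp(z1).sum(), np.exp(z2) / np.exp(z2).sum()
    pairs.append(twd_and_jd(a_i, a_j, B, w))
```

In an SSL loss, the JD term would be added with a weight $\lambda$. That weighting is not reproduced here.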

Note that Frogner et al. [17] considered a multilabel classification problem using a regularized Wasserstein loss and proposed a Kullback–Leibler divergence-based regularization to stabilize training. We derive the Jeffrey divergence bound from the TWD, and JD regularization includes a simple KL divergence-based regularization as a special case. Moreover, we propose employing JD regularization in SSL frameworks, where it has not been extensively studied.

Figure 3: InfoNCE loss and Top-1 (Train) comparisons on STL10 dataset. Panels: (a) Loss of Cosine + Real; (b) Loss of TV + Softmax; (c) Loss of TV + AF (DCT); (d) Top1 of Cosine + Real; (e) Top1 of TV + Softmax; (f) Top1 of TV + AF (DCT).
Figure 4: TWD loss for SimSiam models. Panels: (a) Loss of TV + Softmax; (b) Loss of TV + AF (DCT); (c) KNN comparison.

5 Experiments

This section evaluates SSL methods with different probability models.

5.1 Performance comparison for SimCLR

For all experiments, we employed the ResNet18 model with an output dimension of $d_{\text{out}} = 256$ and implemented all methods based on a standard SimCLR implementation (https://github.com/sthalles/SimCLR). We used the Adam optimizer with a learning rate of 0.0003, a weight decay of 1e-4, and temperature $\tau = 0.07$. For the proposed method, we compared two variants of TWD: total variation and ClusterTree (Figure 1). For the probability models, we assessed the conventional softmax function, the ArcFace model (AF), and simplicial embedding (SEM) [28], and set the temperature parameter of the probability models to 0.1 for all experiments. For SEM, we set $L = 16$ and $V = 16$.

We also evaluated JD regularization, setting the regularization parameter to $\lambda = 0.1$ in all experiments. For reference, we compared against cosine similarity as the similarity function of SimCLR. For all approaches, we used the KNN classifier of the scikit-learn package (https://scikit-learn.org/stable/modules/generated/sklearn.neighbors.KNeighborsClassifier.html) with the number of nearest neighbors set to $K = 50$. We used the L1 distance for the Wasserstein distance-based models and cosine similarity for the non-probability-based models. All experiments were run on A6000 GPUs; we ran each experiment three times and report the average scores.
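The evaluation above relies on scikit-learn's KNeighborsClassifier. As a dependency-free illustration of what that classifier computes with the L1 (Manhattan) metric used for the Wasserstein-based models, the hypothetical helper below takes a majority vote among the K nearest training embeddings; in scikit-learn, a comparable configuration would be `KNeighborsClassifier(n_neighbors=50, metric="manhattan")`. The toy embeddings and the use of K = 3 in the example are illustrative only.

```python
from collections import Counter

def l1(x, y):
    """L1 (Manhattan) distance between two equal-length vectors."""
    return sum(abs(a - b) for a, b in zip(x, y))

def knn_predict(train_x, train_y, query, k=50):
    """Majority vote among the k training points closest to `query` in L1 distance.

    Ties in the vote go to the label encountered first, i.e. the one with the nearest neighbor.
    """
    order = sorted(range(len(train_x)), key=lambda i: l1(train_x[i], query))
    votes = Counter(train_y[i] for i in order[:k])
    return votes.most_common(1)[0][0]

# Toy probability-vector embeddings (hypothetical) for two classes.
train_x = [[0.8, 0.1, 0.1], [0.7, 0.2, 0.1], [0.6, 0.3, 0.1],
           [0.1, 0.1, 0.8], [0.2, 0.1, 0.7], [0.1, 0.2, 0.7]]
train_y = [0, 0, 0, 1, 1, 1]
print(knn_predict(train_x, train_y, [0.75, 0.15, 0.10], k=3))
```

Sorting all training points is fine for a sketch; scikit-learn can additionally use tree-based or brute-force neighbor search depending on its `algorithm` setting.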

Table 1: KNN classification results with a ResNet18 backbone. In this experiment, we set the number of neighbors to $K = 50$ and report the classification accuracy (%) averaged over three runs. Note that the Wasserstein distance with $\bm{B} = \bm{I}_{d_{\text{out}}}$ is equivalent to the total variation.
| Similarity function | Probability model | STL10 | CIFAR10 | CIFAR100 | SVHN |
|---|---|---|---|---|---|
| Cosine similarity | N/A | 75.77 ± 0.47 | 67.39 ± 0.46 | 32.06 ± 0.06 | 76.35 ± 0.39 |
| Cosine similarity | Softmax | 70.12 ± 0.04 | 63.20 ± 0.23 | 26.88 ± 0.26 | 74.46 ± 0.62 |
| Cosine similarity | SEM | 71.33 ± 0.45 | 61.13 ± 0.56 | 26.08 ± 0.07 | 74.28 ± 1.13 |
| Cosine similarity | AF (DCT) | 72.95 ± 0.31 | 65.92 ± 0.65 | 25.96 ± 0.13 | 76.51 ± 0.24 |
| TWD (TV) | Softmax | 65.54 ± 0.47 | 59.72 ± 0.39 | 26.07 ± 0.19 | 72.67 ± 0.33 |
| TWD (TV) | SEM | 65.35 ± 0.31 | 56.56 ± 0.46 | 24.31 ± 0.43 | 73.36 ± 1.19 |
| TWD (TV) | AF | 65.61 ± 0.56 | 60.92 ± 0.42 | 26.33 ± 0.42 | 75.01 ± 0.32 |
| TWD (TV) | AF (PE) | 71.71 ± 0.17 | 64.68 ± 0.33 | 26.38 ± 0.37 | 76.44 ± 0.45 |
| TWD (TV) | AF (DCT) | 73.28 ± 0.27 | 67.03 ± 0.24 | 25.85 ± 0.39 | 77.62 ± 0.40 |
| TWD (TV) | Softmax + JD | 72.64 ± 0.27 | 67.08 ± 0.14 | 27.82 ± 0.22 | 77.69 ± 0.46 |
| TWD (TV) | SEM + JD | 71.79 ± 0.92 | 63.60 ± 0.50 | 26.14 ± 0.40 | 75.64 ± 0.44 |
| TWD (TV) | AF + JD | 72.64 ± 0.37 | 67.15 ± 0.27 | 27.45 ± 0.37 | 78.00 ± 0.15 |
| TWD (TV) | AF (PE) + JD | 74.47 ± 0.10 | 67.28 ± 0.65 | 27.01 ± 0.39 | 78.12 ± 0.48 |
| TWD (TV) | AF (DCT) + JD | 76.28 ± 0.07 | 68.60 ± 0.36 | 26.49 ± 0.24 | 79.70 ± 0.23 |
| TWD (ClusterTree) | Softmax | 69.15 ± 0.45 | 62.33 ± 0.40 | 24.47 ± 0.40 | 74.87 ± 0.13 |
| TWD (ClusterTree) | SEM | 72.88 ± 0.12 | 63.82 ± 0.32 | 22.55 ± 0.28 | 77.47 ± 0.92 |
| TWD (ClusterTree) | AF | 70.40 ± 0.40 | 63.28 ± 0.57 | 24.28 ± 0.15 | 75.24 ± 0.52 |
| TWD (ClusterTree) | AF (PE) | 72.37 ± 0.28 | 65.08 ± 0.74 | 23.33 ± 0.35 | 76.67 ± 0.26 |
| TWD (ClusterTree) | AF (DCT) | 71.95 ± 0.46 | 65.89 ± 0.11 | 21.87 ± 0.19 | 77.92 ± 0.24 |
| TWD (ClusterTree) | Softmax + JD | 73.52 ± 0.16 | 66.76 ± 0.29 | 24.96 ± 0.07 | 77.65 ± 0.53 |
| TWD (ClusterTree) | SEM + JD | 75.93 ± 0.14 | 67.68 ± 0.46 | 22.96 ± 0.28 | 79.19 ± 0.53 |
| TWD (ClusterTree) | AF + JD | 73.66 ± 0.23 | 66.61 ± 0.32 | 24.55 ± 0.14 | 77.64 ± 0.19 |
| TWD (ClusterTree) | AF (PE) + JD | 73.92 ± 0.57 | 67.00 ± 0.13 | 23.83 ± 0.42 | 77.87 ± 0.29 |
| TWD (ClusterTree) | AF (DCT) + JD | 74.29 ± 0.30 | 67.50 ± 0.49 | 22.89 ± 0.12 | 78.31 ± 0.72 |

Figure 3 illustrates the training loss and top-1 accuracy for three methods: cosine + real-valued embedding, TV + Softmax, and TV + AF (DCT). The loss converged at nearly the same speed for all methods. In terms of training top-1 accuracy, cosine + real-valued embedding achieves the highest accuracy, followed by Softmax, with AF (DCT) lagging behind. This behavior is expected because real-valued embeddings offer the most flexibility, followed by Softmax, whereas AF models have the least freedom. For all TWD-based methods, JD regularization significantly aids training, particularly with the Softmax function. For AF (DCT), however, the improvement is relatively small, likely because AF (DCT) itself acts as a form of regularization.

Table 1 presents the test classification accuracy obtained with KNN. First, the simple implementation with the conventional Softmax function performs poorly compared with cosine similarity (approximately 10 points lower, e.g., 65.54 vs. 75.77 on STL10 with total variation). As expected, because AF adds only one layer to the simple Softmax model, it performs similarly to Softmax. Compared with Softmax and AF, AF (PE) and AF (DCT) significantly improve the classification accuracy in both the total variation and ClusterTree cases. However, in the ClusterTree case, AF (PE) achieves better classification performance, whereas the improvement of AF (DCT) over the softmax model is limited. Also in the ClusterTree case, SEM performs particularly well when combined with JD regularization.

Overall, the proposed method, which does not use real-valued vector embeddings, performs better than cosine similarity when the number of classes is relatively small (i.e., STL10, CIFAR10, and SVHN). By contrast, its performance degrades on CIFAR100, and the ClusterTree results are particularly poor. One possible explanation is that the Wasserstein distance can be minimized even when the model cannot fit the data closely, which makes errors more likely when the number of classes is large.

Next, we evaluated the Jeffrey divergence regularization. Surprisingly, this simple regularization substantially improves the classification performance of all the probability models. These results suggest that the main problem with Wasserstein distance-based representation learning is its numerical instability.

Among the methods, the proposed AF (DCT) + JD with total variation achieves the highest classification accuracy on STL10, CIFAR10, and SVHN, comparable to or slightly above the cosine similarity result, and improves on the naive TV + Softmax implementation by more than 10 points on STL10 (76.28 vs. 65.54). Moreover, each probability model tends to achieve lower classification accuracy when combined with cosine similarity than when combined with TWD and JD regularization. Based on our empirical study, we recommend TWD (TV) + AF models or TWD (ClusterTree) + SEM for probability-based representation learning.

Table 2: SimSiam evaluation on the CIFAR10 dataset (linear classifier accuracy, %).

| Similarity | Probability model | Linear classifier |
|---|---|---|
| Cosine | N/A | 91.13 ± 0.14 |
| TWD (TV) | Softmax + JD | 9.99 ± 0.00 |
| TWD (TV) | AF (DCT) + JD | 90.60 ± 0.02 |

5.2 Performance comparison for SimSiam

Next, we evaluated performance in a non-contrastive setup. For all experiments, we used the Resnet18-Cifar-Variant1 model with an output dimension of $d_{\text{out}} = 2048$ and implemented all methods based on a standard SimSiam framework (https://github.com/PatrickHua/SimSiam). Optimization used the SGD optimizer with a base learning rate of 0.03, a weight decay parameter of 0.00005, a momentum parameter of 0.9, a batch size of 512, and a fixed number of 800 epochs. For the proposed method, we employed the total variation as the loss function, together with the softmax function and the ArcFace model (AF). The temperature parameter $\tau$ was set to 0.1 in all experiments. Additionally, we assessed JD regularization with the regularization parameter $\lambda$ set to 0.1 in all experiments. All experiments were run on A100 GPUs; each experiment was run three times, and we report the average scores.

We compared the proposed methods, TWDSimSiam (Softmax + JD) and TWDSimSiam (AF + JD), with the original SimSiam method, which uses a cosine similarity loss. We observe that learning the total variation with softmax encounters numerical issues even with JD regularization (see Figures 4(a) and 4(c)); its linear-classifier accuracy in Table 2 is 9.99%, i.e., chance level for the 10 classes of CIFAR10. Conversely, the AF + JD combination trains the models successfully (see Figures 4(b) and 4(c)). One potential reason for the failure of TWD with Softmax is that the total variation can easily become zero because the softmax model, unlike AF, lacks normalization. For TWDSimSiam (AF + JD), the normalization within the AF model prevents convergence to a poor local minimum. In terms of performance, cosine similarity and total variation (TV) yield comparable results (91.13 vs. 90.60 in Table 2). However, a key contribution of this study is a practical approach for stably training models with a Wasserstein distance, specifically the total variation. This finding may be useful in various SSL tasks.
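To make the collapse argument concrete, the sketch below evaluates a SimSiam-style symmetric loss [8] in which the negative cosine similarity is replaced by the total variation $\tfrac{1}{2}\|\bm{p} - \bm{q}\|_1$ between softmax outputs. It is a hypothetical, forward-only illustration: there is no autodiff, so SimSiam's stop-gradient on the projector branch is noted in a comment but has no effect, and the inputs are made up. It shows that if both branches emit the same distribution for every image, the loss is exactly zero regardless of the input, which is the kind of trivial minimum described above; it does not by itself show why AF avoids this.

```python
import math

def softmax(z, tau=0.1):
    """Temperature softmax with max-subtraction for numerical stability."""
    m = max(z)
    e = [math.exp((x - m) / tau) for x in z]
    s = sum(e)
    return [x / s for x in e]

def total_variation(p, q):
    """Total variation distance 0.5 * ||p - q||_1 between two probability vectors."""
    return 0.5 * sum(abs(a - b) for a, b in zip(p, q))

def simsiam_tv_loss(p1, z1, p2, z2, tau=0.1):
    """Symmetric SimSiam-style loss with TV in place of negative cosine similarity.

    p1, p2: predictor outputs for the two views; z1, z2: projector outputs.
    A real implementation would wrap z1 and z2 in a stop-gradient; this sketch
    only evaluates the loss value, so there is no gradient to stop.
    """
    d12 = total_variation(softmax(p1, tau), softmax(z2, tau))
    d21 = total_variation(softmax(p2, tau), softmax(z1, tau))
    return 0.5 * (d12 + d21)

# Hypothetical outputs where the second view's prediction disagrees with the target: loss > 0.
print(simsiam_tv_loss([2.0, 0.0, 0.0], [0.0, 2.0, 0.0], [0.0, 0.0, 2.0], [2.0, 0.0, 0.0]))

# Collapsed network: every output is the same constant vector, so the loss is exactly 0
# for any input image, a degenerate minimum with no information about the data.
const = [0.5, 0.5, 0.5]
print(simsiam_tv_loss(const, const, const, const))
```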

6 Conclusion

This study investigates SSL using TWD. We empirically evaluate several benchmark datasets and find that a simple combination of the softmax function and TWD performs poorly. To address this, we propose using simplicial embedding [28] and ArcFace models [12] as probability models. Moreover, to mitigate the difficulty of optimizing TWD, we use the Jeffrey divergence, an upper bound on the squared 1-Wasserstein distance, as a regularizer. Overall, TWD combined with ArcFace (DCT) outperforms its cosine similarity counterpart. Finally, we find that the combination of TWD (ClusterTree) and SEM yields favorable performance.

References

  • Arjovsky et al. [2017] Martin Arjovsky, Soumith Chintala, and Léon Bottou. Wasserstein generative adversarial networks. In ICML, 2017.
  • Backurs et al. [2020] Arturs Backurs, Yihe Dong, Piotr Indyk, Ilya Razenshteyn, and Tal Wagner. Scalable nearest neighbor search for optimal transport. In ICML, 2020.
  • Bahdanau et al. [2014] Dzmitry Bahdanau, Kyunghyun Cho, and Yoshua Bengio. Neural machine translation by jointly learning to align and translate. arXiv preprint arXiv:1409.0473, 2014.
  • Becker and Hinton [1992] Suzanna Becker and Geoffrey E Hinton. Self-organizing neural network that discovers surfaces in random-dot stereograms. Nature, 355(6356):161–163, 1992.
  • Caron et al. [2020] Mathilde Caron, Ishan Misra, Julien Mairal, Priya Goyal, Piotr Bojanowski, and Armand Joulin. Unsupervised learning of visual features by contrasting cluster assignments. NeurIPS, 2020.
  • Caron et al. [2021] Mathilde Caron, Hugo Touvron, Ishan Misra, Hervé Jégou, Julien Mairal, Piotr Bojanowski, and Armand Joulin. Emerging properties in self-supervised vision transformers. In ICCV, 2021.
  • Chen et al. [2020a] Ting Chen, Simon Kornblith, Mohammad Norouzi, and Geoffrey Hinton. A simple framework for contrastive learning of visual representations. In ICML, 2020a.
  • Chen and He [2021] Xinlei Chen and Kaiming He. Exploring simple siamese representation learning. In CVPR, 2021.
  • Chen et al. [2020b] Xinlei Chen, Haoqi Fan, Ross Girshick, and Kaiming He. Improved baselines with momentum contrastive learning. arXiv preprint arXiv:2003.04297, 2020b.
  • Cover and Thomas [2012] Thomas M Cover and Joy A Thomas. Elements of information theory. John Wiley & Sons, 2012.
  • Cuturi [2013] Marco Cuturi. Sinkhorn distances: Lightspeed computation of optimal transport. In NIPS, 2013.
  • Deng et al. [2019] Jiankang Deng, Jia Guo, Niannan Xue, and Stefanos Zafeiriou. ArcFace: Additive angular margin loss for deep face recognition. In CVPR, 2019.
  • Deshpande et al. [2019] Ishan Deshpande, Yuan-Ting Hu, Ruoyu Sun, Ayis Pyrros, Nasir Siddiqui, Sanmi Koyejo, Zhizhen Zhao, David Forsyth, and Alexander G Schwing. Max-sliced Wasserstein distance and its use for GANs. In CVPR, 2019.
  • Dey and Zhang [2022] Tamal K Dey and Simon Zhang. Approximating 1-Wasserstein distance between persistence diagrams by graph sparsification. In ALENEX, 2022.
  • Evans and Matsen [2012] Steven N Evans and Frederick A Matsen. The phylogenetic Kantorovich–Rubinstein metric for environmental sequence samples. Journal of the Royal Statistical Society: Series B (Statistical Methodology), 74(3):569–592, 2012.
  • Fan [1953] Ky Fan. Minimax theorems. Proceedings of the National Academy of Sciences, 39(1):42–47, 1953.
  • Frogner et al. [2015] Charlie Frogner, Chiyuan Zhang, Hossein Mobahi, Mauricio Araya, and Tomaso A Poggio. Learning with a Wasserstein loss. NIPS, 2015.
  • Gretton et al. [2005] A. Gretton, O. Bousquet, Alex. Smola, and B. Schölkopf. Measuring statistical dependence with Hilbert-Schmidt norms. In ALT, 2005.
  • Grill et al. [2020] Jean-Bastien Grill, Florian Strub, Florent Altché, Corentin Tallec, Pierre Richemond, Elena Buchatskaya, Carl Doersch, Bernardo Avila Pires, Zhaohan Guo, Mohammad Gheshlaghi Azar, et al. Bootstrap your own latent: A new approach to self-supervised learning. NeurIPS, 2020.
  • He et al. [2020] Kaiming He, Haoqi Fan, Yuxin Wu, Saining Xie, and Ross Girshick. Momentum contrast for unsupervised visual representation learning. In CVPR, 2020.
  • Indyk and Thaper [2003] Piotr Indyk and Nitin Thaper. Fast image retrieval via embeddings. In 3rd International Workshop on Statistical and Computational Theories of Vision, volume 2, page 5. Nice, France, 2003.
  • Jiang et al. [2023] Qian Jiang, Changyou Chen, Han Zhao, Liqun Chen, Qing Ping, Son Dinh Tran, Yi Xu, Belinda Zeng, and Trishul Chilimbi. Understanding and constructing latent modality structures in multi-modal representation learning. In CVPR, 2023.
  • Kingma and Welling [2013] Diederik P Kingma and Max Welling. Auto-encoding variational Bayes. arXiv preprint arXiv:1312.6114, 2013.
  • Kolouri et al. [2016] Soheil Kolouri, Yang Zou, and Gustavo K Rohde. Sliced Wasserstein kernels for probability distributions. In CVPR, 2016.
  • Kolouri et al. [2019] Soheil Kolouri, Kimia Nadjahi, Umut Simsekli, Roland Badeau, and Gustavo Rohde. Generalized sliced Wasserstein distances. In NeurIPS, 2019.
  • Kramer [1991] Mark A Kramer. Nonlinear principal component analysis using autoassociative neural networks. AIChE Journal, 37(2):233–243, 1991.
  • Kusner et al. [2015] Matt Kusner, Yu Sun, Nicholas Kolkin, and Kilian Weinberger. From word embeddings to document distances. In ICML, 2015.
  • Lavoie et al. [2022] Samuel Lavoie, Christos Tsirigotis, Max Schwarzer, Ankit Vani, Michael Noukhovitch, Kenji Kawaguchi, and Aaron Courville. Simplicial embeddings in self-supervised learning and downstream classification. arXiv preprint arXiv:2204.00616, 2022.
  • Le and Nguyen [2021] Tam Le and Truyen Nguyen. Entropy partial transport with tree metrics: Theory and practice. In AISTATS, 2021.
  • Le et al. [2019] Tam Le, Makoto Yamada, Kenji Fukumizu, and Marco Cuturi. Tree-sliced approximation of Wasserstein distances. NeurIPS, 2019.
  • Liu et al. [2020] Yanbin Liu, Linchao Zhu, Makoto Yamada, and Yi Yang. Semantic correspondence as an optimal transport problem. In CVPR, 2020.
  • Lozupone and Knight [2005] Catherine Lozupone and Rob Knight. UniFrac: A new phylogenetic method for comparing microbial communities. Applied and Environmental Microbiology, 71(12):8228–8235, 2005.
  • Mueller and Jaakkola [2015] Jonas W Mueller and Tommi Jaakkola. Principal differences analysis: Interpretable characterization of differences between distributions. NIPS, 2015.
  • Oord et al. [2018] Aaron van den Oord, Yazhe Li, and Oriol Vinyals. Representation learning with contrastive predictive coding. arXiv preprint arXiv:1807.03748, 2018.
  • Paty and Cuturi [2019] François-Pierre Paty and Marco Cuturi. Subspace robust wasserstein distances. In ICML, 2019.
  • Rabin et al. [2011] Julien Rabin, Gabriel Peyré, Julie Delon, and Marc Bernot. Wasserstein barycenter and its application to texture mixing. In International Conference on Scale Space and Variational Methods in Computer Vision, pages 435–446. Springer, 2011.
  • Raginsky et al. [2013] Maxim Raginsky, Igal Sason, et al. Concentration of measure inequalities in information theory, communications, and coding. Foundations and Trends® in Communications and Information Theory, 10(1-2):1–246, 2013.
  • Sarlin et al. [2020] Paul-Edouard Sarlin, Daniel DeTone, Tomasz Malisiewicz, and Andrew Rabinovich. SuperGlue: Learning feature matching with graph neural networks. In CVPR, 2020.
  • Sato et al. [2020] Ryoma Sato, Makoto Yamada, and Hisashi Kashima. Fast unbalanced optimal transport on tree. In NeurIPS, 2020.
  • Sato et al. [2022] Ryoma Sato, Makoto Yamada, and Hisashi Kashima. Re-evaluating word mover’s distance. ICML, 2022.
  • Takezawa et al. [2021] Yuki Takezawa, Ryoma Sato, and Makoto Yamada. Supervised tree-Wasserstein distance. In ICML, 2021.
  • Takezawa et al. [2022] Yuki Takezawa, Ryoma Sato, Zornitsa Kozareva, Sujith Ravi, and Makoto Yamada. Fixed support tree-sliced Wasserstein barycenter. AISTATS, 2022.
  • Toyokuni et al. [2021] Ayato Toyokuni, Sho Yokoi, Hisashi Kashima, and Makoto Yamada. Computationally efficient Wasserstein loss for structured labels. In ECAL: Student Research Workshop, April 2021.
  • Tsai et al. [2021] Yao-Hung Hubert Tsai, Shaojie Bai, Louis-Philippe Morency, and Ruslan Salakhutdinov. A note on connecting barlow twins with negative-sample-free contrastive learning. arXiv preprint arXiv:2104.13712, 2021.
  • v. Neumann [1928] J v. Neumann. Zur Theorie der Gesellschaftsspiele. Mathematische Annalen, 100(1):295–320, 1928.
  • Vaswani et al. [2017] Ashish Vaswani, Noam Shazeer, Niki Parmar, Jakob Uszkoreit, Llion Jones, Aidan N Gomez, Łukasz Kaiser, and Illia Polosukhin. Attention is all you need. NIPS, 2017.
  • Xian et al. [2023] Ruicheng Xian, Lang Yin, and Han Zhao. Fair and optimal classification via post-processing. In ICML, 2023.
  • Yamada et al. [2022] Makoto Yamada, Yuki Takezawa, Ryoma Sato, Han Bao, Zornitsa Kozareva, and Sujith Ravi. Approximating 1-Wasserstein distance with trees. TMLR, 2022.
  • Yokoi et al. [2020] Sho Yokoi, Ryo Takahashi, Reina Akama, Jun Suzuki, and Kentaro Inui. Word rotator’s distance. EMNLP, 2020.
  • Zbontar et al. [2021] Jure Zbontar, Li Jing, Ishan Misra, Yann LeCun, and Stéphane Deny. Barlow twins: Self-supervised learning via redundancy reduction. In ICML, 2021.
  • Zhao [2022] Han Zhao. Costs and benefits of fair regression. Transactions on Machine Learning Research, 2022.
  • Zhao et al. [2019] Wei Zhao, Maxime Peyrard, Fei Liu, Yang Gao, Christian M Meyer, and Steffen Eger. MoverScore: Text generation evaluating with contextualized embeddings and earth mover distance. EMNLP-IJCNLP, 2019.

Appendix A Appendix

A.1 Proof of Proposition 1

(Proof) Let $\bm{B} = [\bm{b}_1, \bm{b}_2, \ldots, \bm{b}_{N_{\text{leaf}}}] \in \{0,1\}^{N \times N_{\text{leaf}}}$ with $\bm{b}_i \in \{0,1\}^N$. The shortest-path distance between leaves $i$ and $j$ can be represented as [48]

$$d_{\mathcal{T}}(\bm{e}_i, \bm{e}_j) = \bm{w}^\top(\bm{b}_i + \bm{b}_j - 2\,\bm{b}_i \circ \bm{b}_j).$$

That is, $d_{\mathcal{T}}(\bm{e}_i, \bm{e}_j)$ is a linear function of $\bm{w}$ for a given $\bm{B}$, and the constraints on $\bm{w}$ and $\bm{\Pi}$ are convex. Thus, strong duality holds, and we obtain the following representation using the minimax theorem [45, 16]:

$$\begin{aligned}
\mathrm{RTWD}(\mu,\mu') &= \frac{1}{2}\max_{\bm{w}\ \text{s.t.}\ \bm{B}^\top\bm{w}=\bm{1}\ \text{and}\ \bm{w}\ge\bm{0}}\ \min_{\bm{\Pi}\in U(\bm{a},\bm{a}')} \sum_{i=1}^{N_{\text{leaf}}}\sum_{j=1}^{N_{\text{leaf}}} \pi_{ij}\,\bm{w}^\top(\bm{b}_i+\bm{b}_j-2\,\bm{b}_i\circ\bm{b}_j)\\
&= \frac{1}{2}\max_{\bm{w}\ \text{s.t.}\ \bm{B}^\top\bm{w}=\bm{1}\ \text{and}\ \bm{w}\ge\bm{0}} \|\mathrm{diag}(\bm{w})\bm{B}(\bm{a}-\bm{a}')\|_1,
\end{aligned}$$

where $\mathrm{TWD}(\mu,\mu') = \min_{\bm{\Pi}\in U(\bm{a},\bm{a}')} \sum_{i=1}^{N_{\text{leaf}}}\sum_{j=1}^{N_{\text{leaf}}} \pi_{ij}\, d_{\mathcal{T}}(\bm{e}_i,\bm{e}_j) = \|\mathrm{diag}(\bm{w})\bm{B}(\bm{a}-\bm{a}')\|_1$.

Without loss of generality, we set $w_0 = 0$. First, we rewrite the norm $\|\mathrm{diag}(\bm{w})\bm{B}(\bm{a}-\bm{a}')\|_1$ as

$$\|\mathrm{diag}(\bm{w})\bm{B}(\bm{a}-\bm{a}')\|_1 = \sum_{j=1}^{N} w_j \Bigg|\sum_{k\in[N_{\text{leaf}}],\,k\in de(j)} (a_k - a'_k)\Bigg|,$$

where $de(j)$ denotes the set of descendants of node $j \in [N]$ (including itself). Using the triangle inequality, we obtain

$$\begin{aligned}
\|\mathrm{diag}(\bm{w})\bm{B}(\bm{a}-\bm{a}')\|_1 &\le \sum_{j=1}^{N} w_j \sum_{k\in[N_{\text{leaf}}],\,k\in de(j)} |a_k - a'_k|\\
&= \sum_{k\in[N_{\text{leaf}}]} |a_k - a'_k| \sum_{j\in[N],\,j\in pa(k)} w_j,
\end{aligned}$$

where $pa(k)$ is the set of ancestors of leaf $k$ (including itself). By rewriting the constraint $\bm{B}^\top\bm{w} = \bm{1}$ as $\sum_{j\in[N],\,j\in pa(k)} w_j = 1$ for any $k \in [N_{\text{leaf}}]$, we obtain

$$\|\mathrm{diag}(\bm{w})\bm{B}(\bm{a}-\bm{a}')\|_1 \le \sum_{k\in[N_{\text{leaf}}]} |a_k - a'_k| = \|\bm{a}-\bm{a}'\|_1.$$

This inequality holds for any weight vector $\bm{w}$ satisfying the constraints. Now consider the vector with $w_j = 1$ if $j \in [N_{\text{leaf}}]$ and $w_j = 0$ otherwise; it satisfies $\bm{w} \ge \bm{0}$ and $\bm{B}^\top\bm{w} = \bm{1}$, because the ancestor set $pa(k)$ of each leaf $k$ contains exactly one leaf, namely $k$ itself. For this vector, we obtain

$$\|\mathrm{diag}(\bm{w})\bm{B}(\bm{a}-\bm{a}')\|_1 = \sum_{k=1}^{N_{\text{leaf}}} |a_k - a'_k| = \|\bm{a}-\bm{a}'\|_1.$$

This completes the proof of the proposition.
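Both steps of this proof can be checked numerically on a small example. The sketch below is a hypothetical illustration, not the paper's code: it assumes the usual tree-Wasserstein convention that $B_{ik} = 1$ when node $i$ is leaf $k$ or one of its ancestors, with the leaves indexed first, and the toy tree, weights, and helper name `twd` are made up for the example.

```python
import numpy as np

# Hypothetical toy tree: leaves 0-3, internal nodes 4 (parent of 0, 1) and
# 5 (parent of 2, 3), root 6. Row i of B marks the leaves in the subtree of node i.
B = np.array([
    [1, 0, 0, 0],
    [0, 1, 0, 0],
    [0, 0, 1, 0],
    [0, 0, 0, 1],
    [1, 1, 0, 0],
    [0, 0, 1, 1],
    [1, 1, 1, 1],
], dtype=float)


def twd(a, a_prime, B, w):
    """Tree-Wasserstein distance ||diag(w) B (a - a')||_1."""
    return float(np.abs(np.diag(w) @ B @ (a - a_prime)).sum())


# Two probability vectors over the leaves.
a = np.array([0.4, 0.1, 0.3, 0.2])
a_prime = np.array([0.1, 0.4, 0.2, 0.3])
l1 = float(np.abs(a - a_prime).sum())

# A generic nonnegative weight vector with B^T w = 1 (each root-to-leaf path sums to 1).
w_generic = np.array([0.5, 0.5, 0.5, 0.5, 0.3, 0.3, 0.2])
# The weight vector used in the proof: 1 on the leaves, 0 on internal nodes.
w_leaf = np.array([1, 1, 1, 1, 0, 0, 0], dtype=float)

for w in (w_generic, w_leaf):
    assert np.allclose(B.T @ w, 1.0)  # constraint B^T w = 1

print(twd(a, a_prime, B, w_generic), "<=", l1)  # the inequality
print(twd(a, a_prime, B, w_leaf), "==", l1)     # equality for leaf weights
```

With `w_generic` the inequality is strict, while the leaf-indicator weights attain the $\ell_1$ distance, as the proof states.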

A.2 Proof of Proposition 2

(Proof) If $\bm{B}^{\top}\bm{w} = \bm{1}$ and $\bm{a}$ is a probability vector (so that $\bm{a}^{\top}\bm{1} = 1$), the following holds:

$$\bm{1}^{\top}\textnormal{diag}(\bm{w})\bm{B}\bm{a} = \bm{w}^{\top}\bm{B}\bm{a} = (\bm{B}^{\top}\bm{w})^{\top}\bm{a} = \bm{1}^{\top}\bm{a} = 1.$$

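Moreover, every entry of $\textnormal{diag}(\bm{w})\bm{B}\bm{a}$ is nonnegative because $\bm{w}$, $\bm{B}$, and $\bm{a}$ are, so $\textnormal{diag}(\bm{w})\bm{B}\bm{a}$ is itself a probability vector. Applying Pinsker's inequality, $\|\bm{p}-\bm{q}\|_{1} \leq \sqrt{2\,\textnormal{KL}(\bm{p}\|\bm{q})}$ for probability vectors $\bm{p}$ and $\bm{q}$, in both directions gives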

$$W_{\mathcal{T}}(\mu_{i},\mu_{j}) = \|\textnormal{diag}(\bm{w})\bm{B}\bm{a}_{i} - \textnormal{diag}(\bm{w})\bm{B}\bm{a}_{j}\|_{1} \leq \sqrt{2\,\textnormal{KL}(\textnormal{diag}(\bm{w})\bm{B}\bm{a}_{i} \,\|\, \textnormal{diag}(\bm{w})\bm{B}\bm{a}_{j})},$$

and

$$W_{\mathcal{T}}(\mu_{i},\mu_{j}) = \|\textnormal{diag}(\bm{w})\bm{B}\bm{a}_{j} - \textnormal{diag}(\bm{w})\bm{B}\bm{a}_{i}\|_{1} \leq \sqrt{2\,\textnormal{KL}(\textnormal{diag}(\bm{w})\bm{B}\bm{a}_{j} \,\|\, \textnormal{diag}(\bm{w})\bm{B}\bm{a}_{i})}.$$

Squaring both bounds and averaging them, we obtain

$$W^{2}_{\mathcal{T}}(\mu_{i},\mu_{j}) \leq \textnormal{KL}(\textnormal{diag}(\bm{w})\bm{B}\bm{a}_{i} \,\|\, \textnormal{diag}(\bm{w})\bm{B}\bm{a}_{j}) + \textnormal{KL}(\textnormal{diag}(\bm{w})\bm{B}\bm{a}_{j} \,\|\, \textnormal{diag}(\bm{w})\bm{B}\bm{a}_{i}).$$

The right-hand side is the Jeffrey divergence between $\textnormal{diag}(\bm{w})\bm{B}\bm{a}_{i}$ and $\textnormal{diag}(\bm{w})\bm{B}\bm{a}_{j}$.
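Proposition 2 can be sanity-checked numerically. The sketch below is a hypothetical illustration under stated assumptions: a made-up four-leaf tree with $B_{ik} = 1$ when node $i$ is leaf $k$ or one of its ancestors, strictly positive weights satisfying $\bm{B}^{\top}\bm{w} = \bm{1}$, and leaf distributions drawn from a Dirichlet distribution so that every KL term is finite. It checks that $\textnormal{diag}(\bm{w})\bm{B}\bm{a}$ sums to one and counts how often $W_{\mathcal{T}}^{2}$ exceeds the Jeffrey divergence.

```python
import numpy as np

# Hypothetical four-leaf tree (leaves 0-3, internal nodes 4 and 5, root 6);
# B[i, k] = 1 if node i is leaf k or one of its ancestors.
B = np.array([
    [1, 0, 0, 0],
    [0, 1, 0, 0],
    [0, 0, 1, 0],
    [0, 0, 0, 1],
    [1, 1, 0, 0],
    [0, 0, 1, 1],
    [1, 1, 1, 1],
], dtype=float)
# Strictly positive weights with B^T w = 1 (every root-to-leaf path sums to 1).
w = np.array([0.5, 0.5, 0.5, 0.5, 0.3, 0.3, 0.2])


def kl(p, q):
    """KL divergence between strictly positive probability vectors."""
    return float(np.sum(p * np.log(p / q)))


rng = np.random.default_rng(0)
violations = 0
for _ in range(1000):
    a_i, a_j = rng.dirichlet(np.ones(4), size=2)  # two random leaf distributions
    p, q = w * (B @ a_i), w * (B @ a_j)            # diag(w) B a for each
    assert np.isclose(p.sum(), 1.0)                # 1^T diag(w) B a = 1
    w_t = np.abs(p - q).sum()                      # W_T(mu_i, mu_j)
    jeffrey = kl(p, q) + kl(q, p)                  # Jeffrey divergence
    if w_t**2 > jeffrey + 1e-12:
        violations += 1

print("violations:", violations)
```

Because the bound follows from Pinsker's inequality, the violation counter should stay at zero for any such sample.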

A.3 Ablation study

Table 3: KNN classification accuracy (%) with a ResNet18 backbone for different values of the Jeffrey divergence regularization parameter λ. We set the number of neighbors to K = 50 and report the accuracy averaged over three runs.
Similarity function | λ | STL10 | CIFAR10 | CIFAR100 | SVHN
TWD (TV) | 0.0 | 73.28 ± 0.27 | 67.03 ± 0.24 | 25.85 ± 0.39 | 77.62 ± 0.40
TWD (TV) | 0.1 | 76.28 ± 0.07 | 68.60 ± 0.36 | 26.49 ± 0.24 | 79.70 ± 0.23
TWD (TV) | 0.2 | 77.40 ± 0.17 | 68.48 ± 0.11 | 25.59 ± 0.16 | 79.67 ± 0.26
TWD (TV) | 0.3 | 77.67 ± 0.06 | 68.26 ± 0.51 | 24.21 ± 0.35 | 79.91 ± 0.42

A.3.1 Effect of the number of nearest neighbors

In this section, we evaluate the KNN classifier with the number of nearest neighbors set to K = 10 or K = 50. Table 4 presents the results for K = 10, and Table 5 compares the best TWD (TV) model under the two values of K. Using K = 50 tends to improve accuracy slightly (by at most 0.79 points, on CIFAR100, in Table 5; on CIFAR10 the two settings are essentially tied), and the relative ordering of the methods remains consistent regardless of the number of nearest neighbors.
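The KNN evaluation with a given K can be sketched as follows. This is a minimal illustration, assuming a plain majority vote over the K training embeddings with the highest cosine similarity; the function name `knn_accuracy` and the synthetic two-cluster data are hypothetical, and the exact protocol behind Tables 3 to 5 (similarity measure, vote weighting) is not specified in this section.

```python
import numpy as np


def knn_accuracy(train_x, train_y, test_x, test_y, k):
    """Majority-vote KNN accuracy using cosine similarity between embeddings."""
    # L2-normalize so that dot products are cosine similarities.
    train_n = train_x / np.linalg.norm(train_x, axis=1, keepdims=True)
    test_n = test_x / np.linalg.norm(test_x, axis=1, keepdims=True)
    sims = test_n @ train_n.T                   # (n_test, n_train)
    nn_idx = np.argsort(-sims, axis=1)[:, :k]   # indices of the k most similar
    nn_labels = train_y[nn_idx]                 # (n_test, k)
    n_classes = int(train_y.max()) + 1
    # Count votes per class; argmax breaks ties toward the smaller label.
    votes = np.stack([(nn_labels == c).sum(axis=1) for c in range(n_classes)], axis=1)
    preds = votes.argmax(axis=1)
    return float((preds == test_y).mean())


# Hypothetical synthetic "embeddings": two well-separated clusters.
rng = np.random.default_rng(0)
centers = np.array([[5.0, 0.0], [0.0, 5.0]])
train_y = rng.integers(0, 2, size=200)
train_x = centers[train_y] + rng.normal(size=(200, 2))
test_y = rng.integers(0, 2, size=50)
test_x = centers[test_y] + rng.normal(size=(50, 2))

for k in (10, 50):
    print(k, knn_accuracy(train_x, train_y, test_x, test_y, k))
```

In a real evaluation, `train_x` and `test_x` would be the frozen backbone features of the training and test splits; larger K smooths the vote at the cost of mixing in more distant neighbors.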

A.3.2 Effect of the regularization parameter for the Jeffrey divergence

In this experiment, we evaluate model performance while varying the regularization parameter λ (Table 3). Introducing the regularization (λ = 0.1 versus λ = 0.0) improves accuracy on all four datasets, by 3.00 points on STL10, 1.57 on CIFAR10, 0.64 on CIFAR100, and 2.08 on SVHN. Among the nonzero values, performance changes little on CIFAR10 and SVHN, whereas STL10 improves slightly as λ grows and CIFAR100 degrades, falling below the unregularized result for λ ≥ 0.2. Overall, setting λ = 0.1 yields favorable results across datasets.
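How λ could enter the training objective is sketched below. This is an assumption-laden illustration rather than the paper's implementation: we take the regularizer to be λ times the mean Jeffrey divergence between the probability vectors of the two augmented views of each image (plain softmax outputs here; the paper may instead apply it to the tree-embedded vectors $\textnormal{diag}(\bm{w})\bm{B}\bm{a}$ of Proposition 2), added to a placeholder base SSL loss. The names `softmax`, `jeffrey_divergence`, and `regularized_loss` are hypothetical.

```python
import numpy as np


def softmax(logits):
    """Row-wise softmax with max-subtraction for numerical stability."""
    z = logits - logits.max(axis=1, keepdims=True)
    e = np.exp(z)
    return e / e.sum(axis=1, keepdims=True)


def jeffrey_divergence(p, q, eps=1e-12):
    """Row-wise KL(p||q) + KL(q||p), computed as sum((p - q) * log(p / q))."""
    p, q = p + eps, q + eps  # guard against log(0)
    return np.sum((p - q) * np.log(p / q), axis=1)


def regularized_loss(base_loss, logits_view1, logits_view2, lam):
    """Hypothetical objective: base SSL loss + lam * mean Jeffrey divergence over positive pairs."""
    p = softmax(logits_view1)
    q = softmax(logits_view2)
    return base_loss + lam * float(jeffrey_divergence(p, q).mean())


rng = np.random.default_rng(0)
logits1 = rng.normal(size=(8, 16))                   # batch of 8, 16-dim outputs (view 1)
logits2 = logits1 + 0.1 * rng.normal(size=(8, 16))  # slightly perturbed view 2
base = 2.0                                           # placeholder for the TWD-based SSL loss

for lam in (0.0, 0.1, 0.2, 0.3):
    print(lam, regularized_loss(base, logits1, logits2, lam))
```

Setting `lam=0.0` recovers the unregularized objective, which corresponds to the λ = 0.0 row of Table 3; the Jeffrey term is symmetric in the two views and vanishes only when their probability vectors coincide.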

Table 4: KNN classification accuracy (%) with a ResNet18 backbone. We set the number of neighbors to K = 10 and report the accuracy averaged over three runs. SEM: simplicial embedding; AF: ArcFace probability model; JD: Jeffrey divergence regularization. Note that the tree-Wasserstein distance with $\bm{B} = \bm{I}_{d_{\text{out}}}$ is equivalent to the total variation, denoted TWD (TV).
Similarity function | Probability model | STL10 | CIFAR10 | CIFAR100 | SVHN
Cosine similarity | N/A | 75.44 ± 0.21 | 66.96 ± 0.45 | 31.63 ± 0.25 | 74.71 ± 0.31
Cosine similarity | Softmax | 71.25 ± 0.30 | 63.80 ± 0.48 | 26.18 ± 0.36 | 73.06 ± 0.47
Cosine similarity | SEM | 71.34 ± 0.31 | 61.26 ± 0.42 | 25.40 ± 0.06 | 73.41 ± 0.95
Cosine similarity | AF (DCT) | 72.15 ± 0.53 | 65.52 ± 0.45 | 24.93 ± 0.24 | 75.68 ± 0.13
TWD (TV) | Softmax | 63.42 ± 0.24 | 59.03 ± 0.58 | 24.95 ± 0.31 | 70.87 ± 0.29
TWD (TV) | SEM | 63.72 ± 0.17 | 55.57 ± 0.35 | 23.40 ± 0.36 | 71.69 ± 0.75
TWD (TV) | AF | 63.97 ± 0.05 | 59.96 ± 0.44 | 25.29 ± 0.17 | 73.44 ± 0.35
TWD (TV) | AF (PE) | 71.04 ± 0.37 | 64.28 ± 0.14 | 25.71 ± 0.20 | 75.70 ± 0.42
TWD (TV) | AF (DCT) | 72.75 ± 0.11 | 67.01 ± 0.03 | 24.95 ± 0.17 | 76.98 ± 0.44
TWD (TV) | Softmax + JD | 72.05 ± 0.30 | 66.61 ± 0.20 | 26.91 ± 0.19 | 76.65 ± 0.56
TWD (TV) | SEM + JD | 70.73 ± 0.89 | 62.75 ± 0.61 | 24.83 ± 0.27 | 74.71 ± 0.43
TWD (TV) | AF + JD | 71.74 ± 0.19 | 66.74 ± 0.20 | 26.68 ± 0.35 | 77.10 ± 0.04
TWD (TV) | AF (PE) + JD | 74.10 ± 0.20 | 66.82 ± 0.36 | 26.17 ± 0.00 | 77.55 ± 0.50
TWD (TV) | AF (DCT) + JD | 76.24 ± 0.22 | 68.62 ± 0.40 | 25.70 ± 0.14 | 79.28 ± 0.22
TWD (Clust) | Softmax | 67.95 ± 0.42 | 61.59 ± 0.29 | 23.34 ± 0.26 | 73.88 ± 0.05
TWD (Clust) | SEM | 72.43 ± 0.11 | 63.63 ± 0.42 | 21.29 ± 0.28 | 77.04 ± 0.77
TWD (Clust) | AF | 69.09 ± 0.05 | 62.49 ± 0.45 | 22.56 ± 0.25 | 74.31 ± 0.40
TWD (Clust) | AF (PE) | 72.08 ± 0.07 | 64.56 ± 0.31 | 22.51 ± 0.29 | 75.98 ± 0.23
TWD (Clust) | AF (DCT) | 71.64 ± 0.15 | 65.51 ± 0.36 | 21.04 ± 0.10 | 77.59 ± 0.25
TWD (Clust) | Softmax + JD | 73.07 ± 0.13 | 66.38 ± 0.27 | 23.97 ± 0.11 | 76.82 ± 0.50
TWD (Clust) | SEM + JD | 75.50 ± 0.15 | 67.44 ± 0.10 | 21.90 ± 0.19 | 78.91 ± 0.30
TWD (Clust) | AF + JD | 72.70 ± 0.08 | 66.12 ± 0.26 | 23.50 ± 0.21 | 76.92 ± 0.06
TWD (Clust) | AF (PE) + JD | 73.66 ± 0.47 | 66.58 ± 0.01 | 22.86 ± 0.02 | 77.44 ± 0.30
TWD (Clust) | AF (DCT) + JD | 73.79 ± 0.12 | 67.34 ± 0.38 | 21.96 ± 0.34 | 78.00 ± 0.60
Table 5: KNN classification accuracy (%) of the best TWD (TV) model with different numbers of nearest neighbors K.
Similarity function | Nearest neighbors (K) | STL10 | CIFAR10 | CIFAR100 | SVHN
TWD (TV) | K = 10 | 76.24 ± 0.22 | 68.62 ± 0.40 | 25.70 ± 0.14 | 79.28 ± 0.22
TWD (TV) | K = 50 | 76.28 ± 0.07 | 68.60 ± 0.36 | 26.49 ± 0.24 | 79.70 ± 0.23