Weakly Supervised Disentangled Generative Causal Representation Learning
Abstract
c 2022 Xinwei Shen, Furui Liu, Hanze Dong, Qing Lian, Zhitang Chen, and Tong Zhang.
License: CC-BY 4.0, see https://creativecommons.org/licenses/by/4.0/. Attribution requirements are provided
at http://jmlr.org/papers/v23/21-0080.html.
fits of the learned representations for downstream tasks in terms of sample efficiency and
distributional robustness.
Keywords: disentanglement, causality, representation learning, deep generative model
1. Introduction
Consider the observed data x from a distribution qx on X ⊆ Rd and the latent variable z
from a prior pz on Z ⊆ Rk . In bidirectional generative models (BGMs), we are normally
interested in learning an encoder E : X → Z to infer latent variables and a generator
G : Z → X to generate data, to achieve both representation learning and data generation.
Classical BGMs include Variational Autoencoder (VAE) (Kingma and Welling, 2014) and
BiGAN (Donahue et al., 2017; Dumoulin et al., 2017). In representation learning, it was
argued that an effective representation for downstream learning tasks should disentangle
the underlying factors of variation (Bengio et al., 2013). In generative modeling, it is highly
desirable if one can control the semantic generative factors by aligning them with the latent
variables such as in StyleGAN (Karras et al., 2019). Both goals can be achieved with
the disentanglement of latent variable z, which informally means that each dimension of z
measures a distinct factor of variation in the data (Bengio et al., 2013).
Earlier unsupervised disentanglement methods mostly regularized the VAE objective
to encourage independence of learned representations (Higgins et al., 2017; Burgess et al.,
2017; Kim and Mnih, 2018; Chen et al., 2018; Kumar et al., 2018). Later, Locatello et al.
(2019) showed that unsupervised learning of disentangled representations is impossible:
many existing unsupervised methods are brittle, requiring careful supervised hyperparam-
eter tuning or implicit inductive biases. To promote identifiability, recent work resorted
to various forms of supervision (Locatello et al., 2020b; Shu et al., 2020; Locatello et al.,
2020a). In this work, we also incorporate supervision on the ground-truth factors in the
form of a certain number of annotated labels as described in Section 3.2. We will present
experimental results showing that our method remains competitive with a small amount of
labeled data (a minimum of around 100 samples).
Most of the existing methods, including those mentioned above, are built on the assumption that the underlying factors of variation are mutually independent. However, in many real-world cases, the semantically meaningful factors of interest are not independent (Bengio et al., 2020). Instead, such high-level variables are often causally related, i.e., connected by a causal graph.
In this paper, we prove formally that methods with independent priors fail to disentangle
causally related factors. Motivated by this observation, we propose a new method to learn
disentangled generative causal representations called DEAR. The key ingredient of our
formulation is a structural causal model (SCM) (Pearl et al., 2000) as the prior for latent
variables in a bidirectional generative model. As discussed in Section 4.1.2, we assume that
a super-graph of the underlying causal graph is known a priori, which ranges from the causal
ordering of the nodes in the graph to the true causal structure. The causal model prior is
then learned jointly with a generator and an encoder using a suitable GAN (Goodfellow
et al., 2014) algorithm. Moreover, we establish theoretical guarantees for DEAR on how it
resolves the unidentifiability issue of many existing methods as well as on the asymptotic
convergence of the proposed algorithm.
• Extensive experiments are conducted on both synthesized and real data to demon-
strate the effectiveness of DEAR in causal controllable generation, and the benefits
of the learned representations for downstream tasks in terms of sample efficiency and
distributional robustness.
Notation Throughout the paper, all distributions are assumed to be absolutely continuous
with respect to Lebesgue measure unless indicated otherwise. For a vector x, let [x]i denote
the i-th component of x. For a scalar function h(x, y), let ∇x h(x, y) denote its gradient with
respect to x and ∇²x h(x, y) denote its Hessian matrix with respect to x. For a vector function
g(x, y), let ∇x g(x, y) denote its Jacobian matrix with respect to x. Without ambiguity, ∇x
is denoted by ∇ for simplicity. The notation ‖·‖ stands for the Euclidean norm.
h(x) − h* ≤ c‖∇h(x)‖₂².
1. Note that the identifiability in this work differs from that in Khemakhem et al. (2020) in terms of goals
and assumptions. See more discussions in the related work and below Proposition 5.
2. Related work
VAE-based disentanglement methods. A number of methods have been proposed to
enrich the VAE loss by various regularizers to enforce the independence of the latent vari-
ables. β-VAE (Higgins et al., 2017) and Annealed VAE (Burgess et al., 2017) introduced
extra constraints on the capacity of the latent bottleneck by adjusting the role of the KL
term; Factor-VAE (Kim and Mnih, 2018) and β-TCVAE (Chen et al., 2018) encouraged the
aggregated posterior (i.e., the marginal distribution of E(x)) to be factorized by penalizing
its total correlation; DIP-VAE (Kumar et al., 2018) enforced a factorized aggregated poste-
rior differently by matching its moments with those of a factorized prior. Going beyond the
independence perspective, Suter et al. (2019) considered disentangled causal mechanisms,
meaning that all the generative factors are conditionally independent given a common con-
founder. This is one special case of causal relationship, while we consider more general
cases where the factors can have more complex causal relationships, e.g., one factor can be
a direct cause of another one.
Based on the above methods, Locatello et al. (2020b) and Locatello et al. (2020a) further incorporated supervised information in the form of labels for a few generative factors and of pairs of observations that differ in a few factors, respectively. The former is more closely related to ours and is discussed in detail in Section 3.2. Shu et al. (2020) proposed several concepts related to disentanglement, based on which they analyzed three forms of weak supervision including restricted labeling, match pairing, and rank pairing.
Going beyond the independent prior, Khemakhem et al. (2020) proposed a conditional
VAE where the latent variables are assumed to be conditionally independent given some
additionally observed variables. Built upon developments of nonlinear ICA, they presented
the first principled identifiability theory of latent variable models, in particular VAEs, thus
leading to a form of provable disentanglement under suitable conditions. Our work, in
contrast, does not aim at achieving general identifiability of latent variable models or general
provable disentanglement, but contributes to resolving the failure of existing methods in
disentangling causally related factors. With this motivation, we consider more general
model assumptions on the latent structure as well as generating transformations than those
in Khemakhem et al. (2020) which apply more suitably to real-world data. To achieve
disentanglement of causal factors, we need to adopt a more direct and somewhat stronger form of supervision than Khemakhem et al. (2020), i.e., we require annotated labels of the true factors for a possibly small number of samples. See Appendix C for a discussion on the
two forms of supervision. The model in Khemakhem et al. (2020), however, has not yet
been applied with the most advanced network architecture for image generation such as
StyleGAN (Karras et al., 2019), nor can their conditionally independent prior model the causal structure of the true factors. Therefore, their model and theory do not apply here, and our work should be regarded as complementary.
To avoid the unidentifiability of the standard Gaussian prior caused by rotation transformations, Stühmer et al. (2020) proposed hierarchical non-Gaussian priors for unsupervised disentanglement, which are not rotationally invariant. However, there remain other kinds of mixing transformations that leave these priors invariant, leading to unidentifiability. Besides, their proposed priors cannot model causal relationships.
Recently, a concurrent work by Träuble et al. (2021) conducted a large-scale empirical
study to investigate the behavior of the most prominent disentanglement approaches on cor-
related data. In particular, they considered the case where the ground-truth factors exhibit
pairwise correlation. Although pairwise correlation largely generalizes the independence
assumption, it is less general than the causal correlation that we consider. For example, a
parental node with multiple children immediately goes beyond pairwise correlation. More-
over, Träuble et al. (2021) focused on verifying the problem that existing methods fail to
learn disentangled representations for strongly correlated factors, while we identify the prob-
lem as a motivation to propose a method to resolve it and learn disentangled representations
under the causal case.
GAN-based disentanglement methods. Existing GAN-based methods, including InfoGAN (Chen et al., 2016) and InfoGAN-CR (Lin et al., 2020), differed from our proposed formulation mainly in two respects. First, they still assumed an independent prior for the latent variables, and so suffered from the same problem as the previous VAE-based methods mentioned above. Besides, the idea of InfoGAN-CR was to encourage each latent code to make changes that are easy to detect, which only applies well when the underlying factors are independent. Second, as a bidirectional generative modeling method, InfoGAN further required variational approximation apart from adversarial training, which is inferior to the principled formulation in BiGAN and AGES (Shen et al., 2020) that we adopt.
Generative modeling involving causal models in the latent space. CausalGAN (Kocaoglu et al., 2018) and a concurrent work (Moraffah et al., 2020) of ours were unidirectional generative models (i.e., generative models that learn a single mapping from the latent variable to data) built upon a cGAN (Mirza and Osindero, 2014). They assigned an SCM to the conditional attributes while leaving the latent variables as independent Gaussian noises. The limitation of a cGAN is that it always requires full supervision on attributes to apply conditional adversarial training. Also, the ground-truth factors were directly fed into the generator as the conditional attributes, without any extra effort to align the dimensions of the latent variables with the underlying factors, so their models had nothing to do with disentanglement learning. Moreover, their unidirectional nature made them unable to learn representations. Besides, they only considered binary factors, so the consequent semantic interpolations appear non-smooth, as shown in Appendix G.
CausalVAE (Yang et al., 2021) assigned the SCM directly to the latent variables; however, being built upon iVAE (Khemakhem et al., 2020), it adopted a conditional prior given the ground-truth factors and was thus also limited to a fully supervised setting.
GraphVAE (He et al., 2018) generalized the chain-structured latent space proposed in
Ladder VAE (Sønderby et al., 2016) and imposed an SCM into the latent space of VAE.
The motivation behind GraphVAE is to improve the expressive capacity of VAE rather than to disentangle the underlying causal factors as ours does. Purely from observational data and
without any supervision on the underlying factors, the impossibility result from Locatello
et al. (2019) indicated that a VAE model cannot identify the true factors. Therefore, the
representations learned by GraphVAE were not guaranteed to disentangle the generative
factors, and consequently the learned SCM did not reflect the true causal structure in
principle. Moreover, their adopted VAE loss (ELBO) required an explicit form of KL
divergence between the prior and the posterior, which limited the model choice for the SCM.
Specifically, GraphVAE used an additive noise model with Gaussian noises. In contrast,
our method does not require the distribution induced by the SCM to be explicitly expressed
and in principle allows any SCMs that can be reparametrized as a generative model (i.e.,
given the exogenous noises, one can generate all the variables by ancestral sampling). For
comparison, in our experiments, we include a baseline which extends the original GraphVAE
method to incorporate the same amount of supervision as ours.
Generative modeling involving other structured latent spaces. VLAE (Zhao et al.,
2017) decomposed the latent space into separate chunks each of which is processed at dif-
ferent levels of the encoder and decoder. VQ-VAE-2 (Razavi et al., 2019) used a two-level
latent space along with a multi-stage generation mechanism to capture both high and low
level information of data. SAE (Leeb et al., 2020) encouraged a hierarchical structure in
the latent space through the structural architecture of the decoder. These methods essentially adopted implicit probabilistic or architectural hierarchies, in contrast to the causal structure that we impose on the latent space, and thus cannot achieve the goal of causal disentanglement. For example, the hierarchy in SAE represents the level of abstraction, in the sense that more high-level, abstract features are processed deeper in the decoder while low-level, linear features are treated towards the end of the network. Such a hierarchy differs essentially from the causal structure that we consider.
Other works considered inferring the latent causal structure from visual data in the
reinforcement learning setting (Dasgupta et al., 2019; Nair et al., 2019). In particular, Nair
et al. (2019) developed learning-based approaches to induce causal knowledge in the form of
directed acyclic graphs, which was then utilized in learning goal-conditioned policies. The
interactive environment enables the agent to perform actions and observe their outcomes.
Therefore, the resulting data involve various interventions, each of which entails an SCM, and this setting is thus essentially different from the common setting in the disentanglement literature, also considered in this paper, where the observed data are independent and identically distributed.
3. Problem setting
Ex∼qx [−EqE (z|x) log pG (x|z) + DKL (qE (z|x), pz (z))], (2)
used in VAEs up to a constant, and the ELBO admits a closed form that is easy to optimize only with a factorized Gaussian prior, encoder and generator (Shen et al., 2020).
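For concreteness, the closed form referred to here, the KL divergence between two factorized (diagonal) Gaussians, can be sketched as follows; this is an illustrative numpy snippet of ours, not the paper's code:

```python
import numpy as np

def kl_factorized_gaussians(mu_q, var_q, mu_p, var_p):
    """Closed-form KL(q || p) between two factorized (diagonal) Gaussians.

    This is the kind of term that makes the ELBO in (2) easy to optimize;
    with an implicit prior no such closed form is available.
    """
    mu_q, var_q = np.asarray(mu_q, float), np.asarray(var_q, float)
    mu_p, var_p = np.asarray(mu_p, float), np.asarray(var_p, float)
    # Sum of per-dimension KL terms for diagonal Gaussians.
    return 0.5 * np.sum(
        np.log(var_p / var_q) + (var_q + (mu_q - mu_p) ** 2) / var_p - 1.0
    )
```

For instance, the KL of N(0, I) against itself is 0, and moving one mean coordinate by 1 under unit variance gives 1/2.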
Since constraints on the latent space are required to enforce disentanglement, it is desirable that the distribution families of qE(x, z) and pG(x, z) be large enough, especially for complex data like images. As demonstrated in the literature on image generation (Karras et al., 2019; Mescheder et al., 2017), implicit distributions, where the randomness is fed into the input or intermediate layers of the network, are favored over factorized Gaussians in terms of expressiveness. Minimizing (1) then requires adversarial training, as discussed in detail in Section 4.3.
where the supervised regularizer is Lsup = E_{x,y}[ls(E; x, y)] with ls = Σ_{i=1}^m CE([Ē(x)]i, [y]i) if [y]i is the binary or bounded (and normalized to [0, 1]) continuous label of factor [ξ]i, where CE(l, y) = −y log σ(l) − (1 − y) log(1 − σ(l)) is the cross-entropy loss with σ(·) being the sigmoid function; and ls = Σ_{i=1}^m ([Ē(x)]i − [y]i)² if [y]i is the continuous observation of [ξ]i.
λ > 0 is the coefficient that balances the two terms. Through the ablation studies in Section 5.4, we empirically find that the method is insensitive to the choice of λ across tasks and data sets, and hence set λ = 5 in all experiments.
Note that in the objective (3), the unsupervised generative modeling loss and the super-
vised regularizer are decoupled in terms of taking expectations, in contrast to the conditional
GANs where supervised labels are involved in the GAN loss. This enables one to use two
separate samples with different sample sizes to estimate the two terms in (3) during train-
ing. Since in practice we may only have access to a limited amount of annotated labels, this
property makes the formulation applicable in such semi-supervised settings. In the exper-
iments, we conduct ablation studies to investigate how our method performs with varying
amounts of labeled samples available.
In addition, Locatello et al. (2020b) propose a regularizer Lsup = Σ_{i=1}^m E_{x,z}[CE([Ē(x)]i, [z]i)] involving only the latent variable z, which is a part of the generative model, without
distinguishing the model component z from the ground-truth factor ξ and its observation y.
Hence they do not establish formal theoretical justification on disentanglement. Moreover,
they follow the earlier VAE-based methods to adopt a VAE loss (2) for generative modeling
with an independent prior and an additional regularizer to enforce independence of the latent
variables, which suffers from the unidentifiability problem described in the next section.
Note that in general, the goal of disentanglement allows for permutations of the ground-truth factors. For example, one may expect that for all i there exists j, not necessarily equal to i, such that [E(x)]i = gj([ξ]j). However, since our method supervises each latent dimension by the annotated label of the corresponding ground-truth factor, we can expect a component-wise correspondence between E(x) and ξ, as justified formally in Proposition 5 below.
As introduced above, we consider the general case where the underlying factors of interest are causally related. Then the goal becomes to disentangle the causal factors. Previous methods mostly use an independent prior for z, which contradicts the truth. We make this formal through the following proposition, which indicates that the disentangled representation is generally unidentifiable with an independent prior.
This proposition directly suggests that minimizing (3) favors an entangled solution (E′, G′) over the one with a disentangled encoder E*. Thus, with an independent prior, we have no way to identify the disentangled solution when λ is not large enough. However, in real applications it is impossible to estimate this threshold, and too large a λ makes it difficult to learn the BGM. After our work was submitted, our attention was brought to a theoretical result in Träuble et al. (2021) that is similar to our Proposition 4. A discussion of the two independently proposed results is given in Appendix A.2 after the proof. In the following section, we propose a solution to this problem.
z = f((I − A⊤)^{-1} h(ε)),  (4)
where A is the weighted adjacency matrix of the directed acyclic graph (DAG) over the k elements of z (i.e., Aij ≠ 0 if and only if [z]i is a parent of [z]j), ε denotes the exogenous variables following N(0, I), f and h are element-wise transformations that are generally nonlinear, and β = (f, h, A) denotes the set of parameters of f, h and A, with parameter space B. Further let I_A = I(A ≠ 0) denote the corresponding binary adjacency matrix, where I(·) is the element-wise indicator function.
When f is invertible, (4) is equivalent to
f^{-1}(z) = A⊤ f^{-1}(z) + h(ε),  (5)
which indicates that the factors z satisfy a linear SCM after nonlinear transformation f ,
and enables interventions on latent variables as discussed later.
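To make the prior concrete, the following numpy sketch samples z from the SCM prior z = f((I − A⊤)^{-1} h(ε)) described above; for illustration, f and h default to the identity (the paper allows general element-wise nonlinear choices), and the function and argument names are ours:

```python
import numpy as np

def sample_scm_prior(A, n=1, eps=None, rng=None, f=lambda u: u, h=lambda e: e):
    """Draw z from the SCM prior z = f((I - A^T)^{-1} h(eps)), eps ~ N(0, I).

    A is the weighted adjacency matrix of a DAG over the k latent dimensions
    (A[i, j] != 0 iff z_i is a parent of z_j).  f and h are element-wise
    transformations; the identity is used here purely for illustration.
    """
    rng = rng or np.random.default_rng(0)
    k = A.shape[0]
    if eps is None:
        eps = rng.standard_normal((n, k))
    # (I - A^T) is invertible because A is the adjacency of a DAG
    # (nilpotent under a topological ordering).
    M = np.linalg.inv(np.eye(k) - A.T)
    return f(h(np.asarray(eps, float)) @ M.T)
```

With identity f and h, the returned z satisfies the linear SCM z = A⊤z + ε, e.g., for a single edge z₁ → z₂ with weight 0.8 and ε = (1, 1), one obtains z = (1, 1.8).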
By combining the above SCM prior with the encoder and generator introduced in Section 3.1, we end up with the model structure presented in Figure 1. Note that, different from our model, where z is the latent variable following the prior (4) with the goal of causal disentanglement, Yu et al. (2019) propose a causal discovery method in which the variables z in the SCM (4) are observed, with the aim of learning the causal structure among z.
4.1.2 Learning of A
In causal structure learning, the graph is required to be acyclic. Traditional causal discovery methods such as PC (Spirtes et al., 2000) and GES (Chickering, 2002) deal with the combinatorial problem over the discrete space of DAGs. Recently, Zheng et al. (2018) proposed an equality constraint whose satisfaction ensures acyclicity and solved the problem with the augmented Lagrangian method, which however leads to optimization difficulties (Ng et al., 2020). In addition, identifiability of the causal structure from purely observational data is known to be an important issue in causal discovery. Despite a number of results on structure identifiability under various parametric or semi-parametric assumptions (Zhang and Hyvarinen, 2009; Peters and Bühlmann, 2014), it cannot be guaranteed in a general nonparametric setting. Yu et al. (2019) did not discuss the identifiability of the SCM (4) in general cases.
In many disentanglement problems, we have some prior information on the causal structure of the factors of interest based on common knowledge or expertise. In particular, we may know a causal ordering of the factors. In addition to the ordering, for some factors we may know that one particular factor cannot be a direct cause of another, which helps us remove some redundant edges in advance. Therefore, in this paper, with its focus on disentanglement, we utilize such prior information on the graph structure in disentanglement learning and leave incorporating causal discovery from scratch to future work. Formally, we assume that a super-graph of the true binary graph I_A0 is given; in the best case this is the true graph, while in the worst case only the causal ordering is available. We then learn the weights of the non-zero elements of the prior adjacency matrix, which indicate the sign and scale of causal effects, jointly with the other parameters of the generative model using the formulation and algorithm described in Sections 4.2 and 4.3.
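The prior knowledge described above can be encoded as a binary edge mask; the following sketch (our naming, not the paper's code) builds the super-graph implied by a causal ordering and removes known non-edges:

```python
import numpy as np

def supergraph_mask(ordering, forbidden=()):
    """Binary mask of admissible edges given prior knowledge (illustrative).

    ordering:  a causal ordering of the k factors, e.g. [0, 1, 2, 3]
               (earlier entries may cause later ones, never the reverse).
    forbidden: pairs (i, j) known NOT to be direct cause-effect pairs,
               removed from the super-graph in advance.
    The learned weighted adjacency A is then constrained to A * mask,
    which is automatically acyclic.
    """
    k = len(ordering)
    pos = {node: r for r, node in enumerate(ordering)}
    mask = np.zeros((k, k), dtype=int)
    for i in range(k):
        for j in range(k):
            if pos[i] < pos[j]:  # edges may only follow the ordering
                mask[i, j] = 1
    for i, j in forbidden:
        mask[i, j] = 0
    return mask
```

In the worst case (only the ordering is known), the mask is the full pattern consistent with the ordering; in the best case, forbidden edges prune it down to the true graph.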
As discussed in Section 4.2, such prior knowledge makes structure identifiability easy to establish. Moreover, the given super-graph ensures the acyclicity of the adjacency matrix,
allowing us to dispense with an additional acyclicity constraint. In Section 5.3, we investigate how our method performs in learning the graph structure and weighted adjacency matrix given various amounts of prior graph information. Note that even when a super-graph is available, to the best of our knowledge, no previous disentanglement method except GraphVAE (He et al., 2018) can utilize it to disentangle causal factors with a guarantee; we propose one such method and show its effectiveness. In fact, He et al. (2018) also assumed an ordering over the latent nodes by specifying that the parents of node zi, i = 1, . . . , k − 1, come from the set {zi+1, . . . , zk}. Later experiments suggest that GraphVAE shows inferior performance compared with ours.
One immediate application of our proposed model is causal controllable generation from
interventional distributions of the latent variables. We now describe the mechanism. To
enable intervention under SCM (5), we require f to be invertible. Then interventions can
be formalized as operations that modify a subset of equations in (5) (Pearl et al., 2000).
Suppose we would like to intervene on the i-th dimension of z, i.e., Do([z]i = c), where c is a constant. Once we obtain the latent factors z inferred from data x, i.e., z = E(x), or sampled from the prior pz, we follow the modified equations in (5) to obtain z′ on the left-hand side using ancestral sampling, i.e., by performing (5) iteratively, where ε can be either fixed or resampled from its prior. We then decode the latent factor z′, which follows the given interventional distribution, to generate the desired sample G(z′). In Section 5.1 we define the two types of interventions of most interest in applications. We discuss how our method generalizes to unseen interventions in Appendix D.
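The intervention mechanism above can be sketched on the linear SCM in (5) with f taken as the identity; `do_intervention` is an illustrative helper of ours that keeps ε fixed, as one of the options described:

```python
import numpy as np

def do_intervention(z, A, i, c):
    """Do([z]_i = c) on the linear SCM z = A^T z + eps (illustrative, f = id).

    z: (n, k) latent codes inferred by the encoder or sampled from the prior.
    The exogenous noise implied by z is kept fixed; the i-th structural
    equation is replaced by the constant c, and downstream variables are
    recomputed by iterating the modified equations (ancestral sampling).
    """
    z = np.array(z, dtype=float)
    k = A.shape[0]
    eps = z - z @ A          # recover eps from eps = z - A^T z (per sample)
    z_new = z.copy()
    for _ in range(k):       # k passes suffice for a DAG on k nodes
        z_new = z_new @ A + eps
        z_new[:, i] = c      # clamp the intervened coordinate on each pass
    return z_new
```

For a single edge z₁ → z₂ with weight 0.8 and latent code z = (1, 1.8), clamping Do(z₁ = 2) propagates to z′ = (2, 2.6).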
Another issue is how to set the latent dimension k of the generative model, which we handle by proposing a composite prior. Recall that m is the number of generative factors that we are interested in disentangling, for example, all the semantic concepts related to some field, where m tends to be smaller than the total number M of generative factors. The latent dimension k should be no less than M to allow a sufficient degree of freedom to generate or reconstruct data well. Since M is generally unknown in reality, we set a sufficiently large k, at least larger than m, which is a trivial lower bound of M.
We then propose to use a prior that is a composition of a causal model for the first m dimensions and another distribution, such as a standard Gaussian, for the other k − m dimensions. In this way, the first m dimensions of z aim at learning the disentangled representation of the m factors of interest, while the role of the remaining k − m dimensions is to capture other factors that are necessary for generation, whose structure we neither care about nor explicitly model. Under this framework, we do not require annotated labels for all generative factors of the data; only those of the factors we wish to disentangle are used in the supervised regularizer in (3), which broadens the applications of our method.
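A minimal sketch of sampling from such a composite prior, assuming a linear SCM for the causal block (function and argument names are ours):

```python
import numpy as np

def sample_composite_prior(A, k, n, rng=None):
    """Composite prior: an SCM over the first m latent dimensions and a
    standard Gaussian over the remaining k - m (illustrative, linear SCM).

    A: (m, m) weighted adjacency of the DAG over the m factors of interest.
    """
    rng = rng or np.random.default_rng(0)
    m = A.shape[0]
    assert k >= m, "latent dimension k must be at least m"
    eps = rng.standard_normal((n, m))
    # Causal block: z = (I - A^T)^{-1} eps.
    z_causal = eps @ np.linalg.inv(np.eye(m) - A.T).T
    # Unstructured nuisance dimensions needed only for generation quality.
    z_rest = rng.standard_normal((n, k - m))
    return np.concatenate([z_causal, z_rest], axis=1)
```

Only the first m coordinates are matched to annotated labels by the supervised regularizer; the remaining k − m are left free.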
Lgen(E, G, F) = DKL(qE(x, z), pG,F(x, z)).  (6)
Then we propose the following formulation to learn disentangled generative causal rep-
resentations:
min_{E,G,F} L(E, G, F) := Lgen(E, G, F) + λ Lsup(E).  (7)
Note that Proposition 5 states identifiability at the population level, i.e., the loss function takes expectations over the distributions of both the data and the labels of the true factors. We thus clarify that Proposition 5 does not establish general provable disentanglement, which would have to be analyzed under a much weaker form of supervision on the true factors, e.g., as in Khemakhem et al. (2020). In contrast, the specific identifiability stated in Proposition 5 should be interpreted as a counterpart of the unidentifiability result in
Proposition 4. Specifically, Proposition 4 shows that the independent prior used by most
existing disentanglement methods causes the contradiction between the generative loss Lgen
and the supervised loss Lsup in (3), which makes the whole loss L prefer an entangled
model. Therefore, even with the same amount of supervised labels of true factors, those
methods cannot learn a generative model with disentangled latent representations. In con-
trast, Proposition 5 formally suggests that due to the introduction of the SCM prior, the
two loss terms Lgen and Lsup in (7) can be simultaneously minimized and the jointly optimal
solution leads to the disentangled model.
4.3 Algorithm
In this section, we propose the algorithm to solve the above formulation (7). Estimating
Lgen requires the unlabeled data set {x1 , . . . , xN } with sample size N , while estimating Lsup
requires a labeled data set {(xj , yj ) : j = 1, . . . , Ns }, where the sample size Ns can be much
smaller than N . Without loss of generality, let SG = {x1 , . . . , xN , y1 , . . . , yNs } denote the
training data set for the generative model.
We parametrize Eφ(x) and Gθ(z) by neural networks. As mentioned in Section 3.1, to enhance the expressiveness of the generative model, we use an implicitly generated conditional pG(x|z), where we inject Gaussian noises into each convolutional layer in the same way as Shen et al. (2020). The SCM prior pF(z) and the implicit pG(x|z) then leave (6) without an analytic form. Hence we adopt a GAN method to adversarially estimate the gradient of (6) as in Shen et al. (2020). Different from their setting, the prior also involves learnable parameters, namely the parameters β of the SCM. In the following lemma we present the gradient formulas of (6).
Lemma 6 Let D*(x, z) = log[qE(x, z)/pG,F(x, z)]. Then we have
∇θ Lgen = −E_{z∼pβ(z)}[s(x, z) ∇x D*(x, z)⊤|_{x=Gθ(z)} ∇θ Gθ(z)],
∇φ Lgen = E_{x∼qx}[∇z D*(x, z)⊤|_{z=Eφ(x)} ∇φ Eφ(x)],  (8)
∇β Lgen = −E_ε[s(x, z)(∇x D*(x, z)⊤ ∇β G(Fβ(ε)) + ∇z D*(x, z)⊤ ∇β Fβ(ε))|_{x=G(Fβ(ε)), z=Fβ(ε)}],
where s(x, z) = e^{D*(x,z)} is the scaling factor.
Since D* depends on the unknown densities, which makes the gradients in (8) not directly computable from data, we estimate the gradients by training a discriminator D via the empirical logistic regression:
min_{D′} (1/Nd) [ Σ_{i:wi=1} log(1 + e^{−D′(xi,zi)}) + Σ_{i:wi=0} log(1 + e^{D′(xi,zi)}) ],  (9)
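A minimal numpy sketch of the objective (9), where wi = 1 marks pairs drawn from qE(x, z) and wi = 0 marks pairs from pG,F(x, z) (function and argument names are ours):

```python
import numpy as np

def logistic_loss(d_real, d_fake):
    """Empirical logistic regression objective (9) for the discriminator.

    d_real: D'(x_i, z_i) on pairs drawn from q_E(x, z)    (labels w_i = 1)
    d_fake: D'(x_i, z_i) on pairs drawn from p_{G,F}(x, z) (labels w_i = 0)
    The minimizer approximates D*(x, z) = log[q_E(x, z) / p_{G,F}(x, z)].
    """
    d_real, d_fake = np.asarray(d_real, float), np.asarray(d_fake, float)
    n = d_real.size + d_fake.size
    # log1p(exp(.)) is the logistic loss for each labeled pair.
    return (np.sum(np.log1p(np.exp(-d_real)))
            + np.sum(np.log1p(np.exp(d_fake)))) / n
```

At D′ ≡ 0 the loss is log 2; a discriminator that scores real pairs high and fake pairs low drives it toward 0.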
4.4 Consistency
In this section, we show the asymptotic convergence of Algorithm 1. Let θ = (θ, φ, β) denote
the set of parameters of the generative model, where θ, φ and β denote the parameters of
the generator, encoder and SCM prior respectively. According to such parametrization, we
write the objective function in (7) as L(θ). In this section, we establish the consistency
result of empirical estimator θ̂, i.e., the output of Algorithm 1, under the parametric setting.
Given a discriminator D, the approximate gradient used in the algorithm is denoted by hD(θ).
We first show in the following lemma that under appropriate conditions the approximate
gradient hD̂ (θ) based on the solution of (9) converges uniformly in probability to the true
gradient. Recall the definition D*(x, z) = log(qE(x, z)/pG,F(x, z)), which depends on θ. Let D* = {D*_θ(x, z) : θ ∈ Θ} denote the true discriminator class, and D = {D(x, z)} denote the modeled discriminator class with the norm ‖D‖₁ = ∫ |D(x, z)| p*_θ(x, z) dx dz, where p*_θ(x, z) = (qE(x, z) + pG,F(x, z))/2, which induces the probability measure µ*_θ.
Lemma 7 Assume the parameter space Θ = {θ = (θ, φ, β)} is compact. Assume the
following regularity conditions hold:
C1 Dθ∗ is smooth with respect to θ over Θ, as defined in Definition 1.
C2 The modeled discriminator class D is compact, and contains the true class D∗ .
C3 {µ*_θ : θ ∈ Θ} is uniformly tight, i.e., for any ε > 0, there exists a compact subset K_ε of X × Z such that for all θ ∈ Θ, µ*_θ(K_ε) ≥ 1 − ε.
C4 Functions in D have uniformly bounded function values, gradients and Hessians, i.e., there exists a positive number B₀ < ∞ such that for all D ∈ D and all x, z, we have |D(x, z)| ≤ B₀, ‖∇D(x, z)‖ ≤ B₀ and |tr(∇²D(x, z))| ≤ B₀.
C5 Ēφ , ∇Gθ , ∇Eφ and ∇Fβ are uniformly bounded.
C6 The training set for the discriminator is independent from that for the generative
model.
Then there exists a sequence of (N, Ns, Nd) → ∞ such that
sup_{θ∈Θ} ‖hD̂(θ) − ∇L(θ)‖ →p 0,  (10)
where →p denotes convergence in probability.
Based on the above, we obtain the consistency of the DEAR algorithm in the following theorem. It indicates that when the sample sizes grow large enough, with high probability the DEAR algorithm approximately achieves the minimum of L(θ), which leads to the desired disentangled model according to Proposition 5.
Remark. The Polyak-Lojasiewicz (PL) condition (Polyak, 1963) asserts that the suboptimality of a model is upper bounded by the squared norm of its gradient, which is a weaker condition than assumptions commonly made to ensure convergence, such as (strong) convexity. Recent literature shows that the PL condition holds in many machine learning scenarios, including some deep neural networks (Charles and Papailiopoulos, 2018; Liu et al., 2020).
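For reference, the PL condition can be written as follows (the standard form with a PL constant µ > 0; the constant is generic and not tied to our objective):

```latex
% Polyak-Lojasiewicz condition with constant \mu > 0:
% suboptimality is bounded by the squared gradient norm,
f(\theta) - \min_{\theta'} f(\theta') \;\le\; \frac{1}{2\mu}\,\|\nabla f(\theta)\|^{2}
\quad \text{for all } \theta,
% so driving the gradient norm to zero, as in (10), drives the
% objective to its minimum -- without requiring convexity of f.
```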
5. Experiments
We present the experimental studies in causal controllable generation in Section 5.1 which
demonstrate the effectiveness of DEAR in causal disentanglement and support the theory
in Section 4. Based on these theoretical and empirical justifications, we then apply the
representations learned by DEAR in downstream prediction tasks in Section 5.2, and show
the benefits of the disentangled causal representations in terms of sample efficiency and
distributional robustness. In addition, we investigate the performance of DEAR in learning
the causal structure and weighted adjacency matrix of the SCM prior in Section 5.3. We also provide ablation studies in terms of varying regularization strength λ and varying amounts of annotated labels in Section 5.4.
We evaluate our methods on two data sets where the ground-truth generative factors are
causally related, while most data sets used in previous disentanglement work are assumed or
designed to have independent generative factors, for example, in the large scale experimental
study by Locatello et al. (2019). The first data set that we use is a synthesized data
set, Pendulum, similar to the one in Yang et al. (2021). As shown in Figure 3, each image is generated by four continuous factors: pendulum angle, light angle, shadow length and shadow position, whose underlying structure, given in Figure 2(a), follows physical mechanisms. To make the data set realistic, we introduce random noise when generating the two effects from their causes, representing measurement error. We further introduce
20% corrupted data whose shadow is randomly generated, mimicking some environmental
disturbance. The sample sizes for the training, validation and test set are all 6,724.
The second one is a real human face data set, CelebA (Liu et al., 2015), with 40 labeled
binary attributes. Among them, we consider two groups of causally related factors of interest, as shown in Figure 2(b,c). The sample sizes for the training, validation and test set
are 162,770, 19,867, and 19,962. We believe these two data sets are diverse enough to assess
our methods because they cover real and synthesized data, with continuous and discrete
annotated labels. In addition, we test our method on benchmark data sets (Gondal et al.,
2019) where the generative factors are independent. The results are given in Appendix E.
All the details of the experimental setup, network architectures and the synthesized data set
are given in Appendix F. Notably, all VAEs and DEAR use the same network architecture
for the encoder and decoder (generator).
[Figure 2: the underlying causal structures over the annotated factors — (a) Pendulum: pendulum_angle(1), light_angle(2), shadow_length(3), shadow_position(4); (b) CelebA: smile(1), gender(2), cheekbone(3), mouth_open(4), narrow_eye(5), chubby(6); (c) CelebA: young(1), gender(2), receding_hairline(3), make_up(4), chubby(5), eye_bag(6).]

5.1 Causal controllable generation
[Figure 3 panels: traversing a single latent with the others fixed — DEAR: a single factor affected (disentangled); S-β-VAE: multiple factors affected.]
Figure 3: Results in causal controllable generation on Pendulum. For example, in line 1 of (a,b)
when changing the first dimension [z]1 of z which is supervised with the annotated label
of pendulum angle while keeping the others fixed, we see that the traversals of DEAR
vary only in pendulum angle (disentanglement), while those of S-β-VAE vary in both
pendulum angle and shadow length (entanglement); in line 3 when changing [z]3 with
the others fixed, only shadow length is affected with DEAR but both shadow length and
pendulum angle are affected with S-β-VAE. In line 1 of (d) we see that intervening on pendulum angle affects its effects, shadow length and shadow position, which is consistent with the desired interventional distribution.
[Figure 4 panels: traversing a single latent with the others fixed, over the factors smile, gender, cheekbone, mouth_open, narrow_eye and chubby — DEAR: a single factor affected (disentangled); S-β-VAE: multiple factors or no factor affected; (d) intervening on gender, with narrow_eye affected.]
Figure 4: Results in causal controllable generation on CelebA. For example, in line 1 of (a,b) when
altering [z]1 with the others fixed, we see that the traversals of DEAR vary only in a single
factor smile with factor mouth open unaffected, while S-β-VAE entangles the two factors.
In lines 5-6 of (a), when changing [z]5 and [z]6, which are supervised with narrow eye and chubby, no factor seems to be affected, indicating that S-β-VAE fails to learn the
representations of some factors. In line 1 of (d) we see that intervening on smile affects
its effect mouth open, which makes sense.
In each figure, we first infer the latent representations from a test image in block (c).
The traditional traversals of the two models are given in blocks (a,b). We see that in each
line when manipulating one latent dimension while keeping the others fixed, the generated
images of our model vary only in a single factor, indicating that our method can disentangle
the causally related factors, while those of S-β-VAE show multiple factors affected. It is
worth pointing out that we are the first to achieve the disentanglement between a cause
factor and its effects, while other methods tend to entangle them. One typical example
is the disentanglement between smile and its effect mouth open as shown in Figure 4. In
block (d), we show the results of intervention on the latent variables representing the cause
factors, which clearly show that intervening on a cause variable changes its effect variables.
Results in Appendix G further show that intervening on an effect variable does not influence
its cause. Specific examples are given in the captions. Note that without an SCM prior,
S-β-VAE cannot generate data from general interventional distributions. More qualitative
traversals from DEAR are given in Appendix G.
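To see why an SCM prior enables sampling from interventional distributions while an independent prior cannot, consider the following toy linear-SCM sketch (our own minimal construction with a hypothetical weighted adjacency, not the exact DEAR prior):

```python
import numpy as np

# Weighted adjacency A of a DAG over k latent factors, A[i, j] != 0
# meaning z_i -> z_j. Linear SCM: z = A^T z + eps.
A = np.array([[0.0, 0.0, 0.8, 0.5],   # z1 (a cause) -> z3, z4
              [0.0, 0.0, 0.6, 0.0],   # z2 -> z3
              [0.0, 0.0, 0.0, 0.0],
              [0.0, 0.0, 0.0, 0.0]])

def ancestral_sample(n, interventions=None, rng=None):
    """Sample z in topological order; `interventions` maps a latent
    index to a fixed value (the do-operator, cutting incoming edges)."""
    rng = np.random.default_rng(0) if rng is None else rng
    k = A.shape[0]
    z = np.zeros((n, k))
    for j in range(k):            # columns are already in topological order here
        z[:, j] = z @ A[:, j] + rng.normal(size=n)
        if interventions and j in interventions:
            z[:, j] = interventions[j]
    return z

obs = ancestral_sample(10000)
do1 = ancestral_sample(10000, interventions={0: 3.0})
# Intervening on the cause z1 shifts its effects z3 and z4,
# but leaves the non-descendant z2 untouched.
```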
tional supervision of the generative factors, we consider another baseline ResNet50 (named
ResNet-pretrain) that is pretrained using multi-label classification to predict the factors
on the same training set. Unless indicated otherwise, DEAR, S-VAEs, S-GraphVAE, and
ResNet-pretrain have access to the annotated labels for all training samples, and DEAR
and S-GraphVAE are given the true graph structure. We provide the detailed results when
there is less supervised information on labels and the graph structure in Sections 5.4 and
5.3.
To measure the sample efficiency, we use the statistical efficiency score defined as the
average test accuracy based on 100 samples divided by the average accuracy based on
10,000/all samples, following Locatello et al. (2019). Note that this metric may be misleading when a method achieves poor accuracy with both small and large training samples.
Therefore, we also report the test accuracies with different training sample sizes to provide
a comprehensive evaluation.
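As a concrete reading of the metric (a sketch with made-up accuracies, illustrating both the definition and the caveat above):

```python
def statistical_efficiency(acc_small, acc_large):
    """Efficiency score following Locatello et al. (2019): accuracy with
    100 labeled samples divided by accuracy with 10,000 (or all) samples.
    A score close to 1 indicates sample efficiency."""
    return acc_small / acc_large

# A strong model that degrades slightly with few samples:
good = statistical_efficiency(0.90, 0.95)
# The caveat from the text: a near-chance model that is uniformly poor
# can still obtain a high (misleading) efficiency score.
poor = statistical_efficiency(0.51, 0.52)
```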
Table 1 presents the results, showing that DEAR achieves the highest sample efficiency and test accuracy on both data sets. ResNet with raw data inputs has the lowest efficiency, although multi-label pretraining improves its performance to a limited extent. S-VAEs have better efficiency than the ResNet baselines but lower accuracy when more training data are available. Since the encoders of all S-VAEs and DEAR share the same architecture, we attribute the inferior performance of S-VAEs mainly to the independent prior contradicting the supervised loss, as indicated in Proposition 4, which makes the learned representations entangled (as shown in the previous section) and less informative. On the Pendulum
data with few underlying factors, S-GraphVAE outperforms the S-VAEs when training on
a smaller sample, indicating that an SCM latent structure has advantages over the inde-
pendent structure under the VAE framework. Nevertheless, even with the same amount of
supervision (on both annotated labels and the same given graph structure), S-GraphVAE
is still inferior to DEAR, potentially due to our better causal modeling and optimization
based on a GAN algorithm. On the more complex data set CelebA, S-GraphVAE gives very
poor performance, even worse than S-VAEs and ResNet.
In addition, we investigate the performance of DEAR under the semi-supervised setting
where only 10% of the labels are available. We find that DEAR with fewer labels achieves sample efficiency comparable to that in the fully supervised setting, at some sacrifice in accuracy that nonetheless remains comparable to other baselines using much more supervision.
In Section 5.4, we provide ablation studies to show how DEAR behaves in terms of varying
amounts of labeled samples and different choices of the regularization strength λ.
We also study the setting with less prior information on the causal graph structure. In the last two lines of Table 1, DEAR-SG stands for the DEAR-LIN model trained with a given super-graph (which is not a full graph) of the true graph, and DEAR-O stands for the DEAR-LIN model trained with a known causal ordering. We see that DEAR-SG achieves performance comparable to DEAR with the known graph structure, while DEAR-O is slightly worse
but still competitive compared with other baseline methods. As we will show later, on
Pendulum, DEAR-O can recover the true structure and the performance in downstream
tasks is identical to that of DEAR given the true structure, so we skip showing the last
two lines in Table 1(b). In Section 5.3, we investigate the performance in learning the SCM
and in particular, the causal structure, given various amounts of prior information about
Table 1: Sample efficiency and test accuracy with different training sample sizes. DEAR-
LIN and -NL denote the DEAR models with linear and nonlinear f respectively.
the true graph, where more insights are given to explain the comparable performance of
DEAR-SG in downstream tasks.
robust. Baseline methods include ERM, multi-label ERM which is trained to predict both
target label and the factors considered in disentanglement in order to have the same amount
of supervision, S-VAEs that are shown unable to disentangle well in the causal case, and
S-GraphVAE.
Table 2 presents the average and worst-case test accuracy to assess both the overall
classification performance and distributional robustness. The worst-case (Sagawa et al.,
2019) accuracy refers to the following: we group the test set according to the two binary
labels, the target one and the spurious attribute, into four cases and regard the group with the worst accuracy as the worst case, which usually exhibits the opposite spurious correlation to the training data. It can be seen that the classifiers trained upon DEAR representations
significantly outperform the baselines in both metrics. Particularly, when comparing the
worst-case accuracy with the average one, we observe a slump from around 80 to around
60 for other methods on CelebA, while DEAR enjoys a much smaller decline. As in sample
efficiency, S-GraphVAE suffers from a smaller drop in worst-case accuracy than S-VAEs
on Pendulum, but remains inferior to DEAR. On CelebA, S-GraphVAE again shows poor
performance.
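The grouping behind the worst-case metric can be sketched as follows (an illustrative implementation with made-up arrays, not the paper's evaluation code):

```python
import numpy as np

def group_accuracies(pred, target, spurious):
    """Split the test set into four groups by the binary (target, spurious)
    pair and return per-group accuracy; the minimum is the worst-case
    accuracy in the sense of Sagawa et al. (2019)."""
    accs = {}
    for t in (0, 1):
        for s in (0, 1):
            idx = (target == t) & (spurious == s)
            if idx.any():
                accs[(t, s)] = float((pred[idx] == target[idx]).mean())
    return accs

# Toy predictions: the classifier is perfect on the majority groups,
# where target and spurious attribute agree, and degrades on the rest.
pred = np.array([1, 1, 0, 0, 1, 0, 1, 0])
target = np.array([1, 1, 0, 0, 1, 1, 0, 0])
spurious = np.array([1, 1, 0, 0, 0, 0, 1, 1])
accs = group_accuracies(pred, target, spurious)
worst = min(accs.values())   # the group carrying the opposite correlation
```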
Moreover, with fewer annotated samples (i.e., 10% of the full sample), DEAR-10% re-
mains competitive against baseline methods which use even more supervised labels. DEAR-
SG (given the super-graph) is slightly better than DEAR-O (given the ordering), both of
which are comparable to DEAR given the full structure. More ablation studies in terms
of the labeled proportion as well as the strength of the supervised regularizer are given in
Section 5.4.
Figure 6: The given causal structures. -O and -SG stand for the causal ordering and super-graph.
The black edges are true and red edges are in fact redundant.
graph over the underlying factors of interest. The experiments shown in previous sections
are all based on the given true binary structure IA0 . Here we investigate the performance in
learning the causal structure on knowing various amounts of information about the graph,
which ranges from the causal ordering to the true structure. Note that the adjacency
matrices learned by DEAR-LIN and DEAR-NL are consistent up to some scaling, so in this
section we only show the results from DEAR-LIN as a representative.
Figure 5 shows the learned weighted adjacency matrices when the true binary structure
is given for the three underlying structures shown in Figure 2. It can be seen that the
weights exhibit meaningful signs and scalings that are consistent with common knowledge.
For example, the factor smile and its effect mouth open are positively correlated, that is, one is more likely to open one's mouth when smiling. The corresponding element A14 of the weighted adjacency matrix in (b) turns out positive, which makes sense. Also, gender (the indicator of being male) and its effect make up are negatively correlated, that is, women tend to wear make-up more often than men. Correspondingly, element A24 of (c) turns out negative.
Next, we evaluate the performance of DEAR in structure learning with less prior knowledge of the true graph, i.e., knowing a super-graph rather than the exact true graph. We first study the synthetic data set Pendulum, whose ground-truth structure is shown in Figure 2(a), where there are fewer causal factors and no hidden confounder. Consider the
causal ordering pendulum angle, light angle, shadow position, shadow length, given which
we start with a full graph (shown in Figure 6(a)) represented by an upper triangular ad-
jacency matrix whose elements are randomly initialized around 0 (shown in Figure 7(a)).
Figure 7(a-d) present the weighted adjacency matrices learned by DEAR at different train-
ing epochs. We observe that the weights of the two redundant edges A12 and A34 vanish
(a) Epoch 0 (b) Epoch 100 (c) Epoch 200 (d) Epoch 500 (e) S-GraphVAE
Figure 7: Learned weighted adjacency matrices on Pendulum given the causal ordering. (a-d)
are the learned matrices from DEAR at different training epochs starting from random
initialization around 0, and (e) is the result from S-GraphVAE.
gradually and it eventually leads to the weighted adjacency that nearly coincides with the
one learned given the true graph shown in Figure 5(a). In contrast, Figure 7(e) shows the
structure learned by S-GraphVAE. Note that GraphVAE learns a binary structure with 0-1
elements and (e) shows the learned probabilities of each element being 1. We see that it
learns a redundant edge A12 from pendulum angle to light angle and misses the edge A23
from light angle to shadow position. This experiment shows the advantage of DEAR over
GraphVAE in learning the latent causal structure.
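The way such prior structural knowledge enters training can be sketched as a binary mask on a learnable weighted adjacency matrix; the helper below is our own illustration (function names and initialization scale are assumptions), with a causal ordering inducing a full DAG mask whose redundant entries are free to shrink toward zero during training:

```python
import numpy as np

def mask_from_ordering(order):
    """Causal ordering -> full-DAG mask: edge i -> j allowed iff i precedes j."""
    k = len(order)
    pos = {v: p for p, v in enumerate(order)}
    mask = np.zeros((k, k))
    for i in range(k):
        for j in range(k):
            if pos[i] < pos[j]:
                mask[i, j] = 1.0
    return mask

# Pendulum ordering from the text:
# pendulum_angle (0), light_angle (1), shadow_position (3), shadow_length (2).
mask = mask_from_ordering([0, 1, 3, 2])

# Learnable weights constrained to the mask's support; redundant edges
# (e.g. A[0, 1], pendulum_angle -> light_angle) may vanish during training.
W = 0.01 * np.random.default_rng(0).standard_normal((4, 4))
A = W * mask
```

A given super-graph would be handled the same way, with the mask taken from its edge set instead of the full ordering-induced DAG.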
(a) Epoch 0 (b) Epoch 5 (c) Epoch 50 (d) Epoch 150 (e) S-GraphVAE
Figure 8: Learned weighted adjacency matrices on CelebA given a super-graph. (a) represents a
random initialization around 0 of the weighted adjacency matrix corresponding to the
super-graph in Figure 6(b); (b-d) are the learned matrices by DEAR at different training
epochs; (e) is the result from S-GraphVAE.
The case is more complicated on the real data set CelebA. Although the number of factors of interest, six, is not large, there are many more underlying generative factors. Some of the factors that we do not aim to disentangle could serve as hidden confounders of the factors that we are interested in. For example, staying up late may cause
a person to have eye bags and look chubby and hence serves as a hidden confounder of the
two factors eye bag and chubby in Figure 2(c). These hidden confounders can be captured
in the remaining dimensions of the learned representations through the composite prior
introduced in Section 4.1.4. However, their existence makes it difficult to identify and learn
the structure of the factors of interest. Another complication comes from some biases in
the data, potentially caused by selection bias or unknown interventions. Such biases may
result in spurious correlations even among the causal variables, complicating causal structure learning. There are orthogonal works (e.g., Ke et al., 2019; Bengio et al., 2020;
(a) Epoch 0 (b) Epoch 5 (c) Epoch 50 (d) Epoch 150 (e) S-GraphVAE
Figure 9: Learned weighted adjacency matrices on CelebA given the causal ordering. (a-d) are the
learned matrices by DEAR at different training epochs starting from random initialization
around 0; (e) is the result from S-GraphVAE.
Brouillard et al., 2020) focusing on causal discovery under hidden confounders or unknown
interventions, which however is beyond the scope of this paper and will be systematically
explored in future work. Here we only provide some empirical studies to evaluate our
method under this complicated case.
We conduct two experiments on CelebA. In the first one, we assume knowing a super-
graph (Figure 6(b)) of the true graph (Figure 2(c)) and randomly initialize its weighted
adjacency matrix around 0 as in Figure 8(a). Then Figure 8(a-d) show the weighted ad-
jacency matrices learned by DEAR at different training epochs. Similar to the previous
experiment on Pendulum, the weights corresponding to the redundant edges gradually van-
ish. Eventually, DEAR learns the weighted adjacency matrix that largely agrees with the
one learned given the true graph shown in Figure 5(c). After edge pruning, one can essen-
tially recover the true graph structure. This explains why DEAR-SG (the DEAR model
given this super-graph) performs competitively with DEAR given the true structure in the
downstream tasks in the previous two sections. In contrast, the graph learned by Graph-
VAE shown in Figure 8(e) fails to recover the true structure, although it is given the same
known super-graph as DEAR.
In the second experiment, we only assume knowing the causal ordering which leads to
a full graph shown in Figure 6(c) with the upper-triangular weighted adjacency matrix
randomly initialized in Figure 9(a). We observe that although DEAR can remove most of
the redundant edges, it mistakenly learns a large weight on the edge from young to gender.
This may be due to the spurious correlation between the two factors young and gender
potentially caused by the selection bias during data collection. In comparison, as shown in
Figure 9(e), the graph learned by GraphVAE given the same causal ordering turns out to be
farther away from the true graph than DEAR. Nevertheless, as discussed in the previous two
sections, DEAR-O (the DEAR model given the causal ordering) still achieves reasonably
satisfying performance, which indicates the robustness of our DEAR method against the
correctness of the learned graph structure.
In summary, when given the true graph structure, DEAR can learn meaningful weights
for each edge. If there is no hidden confounder or spurious correlation among the factors of interest, DEAR can learn the true graph given only the causal ordering. If such biases exist, DEAR can only recover the true structure given some proper super-graphs and in
general cannot learn all edges correctly when only the causal ordering is given. In all cases,
DEAR outperforms GraphVAE in learning the causal structure.
In this section, we conduct ablation studies to illustrate how DEAR performs when using
different choices of the hyperparameter λ which determines the weight of the supervised
regularizer and varying amounts of labeled samples. According to Proposition 5 and Theo-
rem 8, at the population level, i.e., assuming an infinite amount of data, the regularization
strength λ in the objective (7) can be an arbitrary positive value for the theorems to hold. However, in practice with a finite sample, λ cannot be arbitrarily small, roughly due to estimation error. Therefore we suggest treating λ as a hyperparameter and investigating
its sensitivity across different tasks and data sets. Figures 10-11 plot the metrics in sample
efficiency and distributional robustness when using different choices of λ. We observe that
all these results (with λ ranging from 0.1 to 10) remain significantly superior to the baseline
methods in Tables 1-2, which suggests that DEAR can perform reasonably well across a
wide range of λ. As λ becomes close to 0, we generally observe a performance decrease.
Next, we study how DEAR, as well as baseline methods, behave as we reduce the number
of annotated samples. Figures 12-13 plot the metrics in sample efficiency and distributional
robustness when using different amounts of labeled samples. Note that 0.1% of the CelebA training set corresponds to 162 samples and 1% of the Pendulum training set corresponds to 67 samples. Such small numbers of supervised labels belong to weakly supervised settings according to Locatello et al. (2020b) and would make manual labeling feasible even if no label is available beforehand. Naturally, with fewer labeled samples, all methods basically
perform worse. DEAR always outperforms the VAEs. In particular, as shown in Figure
13(a), when training with 0.1%-1% labels of the CelebA training sample, S-β-VAE and
S-TCVAE completely fail in the worst-case group, meaning that the classifiers trained upon
them almost fully rely on the spurious correlation and exhibit no robustness to distribution
shifts at all. In Figure 12(a), when the supervised proportion is lower, although S-β-VAE
and S-TCVAE have higher sample efficiency, they actually perform poorly with both small
and large samples, leading to a misleadingly high efficiency score.
Figure 10: Test accuracy when training on a small sample & sample efficiency, as defined in Sec-
tion 5.2.1, against four different choices of λ: 0.1, 1, 5, and 10.
Figure 11: Worst-case and average test accuracy, as defined in Section 5.2.2, against different
choices of λ. On Pendulum, we experiment with λ = 0.1, 1, 5, 10; on CelebA, we exper-
iment with λ = 0.01, 0.1, 1, 5, 10.
Figure 12: Test accuracy with a small training sample & sample efficiency against different propor-
tions of labeled samples among full data. On the larger data set CelebA, we consider
proportion=0.001, 0.01, 0.1, 1; on the smaller Pendulum data, we consider 0.01, 0.1, 1.
6. Conclusion
In this paper, we showed that previous methods with the independent latent prior assumption fail to learn disentangled representations when the underlying factors of interest are causally related. We then proposed a new disentangled learning method called DEAR
with theoretical guarantees for identifiability and asymptotic consistency. Extensive ex-
periments demonstrated the effectiveness of DEAR in causal controllable generation and
structure learning, and the benefits of the learned representations for downstream tasks.
Several future directions are worth exploring. Although in our ablation experiments,
we demonstrated that DEAR exhibits promising performance in weakly supervised settings
in terms of annotated labels and the graph structure, it is worth considering more flexible
forms of supervision to make DEAR widely applicable to real-world problems. On one hand, regarding the annotated labels of the factors of interest, one may consider utilizing other forms of supervision, such as restricted labeling or rank pairing (Shu et al., 2020). Besides, instead of using direct supervision about the true factors, one may consider
some additionally observed variables such as class labels or time index (Khemakhem et al.,
2020) which serve as auxiliary information to ensure more general identifiability of the true
latent factors in the causal case. On the other hand, regarding the graph structure, our
experiments in Section 5.3 indicated the potential of DEAR in latent structure learning.
Since in many real applications even the causal ordering may not be available, it is promising
Figure 13: Worst-case and average test accuracy against different proportions of labeled samples
among full data.
to incorporate causal discovery methods in the DEAR framework to learn the latent causal
structure from scratch (i.e., without any prior information) with a guarantee of the structure
identifiability.
In addition, the proposed method applies to the case where the observational data are
IID, as commonly considered in the literature of generative models and disentanglement. It
would be interesting to extend the current approach to non-IID settings, in particular, to
the scenarios where one can perform interventions during data collection. For example, in
reinforcement learning, the interactive environment allows the agent to perform actions and
observe their outcomes. The resulting data set that contains a mixture of interventional
distributions (e.g., Ke et al., 2021) could be leveraged in causal disentanglement learning.
Acknowledgments
We would like to thank the anonymous reviewers for their valuable comments that were
very useful for improving the quality of this work. The work was supported by the General
Research Fund (GRF) of Hong Kong (No. 16201320). F. Liu’s research was supported in
part by a Key Research Project of Zhejiang Lab (No. 2022PE0AC04).
Appendix A. Proofs
A.1 Preliminaries
This section presents some preliminary notions and lemmas which will be used in proofs.
Proof  Given any ε > 0, because f is uniformly continuous, there exists δ > 0 such that ‖f(x) − f(y)‖ ≤ ε for all ‖x − y‖ ≤ δ. We have

P( sup_{θ∈Θ} ‖T_θ(X_n) − T_θ(X)‖ ≤ δ ) = P( ∀θ ∈ Θ : ‖T_θ(X_n) − T_θ(X)‖ ≤ δ )    (11)
    ≤ P( ∀θ ∈ Θ : ‖f(T_θ(X_n)) − f(T_θ(X))‖ ≤ ε )
    = P( sup_{θ∈Θ} ‖f(T_θ(X_n)) − f(T_θ(X))‖ ≤ ε ).    (12)

By the uniform convergence of T_θ(X_n), we know the left-hand side of (11) converges to 1. Hence (12) goes to 1, which implies the desired result.
Proof  Note that the assumptions in (∗) satisfy the requirements of the Arzelà-Ascoli theorem. Thus, for each subsequence of p_n, there is a further subsequence p_{n_m} which converges uniformly on the compact set K, i.e., for some p_0, as m → ∞ we have

sup_{x∈K} |p_{n_m}(x) − p_0(x)| → 0.
By Scheffé's theorem we have H(p_{n_m}, p_0) → 0. On the other hand, we have H(p_{n_m}, p) →p 0. By the triangle inequality,

H(p, p_0) ≤ H(p_{n_m}, p_0) + H(p_{n_m}, p) →p 0.

Since the inequality holds for all m and the LHS is deterministic, we have H(p, p_0) = 0, which implies p = p_0 a.e. with respect to the Lebesgue measure. Hence we have

sup_{x∈K} |p_{n_m}(x) − p(x)| → 0, a.e.

By Durrett (2019, Theorem 2.3.2), we have sup_{x∈K} |p_n(x) − p(x)| →p 0 as n → ∞.
is applicable given the true causal ordering under the basic Markov and causal minimality
conditions (Pearl, 2014; Zhang and Spirtes, 2011).
Assumption 2 For all β = (f, h, A) ∈ B with pβ = pβ0 , it holds that IA = IA0 .
Proof To simplify the notations in this section, for a vector x, let xi denote the i-th
element of x instead of [x]i . For a vector function g(x), let gi (x) denote the i-th component
function.
Assume E is deterministic.
On one hand, for each i = 1, . . . , m, first consider the cross-entropy loss

L_{sup,i}(E) = E_{(x,y)}[CE(E_i(x), y_i)]
    = −∫ q_x(x) p(y_i|x) [ y_i log σ(E_i(x)) + (1 − y_i) log(1 − σ(E_i(x))) ] dx dy_i,

where p(y_i|x) is the probability mass function of the binary label y_i given x, characterized by P(y_i = 1|x) = E(y_i|x) and P(y_i = 0|x) = 1 − E(y_i|x). Let

∂L_{sup,i}/∂σ(E_i(x)) = ∫ q_x(x) p(y_i|x) [ 1/(1 − σ(E_i)) − y_i/(σ(E_i)(1 − σ(E_i))) ] dx dy_i = 0.

Then we know that E_i*(x) = σ^{−1}(E(y_i|x)) = σ^{−1}(ξ_i) minimizes L_{sup,i}.
Consider the L2 loss

L_{sup,i}(φ) = E_{(x,y)}[E_i(x) − y_i]² = ∫ q_x(x) p(y_i|x) [E_i(x) − y_i]² dx dy_i.

Let

∂L_{sup,i}/∂E_i(x) = 2 ∫ q_x(x) p(y_i|x) (E_i(x) − y_i) dx dy_i = 0.

Then we know that E_i*(x) = E(y_i|x) = ξ_i minimizes L_{sup,i} in this case.
On the other hand, by Assumption 1 there exists β_0 = (f_0, h_0, A_0) such that p_ξ = p_{β_0}. Further, due to the infinite capacity of G and Assumption 1, the distribution family of p_{G,F}(x, z) contains q_{E*}(x, z). Then by minimizing the loss in (7) over G, we can find G* and F* such that p_{G*,F*}(x, z) matches q_{E*}(x, z) and thus L_{gen}(E*, G*, F*) reaches 0, where F* corresponds to the parameter β* = (f*, h*, A*).
Note that p_{G*,F*}(x, z) = q_{E*}(x, z) implies that the marginal distributions match, i.e., p_{F*}(z) = q_{E*}(z). Generally denote E_i*(x) = g_i(ξ_i) for i = 1, . . . , m. Then, for i = 1, . . . , m, the distributions of g_i^{−1}(E_i*(x)) = ξ_i and g_i^{−1}(F_i*(ε)) are identical. It can be seen that p_{β_0} = p_{β_0*} with β_0* = (g^{−1} ◦ f*, h*, A*), where ◦ denotes elementwise composition. Then according to Assumption 2, we have I_{A*} = I_{A_0}.
Hence minimizing L = L_{gen} + λL_{sup}, which is the DEAR formulation (7), leads to the solution with E_i*(x) = g_i(ξ_i), where g_i(ξ_i) = σ^{−1}(ξ_i) if the CE loss is used and g_i(ξ_i) = ξ_i if the L2 loss is used, together with the true binary adjacency matrix I_{A_0}.
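The first part of the proof — that the expected cross-entropy over a Bernoulli(ξ_i) label is minimized at the logit σ^{−1}(ξ_i) — can be checked numerically (a sanity-check script of ours, not part of the paper):

```python
import math

def sigmoid(l):
    return 1.0 / (1.0 + math.exp(-l))

def ce_risk(logit, xi):
    """Expected cross-entropy for a Bernoulli(xi) label,
    as a function of the predicted logit."""
    p = sigmoid(logit)
    return -(xi * math.log(p) + (1 - xi) * math.log(1 - p))

xi = 0.3
# Grid search over logits in [-5, 5] with step 0.01.
logits = [i / 100.0 for i in range(-500, 501)]
best = min(logits, key=lambda l: ce_risk(l, xi))
# The minimizer should sit at the logit of xi, i.e. sigma^{-1}(xi).
target = math.log(xi / (1 - xi))
```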
E_{(x,z,w)∼p*(x,z,w)}[log p_D(x, z, w)] = E_{p*(x,z,w)} log[p*(x, z) p_D(w|x, z)]
Step I We now establish the consistency of D̂(x, z) to D∗ (x, z) as defined in (14) below
based on the generalization analysis of maximum likelihood estimation.
Let the class

G = { g(x, z, w) = (1/2) log[ (p_D(x, z, w) + p*(x, z, w)) / (2p*(x, z, w)) ] : D ∈ D }.

Note that each element of G can be written as

g(x, z, w) = (1/2) log[ (p_D(w|x, z) + p*(w|x, z)) / (2p*(w|x, z)) ].
Then by the continuous mapping theorem (Lemma 10), and noting that l(p) = log(p/(1 − p)) is uniformly continuous on a closed interval within (0, 1), we have, as N_d → ∞,

sup_{(x,z)∈K} |D̂(x, z) − D*(x, z)| →p 0.    (14)
Step II We then prove the pointwise consistency of ∇D̂(x, z) to ∇D∗ (x, z) as defined in
(17).
Construct an arbitrary probability measure µ on X × Z, with density ρ, that satisfies the following inequality for all twice-differentiable functions u (e.g., a Gaussian measure):

∫ ‖∇u‖² dµ ≤ √(∫ |u|² dµ) √(∫ [tr(∇²u)]² dµ) + √(∫ |u|² dµ) √(∫ (∇u⊤∇log ρ)² dµ).  (15)
Recall from condition C4 that there exists a positive number B0 < ∞ such that ∀x, z,
∀D ∈ D, we have |D(x, z)| ≤ B0 , k∇D(x, z)k ≤ B0 and |tr(∇2 D(x, z))| ≤ B0 .
Given arbitrary ε > 0, we know from the tightness of µ that there exists a compact subset Kε of X × Z such that µ(Kε) ≥ 1 − ε. Let B = max{B0, B1}. Then we have for all θ ∈ Θ that
∫X×Z ‖∇D̂(x, z) − ∇D∗(x, z)‖² dµ
  ≤ 2B √(∫X×Z |D̂(x, z) − D∗(x, z)|² dµ) + 2B² √(∫X×Z |D̂(x, z) − D∗(x, z)|² dµ)
  = (2B + 2B²) √(∫Kε |D̂(x, z) − D∗(x, z)|² dµ + ∫Kεᶜ |D̂(x, z) − D∗(x, z)|² dµ)    (16)
  ≤ (2B + 2B²) (√(∫Kε |D̂(x, z) − D∗(x, z)|² dµ) + 2B√ε),
where ν denotes the Lebesgue measure. Since ρ(x, z) > 0, this implies ‖∇D̂(x, z) − ∇D∗(x, z)‖ →p 0 for all non-extrema. By the Lipschitz continuity of v on any compact set, we have ‖∇D̂(x, z) − ∇D∗(x, z)‖ →p 0 for all extrema.
Up to now we have shown that for all θ ∈ Θ and (x, z) ∈ X × Z, we have ‖∇D̂(x, z) − ∇D∗(x, z)‖ →p 0 as Nd → ∞. Further, from the smoothness in condition C1 and the compactness of Θ, we have ∀x, z, as Nd → ∞,

supθ∈Θ ‖∇D̂(x, z) − ∇D∗(x, z)‖ →p 0.  (17)
Step III Based on the convergence statements established above, we proceed to show the consistency of the approximate gradient hD̂(θ) and complete the proof.
By condition C3, {µ∗} is uniformly tight. For arbitrary ε > 0, there exists a compact subset Kε of X × Z such that µ∗(Kεᶜ) < ε. Because ∇D(x, z) is Lipschitz continuous with respect to (x, z) on Kε, we have, as Nd → ∞,

supθ∈Θ,(x,z)∈Kε ‖∇D̂(x, z) − ∇D∗(x, z)‖ →p 0.  (18)
P(ANd) ≥ P(ANd ∩ BN,ε) = P(ANd | BN,ε) P(BN,ε) ≥ P(ANd | BN,ε)(1 − ε)^N → (1 − ε)^N.
supθ∈Θ ‖hD̂(θ) − ∇L(θ)‖ ≤ supθ∈Θ ‖hD̂(θ) − hD∗(θ)‖ + supθ∈Θ ‖hD∗(θ) − ∇L(θ)‖,

and hence

supθ∈Θ ‖hD̂(θ) − ∇L(θ)‖ →p 0.
L(θt) ≤ L(θt−1) − η hD̂(θt−1)⊤∇L(θt−1) + (η²ℓ0/2) ‖hD̂(θt−1)‖².
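The displayed bound is the standard descent inequality for an ℓ0-smooth objective; with the update θt = θt−1 − η hD̂(θt−1), it follows in one line:

```latex
L(\theta_t)
  \le L(\theta_{t-1}) + \nabla L(\theta_{t-1})^\top (\theta_t - \theta_{t-1})
      + \tfrac{\ell_0}{2}\,\lVert \theta_t - \theta_{t-1} \rVert^2
  = L(\theta_{t-1}) - \eta\, h_{\hat D}(\theta_{t-1})^\top \nabla L(\theta_{t-1})
      + \tfrac{\eta^2 \ell_0}{2}\, \lVert h_{\hat D}(\theta_{t-1}) \rVert^2 .
```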
Let ε̂(θ) = ∇L(θ) − hD̂(θ). By Lemma 7, there exists a sequence of (N, Ns, Nd) → ∞ such that ε̂ = supθ ‖ε̂(θ)‖ →p 0. Then we have

−η hD̂(θt−1)⊤∇L(θt−1) = −η hD̂(θt−1)⊤(hD̂(θt−1) + ε̂(θt−1)),

and hence

L(θt) ≤ L(θt−1) − (η/4) ‖hD̂(θt−1)‖² + (η²ℓ0/2) ‖hD̂(θt−1)‖²
     ≤ L(θt−1) − (η/8) ‖hD̂(θt−1)‖²,

when η < 1/(4ℓ0), which can be satisfied with a sufficiently small learning rate.
By summing over t = 1, . . . , T, we have

L(θT) ≤ L(θ0) − 0.125η Σ_{t=1}^{T} ‖hD̂(θt−1)‖².

Note that L(θ) is lower bounded by 0. Then we have Σt ‖hD̂(θt−1)‖² = O(1). Thus there exists t in {1, . . . , T} such that ‖hD̂(θt−1)‖² = O(1/T).
Otherwise, there exists t such that ‖hD̂(θt−1)‖ < √2 ε̂ = op(1).
Therefore, for the empirical estimator θ̂, we have ‖hD̂(θ̂)‖ →p 0.
By the uniform convergence (10) from Lemma 7, we have ‖∇L(θ̂)‖ →p 0. Then by the PL condition, there exists a sequence of (N, Ns, Nd) → ∞ such that

L(θ̂) − L∗ →p 0.
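As a sanity check on the O(1/T) argument above, here is a toy numerical illustration; the quadratic objective, step size, and constants below are our own illustrative choices, not the paper's setting.

```python
# Toy check of the summed descent inequality: running gradient descent on
# L(t) = t^2 / 2 (smoothness constant l0 = 1, PL condition holds) with exact
# gradients, the smallest squared gradient over T steps is O(1/T).
def L(t):
    return 0.5 * t * t

def grad(t):
    return t

eta = 0.1          # step size, below 1/(4 * l0) = 0.25 as required above
theta = 5.0
sq_grads = []
T = 200
for _ in range(T):
    g = grad(theta)
    sq_grads.append(g * g)
    theta -= eta * g

# Summing L(t) <= L(t-1) - (eta/8)||h||^2 gives min_t ||h||^2 <= 8 L(theta_0) / (eta T).
assert min(sq_grads) <= 8 * L(5.0) / (eta * T)
```

With exact gradients the iterates contract geometrically, so the minimum squared gradient falls far below the 8L(θ0)/(ηT) bound.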
for all vector functions g(x) such that g(∞) = 0. Given a matrix function w(x) = (w1(x), . . . , wl(x)) : R^k → R^{k×l}, where each wi(x), i = 1, . . . , l, is a k-dimensional differentiable vector function, its divergence is defined as ∇ · w(x) = (∇ · w1(x), . . . , ∇ · wl(x)).
To prove Lemma 6, we need the following lemma which specifies the dynamics of the
generator joint distribution pg (x, z) and the encoder joint distribution pe (x, z), denoted by
pθ (x, z) and pφ (x, z) here.
Lemma 13 Using the definitions and notations in Lemma 6, we have

∇θ pθ,β(x, z) = −∇x pθ,β(x, z)⊤ gθ(x) − pθ,β(x, z) ∇ · gθ(x),  (19)
∇φ qφ(x, z) = −∇z qφ(x, z)⊤ eφ(z) − qφ(x, z) ∇ · eφ(z),  (20)
∇β pθ,β(x, z) = −∇x pθ,β(x, z)⊤ f̃β(x) − ∇z pθ,β(x, z)⊤ fβ(z) − pθ,β(x, z) ∇ · (f̃β(x), fβ(z)),  (21)

for all data x and latent variables z, where gθ(Gθ(z, ε)) = ∇θ Gθ(z, ε), eφ(Eφ(x, ε)) = ∇φ Eφ(x, ε), fβ(Fβ(ε)) = ∇β Fβ(ε), and f̃β(G(Fβ(ε))) = ∇β G(Fβ(ε)).
Proof [Proof of Lemma 13] We only prove (21), which is the part distinct from Shen et al. (2020).
Let l be the dimension of the parameter β. To simplify notation, let the random vector Z = Fβ(ε), X = G(Z) ∈ R^d and Y = (X, Z) ∈ R^{d+k}, and let p be the probability density of Y. For each i = 1, . . . , l, let ∆ = δei, where ei is an l-dimensional unit vector whose i-th component is one and all the others are zero, and δ is a small scalar. Let Z′ = Fβ+∆(ε), X′ = G(Z′) and Y′ = (X′, Z′), so that Y′ is a random variable transformed from Y by

Y′ = Y + (f̃β(X), fβ(Z)) ∆ + o(δ).

Let p′ be the probability density of Y′. For an arbitrary y′ = (x′, z′) ∈ R^{d+k}, let y′ = y + (f̃β(x), fβ(z)) ∆ + o(δ) and y = (x, z). Then we have
Proof [Proof of Lemma 6] Recall the objective DKL(q, p) = ∫ q(x, z) log(p(x, z)/q(x, z)) dx dz. Denote its integrand by ℓ(q, p). Let ℓ′2(q, p) = ∂ℓ(q, p)/∂p. We have

∇β ℓ(q(x, z), p(x, z)) = ℓ′2(q(x, z), p(x, z)) ∇β pθ,β(x, z),

where ∇β pθ,β(x, z) is computed in Lemma 13.
Besides, we have

∇x · [ℓ′2(q, p) p(x, z) f̃β(x)] = ℓ′2(q, p) p(x, z) ∇ · f̃β(x) + ℓ′2(q, p) ∇x p(x, z) · f̃β(x) + p(x, z) ∇x ℓ′2(q, p) · f̃β(x),
∇z · [ℓ′2(q, p) p(x, z) fβ(z)] = ℓ′2(q, p) p(x, z) ∇ · fβ(z) + ℓ′2(q, p) ∇z p(x, z) · fβ(z) + p(x, z) ∇z ℓ′2(q, p) · fβ(z).
Thus,

∇β Lgen = ∫∫ ∇β ℓ(q(x, z), p(x, z)) dx dz = −∫∫ p(x, z) [∇x ℓ′2(q, p) · f̃β(x) + ∇z ℓ′2(q, p) · fβ(z)] dx dz,

where we have ∇x ℓ′2(q, p) = s(x, z) ∇x D∗(x, z) and ∇z ℓ′2(q, p) = s(x, z) ∇z D∗(x, z). Hence

∇β Lgen = −E(x,z)∼p(x,z) [s(x, z)(∇x D∗(x, z)⊤ f̃β(x) + ∇z D∗(x, z)⊤ fβ(z))]
        = −E [s(x, z)(∇x D∗(x, z)⊤ ∇β G(Fβ(ε)) + ∇z D∗(x, z)⊤ ∇β Fβ(ε))] |x=G(Fβ(ε)), z=Fβ(ε).
• Line 4: the FactorVAE metric samples all factors independently from uniform distributions, which does not match (and can be far from) the true distribution of the causal factors. Instead, we sample the factors following the true SCM and hence respect the data distribution.
• Lines 10-12: the FactorVAE metric uses the error rate of a majority-vote classifier as the metric because, in an unsupervised setting, one does not know which factor each representation captures. In contrast, the weakly supervised setting guarantees the alignment between each representation and a particular factor, so we do not need the majority-vote classifier to identify this correspondence. Instead, we directly check whether the dimension with the lowest empirical variance matches the given index k.
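A minimal sketch of this check; the helpers `sample_factors_with_fixed` and `encode` are hypothetical stand-ins for the SCM simulator and the trained encoder, and we omit the per-dimension variance normalization of the original FactorVAE metric for brevity.

```python
import statistics

def fixed_factor_identified(encode, sample_factors_with_fixed, k, value, n=50):
    """Sample n factor vectors from the true SCM with factor k clamped to
    `value`, encode them, and check that the latent dimension with the
    lowest empirical variance is exactly dimension k."""
    codes = [encode(f) for f in sample_factors_with_fixed(k, value, n)]
    dims = len(codes[0])
    variances = [statistics.pvariance([c[d] for c in codes]) for d in range(dims)]
    return min(range(dims), key=variances.__getitem__) == k
```

With a perfectly aligned encoder, clamping factor k forces dimension k of the code to be constant, so its empirical variance is zero and the check passes.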
As we note, this metric is limited in that it not only requires the ground-truth factors of the data with sufficient coverage of the data distribution, as previous metrics do, but also requires the ground-truth SCM, which is only available for synthetic data. Nevertheless, in this work we use this metric only to evaluate and justify the relationship between causal disentanglement and performance in downstream tasks. We leave a widely applicable quantitative metric for causal disentanglement to future work.
Figure 14 shows the scatter plots of the metrics that we considered in downstream tasks (Section 5.2) against the metric for causal disentanglement (with M = 200 and N = 50). Each metric is used to evaluate seven disentanglement models, including S-β-VAE, S-TCVAE, S-GraphVAE, and multiple DEAR-LIN models with λ = 0.1, 1, 5, 10. All models are trained using fully supervised labels, and GraphVAE and DEAR are given the true graph structure. The network architectures for the encoders and decoders are all the same. We observe a positive correlation between causal disentanglement and performance in downstream tasks, which indicates that learned representations with a higher disentanglement score tend to perform better in terms of sample efficiency and distributional robustness in downstream tasks. In particular, we notice that the small sample accuracy and worst-case accuracy benefit the most from better causal disentanglement, since the corresponding fitted lines have the largest slope.
[Figure 14 scatter plots: sample efficiency metrics (efficiency, large sample accuracy, small sample accuracy) and distributional robustness metrics (average accuracy, worst-case accuracy) plotted against the disentanglement score for DEAR, S-beta-VAE, S-GraphVAE, and S-TCVAE.]
Figure 14: Relationship between causal disentanglement and performance in downstream tasks.
We comment on the two forms of supervision that may be available and are commonly considered in the literature for the task of disentangled representation learning.
• Form 1 (direct and few labels): in some scenarios, we may have conceptual knowledge about the data, in the sense that we know the concepts of the underlying generative factors, especially those we are interested in. In such cases, a weakly supervised setting is feasible where only a few samples have annotated labels of the factors, since manual labeling of at least a few examples is practical. A representative work using this form of supervision is Locatello et al. (2020b).
Both settings have real applications and limitations, which makes them complementary. On one hand, Form 2 in general requires "weaker" supervision than Form 1, in the sense that it does not require direct annotations of the true factors themselves. Thus, efforts towards general provable disentanglement should focus on Form 2. However, the auxiliary observed variables in Form 2 also require certain knowledge of the true factors in order to verify the mathematical assumptions required for identifiability, e.g., the variability condition in Khemakhem et al. (2020). Intuitively, auxiliary variables that can guarantee disentanglement should have enough variability and correlation with the true factors. In addition, current identifiability theory with Form 2 still imposes relatively strong and restrictive structural assumptions on the true factors, e.g., conditional independence in Khemakhem et al. (2020).
On the other hand, current research on disentanglement mostly focuses on scenarios where we indeed have some conceptual knowledge of the true factors, which makes Form 1 at least a feasible and practical setting. For simple structures of the true factors (e.g., independence or conditional independence, as assumed in most previous work), existing methods with Form 1 can achieve disentanglement, which is much more straightforward than provable disentanglement with supervision of Form 2. However, for more complex structures (e.g., a causal graph, as considered in our paper), existing methods using independent or conditionally independent priors generally cannot identify disentanglement even with supervision in Form 1, as shown in our Proposition 4. In particular, existing formulations (e.g., Locatello et al. (2020b)) in general cannot even reach the optimum of the supervised loss, so they cannot disentangle. To this end, our paper proposes a bidirectional generative model with an SCM prior trained using a GAN-type algorithm, which resolves this problem under the clearly stated setup and assumptions.
where A denotes the adjacency matrix, Pa(Zi; A) denotes the set of parents of node Zi, and εi is the exogenous noise. Learning an SCM consists of structure learning of A and parameter estimation of all the assignments fi, i = 1, . . . , m, in the SCM, i.e., how each node is generated given its parents and exogenous noise. Given the underlying causal structure, standard parameter estimation methods such as maximum likelihood estimation can yield a consistent estimator of the true SCM assignments from observational data.
Note that an intervention can be defined as an operation that modifies a subset of assignments in (22), e.g., changing εi, or setting fi (and thus Zi) to a constant (Pearl et al., 2000; Schölkopf, 2019). Therefore, with the estimated SCM (23) at hand, we can sample from any interventional distribution.
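As an illustration, sampling from a do-intervention reduces to overriding the intervened assignment during ancestral sampling; the linear-Gaussian SCM below is a simplified, hypothetical stand-in for the learned SCM, not our actual prior.

```python
import random

def sample_scm(weights, order, do=None, rng=None):
    """Ancestral sampling from a linear-Gaussian SCM.
    weights[(j, i)] is the edge coefficient j -> i; `do` maps an
    intervened node to the constant it is clamped to (incoming edges cut)."""
    rng = rng or random.Random(0)
    do = do or {}
    z = {}
    for i in order:                 # `order` must be topological
        if i in do:
            z[i] = do[i]            # intervention: ignore parents and noise
        else:
            z[i] = sum(w * z[j] for (j, t), w in weights.items() if t == i) \
                   + rng.gauss(0, 1)
    return z
```

For a chain 0 → 1 with weight 2.0, `sample_scm({(0, 1): 2.0}, [0, 1], do={0: 1.0})` clamps node 0 to 1.0 and propagates the intervention to its descendant.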
We illustrate this through some experimental results shown in Figure 15. In (a), we intervene on the two factors bald and gender. In each row, we keep gender = female and gradually increase the probability of being bald. In particular, in the red box we obtain images of bald female faces, which are never seen in the observational data. In (b), we intervene on beard and gender to generate images of females with beards, shown in the red box. In (c), we show generated samples that gradually put on (sun)glasses, while the training data contain only images with or without glasses and no intermediate states. In (d), we intervene on all four factors. In each row, the image in the middle follows the true SCM (described later in Appendix F), so that the factors satisfy the projection law. Then we change the value of only one factor while keeping the others fixed, which leads to samples that do not satisfy the projection law. In summary, we see that although these interventions do not appear in the observational data, DEAR is able to generate samples from such interventional distributions, suggesting its generalizability to unseen interventions.
More systematic analysis on the out-of-distribution generalizability of the encoder is to
be explored in future work. One potential direction is to utilize the generalizability of the
generator to unseen interventions to improve the OOD performance of the encoder. Along
this direction, for example, Sauer and Geiger (2021) recently combined disentangled generative models and out-of-distribution classification, but adopted a different disentanglement framework.
Figure 16: Generative factors of the Pendulum data set. ξ1 : pendulum angle, ξ2 : light angle,
ξ3 : shadow length, ξ4 : shadow position.
In downstream tasks, for BGMs with an encoder, we train a two-layer MLP classifier with 100 hidden nodes using Adam with a learning rate of 1 × 10−2 and a mini-batch size of 128. Models were trained for around 150 epochs on CelebA, 600 epochs on Pendulum, and 50 epochs on MPI3D, on an NVIDIA RTX 2080 Ti.
Description of the Pendulum data set. In Figure 16 we illustrate the generative factors of the synthesized Pendulum data set, following Yang et al. (2021). Given the pendulum angle (ξ1) and the light angle (ξ2), the projection law determines the shadow length (ξ3) and the shadow position (ξ4). Note that we consider parallel light in our simulator. Specifically, define the constants: cx = 10, cy = 10.5 are the coordinates of the center (pendulum origin); lp = 9.5 is the pendulum length (including the red ball); the bottom line of a single plot corresponds to y = b with base b = −0.5. Then the ground-truth structural causal model is expressed as follows.
ξ1 ∼ U(π/4, π/2),
ξ2 ∼ U(0, π/4),
ξ3 = (cx + lp sin ξ1 − (cy − lp cos ξ1 − b)/tan ξ2) − (cx − (cy − b)/tan ξ2),
ξ4 = ((cx + lp sin ξ1 − (cy − lp cos ξ1 − b)/tan ξ2) + (cx − (cy − b)/tan ξ2))/2.
We find Gaussians expressive enough as unexplained noises, so we set h to be the identity mapping. As mentioned in Section 4.1, we require the invertibility of f. We implement both a linear and a nonlinear version. For a linear f, we take f(z) = Wz + b, where W and b are learnable weights and biases. Note that W is a diagonal matrix, so as to model an element-wise transformation. Its inverse function can be easily computed as f−1(z) = W−1(z − b).
For a nonlinear f, we use piece-wise linear functions defined by

[f(z)]_i = [w0]_i [z]_i + Σ_{t=1}^{Na} [wt]_i ([z]_i − a_t) I([z]_i ≥ a_t) + [b]_i,

where a1 < a2 < · · · < aNa are the points of division, I(·) is the indicator function, and {b, wt : t = 0, . . . , Na} is the set of learnable parameters. According to the denseness of piecewise linear functions in C[0, 1] (Shekhtman, 1982), the family of such piece-wise linear functions is expressive enough to model general element-wise nonlinear invertible transformations.
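A sketch of this element-wise map for a single coordinate, with a numerical inverse obtained by bisection; invertibility is assumed via positive cumulative slopes w0 + · · · + wt, so f is strictly increasing, and the parameter values in the usage note are illustrative.

```python
def pwl(z, w, a, b):
    """Piece-wise linear map f(z) = w[0]*z + sum_t w[t]*(z - a[t-1])*1{z >= a[t-1]} + b
    for one coordinate, where a[t-1] holds the t-th division point.
    Strictly increasing (hence invertible) if all cumulative slopes are positive."""
    out = w[0] * z + b
    for t in range(1, len(w)):
        if z >= a[t - 1]:
            out += w[t] * (z - a[t - 1])
    return out

def pwl_inverse(y, w, a, b, lo=-1e6, hi=1e6, iters=200):
    """Invert a strictly increasing piece-wise linear map by bisection."""
    for _ in range(iters):
        mid = 0.5 * (lo + hi)
        if pwl(mid, w, a, b) < y:
            lo = mid
        else:
            hi = mid
    return 0.5 * (lo + hi)
```

For example, with w = [1.0, 0.5, 0.25], a = [0.0, 1.0], b = 0.3, `pwl_inverse` recovers z from f(z) to numerical precision.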
Network architectures. We follow the architectures used in Shen et al. (2020). Specifically, for such realistic data, we adopt the SAGAN (Zhang et al., 2019) architecture for D and G. The D network consists of three modules, as shown in Figure 17(a) and described in detail in Shen et al. (2020). Architectures for the networks G and Dx are given in Figure 17(b-c) and Table 4. The encoder architecture is ResNet50 (He et al., 2016) followed by a 4-layer MLP of width 1024 after ResNet's global average pooling layer.
Table 4: SAGAN architecture (k = 100 for CelebA, k = 6 for Pendulum, and ch = 32). (a) Generator; (b) discriminator module Dx.
Figure 17: (a) Architecture of the discriminator D(x, z); (b) a residual block (upscale) in the SAGAN generator, where we use nearest-neighbor interpolation for upsampling; (c) a residual block (downscale) in the SAGAN discriminator.
Note that this implementation follows the original one: z|x, parent is obtained by precision-weighted fusion as in He et al. (2018). Since our factor dependencies are explicit, we use a 32-dimensional latent space for more efficient optimization.
For the supervised regularizer, we use λ = 1000 to balance generative modeling and the supervised regularizer. The ERM ResNet is trained using the same optimizer with a learning rate of 1 × 10−4. We run the public source code from https://github.com/mkocaoglu/CausalGAN to produce the results of CausalGAN.
underlying structures on two data sets: Pendulum in Figure 2(a), and CelebA with the structures CelebA-Smile in Figure 2(b) and CelebA-Attractive in Figure 2(c). Note that the ordering of the rows in the traversals below matches the indices in Figure 2.
Figure 18: Results of DEAR. On the left we present the traditional latent traversals (the first type of intervention stated in Section 5.1), which show the disentanglement. On the right we show the results of intervening on one latent variable, from which we see the consequent changes of the others (the second type of intervention). Specifically, intervening on a cause variable influences the effect variables, while intervening on effect variables makes no difference to the causes.
Figure 19: Traversal results of baseline methods. We see that entanglement occurs and some factors are not captured by the generative models (traversing some dimensions of the latent vector makes no difference in the decoded images). Besides, the generated images from VAEs are blurry.
Figure 20: Traversal results of baseline methods. CausalGAN uses the binary factors as conditional attributes, so the traversals (a-b) exhibit sudden changes. In contrast, we regard the continuous logits of the binary labels as the underlying factors and hence enjoy smooth manipulations. In addition, the controllability of CausalGAN is limited, since entanglement still exists. Results of S-VAEs are explained in Figure 19. The traversal of S-GraphVAE on Pendulum looks better than those of S-VAEs, especially in the first two factors, while its performance on CelebA is poor. Besides, S-GraphVAE has poor generation quality.
References
Martin Arjovsky, Léon Bottou, Ishaan Gulrajani, and David Lopez-Paz. Invariant risk
minimization. arXiv preprint arXiv:1907.02893, 2019.
Yoshua Bengio, Aaron Courville, and Pascal Vincent. Representation learning: A review
and new perspectives. IEEE Transactions on Pattern Analysis and Machine Intelligence,
35(8):1798–1828, 2013.
Yoshua Bengio, Tristan Deleu, Nasim Rahaman, Nan Rosemary Ke, Sebastien Lachapelle,
Olexa Bilaniuk, Anirudh Goyal, and Christopher Pal. A meta-transfer objective for
learning to disentangle causal mechanisms. In International Conference on Learning
Representations, 2020. URL https://openreview.net/forum?id=ryxWIgBFPS.
Christopher P. Burgess, Irina Higgins, Arka Pal, Loı̈c Matthey, Nicholas Watters, Guillaume
Desjardins, and Alexander Lerchner. Understanding disentangling in beta-vae. NeurIPS
Workshop of Learning Disentangled Features, 2017.
Zachary Charles and Dimitris Papailiopoulos. Stability and generalization of learning algo-
rithms that converge to global optima. In International Conference on Machine Learning,
pages 745–754. PMLR, 2018.
Tian Qi Chen, Xuechen Li, Roger B. Grosse, and David K. Duvenaud. Isolating sources
of disentanglement in variational autoencoders. In Advances in Neural Information Pro-
cessing Systems, 2018.
Xi Chen, Yan Duan, Rein Houthooft, John Schulman, Ilya Sutskever, and Pieter Abbeel.
InfoGAN: Interpretable representation learning by information maximizing generative adversarial nets. In Advances in Neural Information Processing Systems, pages 2172–2180,
2016.
David Maxwell Chickering. Optimal structure identification with greedy search. Journal of Machine Learning Research, 3(Nov):507–554, 2002.
Ishita Dasgupta, Jane Wang, Silvia Chiappa, Jovana Mitrovic, Pedro Ortega, David Raposo,
Edward Hughes, Peter Battaglia, Matthew Botvinick, and Zeb Kurth-Nelson. Causal
reasoning from meta-reinforcement learning. arXiv preprint arXiv:1901.08162, 2019.
Andrea Dittadi, Frederik Träuble, Francesco Locatello, Manuel Wuthrich, Vaibhav Agrawal,
Ole Winther, Stefan Bauer, and Bernhard Schölkopf. On the transfer of disentangled
representations in realistic settings. In International Conference on Learning Represen-
tations, 2021. URL https://openreview.net/forum?id=8VXvj1QNRl1.
Jeff Donahue, Philipp Krähenbühl, and Trevor Darrell. Adversarial feature learning. In
International Conference on Learning Representations, 2017.
Vincent Dumoulin, Ishmael Belghazi, Ben Poole, Alex Lamb, Martı́n Arjovsky, Olivier
Mastropietro, and Aaron C. Courville. Adversarially learned inference. In International
Conference on Learning Representations, 2017.
Rick Durrett. Probability: Theory and Examples, volume 49. Cambridge University Press, 2019.
Ian Goodfellow, Jean Pouget-Abadie, Mehdi Mirza, Bing Xu, David Warde-Farley, Sherjil
Ozair, Aaron Courville, and Yoshua Bengio. Generative adversarial nets. In Advances in
Neural Information Processing Systems, pages 2672–2680, 2014.
Jiawei He, Yu Gong, Joseph Marino, Greg Mori, and Andreas Lehrmann. Variational
autoencoders with jointly optimized latent dependency structure. In International Con-
ference on Learning Representations, 2018.
Kaiming He, Xiangyu Zhang, Shaoqing Ren, and Jian Sun. Deep residual learning for
image recognition. In IEEE International Conference on Computer Vision and Pattern
Recognition, pages 770–778, 2016.
Irina Higgins, Loı̈c Matthey, Arka Pal, Christopher Burgess, Xavier Glorot, Matthew
Botvinick, Shakir Mohamed, and Alexander Lerchner. beta-vae: Learning basic visual
concepts with a constrained variational framework. In International Conference on Learn-
ing Representations, 2017.
Tero Karras, Samuli Laine, and Timo Aila. A style-based generator architecture for gener-
ative adversarial networks. In IEEE International Conference on Computer Vision and
Pattern Recognition, pages 4401–4410, 2019.
Nan Rosemary Ke, Olexa Bilaniuk, Anirudh Goyal, Stefan Bauer, Hugo Larochelle, Bern-
hard Schölkopf, Michael C Mozer, Chris Pal, and Yoshua Bengio. Learning neural causal
models from unknown interventions. arXiv preprint arXiv:1910.01075, 2019.
Nan Rosemary Ke, Aniket Rajiv Didolkar, Sarthak Mittal, Anirudh Goyal, Guillaume La-
joie, Stefan Bauer, Danilo Jimenez Rezende, Michael Curtis Mozer, Yoshua Bengio, and
Christopher Pal. Systematic evaluation of causal discovery in visual model based rein-
forcement learning. 2021.
Ilyes Khemakhem, Diederik Kingma, Ricardo Monti, and Aapo Hyvarinen. Variational autoencoders and nonlinear ICA: A unifying framework. In International Conference on Artificial Intelligence and Statistics, pages 2207–2217, 2020.
Felix Leeb, Yashas Annadani, Stefan Bauer, and Bernhard Schölkopf. Structural
autoencoders improve representations for generation and transfer. arXiv preprint
arXiv:2006.07796, 2020.
Zinan Lin, Kiran K Thekumparampil, Giulia Fanti, and Sewoong Oh. InfoGAN-CR and ModelCentrality: Self-supervised model training and selection for disentangling GANs. In International Conference on Machine Learning, 2020.
Chaoyue Liu, Libin Zhu, and Mikhail Belkin. Loss landscapes and optimization in over-
parameterized non-linear systems and neural networks. arXiv preprint arXiv:2003.00307,
2020.
Ziwei Liu, Ping Luo, Xiaogang Wang, and Xiaoou Tang. Deep learning face attributes in
the wild. In IEEE International Conference on Computer Vision, pages 3730–3738, 2015.
Francesco Locatello, Ben Poole, Gunnar Rätsch, Bernhard Schölkopf, Olivier Bachem, and
Michael Tschannen. Weakly-supervised disentanglement without compromises. In Inter-
national Conference on Machine Learning, 2020a.
Lars Mescheder, Sebastian Nowozin, and Andreas Geiger. Adversarial variational bayes:
Unifying variational autoencoders and generative adversarial networks. In International
Conference on Machine Learning, pages 2391–2400. JMLR. org, 2017.
Mehdi Mirza and Simon Osindero. Conditional generative adversarial nets. arXiv preprint
arXiv:1411.1784, 2014.
Raha Moraffah, Bahman Moraffah, Mansooreh Karami, Adrienne Raglin, and Huan Liu. CAN: A causal adversarial network for learning observational and interventional distributions. arXiv preprint arXiv:2008.11376, 2020.
Suraj Nair, Yuke Zhu, Silvio Savarese, and Li Fei-Fei. Causal induction from visual obser-
vations for goal directed tasks. arXiv preprint arXiv:1910.01751, 2019.
Ignavier Ng, AmirEmad Ghassami, and Kun Zhang. On the role of sparsity and DAG constraints for learning linear DAGs. arXiv preprint arXiv:2006.10201, 2020.
Judea Pearl et al. Causality: Models, Reasoning and Inference. Cambridge, UK: Cambridge University Press, 2000.
Jonas Peters and Peter Bühlmann. Identifiability of gaussian structural equation models
with equal error variances. Biometrika, 101(1):219–228, 2014.
Boris Teodorovich Polyak. Gradient methods for minimizing functionals. Zhurnal vychisli-
tel’noi matematiki i matematicheskoi fiziki, 3(4):643–653, 1963.
Ali Razavi, Aaron van den Oord, and Oriol Vinyals. Generating diverse high-fidelity images with VQ-VAE-2. In Advances in Neural Information Processing Systems, pages 14866–14876, 2019.
Shiori Sagawa, Pang Wei Koh, Tatsunori B Hashimoto, and Percy Liang. Distributionally
robust neural networks for group shifts: On the importance of regularization for worst-
case generalization. arXiv preprint arXiv:1911.08731, 2019.
Bernhard Schölkopf, Dominik Janzing, Jonas Peters, Eleni Sgouritsa, Kun Zhang, and
Joris Mooij. On causal and anticausal learning. In International Conference on Machine
Learning, 2012.
Boris Shekhtman. Why piecewise linear functions are dense in c [0, 1]. Journal of Approx-
imation Theory, 36(3):265–267, 1982.
Xinwei Shen, Tong Zhang, and Kani Chen. Bidirectional generative modeling using adver-
sarial gradient estimation. arXiv preprint arXiv:2002.09161, 2020.
Rui Shu, Yining Chen, Abhishek Kumar, Stefano Ermon, and Ben Poole. Weakly supervised
disentanglement with guarantees. In International Conference on Learning Representa-
tions, 2020.
Casper Kaae Sønderby, Tapani Raiko, Lars Maaløe, Søren Kaae Sønderby, and Ole Winther.
Ladder variational autoencoders. In Advances in Neural Information Processing Systems,
pages 3738–3746, 2016.
Peter Spirtes, Clark N Glymour, Richard Scheines, and David Heckerman. Causation,
prediction, and search. MIT press, 2000.
Jan Stühmer, Richard Turner, and Sebastian Nowozin. Independent subspace analysis for
unsupervised learning of disentangled representations. In International Conference on
Artificial Intelligence and Statistics, pages 1200–1210. PMLR, 2020.
Raphael Suter, Djordje Miladinovic, Bernhard Schölkopf, and Stefan Bauer. Robustly disen-
tangled causal mechanisms: Validating deep representations for interventional robustness.
In International Conference on Machine Learning, pages 6056–6065. PMLR, 2019.
Frederik Träuble, Elliot Creager, Niki Kilbertus, Francesco Locatello, Andrea Dittadi,
Anirudh Goyal, Bernhard Schölkopf, and Stefan Bauer. On disentangled representa-
tions learned from correlated data. In International Conference on Machine Learning,
pages 10401–10412. PMLR, 2021.
Aad W Van der Vaart. Asymptotic Statistics, volume 3. Cambridge University Press, 2000.
Mengyue Yang, Furui Liu, Zhitang Chen, Xinwei Shen, Jianye Hao, and Jun Wang. Causal-
vae: Disentangled representation learning via neural structural causal models. In Proceed-
ings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition, pages
9593–9602, June 2021.
Yue Yu, Jie Chen, Tian Gao, and Mo Yu. DAG-GNN: DAG structure learning with graph neural networks. In International Conference on Machine Learning, 2019.
Han Zhang, Ian Goodfellow, Dimitris Metaxas, and Augustus Odena. Self-attention gen-
erative adversarial networks. In International Conference on Machine Learning, pages
7354–7363. PMLR, 2019.
Jiji Zhang and Peter Spirtes. Intervention, determinism, and the causal minimality condi-
tion. Synthese, 182(3):335–347, 2011.
Kun Zhang and Aapo Hyvarinen. On the identifiability of the post-nonlinear causal model.
In Proceedings of the 25th Conference on Uncertainty in Artificial Intelligence, 2009.
Tong Zhang. Statistical behavior and consistency of classification methods based on convex
risk minimization. Annals of Statistics, pages 56–85, 2004.
Shengjia Zhao, Jiaming Song, and Stefano Ermon. Learning hierarchical features from
generative models. In International Conference on Machine Learning, 2017.
Xun Zheng, Bryon Aragam, Pradeep K Ravikumar, and Eric P Xing. DAGs with no tears: Continuous optimization for structure learning. Advances in Neural Information Processing Systems, 31, 2018.