Yet another introduction to disentanglement applied to VAEs

Michele De Vita
7 min read · May 9, 2023


Introduction to VAEs

Variational autoencoders (VAEs) have become a key building block of powerful generative models of images.

VAEs learn to compress high-dimensional inputs such as images into a lower-dimensional latent space while retaining as much information as possible. Unlike standard autoencoders, VAEs learn a latent space that can be sampled to generate novel examples similar to the training data. For example, Stable Diffusion combines a VAE with a diffusion model: the diffusion process runs in the VAE's latent space, and the VAE decoder turns the result into a high-quality image matching a text description, with stunning results. Furthermore, we can analyze a VAE's latent space to build a map of the learned concepts and measure distances between elements.
In this article we assume the reader already knows how VAEs work, so that we can focus on the interpretability aspects of the model.

VAE latent space map for the MNIST dataset. Each color is a digit.

Disentanglement

However, VAEs are not only useful for generation. They have also become crucial for building interpretable generative models with disentangled latent spaces. A disentangled latent space is one where each dimension corresponds to a single semantic factor of variation, and different dimensions represent non-overlapping concepts. VAEs can be designed and trained specifically to produce disentangled representations.

The key benefit of disentangled latent spaces is that they give far greater control over generation. In a well-disentangled latent space each dimension has a clear semantic meaning, so moving a point along a single dimension produces a meaningful, predictable change in the generated output. This control and interpretability enable applications such as controllable text-to-image synthesis, where each latent dimension could correspond to a semantic aspect specified in the prompt, and changing that dimension would alter only that aspect of the generated image.
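Changing a single dimension in this way is usually visualized with a latent traversal: take one latent code, sweep one coordinate over a range while keeping the others fixed, and decode each result. Below is a minimal sketch, assuming a PyTorch decoder that maps a batch of latent codes to images; `latent_traversal` is a hypothetical helper, not a library function.

```python
import torch
import torch.nn as nn


def latent_traversal(decoder, z, dim, values):
    """Decode copies of the latent code z in which only coordinate `dim` varies."""
    # One row per traversal value, all starting from the same code z
    z_batch = z.unsqueeze(0).repeat(len(values), 1)
    # Overwrite the chosen coordinate with the sweep values
    z_batch[:, dim] = torch.as_tensor(values, dtype=z.dtype)
    with torch.no_grad():
        return decoder(z_batch)


# Usage with a stand-in decoder (nn.Identity returns the latent codes themselves)
z = torch.zeros(10)
out = latent_traversal(nn.Identity(), z, dim=3, values=torch.linspace(-3, 3, 7))
print(out.shape)  # torch.Size([7, 10])
```

With a trained model you would pass its decoder and, for instance, the encoder mean of one image; in a well-disentangled model the decoded images should differ in a single factor only.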


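VAEs, and Beta-VAEs in particular, have strong disentanglement capabilities compared to many other latent variable models. A Beta-VAE is a VAE whose KL-divergence term is weighted by a factor β > 1, which pushes the approximate posterior towards the factorized prior at some cost in reconstruction quality.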
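For reference, the Beta-VAE objective (to be maximized) is the usual evidence lower bound with the KL term scaled by β; β = 1 recovers the standard VAE. With a diagonal Gaussian encoder and a standard normal prior the KL term has the closed form below, which is the expression used in `loss_function` later in this article. Minimizing binary cross-entropy plus β times this KL term is equivalent to maximizing the objective under a Bernoulli decoder.

```latex
\mathcal{L}_\beta(\theta, \phi; x) = \mathbb{E}_{q_\phi(z \mid x)}\left[\log p_\theta(x \mid z)\right] - \beta \, D_{\mathrm{KL}}\left(q_\phi(z \mid x) \,\|\, p(z)\right), \qquad p(z) = \mathcal{N}(0, I)

D_{\mathrm{KL}}\left(\mathcal{N}\left(\mu, \operatorname{diag}(\sigma^2)\right) \,\|\, \mathcal{N}(0, I)\right) = -\frac{1}{2} \sum_{i=1}^{d} \left(1 + \log \sigma_i^2 - \mu_i^2 - \sigma_i^2\right)
```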

Example of a well-disentangled latent space

Disentanglement metrics

There is a lot of research under the disentanglement umbrella. Most publications propose and analyze disentanglement metrics, because there is no obvious way to measure disentanglement. According to “Measuring disentanglement: A review of metrics”, disentanglement metrics can be classified into three families:

  • Predictor-based metrics evaluate learned representations by assessing how well a supervised predictor (e.g., a classifier or a regressor) can predict the true generative factors from the learned latent variables. The idea is that if the latent representation is disentangled, it should be easier for a simple predictor to map the learned variables to the original generative factors. This family includes DCI (Disentanglement, Completeness and Informativeness), SAP (Separated Attribute Predictability) and the Explicitness score.
  • Information-based metrics evaluate learned representations by analyzing how the information contained in the latent variables relates to the true generative factors. Grounded in information theory, they use quantities such as mutual information, entropy or conditional entropy to quantify the degree of disentanglement. This family includes MIG (Mutual Information Gap) and its derivatives (JEMMIG, MIG-SUP).
  • Intervention-based metrics evaluate learned representations by intervening on the data or the latent space and observing the effect. The idea is that if a representation effectively disentangles the generative factors, it should be possible to manipulate each factor independently in the learned latent space. This family includes Z-Diff, Z-min Variance, Z-max Variance and IRS (Interventional Robustness Score).
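To make the information-based family concrete, here is a minimal sketch of MIG: for each ground-truth factor, compute the mutual information between the factor and every discretized latent dimension, then take the gap between the two most informative latents, normalized by the factor's entropy. It assumes discrete factor labels and latents discretized into 20 equal-width histogram bins (an arbitrary but common choice); `discretize` and `mig` are hypothetical helper names.

```python
import numpy as np
from sklearn.metrics import mutual_info_score


def discretize(latents, num_bins=20):
    """Bin each latent dimension independently into equal-width histogram bins."""
    binned = np.zeros_like(latents, dtype=np.int64)
    for i in range(latents.shape[1]):
        edges = np.histogram_bin_edges(latents[:, i], bins=num_bins)
        binned[:, i] = np.digitize(latents[:, i], edges[1:-1])
    return binned


def mig(latents, factors, num_bins=20):
    """Mutual Information Gap averaged over factors (latents: (N, d), factors: (N, K))."""
    z = discretize(latents, num_bins)
    gaps = []
    for k in range(factors.shape[1]):
        v = factors[:, k]
        mi = np.array([mutual_info_score(v, z[:, j]) for j in range(z.shape[1])])
        entropy_v = mutual_info_score(v, v)  # H(v) = I(v; v), in nats like the MI values
        top_two = np.sort(mi)[::-1][:2]
        gaps.append((top_two[0] - top_two[1]) / entropy_v)
    return float(np.mean(gaps))


# Usage on synthetic data: one latent per factor vs. each factor copied into two latents
rng = np.random.default_rng(0)
f = rng.integers(0, 5, size=(5000, 2))
disentangled = np.column_stack([f[:, 0], f[:, 1], rng.normal(size=5000)]).astype(float)
entangled = np.column_stack([f[:, 0], f[:, 0], f[:, 1], f[:, 1]]).astype(float)
print(mig(disentangled, f), mig(entangled, f))
```

Duplicating a factor across two latents drives the gap to zero, which is exactly the redundancy MIG is designed to penalize.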

Application

In this example, we use PyTorch to train a Beta-VAE on the dSprites dataset and evaluate the learned representation with the DCI (Disentanglement, Completeness and Informativeness) metric.
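DCI follows the framework of Eastwood and Williams: one predictor is trained per generative factor, its feature importances form an importance matrix R with one row per latent dimension (L rows) and one column per factor (M columns), and entropies of the normalized rows and columns give disentanglement and completeness. The formulas below mirror `compute_dci_disentanglement` and `compute_dci_completeness` in step 4, where completeness is a plain average over factors and informativeness is the predictors' mean test accuracy.

```latex
P_{ij} = \frac{R_{ij}}{\sum_{k=1}^{M} R_{ik}}, \qquad
D_i = 1 + \sum_{j=1}^{M} P_{ij} \log_M P_{ij}, \qquad
\rho_i = \frac{\sum_{j=1}^{M} R_{ij}}{\sum_{k=1}^{L} \sum_{j=1}^{M} R_{kj}}, \qquad
D = \sum_{i=1}^{L} \rho_i D_i

\tilde{R}_{ij} = \frac{R_{ij}}{\sum_{k=1}^{L} R_{kj}}, \qquad
C_j = 1 + \sum_{i=1}^{L} \tilde{R}_{ij} \log_L \tilde{R}_{ij}, \qquad
C = \frac{1}{M} \sum_{j=1}^{M} C_j
```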

1. Define the dSprites dataset

```python
import os
import numpy as np
import torch
from torch.utils.data import Dataset


class DspritesDataset(Dataset):
    def __init__(self, data_dir, transform=None):
        self.data_dir = data_dir
        self.transform = transform
        # The 'metadata' entry is a pickled (Python 2) dict, so pickling must be allowed
        data = np.load(os.path.join(data_dir, 'dsprites_ndarray_co1sh3sc6or40x32y32_64x64.npz'),
                       allow_pickle=True, encoding='latin1')
        self.images = data['imgs']  # (737280, 64, 64) uint8 array with values in {0, 1}
        metadata = data['metadata'][()]
        # Drop the first factor ('color'): it takes a single value in dSprites,
        # so no classifier can be trained to predict it
        self.latents_classes = data['latents_classes'][:, 1:]  # shape, scale, orientation, posX, posY
        self.latents_names = metadata['latents_names'][1:]
        self.latents_sizes = metadata['latents_sizes'][1:]

    def __len__(self):
        return len(self.images)

    def __getitem__(self, idx):
        # Convert to float and add a channel dimension: (1, 64, 64)
        image = torch.from_numpy(self.images[idx]).float().unsqueeze(0)
        if self.transform:
            image = self.transform(image)
        # Integer class index of each ground-truth factor of variation
        factors = torch.from_numpy(self.latents_classes[idx]).long()
        return image, factors
```
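dSprites stores one image for every combination of factor values, ordered row-major over the factors (color, shape, scale, orientation, posX, posY); the official dSprites notebook relies on this ordering to map factor classes to an image index. Assuming that ordering, the mapping is a mixed-radix number, which NumPy's `np.ravel_multi_index` computes directly. This is handy for intervention-based metrics, which need images sharing all but one factor. `latent_to_index` is a hypothetical helper that works on the full six-factor class vectors stored in the file's `latents_classes` array.

```python
import numpy as np

# Number of classes per factor: color, shape, scale, orientation, posX, posY
LATENTS_SIZES = (1, 3, 6, 40, 32, 32)


def latent_to_index(latents_classes, latents_sizes=LATENTS_SIZES):
    """Map factor class vectors of shape (..., 6) to flat dSprites image indices."""
    latents_classes = np.asarray(latents_classes)
    # Row-major (mixed-radix) encoding: the last factor (posY) varies fastest
    return np.ravel_multi_index(tuple(np.moveaxis(latents_classes, -1, 0)), latents_sizes)


print(latent_to_index([0, 0, 0, 0, 0, 1]))    # 1
print(latent_to_index([0, 0, 0, 0, 1, 0]))    # 32
print(latent_to_index([0, 2, 5, 39, 31, 31]))  # 737279
```

The inverse mapping is `np.unravel_index(index, LATENTS_SIZES)`.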

2. Define the Beta-VAE model in PyTorch

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class Encoder(nn.Module):
    def __init__(self, latent_dim):
        super(Encoder, self).__init__()
        # Define convolutional layers; with padding=1 each halves the size: 64 -> 32 -> 16 -> 8 -> 4
        self.conv1 = nn.Conv2d(1, 32, 4, stride=2, padding=1)
        self.conv2 = nn.Conv2d(32, 64, 4, stride=2, padding=1)
        self.conv3 = nn.Conv2d(64, 128, 4, stride=2, padding=1)
        self.conv4 = nn.Conv2d(128, 256, 4, stride=2, padding=1)
        # Define fully connected layers for mean and log variance
        self.fc_mu = nn.Linear(256 * 4 * 4, latent_dim)
        self.fc_logvar = nn.Linear(256 * 4 * 4, latent_dim)

    def forward(self, x):
        # Apply convolutional layers with ReLU activation
        x = F.relu(self.conv1(x))
        x = F.relu(self.conv2(x))
        x = F.relu(self.conv3(x))
        x = F.relu(self.conv4(x))
        # Flatten tensor and compute mean and log variance
        x = x.view(x.size(0), -1)
        mu = self.fc_mu(x)
        logvar = self.fc_logvar(x)
        return mu, logvar


class Decoder(nn.Module):
    def __init__(self, latent_dim):
        super(Decoder, self).__init__()
        # Define fully connected layer
        self.fc = nn.Linear(latent_dim, 256 * 4 * 4)
        # Define transposed convolutional layers; each doubles the size: 4 -> 8 -> 16 -> 32 -> 64
        self.conv1 = nn.ConvTranspose2d(256, 128, 4, stride=2, padding=1)
        self.conv2 = nn.ConvTranspose2d(128, 64, 4, stride=2, padding=1)
        self.conv3 = nn.ConvTranspose2d(64, 32, 4, stride=2, padding=1)
        self.conv4 = nn.ConvTranspose2d(32, 1, 4, stride=2, padding=1)

    def forward(self, z):
        # Apply fully connected layer
        z = self.fc(z)
        # Reshape tensor and apply transposed convolutional layers with ReLU activation
        z = z.view(z.size(0), 256, 4, 4)
        z = F.relu(self.conv1(z))
        z = F.relu(self.conv2(z))
        z = F.relu(self.conv3(z))
        # Apply final transposed convolutional layer with sigmoid activation: (N, 1, 64, 64)
        x_recon = torch.sigmoid(self.conv4(z))
        return x_recon


class BetaVAE(nn.Module):
    def __init__(self, latent_dim, beta):
        # We assume that the prior distribution is a standard Normal(mu=0, scale=1)
        super(BetaVAE, self).__init__()
        self.encoder = Encoder(latent_dim)
        self.decoder = Decoder(latent_dim)
        self.beta = beta

    def forward(self, x):
        # Encode input and obtain mean and log variance
        mu, logvar = self.encoder(x)
        # Reparameterize to sample latent variable
        z = self.reparameterize(mu, logvar)
        # Decode latent variable to obtain reconstructed input
        x_recon = self.decoder(z)
        return x_recon, mu, logvar

    def reparameterize(self, mu, logvar):
        # Sample from the learned distribution using the reparameterization trick
        std = torch.exp(0.5 * logvar)
        eps = torch.randn_like(std)
        return mu + eps * std

    def loss_function(self, x, x_recon, mu, logvar):
        # Compute reconstruction loss using binary cross entropy, averaged over the batch
        recon_loss = F.binary_cross_entropy(x_recon, x, reduction='sum') / x.size(0)
        # Compute closed-form KL divergence to the N(0, I) prior, averaged over the batch
        kld_loss = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp()) / x.size(0)
        # Compute total loss with beta weight on KL-divergence loss
        loss = recon_loss + self.beta * kld_loss
        return loss
```
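Two quick sanity checks on the model's building blocks, written as a standalone sketch (`spatial_sizes` and `closed_form_kl` are hypothetical helper names). First, with 4×4 kernels and stride 2, padding=0 maps a 64×64 input to 2×2 feature maps but decodes back to only 62×62, while padding=1 gives 4×4 feature maps and a 64×64 reconstruction; `binary_cross_entropy` needs the reconstruction to match the input shape. Second, the closed-form KL expression used in `loss_function` agrees with PyTorch's `torch.distributions.kl_divergence` between the encoder's Gaussian and a standard normal prior.

```python
import torch
import torch.nn as nn
from torch.distributions import Normal, kl_divergence


def spatial_sizes(padding):
    """Trace a 64x64 image through four 4x4, stride-2 convs and four transposed convs."""
    channels = [1, 32, 64, 128, 256]
    x = torch.zeros(1, 1, 64, 64)
    for c_in, c_out in zip(channels[:-1], channels[1:]):
        x = nn.Conv2d(c_in, c_out, 4, stride=2, padding=padding)(x)
    encoded = x.shape[-1]
    reversed_channels = channels[::-1]
    for c_in, c_out in zip(reversed_channels[:-1], reversed_channels[1:]):
        x = nn.ConvTranspose2d(c_in, c_out, 4, stride=2, padding=padding)(x)
    return encoded, x.shape[-1]


def closed_form_kl(mu, logvar):
    """Per-sample KL(N(mu, exp(logvar)) || N(0, I)), summed over latent dimensions."""
    return -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp(), dim=1)


print(spatial_sizes(padding=0))  # (2, 62)
print(spatial_sizes(padding=1))  # (4, 64)

torch.manual_seed(0)
mu, logvar = torch.randn(8, 10), torch.randn(8, 10)
posterior = Normal(mu, torch.exp(0.5 * logvar))
prior = Normal(torch.zeros(8, 10), torch.ones(8, 10))
reference = kl_divergence(posterior, prior).sum(dim=1)
print(torch.allclose(closed_form_kl(mu, logvar), reference, atol=1e-5))  # True
```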

3. Train the Beta-VAE

```python
import torch
from torch.utils.data import DataLoader
from torch.optim import Adam

# Hyperparameters
num_epochs = 10
batch_size = 64
learning_rate = 1e-4
latent_dim = 10
beta = 4.0

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = BetaVAE(latent_dim, beta).to(device)
dsprites = DspritesDataset('dataset/dsprites')

optimizer = Adam(model.parameters(), lr=learning_rate)
dataloader = DataLoader(dsprites, batch_size=batch_size, shuffle=True)

for epoch in range(num_epochs):
    model.train()
    train_loss = 0.0
    for images, _ in dataloader:  # the ground-truth factors are not used for training
        images = images.to(device)
        optimizer.zero_grad()
        x_recon, mu, logvar = model(images)
        loss = model.loss_function(images, x_recon, mu, logvar)
        loss.backward()
        optimizer.step()
        train_loss += loss.item()

    # loss_function already averages over the batch, so average over the number of batches
    train_loss /= len(dataloader)
    print(f"Epoch {epoch + 1}/{num_epochs}, Loss: {train_loss:.4f}")
```
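Because β > 1 strengthens the pull towards the prior, a Beta-VAE often leaves some latent dimensions unused: their posterior stays close to N(0, 1) for every input, so their KL contribution is near zero. Below is a minimal sketch for spotting them, assuming `mu` and `logvar` come from the trained encoder (for example `mu, logvar = model.encoder(images)` on a batch); the 0.01-nat threshold and both helper names are hypothetical choices.

```python
import torch


def kl_per_dimension(mu, logvar):
    """Batch-averaged KL(q(z_i | x) || N(0, 1)) for each latent dimension i."""
    return (-0.5 * (1 + logvar - mu.pow(2) - logvar.exp())).mean(dim=0)


def active_dimensions(mu, logvar, threshold=0.01):
    # Hypothetical rule: a dimension counts as active if its average KL exceeds the threshold
    kl = kl_per_dimension(mu, logvar)
    return torch.nonzero(kl > threshold).flatten().tolist()


# Usage with synthetic encoder outputs: only dimension 2 carries information
mu = torch.zeros(128, 10)
mu[:, 2] = torch.linspace(-2, 2, 128)
logvar = torch.zeros(128, 10)
print(active_dimensions(mu, logvar))  # [2]
```

Only the active dimensions are worth inspecting with latent traversals; the inactive ones decode to nearly identical images.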

4. Compute DCI

```python
import numpy as np
import torch
from sklearn.linear_model import LogisticRegression
from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score


def dci_score(encoder, data_loader, test_size=0.2, device="cpu"):
    # Encode the entire dataset (pass a loader over a subset for faster evaluation)
    encoder.eval()
    mu_list, factors_list = [], []
    with torch.no_grad():
        for images, factors in data_loader:  # iterate over data loader
            mu, _ = encoder(images.to(device))  # posterior mean as the latent representation
            mu_list.append(mu.cpu())
            factors_list.append(factors)

    mu_all = torch.cat(mu_list, dim=0)  # (N, latent_dim)
    factors_all = torch.cat(factors_list, dim=0)  # (N, num_factors)

    # Compute DCI scores: disentanglement, completeness, informativeness
    return _compute_dci_scores(mu_all.numpy(), factors_all.numpy(), test_size)


def compute_dci_disentanglement(params):
    """
    >>> bool(abs(compute_dci_disentanglement(np.array([[0.5, 0.5], [0.5, 0.5]]))) <= 1e-5)
    True

    :param params: non-negative importance matrix of shape (num_latents, num_factors)
    :return: the DCI disentanglement score
    """
    L, M = params.shape  # latents x factors

    P = params / params.sum(axis=1, keepdims=True)  # normalize rows to get the P matrix
    log_P = np.log(P + 1e-11)  # small epsilon avoids log(0)
    entropy_Pi = (-P * (log_P / np.log(M))).sum(axis=1)  # entropy of each row, base M
    D_i = 1 - entropy_Pi  # disentanglement of each latent dimension
    phi_i = params.sum(axis=1) / params.sum()  # relative importance of each latent dimension
    return (phi_i * D_i).sum()  # weighted sum of per-dimension disentanglements


def compute_dci_completeness(params):
    L, M = params.shape

    R = params / params.sum(axis=0, keepdims=True)  # normalize columns to get the R matrix
    entropy_Rj = (-R * (np.log(R + 1e-11) / np.log(L))).sum(axis=0)  # entropy of each column, base L
    C_j = 1 - entropy_Rj  # completeness of each factor of variation
    return C_j.mean()  # average over all factor completenesses


def _compute_dci_scores(representations, factors, test_size, random_state=None):
    num_latents = representations.shape[1]
    num_factors = factors.shape[1]
    # Train/test split
    train_idx, test_idx = train_test_split(np.arange(representations.shape[0]),
                                           test_size=test_size, random_state=random_state)

    coef_matrix = np.zeros((num_latents, num_factors))
    informativeness = []
    for j in range(num_factors):
        clf = LogisticRegression(solver='lbfgs', max_iter=1000)  # multinomial for multi-class targets
        clf.fit(representations[train_idx], factors[train_idx, j])
        # coef_ has shape (n_classes, num_latents): use the mean absolute weight as importance
        coef_matrix[:, j] = np.abs(clf.coef_).mean(axis=0)
        pred_factors = clf.predict(representations[test_idx])
        informativeness.append(accuracy_score(factors[test_idx, j], pred_factors))

    # Compute DCI scores
    disentanglement = compute_dci_disentanglement(coef_matrix)
    completeness = compute_dci_completeness(coef_matrix)
    informativeness = np.mean(informativeness)

    return {'disentanglement': disentanglement, 'completeness': completeness, 'informativeness': informativeness}


dci = dci_score(model.encoder, dataloader, device=device)
print("DCI score:", dci)
```

However, it is important to note that the DCI scores depend on the predictor used to build the importance matrix and on its hyperparameters. Different predictors and hyperparameters can produce different disentanglement, completeness and informativeness scores for the same representation.

As an example of this dependence, consider logistic regression versus random forest for a binary classification task. Logistic regression is a linear model that learns a linear decision boundary in the input space, while random forest is a tree-based model that learns a nonlinear decision boundary by partitioning the feature space into regions.

In this scenario, the choice of predictor can have a significant impact on the DCI scores, even though the learned representation itself does not change. For instance, logistic regression may fail to capture nonlinear relationships between the latent variables and the factors, resulting in lower informativeness scores. A random forest, on the other hand, can capture more complex interactions between the latent variables, potentially leading to higher informativeness scores. The two predictors also measure importance differently (coefficient magnitudes versus impurity-based feature importances), which affects the disentanglement and completeness scores as well.
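A minimal sketch of this effect on synthetic data, assuming scikit-learn. The hypothetical helper `importance_matrix` fits one predictor per factor and returns an importance matrix plus the mean test accuracy (the informativeness used in step 4). For logistic regression the importance is the mean absolute coefficient; for the random forest it is scikit-learn's impurity-based `feature_importances_`. Swapping in a random forest is an illustration, not part of the experiment above. Factor 0 depends on |z_0| (a nonlinear, symmetric rule) while factor 1 depends linearly on z_1.

```python
import numpy as np
from sklearn.ensemble import RandomForestClassifier
from sklearn.linear_model import LogisticRegression
from sklearn.model_selection import train_test_split


def importance_matrix(representations, factors, make_model, test_size=0.2, random_state=0):
    """Return a (num_latents, num_factors) importance matrix and the mean test accuracy."""
    X_train, X_test, y_train, y_test = train_test_split(
        representations, factors, test_size=test_size, random_state=random_state)
    R = np.zeros((representations.shape[1], factors.shape[1]))
    accuracies = []
    for j in range(factors.shape[1]):
        model = make_model().fit(X_train, y_train[:, j])
        if hasattr(model, "feature_importances_"):
            R[:, j] = model.feature_importances_  # tree ensembles: impurity-based importances
        else:
            R[:, j] = np.abs(model.coef_).mean(axis=0)  # linear models: mean |coefficient| over classes
        accuracies.append(model.score(X_test, y_test[:, j]))
    return R, float(np.mean(accuracies))


# Synthetic "latents": factor 0 depends nonlinearly on z_0, factor 1 linearly on z_1
rng = np.random.default_rng(0)
z = rng.normal(size=(2000, 3))
factors = np.column_stack([(np.abs(z[:, 0]) > 0.674).astype(int), (z[:, 1] > 0).astype(int)])

R_lr, acc_lr = importance_matrix(z, factors, lambda: LogisticRegression(max_iter=1000))
R_rf, acc_rf = importance_matrix(
    z, factors, lambda: RandomForestClassifier(n_estimators=100, random_state=0))
print(f"informativeness: logistic regression {acc_lr:.2f}, random forest {acc_rf:.2f}")
```

`R_lr` and `R_rf` can be passed to `compute_dci_disentanglement` and `compute_dci_completeness` from step 4 to see how the choice of predictor shifts those scores too.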
