Yet another introduction to disentanglement in VAEs
Introduction to VAEs
Variational autoencoders (VAEs) have become a key building block of powerful generative models of images.
VAEs learn to compress high-dimensional inputs such as images into a lower-dimensional latent space while retaining as much information as possible. Unlike standard autoencoders, this latent space can also be sampled to generate novel examples similar to the training data. For example, Stable Diffusion runs its diffusion model in the compressed latent space of a VAE and uses the VAE decoder to turn the denoised latents back into images, producing high-quality images from text descriptions. Furthermore, we can analyze a VAE's latent space to build a map of the concepts it has learned and to measure distances between encoded samples.
In this article we assume that the reader already knows how VAEs work, so that we can focus on the interpretability aspects of the model.
Disentanglement
However, VAEs are not only useful for generation. They have also become central to building interpretable generative models with disentangled latent spaces. A latent space is disentangled when each dimension corresponds to a single semantic factor of variation and different dimensions capture non-overlapping factors. VAEs can be designed and trained specifically to produce such disentangled representations.
The key benefit of a disentangled latent space is far greater control over generation. Because each latent dimension has a clear semantic meaning, moving a point along a single dimension produces a predictable change in a single attribute of the generated output. This combination of control and interpretability enables applications such as controllable text-to-image synthesis, where each latent dimension could correspond to a semantic aspect specified in the prompt, so that manipulating that dimension changes only that factor in the generated image.
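VAEs, and Beta-VAEs in particular, tend to learn more disentangled representations than many other latent variable models. A Beta-VAE weights the KL-divergence term of the VAE objective by a factor β > 1, which pushes the approximate posterior toward the factorized isotropic Gaussian prior and encourages each latent dimension to capture a separate factor of variation.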
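For reference, the Beta-VAE objective (maximized for each input $x$) is the standard VAE evidence lower bound with its KL term weighted by $\beta$; setting $\beta = 1$ recovers the standard VAE:

$$\mathcal{L}_\beta(\theta, \phi; x) = \mathbb{E}_{q_\phi(z \mid x)}\left[\log p_\theta(x \mid z)\right] - \beta\, D_{\mathrm{KL}}\left(q_\phi(z \mid x) \,\|\, p(z)\right)$$

With a diagonal Gaussian posterior $q_\phi(z \mid x) = \mathcal{N}(\mu, \operatorname{diag}(\sigma^2))$ over $d$ latent dimensions and a standard normal prior $p(z) = \mathcal{N}(0, I)$, the KL term has a closed form:

$$D_{\mathrm{KL}}\left(q_\phi(z \mid x) \,\|\, p(z)\right) = -\frac{1}{2} \sum_{j=1}^{d} \left(1 + \log \sigma_j^2 - \mu_j^2 - \sigma_j^2\right)$$

In practice one minimizes $-\mathcal{L}_\beta$: with a Bernoulli decoder the first term becomes a binary cross-entropy reconstruction loss, and the encoder predicts $\log \sigma^2$ directly as a log variance.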
Disentanglement metrics
There is a lot of research under the disentanglement umbrella. Much of it proposes and analyzes disentanglement metrics, because there is no single obvious way to measure disentanglement. According to “Measuring disentanglement: A review of metrics”, disentanglement metrics can be grouped into three families:
- Predictor-based metrics evaluate a learned representation by measuring how well a supervised predictor (e.g., a classifier or a regressor) can recover the true generative factors from the latent variables. The idea is that if the representation is disentangled, a simple predictor should be able to map the learned variables to the original generative factors. This family includes DCI (Disentanglement, Completeness and Informativeness), SAP (Separated Attribute Predictability) and the Explicitness score.
- Information-based metrics evaluate a learned representation by analyzing how much information the latent variables carry about the true generative factors. They are grounded in information theory and use quantities such as mutual information, entropy and conditional entropy to quantify disentanglement. This family includes MIG (Mutual Information Gap) and its derivatives (JEMMIG, MIG-SUP).
- Intervention-based metrics evaluate a learned representation through interventions: they fix or vary some generative factors (or latent dimensions) and observe how the representation or the output responds. The underlying idea is that if a representation disentangles the generative factors, it should be possible to manipulate each factor independently in the latent space. This family includes Z-diff, Z-min Variance, Z-max Variance and IRS (Interventional Robustness Score).
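As an illustration of the information-based family, here is a minimal sketch of MIG. For each factor it takes the gap between the two latent dimensions with the highest mutual information with that factor, divides it by the factor's entropy, and averages over factors. It assumes latent codes of shape (num_samples, num_latents) and integer factor classes of shape (num_samples, num_factors), and it discretizes every latent dimension into 20 equal-width bins (a common choice, but an assumption here) so that scikit-learn's `mutual_info_score` can estimate mutual information in nats. `discretize` and `mig_score` are hypothetical helper names.

```python
import numpy as np
from sklearn.metrics import mutual_info_score


def discretize(latents, num_bins=20):
    """Bin each latent dimension (column) into num_bins equal-width bins."""
    binned = np.zeros(latents.shape, dtype=np.int64)
    for j in range(latents.shape[1]):
        edges = np.histogram_bin_edges(latents[:, j], bins=num_bins)
        # Digitizing against the interior edges gives bin ids in [0, num_bins - 1]
        binned[:, j] = np.digitize(latents[:, j], edges[1:-1])
    return binned


def mig_score(latents, factors, num_bins=20):
    """Mutual Information Gap, averaged over the non-constant factors."""
    z = discretize(latents, num_bins)
    gaps = []
    for k in range(factors.shape[1]):
        v = factors[:, k]
        entropy = mutual_info_score(v, v)  # H(v) = I(v; v)
        if entropy == 0:
            continue  # a constant factor carries no information
        mi = np.array([mutual_info_score(v, z[:, j]) for j in range(z.shape[1])])
        top_two = np.sort(mi)[::-1][:2]
        gaps.append((top_two[0] - top_two[1]) / entropy)
    return float(np.mean(gaps))
```

When each factor is captured by exactly one latent dimension the score approaches 1; when two dimensions encode the same factor, that factor's gap drops to about 0.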
Application
In this example, we will use PyTorch to train a Beta-VAE model on the dSprites dataset and evaluate the learned representation using the DCI (Disentanglement, Completeness and Informativeness) metric.
1. Define the dSprites dataset
```python
import os

import numpy as np
import torch
from torch.utils.data import Dataset


class DspritesDataset(Dataset):
    """dSprites images together with their ground-truth factor classes.

    Expects dsprites_ndarray_co1sh3sc6or40x32y32_64x64.npz inside `data_dir`.
    """

    def __init__(self, data_dir, transform=None):
        self.data_dir = data_dir
        self.transform = transform
        # 'metadata' is a pickled dict, so allow_pickle is required;
        # encoding='bytes' makes its keys bytes (hence b'latents_sizes')
        data = np.load(
            os.path.join(data_dir, 'dsprites_ndarray_co1sh3sc6or40x32y32_64x64.npz'),
            allow_pickle=True,
            encoding='bytes',
        )
        self.images = data['imgs']  # (737280, 64, 64), pixel values in {0, 1}
        self.latents_classes = data['latents_classes']  # (737280, 6), one class per factor
        self.latents_sizes = data['metadata'][()][b'latents_sizes']

    def __len__(self):
        return len(self.images)

    def __getitem__(self, idx):
        # Add a channel dimension: a (1, 64, 64) float tensor for Conv2d and BCE
        image = torch.from_numpy(self.images[idx]).float().unsqueeze(0)
        if self.transform:
            image = self.transform(image)
        # Ground-truth factor classes of this image, in dataset order:
        # color, shape, scale, orientation, posX, posY
        factors = torch.from_numpy(self.latents_classes[idx]).long()
        return image, factors
```
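dSprites contains exactly one image for every combination of factor values, which makes it easy to fetch images with chosen factor values, something intervention-based metrics need (for example, holding one factor fixed while sampling the others). The sketch below assumes the factor sizes stored in the dataset metadata ([1, 3, 6, 40, 32, 32] for color, shape, scale, orientation, posX, posY) and the row-major indexing convention shown in the example notebook of the dSprites repository; `latent_to_index` and `sample_latent_classes` are hypothetical helper names.

```python
import numpy as np

# Number of classes per factor: color, shape, scale, orientation, posX, posY
LATENTS_SIZES = np.array([1, 3, 6, 40, 32, 32])

# Stride of each factor in the flattened dataset (row-major, like C-order strides)
LATENTS_BASES = np.concatenate((LATENTS_SIZES[::-1].cumprod()[::-1][1:], np.array([1])))


def latent_to_index(latent_classes):
    """Map factor-class vectors (shape [..., 6]) to dataset indices."""
    return np.dot(latent_classes, LATENTS_BASES).astype(int)


def sample_latent_classes(num_samples, rng):
    """Draw random factor-class vectors, one column per factor."""
    return np.stack([rng.integers(0, size, num_samples) for size in LATENTS_SIZES], axis=1)


rng = np.random.default_rng(0)
classes = sample_latent_classes(8, rng)
classes[:, 1] = 0  # hold the shape factor fixed while the other factors vary
print(latent_to_index(classes))
```

With a `DspritesDataset` instance `ds`, the matching images would then be `ds.images[latent_to_index(classes)]`.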
2. Define the Beta-VAE model in PyTorch
```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class Encoder(nn.Module):
    def __init__(self, latent_dim):
        super(Encoder, self).__init__()
        # Convolutional layers (kernel 4, stride 2, padding 1) halve the
        # spatial size at each step: 64 -> 32 -> 16 -> 8 -> 4
        self.conv1 = nn.Conv2d(1, 32, 4, stride=2, padding=1)
        self.conv2 = nn.Conv2d(32, 64, 4, stride=2, padding=1)
        self.conv3 = nn.Conv2d(64, 128, 4, stride=2, padding=1)
        self.conv4 = nn.Conv2d(128, 256, 4, stride=2, padding=1)
        # Fully connected layers for the mean and log variance of q(z|x)
        self.fc_mu = nn.Linear(256 * 4 * 4, latent_dim)
        self.fc_logvar = nn.Linear(256 * 4 * 4, latent_dim)

    def forward(self, x):
        # Apply convolutional layers with ReLU activation
        x = F.relu(self.conv1(x))
        x = F.relu(self.conv2(x))
        x = F.relu(self.conv3(x))
        x = F.relu(self.conv4(x))
        # Flatten tensor and compute mean and log variance
        x = x.view(x.size(0), -1)
        mu = self.fc_mu(x)
        logvar = self.fc_logvar(x)
        return mu, logvar


class Decoder(nn.Module):
    def __init__(self, latent_dim):
        super(Decoder, self).__init__()
        # Fully connected layer mapping z to a 256 x 4 x 4 feature map
        self.fc = nn.Linear(latent_dim, 256 * 4 * 4)
        # Transposed convolutions (kernel 4, stride 2, padding 1) double the
        # spatial size at each step: 4 -> 8 -> 16 -> 32 -> 64
        self.conv1 = nn.ConvTranspose2d(256, 128, 4, stride=2, padding=1)
        self.conv2 = nn.ConvTranspose2d(128, 64, 4, stride=2, padding=1)
        self.conv3 = nn.ConvTranspose2d(64, 32, 4, stride=2, padding=1)
        self.conv4 = nn.ConvTranspose2d(32, 1, 4, stride=2, padding=1)

    def forward(self, z):
        # Apply fully connected layer
        z = self.fc(z)
        # Reshape tensor and apply transposed convolutional layers with ReLU activation
        z = z.view(z.size(0), 256, 4, 4)
        z = F.relu(self.conv1(z))
        z = F.relu(self.conv2(z))
        z = F.relu(self.conv3(z))
        # Final transposed convolution with sigmoid: pixel probabilities in (0, 1)
        x_recon = torch.sigmoid(self.conv4(z))
        return x_recon


class BetaVAE(nn.Module):
    def __init__(self, latent_dim, beta):
        # The prior distribution is assumed to be a standard Normal(mu=0, scale=1)
        super(BetaVAE, self).__init__()
        self.encoder = Encoder(latent_dim)
        self.decoder = Decoder(latent_dim)
        self.beta = beta

    def forward(self, x):
        # Encode input and obtain mean and log variance
        mu, logvar = self.encoder(x)
        # Reparameterize to sample latent variable
        z = self.reparameterize(mu, logvar)
        # Decode latent variable to obtain reconstructed input
        x_recon = self.decoder(z)
        return x_recon, mu, logvar

    def reparameterize(self, mu, logvar):
        # Sample from the learned distribution using the reparameterization trick
        std = torch.exp(0.5 * logvar)
        eps = torch.randn_like(std)
        return mu + eps * std

    def loss_function(self, x, x_recon, mu, logvar):
        # Reconstruction loss: binary cross entropy, summed over pixels, averaged over the batch
        recon_loss = F.binary_cross_entropy(x_recon, x, reduction='sum') / x.size(0)
        # Closed-form KL divergence to the N(0, I) prior, averaged over the batch
        kld_loss = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp()) / x.size(0)
        # Total loss with beta weight on the KL-divergence term
        loss = recon_loss + self.beta * kld_loss
        return loss
```
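Whether the decoder reproduces the 64×64 input resolution depends on the padding of every layer. The sketch below applies the output-size formulas from the PyTorch documentation for `nn.Conv2d` and `nn.ConvTranspose2d` (dilation 1) to trace the spatial size through four layers with kernel 4 and stride 2, with and without padding. The helper names `conv2d_out`, `conv_transpose2d_out` and `trace` are hypothetical.

```python
def conv2d_out(size, kernel=4, stride=2, padding=0):
    """Output height/width of nn.Conv2d with dilation 1."""
    return (size + 2 * padding - kernel) // stride + 1


def conv_transpose2d_out(size, kernel=4, stride=2, padding=0, output_padding=0):
    """Output height/width of nn.ConvTranspose2d with dilation 1."""
    return (size - 1) * stride - 2 * padding + kernel + output_padding


def trace(size, layer_fn, num_layers, **kwargs):
    """Spatial size before and after each of num_layers identical layers."""
    sizes = [size]
    for _ in range(num_layers):
        sizes.append(layer_fn(sizes[-1], **kwargs))
    return sizes


# Without padding the decoder ends at 62 x 62, which does not match a 64 x 64 input
print(trace(64, conv2d_out, 4))                      # [64, 31, 14, 6, 2]
print(trace(2, conv_transpose2d_out, 4))             # [2, 6, 14, 30, 62]
# With padding=1 every layer exactly halves or doubles the size
print(trace(64, conv2d_out, 4, padding=1))           # [64, 32, 16, 8, 4]
print(trace(4, conv_transpose2d_out, 4, padding=1))  # [4, 8, 16, 32, 64]
```

The last encoder size also fixes the input width of the fully connected layers: 256 channels times the final height and width.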
3. Train the Beta-VAE
```python
from torch.utils.data import DataLoader
from torch.optim import Adam

# Hyperparameters
num_epochs = 10
batch_size = 64
learning_rate = 1e-4
latent_dim = 10
beta = 4.0

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = BetaVAE(latent_dim, beta).to(device)
dsprites = DspritesDataset('dataset/dsprites')
optimizer = Adam(model.parameters(), lr=learning_rate)
dataloader = DataLoader(dsprites, batch_size=batch_size, shuffle=True)

for epoch in range(num_epochs):
    model.train()
    train_loss = 0.0
    # The ground-truth factors are not used for (unsupervised) training
    for images, _ in dataloader:
        images = images.to(device)
        optimizer.zero_grad()
        x_recon, mu, logvar = model(images)
        loss = model.loss_function(images, x_recon, mu, logvar)
        loss.backward()
        optimizer.step()
        train_loss += loss.item()
    # loss_function already averages over the batch, so average over batches
    train_loss /= len(dataloader)
    print(f"Epoch {epoch + 1}/{num_epochs}, Loss: {train_loss:.4f}")
```
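Once the model is trained, a common qualitative check of disentanglement is a latent traversal: start from one latent code, sweep a single dimension over a range of values while keeping the others fixed, decode each code, and check whether only one attribute (for example position or scale) changes. The sketch below only builds the traversal codes with NumPy; `latent_traversal_codes` is a hypothetical name, and the range [-3, 3] is an assumption that roughly covers a standard normal prior.

```python
import numpy as np


def latent_traversal_codes(z, dim, values):
    """Return one latent code per value, equal to `z` except at index `dim`."""
    z = np.asarray(z, dtype=np.float32)
    codes = np.repeat(z[None, :], len(values), axis=0)  # copies, so `z` is untouched
    codes[:, dim] = values
    return codes


base = np.zeros(10, dtype=np.float32)  # e.g. the posterior mean of one image
codes = latent_traversal_codes(base, dim=3, values=np.linspace(-3, 3, 7))
print(codes.shape)  # (7, 10)
```

The codes can then be decoded with the trained model, e.g. `model.decoder(torch.from_numpy(codes).to(device))` inside `torch.no_grad()`, which yields one image per traversal value.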
4. Compute DCI
```python
import numpy as np
import torch
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import accuracy_score
from sklearn.model_selection import train_test_split


def dci_score(encoder, data_loader, device, test_size=0.2):
    # Encode the entire dataset, using the posterior mean as the representation
    encoder.eval()
    mu_list, factors_list = [], []
    with torch.no_grad():
        for images, factors in data_loader:
            mu, _ = encoder(images.to(device))
            mu_list.append(mu.cpu())
            factors_list.append(factors)
    mu_all = torch.cat(mu_list, dim=0)  # (num_samples, latent_dim)
    factors_all = torch.cat(factors_list, dim=0)  # (num_samples, num_factors)
    # Return the disentanglement, completeness and informativeness scores
    return _compute_dci_scores(mu_all.numpy(), factors_all.numpy(), test_size)


def compute_dci_disentanglement(params):
    """
    >>> abs(compute_dci_disentanglement(np.array([[0.5, 0.5], [0.5, 0.5]]))) <= 1e-5
    True
    :param params: non-negative importance matrix, shape (num_latents, num_factors)
    :return: the DCI disentanglement score
    """
    _, M = params.shape
    P = params / params.sum(axis=1, keepdims=True)  # normalize rows to get P
    # Entropy of each row in base M; the small constant avoids log(0)
    entropy_Pi = (-P * (np.log(P + 1e-11) / np.log(M))).sum(axis=1)
    D_i = 1 - entropy_Pi  # disentanglement of each latent dimension
    phi_i = params.sum(axis=1) / params.sum()  # relative importance of each latent dimension
    return (phi_i * D_i).sum()  # weighted sum of per-dimension disentanglement


def compute_dci_completeness(params):
    L, _ = params.shape
    R = params / params.sum(axis=0, keepdims=True)  # normalize columns to get R
    # Entropy of each column in base L
    entropy_Rj = (-R * (np.log(R + 1e-11) / np.log(L))).sum(axis=0)
    C_j = 1 - entropy_Rj  # completeness of each factor of variation
    return C_j.mean()  # average over all factors


def _compute_dci_scores(representations, factors, test_size, random_state=None):
    # Skip factors with a single value (e.g. dSprites' color): there is nothing to predict
    factor_ids = [j for j in range(factors.shape[1]) if len(np.unique(factors[:, j])) > 1]
    # Train/test split
    train_idx, test_idx = train_test_split(
        np.arange(representations.shape[0]), test_size=test_size, random_state=random_state)
    importance = np.zeros((representations.shape[1], len(factor_ids)))
    informativeness = []
    for col, j in enumerate(factor_ids):
        # lbfgs (the default solver) fits a multinomial model for multi-class factors
        clf = LogisticRegression(max_iter=1000)
        clf.fit(representations[train_idx], factors[train_idx, j])
        # coef_ has shape (n_classes, latent_dim), or (1, latent_dim) for binary factors;
        # the summed absolute weights serve as the importance of each latent dimension
        importance[:, col] = np.abs(clf.coef_).sum(axis=0)
        pred_factors = clf.predict(representations[test_idx])
        informativeness.append(accuracy_score(factors[test_idx, j], pred_factors))
    return {
        'disentanglement': compute_dci_disentanglement(importance),
        'completeness': compute_dci_completeness(importance),
        'informativeness': np.mean(informativeness),
    }


# Evaluate the trained model; this encodes all 737,280 images and fits one
# classifier per factor, so it can take a while
dci = dci_score(model.encoder, dataloader, device)
print("DCI score:", dci)
```
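For reference, DCI is defined on a non-negative importance matrix $R \in \mathbb{R}^{L \times M}$, where $R_{ij}$ is the importance of latent dimension $i$ for factor $j$ ($L$ latent dimensions, $M$ factors) and $H_K$ denotes entropy in base $K$. The formulas below match the two helper functions above:

$$P_{ij} = \frac{R_{ij}}{\sum_{k=1}^{M} R_{ik}}, \qquad D_i = 1 - H_M(P_{i\cdot}) = 1 + \sum_{j=1}^{M} P_{ij} \log_M P_{ij}$$

$$\rho_i = \frac{\sum_{j=1}^{M} R_{ij}}{\sum_{k=1}^{L} \sum_{j=1}^{M} R_{kj}}, \qquad D = \sum_{i=1}^{L} \rho_i D_i$$

$$\tilde{P}_{ij} = \frac{R_{ij}}{\sum_{k=1}^{L} R_{kj}}, \qquad C_j = 1 + \sum_{i=1}^{L} \tilde{P}_{ij} \log_L \tilde{P}_{ij}, \qquad C = \frac{1}{M} \sum_{j=1}^{M} C_j$$

A latent dimension that is important for exactly one factor has $D_i = 1$, while one that is equally important for all factors has $D_i = 0$, which is what the doctest of `compute_dci_disentanglement` checks on a uniform 2×2 matrix. Informativeness is the predictor's performance on held-out data.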
However, it is important to note that DCI scores depend on the predictor used to compute them (here, logistic regression) and on its hyperparameters. Different predictors and hyperparameters can yield different disentanglement, completeness and informativeness scores for the same representation.
As an example of this dependence, consider logistic regression versus random forest as the predictor for a binary factor. Logistic regression is a linear model that learns a linear decision boundary in the latent space, while random forest is a tree-based model that learns a nonlinear decision boundary by partitioning the latent space into regions.
In this scenario, the choice of predictor can have a significant impact on the DCI scores, even though the learned representation itself is unchanged. For instance, logistic regression may fail to capture nonlinear relationships between the latent dimensions and a factor, resulting in lower informativeness scores. Random forest, on the other hand, can capture more complex interactions between latent dimensions, potentially leading to higher informativeness scores. Because the two models also measure importance differently (weight magnitudes versus impurity-based importances), the disentanglement and completeness scores can change as well.
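To make this concrete, here is a small synthetic sketch (not dSprites): a hypothetical binary factor that is on when $|z_0| > 1$, with two extra latent dimensions that are pure noise. Because the factor is symmetric in $z_0$, no linear boundary separates it, so logistic regression stays near the majority-class accuracy while a random forest recovers it. The importances are taken as summed absolute logistic-regression weights and scikit-learn's impurity-based `feature_importances_`; the data, sizes and seeds are assumptions for illustration.

```python
import numpy as np
from sklearn.ensemble import RandomForestClassifier
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import accuracy_score
from sklearn.model_selection import train_test_split

rng = np.random.default_rng(0)
z = rng.normal(size=(2000, 3))           # synthetic "latent codes"
y = (np.abs(z[:, 0]) > 1.0).astype(int)  # factor depends nonlinearly on z0 only

z_tr, z_te, y_tr, y_te = train_test_split(z, y, test_size=0.25, random_state=0)

lr = LogisticRegression(max_iter=1000).fit(z_tr, y_tr)
rf = RandomForestClassifier(n_estimators=100, random_state=0).fit(z_tr, y_tr)

# Informativeness-style score: accuracy on held-out data
lr_acc = accuracy_score(y_te, lr.predict(z_te))
rf_acc = accuracy_score(y_te, rf.predict(z_te))

# Importance of each latent dimension for this factor under each predictor
lr_importance = np.abs(lr.coef_).sum(axis=0)
rf_importance = rf.feature_importances_

print(f"logistic regression accuracy: {lr_acc:.2f}")
print(f"random forest accuracy:       {rf_acc:.2f}")
print("LR importances:", np.round(lr_importance / lr_importance.sum(), 2))
print("RF importances:", np.round(rf_importance, 2))
```

The logistic-regression weights here are small and mostly reflect noise, so a DCI importance matrix built from them would describe the same representation very differently from one built from the random forest.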