Skip to content

Replicate Synthetic Dataset Latent Space Figure 5 #19

@isacostamaia

Description

@isacostamaia

Hi!
I am having the following problem: when I take the origin data ($y_i$) and plot its latent representations, I find that the hierarchy used in the Branching Diffusion process is not preserved in the latent space as it is in the paper's figures.

I wonder if this is related to issue #3, since I also encounter the same problem of negative KL values. Do you have any idea what the origin of this problem might be?

I am trying to reproduce the leftmost image in Figure 5 from the paper (which should be about the same as the leftmost image in Figure 9 in the appendix), which contains the embeddings and posterior samples of the Synthetic Dataset for the case where $c=1.2$ .

If I understood correctly, the samples used to obtain the black crosses are the node values $y_i$, whose values are stored in self.origin_data and whose hierarchy labels are stored in self.origin_labels in the SyntheticDataset class.
In my understanding, the black crosses are the embeddings/deterministic values of self.origin_data, i.e. the mean of the model's posterior (model.qz_x.loc).

For reproducing my visualization:

I put self.origin_data and self.origin_labels from the SyntheticDataset class into a new dataloader (which I call origin_loader) and plotted model.qz_x.loc for each of its samples.

Here is how I coded it:

  1. Create the origin dataloader in pvae/datasets/datasets.py:
    a. Add one line to SyntheticDataset (comment "Added")
class SyntheticDataset(torch.utils.data.Dataset):
    def __init__(self, dim, depth, numberOfChildren=2, sigma_children=1, param=1, numberOfsiblings=1, factor_sibling=10):
        """Store the generation parameters, build the data and standardise the observations.

        Args:
            dim: Length of the root vector; cast to ``int``.
            depth: Cast to ``int`` and stored as ``self.depth``.
            numberOfChildren: Cast to ``int`` and stored.
            sigma_children: Stored unchanged.
            param: Stored unchanged.
            numberOfsiblings: Cast to ``int`` and stored.
            factor_sibling: Stored unchanged.

        How these values shape the generated data is decided by ``self.bst()``,
        which is not part of this excerpt.
        """
        self.dim = int(dim)
        self.root = np.zeros(self.dim)  # the root is the zero vector of length ``dim``
        self.depth = int(depth)
        self.sigma_children = sigma_children
        self.factor_sibling = factor_sibling
        self.param = param
        self.numberOfChildren = int(numberOfChildren)
        self.numberOfsiblings = int(numberOfsiblings)  

        # ``bst()`` is defined outside this excerpt; its four return values are
        # kept as the origin (node) values/labels and the dataset values/labels.
        self.origin_data, self.origin_labels, self.data, self.labels = self.bst()

        # Added
        # NOTE(review): SyntheticDatasetOrigin standardises the node values with
        # their own mean/std, whereas ``self.data`` below is standardised with
        # the statistics of ``self.data``. The two sets therefore end up in
        # different coordinate systems — confirm this is intended before
        # feeding the origin nodes to a model trained on ``self.data``.
        self.origin_dataset = SyntheticDatasetOrigin(self.origin_data, self.origin_labels)

        # Normalise data (0 mean, 1 std)
        # Done in place, per column (axis 0).
        self.data -= np.mean(self.data, axis=0, keepdims=True)
        self.data /= np.std(self.data, axis=0, keepdims=True)

b. Create SyntheticDatasetOrigin class:

#Added
class SyntheticDatasetOrigin(torch.utils.data.Dataset):
    """
        Dataset of origin nodes

        The node values are standardised column-wise when the dataset is
        built. By default the statistics are computed from the node values
        themselves; pass ``mean`` / ``std`` to reuse statistics computed
        elsewhere (for example those of the data a model was trained on).

        Args:
            origin_data: Array-like of node values, one row per node.
            origin_labels: Per-node labels, indexed the same way as
                ``origin_data``.
            mean: Optional value subtracted from ``origin_data``; must
                broadcast against it. Defaults to the column means.
            std: Optional value the centred data is divided by; must
                broadcast against it. Defaults to the column standard
                deviations.
    """
    def __init__(self, origin_data, origin_labels, mean=None, std=None):
        # Work on a float copy: normalising with in-place ``-=`` / ``/=``
        # would rewrite the caller's array (which the caller keeps using)
        # and is rejected by NumPy for integer arrays.
        values = np.array(origin_data, dtype=float)

        # Normalise data (0 mean, 1 std)
        if mean is None:
            mean = np.mean(values, axis=0, keepdims=True)
        values = values - mean
        if std is None:
            std = np.std(values, axis=0, keepdims=True)
        values = values / std

        self.origin_data = values
        self.origin_labels = origin_labels

    def __len__(self):
        """Return the number of origin nodes."""
        return len(self.origin_data)

    def __getitem__(self, idx):
        '''
        Generates one sample: the normalised node value and its label,
        both converted with ``torch.Tensor``.
        '''
        data, labels = self.origin_data[idx], self.origin_labels[idx]
        return torch.Tensor(data), torch.Tensor(labels)
  2. Make class Tree also return the origin dataloader (origin_loader):
class Tree(Tabular):
    """VAE sub-class for tree data, built on top of ``Tabular``."""

    def __init__(self, params):
        super().__init__(params)

    def getDataLoaders(self, batch_size, shuffle, device, *args):
        """Build the synthetic dataset and wrap it in data loaders.

        Args:
            batch_size: Batch size of the train and test loaders.
            shuffle: Whether the train loader shuffles its samples.
            device: Extra worker / pinned-memory options are only enabled
                when this compares equal to the string ``"cuda"``.
            *args: Extra dataset parameters; each one is converted with
                ``float`` and passed to ``SyntheticDataset`` after
                ``self.data_size``.

        Returns:
            ``(train_loader, test_loader, origin_loader)``. The origin loader
            iterates ``dataset.origin_dataset`` one sample at a time, in order.
        """
        if device == "cuda":
            loader_opts = {'num_workers': 1, 'pin_memory': True}
        else:
            loader_opts = {}

        print('Load training data...')
        dataset = SyntheticDataset(*self.data_size, *(float(arg) for arg in args))

        # Split sizes come from the helper, called with train_size=0.7.
        n_train, n_test = _validate_shuffle_split(len(dataset), test_size=None, train_size=0.7)
        train_dataset, test_dataset = torch.utils.data.random_split(dataset, [n_train, n_test])

        train_loader = DataLoader(
            train_dataset, batch_size=batch_size, drop_last=True, shuffle=shuffle, **loader_opts
        )
        test_loader = DataLoader(
            test_dataset, batch_size=batch_size, drop_last=True, shuffle=False, **loader_opts
        )
        # Added: loader over the origin nodes of the dataset
        origin_loader = DataLoader(
            dataset.origin_dataset, batch_size=1, drop_last=True, shuffle=False, **loader_opts
        )
        return train_loader, test_loader, origin_loader
  3. Write the script visualize_latent_space.py to plot the origin dataloader (origin_loader):
import numpy as np
import sys
import os
import json
import torch
import matplotlib.pyplot as plt

# Seed both random number generators used by this script.
np.random.seed(8)  # author's note: 8,6
torch.manual_seed(20)

# Put the directory above this script's directory at the front of sys.path.
script_dir = os.path.dirname(__file__)
parent_dir = os.path.abspath(os.path.join(script_dir, os.pardir))
sys.path.insert(0, parent_dir)

from pvae.models import VAE_tree  # Your model class

class AttrDict(dict):
    """Dictionary whose entries can also be read and written as attributes."""

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # Use the mapping itself as the instance namespace, so ``d.key`` and
        # ``d['key']`` refer to the same storage.
        self.__dict__ = self

def visualize_latent_space(model, data_loader, device, c):
    """Embed every sample of ``data_loader`` and plot the first two latent coordinates.

    Each batch is encoded with ``model.qz_x(*model.enc(x))`` and the ``loc``
    of the resulting distribution is used as the embedding. Nodes are drawn
    as annotated points, and an edge is drawn from every node to the node
    whose label equals its own label with the last non-zero entry set to
    zero (if such a node exists). The figure is saved to ``latent_space.png``
    and shown.

    Args:
        model: Object providing ``eval()``, ``enc(x)`` and ``qz_x(*params)``.
        data_loader: Iterable yielding ``(x, y)`` tensor batches.
        device: Device the inputs are moved to before encoding.
        c: Curvature; the dashed boundary circle has radius ``1 / sqrt(c)``.

    Returns:
        Tuple ``(latents, labels, X)`` of NumPy arrays holding the embeddings,
        the labels and the inputs, concatenated over all batches.
    """
    model.eval()
    latent_batches = []
    label_batches = []
    input_batches = []

    with torch.no_grad():
        for x, y in data_loader:
            x = x.to(device)
            # Get the variational posterior q(z|x)
            qz_x = model.qz_x(*model.enc(x))
            latent_batches.append(qz_x.loc.cpu())
            label_batches.append(y.cpu())
            # Keep the inputs too: they are part of the return value
            # (the name ``X`` was previously returned without being defined).
            input_batches.append(x.cpu())

    # Concatenate batches
    latents = torch.cat(latent_batches, dim=0).numpy()  # Expected shape: (N, latent_dim)
    labels = torch.cat(label_batches, dim=0).numpy()    # Expected shape: (N, label_dim)
    X = torch.cat(input_batches, dim=0).numpy()

    # Only the first two latent coordinates are plotted, for nodes and edges alike.
    points = latents[:, :2]

    # Dictionary to map path tuples to their indices
    path_to_index = {tuple(row): i for i, row in enumerate(labels)}

    plt.figure(figsize=(8, 8))
    ax = plt.gca()

    # Plot nodes; the label gives the legend below an entry to display.
    ax.scatter(points[:, 0], points[:, 1], s=50, zorder=2, c='lightblue',
               edgecolors='k', label='nodes')
    for node_idx, (px, py) in enumerate(points):
        ax.annotate(str(node_idx), (px, py))

    # Draw edges from parent to child
    for child_idx, path in enumerate(labels):
        path = np.array(path)
        non_zero = np.nonzero(path)[0]
        if non_zero.size == 0:
            continue  # root node, no parent
        parent_path = path.copy()
        parent_path[non_zero[-1]] = 0  # remove last non-zero element
        parent_idx = path_to_index.get(tuple(parent_path))
        if parent_idx is None:
            continue  # no node carries the parent's label
        x_parent, y_parent = points[parent_idx]
        x_child, y_child = points[child_idx]
        ax.plot([x_parent, x_child], [y_parent, y_child], 'k-', zorder=1)

    # Poincaré disk boundary (radius = 1/√c)
    boundary_radius = 1.0 / (c ** 0.5)
    # Keep the historical [-1, 1] window, but widen it when the boundary
    # circle would otherwise be clipped (c < 1).
    axis_limit = max(1.0, 1.05 * boundary_radius)

    plt.title('Latent Space Visualization')
    plt.xlabel('Dimension 1')
    plt.ylabel('Dimension 2')
    plt.xlim([-axis_limit, axis_limit])
    plt.ylim([-axis_limit, axis_limit])

    plt.legend(title='Label', loc='best')
    plt.grid(True)

    circle = plt.Circle((0, 0), boundary_radius, color='black', fill=False, linestyle='dashed', linewidth=1)
    ax.add_artist(circle)

    plt.tight_layout()
    plt.savefig('latent_space.png')
    plt.show()
    return latents, labels, X


if __name__ == "__main__":
    # Directory of the training run to visualise; it is expected to contain
    # 'args.json' and 'model.rar'.
    args = {}
    args["model_dir"] = "2025-05-05T14_59_37.077755zsl4w3m7"

    #load training arguments from args.json
    train_args_path = os.path.join(args["model_dir"], 'args.json') #args.model_dir if from parser
    with open(train_args_path, 'r') as f:
        train_args = json.load(f)
    train_args = AttrDict(train_args)
    print(train_args)
    print("Model name:", train_args.name)

    #determine computation device
    train_args.cuda = not train_args.no_cuda and torch.cuda.is_available()
    device = torch.device("cuda" if train_args.cuda else "cpu")

    #initialize the model using training arguments
    model = VAE_tree(train_args)
    model.to(device)
    
    #load the model state
    model_path = os.path.join(args["model_dir"], 'model.rar') #args.model_dir if from parser
    model.load_state_dict(torch.load(model_path, map_location=device))
    model.eval()

    # getDataLoaders() generates the synthetic data again. If the training
    # arguments carry a seed, re-apply it first so the data is not generated
    # under this script's unrelated seed.
    # NOTE(review): this assumes the training script stored its RNG seed
    # under 'seed' in args.json and seeded NumPy/PyTorch with it before
    # building the dataset — confirm against the training code. Without a
    # stored seed the behaviour is unchanged.
    train_seed = train_args.get("seed")
    if train_seed is not None:
        torch.manual_seed(int(train_seed))
        np.random.seed(int(train_seed))

    #get dataloaders using model's method; pass data_params properly from training configuration
    _, _, origin_loader = model.getDataLoaders(train_args.batch_size, True, device, *train_args.data_params)

    #visualize the latent space using the origin-node loader and training curvature c
    latents, labels, X = visualize_latent_space(model, origin_loader, device, c=float(train_args.c))

The figure I get is the following (displaying only 2 levels of the tree):

Image
Where the annotations are the node indices, and the hierarchy is:
0 ≺ 1
0 ≺ 2
1 ≺ 3
1 ≺ 4
2 ≺ 5
2 ≺ 6

Thank you for your attention!

Metadata

Metadata

Assignees

No one assigned

    Labels

    No labels
    No labels

    Projects

    No projects

    Milestone

    No milestone

    Relationships

    None yet

    Development

    No branches or pull requests

    Issue actions