Hi!
I am having the following problem: when I recover the origin data ($y_i$) and plot the latent representations, I find that the hierarchy used in the Branching Diffusion process is not preserved in the latent space, unlike in the paper's figures.
I wonder whether this is related to issue #3, since I also encounter negative KL values. Do you have any idea what might be causing this problem?
I am trying to reproduce the leftmost image in Figure 5 of the paper (which should be about the same as the leftmost image in Figure 9 in the appendix), which shows the embeddings and posterior samples of the Synthetic Dataset for the case $c=1.2$.
If I understand correctly, the samples used to obtain the black crosses are the node values $y_i$, whose values are stored in `self.origin_data` and whose hierarchy labels are stored in `self.origin_labels` in the `SyntheticDataset` class.
In my understanding, the black crosses are the embeddings/deterministic values of self.origin_data, i.e. the mean of the model's posterior (model.qz_x.loc).
To reproduce my visualization:
I wrapped `self.origin_data` and `self.origin_labels` from the `SyntheticDataset` class in a new dataloader (which I call `origin_loader`) and plotted `model.qz_x.loc` for each of its samples.
Here is how I coded it:
- Create the origin dataloader in `pvae/datasets/datasets.py`:
a. Add one line to `SyntheticDataset` (marked with the comment "Added"):
class SyntheticDataset(torch.utils.data.Dataset):
    # NOTE(review): only __init__ is shown in this snippet; bst() (called
    # below) and the Dataset protocol methods live elsewhere in the class.
    def __init__(self, dim, depth, numberOfChildren=2, sigma_children=1, param=1, numberOfsiblings=1, factor_sibling=10):
        """Generate the synthetic tree data and a dataset of its nodes.

        Sets ``origin_data``/``origin_labels`` (tree nodes and their labels)
        and ``data``/``labels`` as returned by ``bst()``. ``data`` is
        standardised per column (0 mean, 1 std), and ``origin_dataset`` holds
        the nodes expressed in that same standardised frame.
        """
        self.dim = int(dim)
        self.root = np.zeros(self.dim)
        self.depth = int(depth)
        self.sigma_children = sigma_children
        self.factor_sibling = factor_sibling
        self.param = param
        self.numberOfChildren = int(numberOfChildren)
        self.numberOfsiblings = int(numberOfsiblings)
        self.origin_data, self.origin_labels, self.data, self.labels = self.bst()

        # Normalise data (0 mean, 1 std). Keep the statistics so the tree
        # nodes can be mapped into the same input frame the encoder is
        # trained on.
        data_mean = np.mean(self.data, axis=0, keepdims=True)
        self.data -= data_mean
        data_std = np.std(self.data, axis=0, keepdims=True)
        self.data /= data_std

        # Added: dataset of the tree nodes y_i. Hand over a float copy so any
        # in-place normalisation done by SyntheticDatasetOrigin cannot
        # overwrite self.origin_data.
        self.origin_dataset = SyntheticDatasetOrigin(np.array(self.origin_data, dtype=float), self.origin_labels)
        # Encode the nodes with the *training-data* statistics. Statistics of
        # the nodes alone would put them in a different frame from the one
        # the encoder saw.
        self.origin_dataset.origin_data = (np.asarray(self.origin_data, dtype=float) - data_mean) / data_std
b. Create the `SyntheticDatasetOrigin` class:
#Added
class SyntheticDatasetOrigin(torch.utils.data.Dataset):
    """
    Dataset of origin nodes: the tree node values y_i and their labels.

    Args:
        origin_data: array-like of node values, one row per node.
        origin_labels: per-node labels, indexed in parallel with origin_data.
        mean: optional mean to subtract (broadcastable against origin_data).
            Defaults to the column mean of origin_data.
        std: optional scale to divide by (broadcastable against origin_data).
            Defaults to the column std of the centred origin_data.

    With neither ``mean`` nor ``std`` given, the nodes are standardised with
    their own statistics (0 mean, 1 std). Columns with zero spread are left
    centred rather than turned into NaN/inf.
    """

    def __init__(self, origin_data, origin_labels, mean=None, std=None):
        # Work on a float copy: normalising the caller's array in place would
        # silently modify it.
        data = np.array(origin_data, dtype=float)
        if mean is None:
            mean = np.mean(data, axis=0, keepdims=True)
        data -= mean
        if std is None:
            std = np.std(data, axis=0, keepdims=True)
        # Guard constant columns (std == 0) against division by zero.
        std = np.where(np.asarray(std) == 0, 1.0, std)
        data /= std
        self.origin_data = data
        self.origin_labels = origin_labels

    def __len__(self):
        return len(self.origin_data)

    def __getitem__(self, idx):
        '''
        Return the (node value, node label) pair at ``idx`` as float tensors.
        '''
        data, labels = self.origin_data[idx], self.origin_labels[idx]
        return torch.Tensor(data), torch.Tensor(labels)
- Make the `Tree` class also return `origin_loader`:
class Tree(Tabular):
    """VAE specialisation for the synthetic tree dataset."""

    def __init__(self, params):
        super().__init__(params)

    def getDataLoaders(self, batch_size, shuffle, device, *args):
        """Generate a SyntheticDataset and wrap it in data loaders.

        Args:
            batch_size: batch size of the train and test loaders.
            shuffle: whether the train loader shuffles (test/origin never do).
            device: when it compares equal to "cuda", the loaders use one
                worker and pinned memory.
            *args: extra dataset parameters; each is converted to float and
                passed positionally to SyntheticDataset after
                ``self.data_size``.

        Returns:
            (train_loader, test_loader, origin_loader): a 70/30 random split
            of the dataset, plus a batch-size-1 loader over
            ``dataset.origin_dataset`` (the tree nodes).
        """
        # NOTE(review): three loaders are returned; any caller written for a
        # two-loader return (e.g. a training entry point not shown here)
        # needs updating — verify.
        loader_kwargs = {}
        if device == "cuda":
            loader_kwargs.update(num_workers=1, pin_memory=True)
        print('Load training data...')
        dataset = SyntheticDataset(*self.data_size, *(float(arg) for arg in args))
        n_train, n_test = _validate_shuffle_split(len(dataset), test_size=None, train_size=0.7)
        train_dataset, test_dataset = torch.utils.data.random_split(dataset, [n_train, n_test])

        def _loader(subset, size, shuffle_flag):
            # All three loaders share drop_last and the device-dependent kwargs.
            return DataLoader(subset, batch_size=size, drop_last=True, shuffle=shuffle_flag, **loader_kwargs)

        return (
            _loader(train_dataset, batch_size, shuffle),
            _loader(test_dataset, batch_size, False),
            # Added: one tree node per batch, in dataset order.
            _loader(dataset.origin_dataset, 1, False),
        )
- Write the script `visualize_latent_space.py` to plot `origin_loader`:
import numpy as np
import sys
import os
import json
import torch
import matplotlib.pyplot as plt
# Fix the random state so repeated runs of this script are reproducible.
np.random.seed(8)  # other seed tried: 6
torch.manual_seed(20)

# Put the parent of this script's directory at the front of the import path,
# so the `pvae` package imported below resolves without installation.
parent_dir = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
sys.path.insert(0, parent_dir)
from pvae.models import VAE_tree # Your model class
class AttrDict(dict):
    """A dict whose entries can also be read and written as attributes.

    Used here so the JSON-loaded training arguments support ``args.name``
    style access.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # Make the instance namespace *be* the dict, so attribute access and
        # item access share a single storage.
        self.__dict__ = self
def _posterior_means(model, data_loader, device):
    """Encode every sample of ``data_loader`` with ``model``.

    Returns ``(latents, labels, inputs)`` as numpy arrays concatenated over
    batches: the location (``loc``) of q(z|x) for each sample, the labels
    yielded by the loader, and the inputs that were encoded.

    Raises:
        ValueError: if ``data_loader`` yields no batches.
    """
    latent_list, label_list, input_list = [], [], []
    with torch.no_grad():
        for x, y in data_loader:
            x = x.to(device)
            # Variational posterior q(z|x); its location serves as the
            # deterministic embedding of x.
            qz_x = model.qz_x(*model.enc(x))
            latent_list.append(qz_x.loc.cpu())
            label_list.append(y)
            input_list.append(x.cpu())
    if not latent_list:
        raise ValueError("data_loader yielded no samples to visualise")
    return (
        torch.cat(latent_list, dim=0).numpy(),
        torch.cat(label_list, dim=0).numpy(),
        torch.cat(input_list, dim=0).numpy(),
    )


def _tree_edges(labels):
    """Yield ``(parent_idx, child_idx)`` pairs implied by path-style labels.

    Each label row is read as a path: its parent's label is the same row with
    the last non-zero entry set to 0. All-zero rows (the root) have no parent,
    and children whose parent label does not occur in ``labels`` are skipped.
    """
    path_to_index = {tuple(row): i for i, row in enumerate(labels)}
    for child_idx, path in enumerate(labels):
        non_zero = np.nonzero(path)[0]
        if non_zero.size == 0:
            continue  # root node, no parent
        parent_path = np.array(path)  # copy, so labels are not modified
        parent_path[non_zero[-1]] = 0
        parent_idx = path_to_index.get(tuple(parent_path))
        if parent_idx is not None:
            yield parent_idx, child_idx


def visualize_latent_space(model, data_loader, device, c, save_path='latent_space.png'):
    """Plot the posterior means of ``data_loader`` samples in the Poincaré disk.

    Each sample is encoded with ``model``. The location of q(z|x) is drawn
    using the first two latent dimensions and annotated with the sample's
    index in loader order. Parent -> child edges are inferred from the labels
    (see ``_tree_edges``). The disk boundary of radius 1/sqrt(c) is dashed.

    Args:
        model: trained VAE exposing ``eval()``, ``enc`` and ``qz_x``.
        data_loader: yields ``(x, y)`` batches.
        device: device the inputs are moved to before encoding.
        c: curvature; must be positive.
        save_path: file the figure is written to.

    Returns:
        ``(latents, labels, X)``: posterior means, loader labels and the
        encoded inputs, each concatenated over all batches as numpy arrays.

    Raises:
        ValueError: if ``c`` is not positive or the loader is empty.
    """
    if c <= 0:
        raise ValueError(f"curvature c must be positive, got {c}")
    model.eval()
    latents, labels, inputs = _posterior_means(model, data_loader, device)
    xy = latents[:, :2]  # only the first two latent dimensions are drawn
    boundary_radius = 1.0 / (c ** 0.5)

    fig, ax = plt.subplots(figsize=(8, 8))
    # Parent -> child edges, drawn underneath the nodes.
    for parent_idx, child_idx in _tree_edges(labels):
        pair = [parent_idx, child_idx]
        ax.plot(xy[pair, 0], xy[pair, 1], 'k-', zorder=1)
    ax.scatter(xy[:, 0], xy[:, 1], s=50, zorder=2, c='lightblue', edgecolors='k', label='posterior mean')
    for idx, (z0, z1) in enumerate(xy):
        ax.annotate(str(idx), (z0, z1))

    # Draw the Poincaré disk boundary (radius = 1/sqrt(c)).
    ax.add_artist(plt.Circle((0, 0), boundary_radius, color='black', fill=False, linestyle='dashed', linewidth=1))

    # Keep the original [-1, 1] window, but widen it when the disk is larger
    # (c < 1) so the boundary is never clipped.
    limit = max(1.0, 1.05 * boundary_radius)
    ax.set_xlim(-limit, limit)
    ax.set_ylim(-limit, limit)
    ax.set_aspect('equal')
    ax.set_title('Latent Space Visualization')
    ax.set_xlabel('Dimension 1')
    ax.set_ylabel('Dimension 2')
    ax.legend(loc='best')
    ax.grid(True)
    fig.tight_layout()
    fig.savefig(save_path)
    plt.show()
    return latents, labels, inputs
if __name__ == "__main__":
    # Training run directory; may be overridden on the command line:
    #   python visualize_latent_space.py <model_dir>
    args = {}
    args["model_dir"] = sys.argv[1] if len(sys.argv) > 1 else "2025-05-05T14_59_37.077755zsl4w3m7"

    # Load training arguments from args.json.
    train_args_path = os.path.join(args["model_dir"], 'args.json')
    with open(train_args_path, 'r') as f:
        train_args = AttrDict(json.load(f))
    print(train_args)
    print("Model name:", train_args.name)

    # Determine the computation device.
    train_args.cuda = not train_args.no_cuda and torch.cuda.is_available()
    device = torch.device("cuda" if train_args.cuda else "cpu")

    # Initialise the model from the training arguments and load its weights.
    model = VAE_tree(train_args)
    model.to(device)
    model_path = os.path.join(args["model_dir"], 'model.rar')
    model.load_state_dict(torch.load(model_path, map_location=device))
    model.eval()

    # getDataLoaders rebuilds the SyntheticDataset from scratch. If its tree
    # is drawn randomly (bst() is not shown here — TODO confirm), the nodes
    # only match those the model was trained on when the RNGs are seeded as
    # during training. Reseed with the training seed when args.json records
    # one; otherwise the module-level seeds stay in effect.
    # NOTE(review): assumes training seeded numpy/torch with args.seed
    # before building the dataset — verify against the training script.
    seed = train_args.get("seed")
    if seed is not None:
        np.random.seed(int(seed))
        torch.manual_seed(int(seed))
    _, _, origin_loader = model.getDataLoaders(train_args.batch_size, True, device, *train_args.data_params)

    # Visualise the tree-node (origin) embeddings using the training curvature c.
    latents, labels, X = visualize_latent_space(model, origin_loader, device, c=float(train_args.c))
The figure I get is the following (displaying only 2 levels of the tree):

where the annotations are the node indices and the parent–child relations are:
0 ≺ 1
0 ≺ 2
1 ≺ 3
1 ≺ 4
2 ≺ 5
2 ≺ 6
Thank you for your attention!
Hi!
I am having the following problem: when I recover the origin data ($y_i$) and plot the latent representations, I find that the hierarchy used in the Branching Diffusion process is not preserved in the latent space, unlike in the paper's figures.
I wonder whether this is related to issue #3, since I also encounter negative KL values. Do you have any idea what might be causing this problem?
I am trying to reproduce the leftmost image in Figure 5 of the paper (which should be about the same as the leftmost image in Figure 9 in the appendix), which shows the embeddings and posterior samples of the Synthetic Dataset for the case $c=1.2$.
If I understand correctly, the samples used to obtain the black crosses are the node values $y_i$, whose values are stored in `self.origin_data` and whose hierarchy labels are stored in `self.origin_labels` in the `SyntheticDataset` class.
In my understanding, the black crosses are the embeddings/deterministic values of `self.origin_data`, i.e. the mean of the model's posterior (`model.qz_x.loc`).
To reproduce my visualization:
I wrapped `self.origin_data` and `self.origin_labels` from the `SyntheticDataset` class in a new dataloader (which I call `origin_loader`) and plotted `model.qz_x.loc` for each of its samples.
Here is how I coded it:
- Create the origin dataloader in `pvae/datasets/datasets.py`:
  a. Add one line to `SyntheticDataset` (marked with the comment "Added")
  b. Create the `SyntheticDatasetOrigin` class
- Make the `Tree` class also return `origin_loader`
- Write the script `visualize_latent_space.py` to plot `origin_loader`
The figure I get is the following (displaying only 2 levels of the tree):
where the annotations are the node indices and the parent–child relations are:
0 ≺ 1
0 ≺ 2
1 ≺ 3
1 ≺ 4
2 ≺ 5
2 ≺ 6
Thank you for your attention!