
RRAEsTorch

A library that offers the same functionality as RRAEs (originally implemented in JAX), but in PyTorch.

MAIN DIFFERENCE: the number of samples (or batch) is the first dimension here (as usual in Torch), as opposed to the last dimension in the JAX version.
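To illustrate the convention difference, here is a minimal sketch (the sample count and curve length are made-up, not from the library):

```python
import torch

# Illustrative shapes: 100 samples, each a 1D curve with 64 points.
n_samples, n_features = 100, 64

# PyTorch convention used by this library: batch is the FIRST dimension.
x_torch = torch.randn(n_samples, n_features)

# The original JAX version stores the same data transposed,
# with the batch as the LAST dimension.
x_jax_style = x_torch.T

print(x_torch.shape)      # torch.Size([100, 64])
print(x_jax_style.shape)  # torch.Size([64, 100])
```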

The following README is copied from RRAEs:

Welcome

This repository allows users to easily train and manipulate Equinox models, specifically Autoencoders.

The library provides trainor classes that allow training Neural Networks in one line using JAX.

It also provides easy ways to normalize and vectorize matrices during training.
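As a rough illustration of what normalization and vectorization mean here (this sketch uses plain PyTorch, not the library's built-in helpers, and the shapes are made up):

```python
import torch

# 100 samples of 8x8 matrices (illustrative shapes only).
x = torch.rand(100, 8, 8)

# Min-max normalize the dataset to [0, 1].
x_min, x_max = x.min(), x.max()
x_norm = (x - x_min) / (x_max - x_min + 1e-12)

# "Vectorize": flatten each 2D sample into a vector, e.g. for an MLP.
x_vec = x_norm.reshape(x_norm.shape[0], -1)  # shape (100, 64)
```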

There are also pre-built Autoencoder models, specifically Rank Reduction Autoencoders (RRAEs).

What are RRAEs?

RRAEs, or Rank Reduction Autoencoders, are autoencoders that include an SVD in the latent space to regularize the bottleneck.
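A minimal sketch of the core idea (not the library's actual implementation, which wires this into the autoencoder's forward pass): take the matrix of latent codes for a batch, compute its SVD, and keep only the top-k singular components, truncating the rank of the latent representation.

```python
import torch

def truncate_latent_rank(z: torch.Tensor, k: int) -> torch.Tensor:
    """Rank-k approximation of a batch of latent codes z, shape (batch, latent_dim).

    Illustrative sketch of the SVD-in-the-latent-space idea behind RRAEs.
    """
    U, S, Vh = torch.linalg.svd(z, full_matrices=False)
    # Keep only the k largest singular values and their vectors.
    return U[:, :k] @ torch.diag(S[:k]) @ Vh[:k, :]

z = torch.randn(32, 16)              # 32 samples, 16-dim latent space
z_low = truncate_latent_rank(z, k=4) # same shape, but rank at most 4
```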

This library presents all the required classes for creating customized RRAEs and training them (other architectures such as Vanilla AEs, IRMAEs and LoRAEs are also available).

Each script is an example of how to train a different model.

To simply train an MLP (from Equinox), try this file.

To train an RRAE on curves (1D) using an MLP, refer to this file.

To train an RRAE on curves (1D) using Convolutions, refer to this file.

To train an RRAE on images, refer to this file.

To train a VRRAE on images, refer to this file.

To train with an adaptive bottleneck size, refer to this file and this file (main-adap-CNN.py).

For examples of post-processing and what RRAE trainors can do, refer to this file

General instructions for preparing your own data

In RRAEs.utilities, there's a function called get_data that can load several datasets for testing.

If you want to generate your own dataset, you will have to define the following:

x_train: Train input (refer to each script to see the shape)

x_test: Test input (refer to each script to see the shape)

p_train: None if you don't have any parameters; otherwise, these can be used for interpolation in the latent space

p_test: Same as p_train

y_train: = x_train for autoencoders

y_test: = x_test for autoencoders

pre_func_inp: lambda x: x if not needed; otherwise, a function applied to each batch when memory is insufficient to process the whole dataset at once

pre_func_out: lambda x:x (same as above but for output)

kwargs: {} (any other kwargs you might need)
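Under these conventions, a minimal hand-built dataset might look like the following. The shapes are placeholders; refer to each example script for the shapes a given model expects.

```python
import torch

# Illustrative dataset following the fields described above.
x_train = torch.randn(80, 64)   # 80 training curves with 64 points each
x_test = torch.randn(20, 64)    # 20 test curves

p_train = None                  # no parameters available
p_test = None                   # same as p_train

y_train = x_train               # autoencoders reconstruct their input
y_test = x_test

pre_func_inp = lambda x: x      # identity: no per-batch preprocessing
pre_func_out = lambda x: x      # identity, for the output side

kwargs = {}                     # no extra keyword arguments
```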

Installation

pip install RRAEsTorch

Or to get the newest changes:

pip install git+https://github.com/JadM133/RRAEsTorch.git
