Title: | Generative Adversarial Nets (GAN) in R |
---|---|
Description: | An easy way to get started with Generative Adversarial Nets (GAN) in R. The GAN algorithm was initially described by Goodfellow et al. (2014) <https://proceedings.neurips.cc/paper/2014/file/5ca3e9b122f61f8f06494c97b1afccf3-Paper.pdf>. A GAN can be used to learn the joint distribution of complex data by comparison. A GAN consists of two neural networks, a Generator and a Discriminator, which play an adversarial minimax game. Built-in GAN models make the training of GANs in R possible in one line and make it easy to experiment with different design choices (e.g. different network architectures, value functions, optimizers). The built-in GAN models work with tabular data (e.g. to produce synthetic data) and image data. Methods to post-process the output of GAN models to enhance the quality of samples are available. |
Authors: | Marcel Neunhoeffer [aut, cre] |
Maintainer: | Marcel Neunhoeffer <[email protected]> |
License: | MIT + file LICENSE |
Version: | 0.1.1 |
Built: | 2025-03-11 02:44:30 UTC |
Source: | https://github.com/mneunhoe/rgan |
Provides a class to transform data for RGAN. Method $new() initializes a new transformer, method $fit(data) learns the parameters for the transformation from data (e.g. means and sds). Methods $transform() and $inverse_transform() can be used to transform and back transform a data set based on the learned parameters. Currently, DataTransformer supports z-transformation (a.k.a. normalization) for numerical features/variables and one hot encoding for categorical features/variables. In your call to fit you just need to indicate which columns contain discrete features.
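The two transformations named above can be illustrated in base R. This is a minimal sketch of the general technique, not the package's implementation: the toy vectors x and g and all variable names are hypothetical, and the package may compute the standard deviation or order the categories differently.

```r
# Hypothetical toy data: one numeric column and one categorical column.
x <- c(1.5, 2.0, 3.5, 4.0)
g <- c("a", "b", "a", "c")

# "Fit": learn the parameters of the z-transformation.
x_mean <- mean(x)
x_sd <- sd(x)

# "Transform": centre and scale the numeric column.
x_z <- (x - x_mean) / x_sd

# "Inverse transform": undo the scaling with the stored parameters.
x_back <- x_z * x_sd + x_mean

# One hot encoding: one indicator column per observed category.
g_levels <- sort(unique(g))
g_onehot <- outer(g, g_levels, "==") * 1
colnames(g_onehot) <- g_levels

# Decoding picks the category with the largest indicator in each row.
g_back <- g_levels[max.col(g_onehot, ties.method = "first")]
```

Storing the fitted parameters (here x_mean, x_sd and g_levels) is what makes the back transformation possible after synthetic data has been generated on the transformed scale.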
A class to transform (normalize or one hot encode) tabular data for RGAN
new()
Create a new data_transformer object
data_transformer$new()
fit_continuous()
data_transformer$fit_continuous(column = NULL, data = NULL)
fit_discrete()
data_transformer$fit_discrete(column = NULL, data = NULL)
fit()
Fit a data_transformer to data.
data_transformer$fit(data, discrete_columns = NULL)
data
The data set to transform
discrete_columns
Column ids for columns with discrete/nominal values to be one hot encoded.
data <- sample_toydata()
transformer <- data_transformer$new()
transformer$fit(data)
transform_continuous()
data_transformer$transform_continuous(column_meta, data)
transform_discrete()
data_transformer$transform_discrete(column_meta, data)
transform()
Transform data using a fitted data_transformer. (From original format to transformed format.)
data_transformer$transform(data)
data
The data set to transform
data <- sample_toydata()
transformer <- data_transformer$new()
transformer$fit(data)
transformed_data <- transformer$transform(data)
inverse_transform_continuous()
data_transformer$inverse_transform_continuous(meta, data)
inverse_transform_discrete()
data_transformer$inverse_transform_discrete(meta, data)
inverse_transform()
Inverse Transform data using a fitted data_transformer. (From transformed format to original format.)
data_transformer$inverse_transform(data)
data
The data set to transform
data <- sample_toydata()
transformer <- data_transformer$new()
transformer$fit(data)
transformed_data <- transformer$transform(data)
reconstructed_data <- transformer$inverse_transform(transformed_data)
clone()
The objects of this class are cloneable with this method.
data_transformer$clone(deep = FALSE)
deep
Whether to make a deep clone.
## Not run:
# Before running the first time the torch backend needs to be installed
torch::install_torch()
# Load data
data <- sample_toydata()
# Build new transformer
transformer <- data_transformer$new()
# Fit transformer to data
transformer$fit(data)
# Transform data and store as new object
transformed_data <- transformer$transform(data)
# Train the default GAN
trained_gan <- gan_trainer(transformed_data)
# Sample synthetic data from the trained GAN
synthetic_data <- sample_synthetic_data(trained_gan, transformer)
# Plot the results
GAN_update_plot(data = data, synth_data = synthetic_data, main = "Real and Synthetic Data after Training")
## End(Not run)

## ------------------------------------------------
## Method `data_transformer$fit`
## ------------------------------------------------
data <- sample_toydata()
transformer <- data_transformer$new()
transformer$fit(data)

## ------------------------------------------------
## Method `data_transformer$transform`
## ------------------------------------------------
data <- sample_toydata()
transformer <- data_transformer$new()
transformer$fit(data)
transformed_data <- transformer$transform(data)

## ------------------------------------------------
## Method `data_transformer$inverse_transform`
## ------------------------------------------------
data <- sample_toydata()
transformer <- data_transformer$new()
transformer$fit(data)
transformed_data <- transformer$transform(data)
reconstructed_data <- transformer$inverse_transform(transformed_data)
Provides a torch::nn_module with a simple deep convolutional neural net architecture, for use as the default architecture for image data in RGAN. Architecture inspired by: https://pytorch.org/tutorials/beginner/dcgan_faces_tutorial.html
DCGAN_Discriminator( number_channels = 3, ndf = 64, dropout_rate = 0.5, sigmoid = FALSE )
number_channels |
The number of channels in the image (RGB is 3 channels) |
ndf |
The number of feature maps in discriminator |
dropout_rate |
The dropout rate for each hidden layer |
sigmoid |
Switch between a sigmoid and linear output layer (the sigmoid is needed for the original GAN value function) |
A torch::nn_module for the DCGAN Discriminator
Provides a torch::nn_module with a simple deep convolutional neural net architecture, for use as the default architecture for image data in RGAN. Architecture inspired by: https://pytorch.org/tutorials/beginner/dcgan_faces_tutorial.html
DCGAN_Generator( noise_dim = 100, number_channels = 3, ngf = 64, dropout_rate = 0.5 )
noise_dim |
The length of the noise vector per example |
number_channels |
The number of channels in the image (RGB is 3 channels) |
ngf |
The number of feature maps in generator |
dropout_rate |
The dropout rate for each hidden layer |
A torch::nn_module for the DCGAN Generator
Provides a torch::nn_module with a simple fully connected neural net, for use as the default architecture for tabular data in RGAN.
Discriminator( data_dim, hidden_units = list(128, 128), dropout_rate = 0.5, sigmoid = FALSE )
data_dim |
The number of columns in the data set |
hidden_units |
A list of the number of neurons per layer, the length of the list determines the number of hidden layers |
dropout_rate |
The dropout rate for each hidden layer |
sigmoid |
Switch between a sigmoid and linear output layer (the sigmoid is needed for the original GAN value function) |
A torch::nn_module for the Discriminator
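To make the role of hidden_units concrete, the base R sketch below lists the input and output sizes of the linear layers that such a fully connected net would contain. The helper layer_dims is hypothetical and not part of RGAN, and the assumption that the discriminator ends in a single output score per example is an illustration rather than a statement about the package's internals.

```r
# Hypothetical helper (not part of RGAN): input/output sizes of the linear
# layers implied by an input dimension, hidden_units and an output dimension.
layer_dims <- function(input_dim, hidden_units, output_dim) {
  sizes <- c(input_dim, unlist(hidden_units), output_dim)
  data.frame(
    layer = seq_len(length(sizes) - 1),
    n_in = head(sizes, -1),
    n_out = tail(sizes, -1)
  )
}

# Two hidden layers with 128 neurons each, for a data set with 2 columns.
# Assumption: the discriminator returns one score per example.
dims <- layer_dims(input_dim = 2, hidden_units = list(128, 128), output_dim = 1)
print(dims)
```

Adding a third entry to hidden_units would add a third hidden layer. The same reasoning applies to a generator, with noise_dim as the input size and data_dim as the output size.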
Provides a function that makes it easy to sample synthetic data from a Generator
expert_sample_synthetic_data(g_net, z, device, eval_dropout = FALSE)
g_net |
A torch::nn_module with a Generator |
z |
A noise vector |
device |
The device on which synthetic data should be sampled (cpu or cuda) |
eval_dropout |
Should dropout be applied during inference |
Synthetic data
Provides a function to quickly train a GAN model.
gan_trainer(
  data,
  noise_dim = 2,
  noise_distribution = "normal",
  value_function = "original",
  data_type = "tabular",
  generator = NULL,
  generator_optimizer = NULL,
  discriminator = NULL,
  discriminator_optimizer = NULL,
  base_lr = 1e-04,
  ttur_factor = 4,
  weight_clipper = NULL,
  batch_size = 50,
  epochs = 150,
  plot_progress = FALSE,
  plot_interval = "epoch",
  eval_dropout = FALSE,
  synthetic_examples = 500,
  plot_dimensions = c(1, 2),
  track_loss = FALSE,
  plot_loss = FALSE,
  device = "cpu"
)
data |
Input a data set. Needs to be a matrix, array, torch::torch_tensor or torch::dataset. |
noise_dim |
The dimensions of the GAN noise vector z. Defaults to 2. |
noise_distribution |
The noise distribution. Expects a function that samples from a distribution and returns a torch_tensor. For convenience "normal" and "uniform" will automatically set a function. Defaults to "normal". |
value_function |
The value function for GAN training. Expects a function that takes discriminator scores of real and fake data as input and returns a list with the discriminator loss and generator loss. For reference see: . For convenience three loss functions "original", "wasserstein" and "f-wgan" are already implemented. Defaults to "original". |
data_type |
"tabular" or "image", controls the data type, defaults to "tabular". |
generator |
The generator network. Expects a neural network provided as torch::nn_module. Default is NULL which will create a simple fully connected neural network. |
generator_optimizer |
The optimizer for the generator network. Expects a torch::optim_xxx function, e.g. torch::optim_adam(). Default is NULL which will setup |
discriminator |
The discriminator network. Expects a neural network provided as torch::nn_module. Default is NULL which will create a simple fully connected neural network. |
discriminator_optimizer |
The optimizer for the discriminator network. Expects a torch::optim_xxx function, e.g. torch::optim_adam(). Default is NULL which will setup |
base_lr |
The base learning rate for the optimizers. Default is 0.0001. Only used if no optimizer is explicitly passed to the trainer. |
ttur_factor |
A multiplier for the learning rate of the discriminator, to implement the two time scale update rule. |
weight_clipper |
The wasserstein GAN puts some constraints on the weights of the discriminator, therefore weights are clipped during training. |
batch_size |
The number of training samples selected into the mini batch for training. Defaults to 50. |
epochs |
The number of training epochs. Defaults to 150. |
plot_progress |
Monitor training progress with plots. Defaults to FALSE. |
plot_interval |
Number of training steps between plots. Input number of steps or "epoch". Defaults to "epoch". |
eval_dropout |
Should dropout be applied during the sampling of synthetic data? Defaults to FALSE. |
synthetic_examples |
Number of synthetic examples that should be generated. Defaults to 500. For image data e.g. 16 would be more reasonable. |
plot_dimensions |
If you monitor training progress with a plot which dimensions of the data do you want to look at? Defaults to c(1, 2), i.e. the first two columns of the tabular data. |
track_loss |
Store the training losses as additional output. Defaults to FALSE. |
plot_loss |
Monitor the losses during training with plots. Defaults to FALSE. |
device |
Input on which device (e.g. "cpu", "cuda", or "mps") training should be done. Defaults to "cpu". |
gan_trainer trains the neural networks and returns an object of class trained_RGAN that contains the last generator, discriminator and the respective optimizers, as well as the settings.
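As a small worked example of how base_lr and ttur_factor interact when no optimizers are passed, the base R sketch below computes the implied learning rates from the documented defaults. The names generator_lr and discriminator_lr are illustrative, and the assumption that the generator uses base_lr unchanged is not stated in the argument descriptions.

```r
# Documented defaults of gan_trainer.
base_lr <- 1e-04
ttur_factor <- 4

# Assumption: the generator optimizer uses the base learning rate as is.
generator_lr <- base_lr

# ttur_factor multiplies the learning rate of the discriminator
# (two time scale update rule).
discriminator_lr <- base_lr * ttur_factor

print(c(generator = generator_lr, discriminator = discriminator_lr))
```

With the defaults the discriminator therefore learns four times as fast as the generator. Setting ttur_factor = 1 would give both networks the same learning rate.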
## Not run:
# Before running the first time the torch backend needs to be installed
torch::install_torch()
# Load data
data <- sample_toydata()
# Build new transformer
transformer <- data_transformer$new()
# Fit transformer to data
transformer$fit(data)
# Transform data and store as new object
transformed_data <- transformer$transform(data)
# Train the default GAN
trained_gan <- gan_trainer(transformed_data)
# Sample synthetic data from the trained GAN
synthetic_data <- sample_synthetic_data(trained_gan, transformer)
# Plot the results
GAN_update_plot(data = data, synth_data = synthetic_data, main = "Real and Synthetic Data after Training")
## End(Not run)
Provides a function to plot real and synthetic data against each other, e.g. to monitor progress during GAN training.
GAN_update_plot(data, dimensions = c(1, 2), synth_data, epoch, main = NULL)
data |
Real data to be plotted |
dimensions |
Which columns of the data should be plotted |
synth_data |
The synthetic data to be plotted |
epoch |
The epoch during training for the plot title |
main |
An optional plot title |
A function
## Not run:
# Before running the first time the torch backend needs to be installed
torch::install_torch()
# Load data
data <- sample_toydata()
# Build new transformer
transformer <- data_transformer$new()
# Fit transformer to data
transformer$fit(data)
# Transform data and store as new object
transformed_data <- transformer$transform(data)
# Train the default GAN
trained_gan <- gan_trainer(transformed_data)
# Sample synthetic data from the trained GAN
synthetic_data <- sample_synthetic_data(trained_gan, transformer)
# Plot the results
GAN_update_plot(data = data, synth_data = synthetic_data, main = "Real and Synthetic Data after Training")
## End(Not run)
Provides a function to plot a grid of synthetic images, e.g. to monitor progress during GAN training on image data.
GAN_update_plot_image(mfrow = c(4, 4), synth_data)
mfrow |
The dimensions of the grid of images to be plotted |
synth_data |
The synthetic data (images) to be plotted |
A function
Provides a function to perform a single GAN update step on one mini batch, as used during GAN training.
gan_update_step(
  data,
  batch_size,
  noise_dim,
  sample_noise,
  device = "cpu",
  g_net,
  d_net,
  g_optim,
  d_optim,
  value_function,
  weight_clipper,
  track_loss = FALSE
)
data |
Input a data set. Needs to be a matrix, array, torch::torch_tensor or torch::dataset. |
batch_size |
The number of training samples selected into the mini batch for training. Defaults to 50. |
noise_dim |
The dimensions of the GAN noise vector z. Defaults to 2. |
sample_noise |
A function to sample noise to a torch::tensor |
device |
Input on which device (e.g. "cpu" or "cuda") training should be done. Defaults to "cpu". |
g_net |
The generator network. Expects a neural network provided as torch::nn_module. Default is NULL which will create a simple fully connected neural network. |
d_net |
The discriminator network. Expects a neural network provided as torch::nn_module. Default is NULL which will create a simple fully connected neural network. |
g_optim |
The optimizer for the generator network. Expects a torch::optim_xxx function, e.g. torch::optim_adam(). Default is NULL which will setup |
d_optim |
The optimizer for the discriminator network. Expects a torch::optim_xxx function, e.g. torch::optim_adam(). Default is NULL which will setup |
value_function |
The value function for GAN training. Expects a function that takes discriminator scores of real and fake data as input and returns a list with the discriminator loss and generator loss. For reference see: . For convenience three loss functions "original", "wasserstein" and "f-wgan" are already implemented. Defaults to "original". |
weight_clipper |
The wasserstein GAN puts some constraints on the weights of the discriminator, therefore weights are clipped during training. |
track_loss |
Store the training losses as additional output. Defaults to FALSE. |
A function
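For orientation, the sketch below shows what one alternating GAN update on a single mini batch can look like when written directly with torch for R. It is an illustration under stated assumptions and not the body of gan_update_step: the small networks, the random stand-in for real data, the learning rates and the use of the minimax losses are all hypothetical choices. The code is guarded so that it only runs when the torch backend is available.

```r
if (requireNamespace("torch", quietly = TRUE) && torch::torch_is_installed()) {
  library(torch)
  torch_manual_seed(1)

  # Hypothetical small networks for 2-dimensional data and 2-dimensional noise.
  g_net <- nn_sequential(nn_linear(2, 16), nn_relu(), nn_linear(16, 2))
  d_net <- nn_sequential(nn_linear(2, 16), nn_relu(), nn_linear(16, 1), nn_sigmoid())
  g_optim <- optim_adam(g_net$parameters, lr = 1e-4)
  d_optim <- optim_adam(d_net$parameters, lr = 4e-4)

  # One mini batch of (stand-in) real data and one batch of noise.
  real_batch <- torch_randn(c(50, 2))
  z <- torch_randn(c(50, 2))

  # 1) Discriminator update: fake data is detached so that only the
  #    discriminator receives gradients in this step.
  fake_batch <- g_net(z)$detach()
  d_loss <- -(torch_log(d_net(real_batch)) + torch_log(1 - d_net(fake_batch)))$mean()
  d_optim$zero_grad()
  d_loss$backward()
  d_optim$step()

  # 2) Generator update: gradients flow through the discriminator
  #    into the generator, but only the generator is stepped.
  g_loss <- torch_log(1 - d_net(g_net(z)))$mean()
  g_optim$zero_grad()
  g_loss$backward()
  g_optim$step()
}
```

For Wasserstein training a weight clipping step on the discriminator would follow the discriminator update.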
Implements the original GAN value function as a function to be called in gan_trainer. The function can serve as a template to implement new value functions in RGAN.
GAN_value_fct(real_scores, fake_scores)
real_scores |
The discriminator scores on real examples ($D(x)$) |
fake_scores |
The discriminator scores on fake examples ($D(G(z))$) |
The function returns a named list with the entries d_loss and g_loss
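The shape of such a value function can be sketched in base R on plain numeric vectors. The function gan_value_sketch is hypothetical and follows the minimax formulation of Goodfellow et al. (2014); the package's own function operates on torch tensors and may differ in detail, for example in the form of the generator loss.

```r
# Hypothetical base R analogue of an original GAN value function.
# real_scores and fake_scores are discriminator outputs in (0, 1),
# i.e. after a sigmoid output layer.
gan_value_sketch <- function(real_scores, fake_scores) {
  # The discriminator maximises log D(x) + log(1 - D(G(z))),
  # so its loss is the negative mean of that sum.
  d_loss <- -mean(log(real_scores) + log(1 - fake_scores))
  # In the minimax formulation the generator minimises log(1 - D(G(z))).
  g_loss <- mean(log(1 - fake_scores))
  list(d_loss = d_loss, g_loss = g_loss)
}

losses <- gan_value_sketch(
  real_scores = c(0.9, 0.8, 0.7),
  fake_scores = c(0.2, 0.1, 0.3)
)
print(losses)
```

A new value function written for gan_trainer would keep the same interface: two score arguments in, a named list with d_loss and g_loss out.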
Provides a torch::nn_module with a simple fully connected neural net, for use as the default architecture for tabular data in RGAN.
Generator( noise_dim, data_dim, hidden_units = list(128, 128), dropout_rate = 0.5 )
noise_dim |
The length of the noise vector per example |
data_dim |
The number of columns in the data set |
hidden_units |
A list of the number of neurons per layer, the length of the list determines the number of hidden layers |
dropout_rate |
The dropout rate for each hidden layer |
A torch::nn_module for the Generator
Utility function for the kl WGAN loss for fake examples as described in https://arxiv.org/abs/1910.09779 and implemented in python in https://github.com/ermongroup/f-wgan.
kl_fake(dis_fake)
dis_fake |
Discriminator scores on fake examples ($D(G(z))$) |
The loss
Utility function for the kl WGAN loss for Generator training as described in https://arxiv.org/abs/1910.09779 and implemented in python in https://github.com/ermongroup/f-wgan.
kl_gen(dis_fake)
dis_fake |
Discriminator scores on fake examples ($D(G(z))$) |
The loss
Utility function for the kl WGAN loss for real examples as described in https://arxiv.org/abs/1910.09779 and implemented in python in https://github.com/ermongroup/f-wgan.
kl_real(dis_real)
dis_real |
Discriminator scores on real examples ($D(x)$) |
The loss
Implements the KL-WGAN value function (see https://arxiv.org/abs/1910.09779) as a function to be called in gan_trainer.
KLWGAN_value_fct(real_scores, fake_scores)
real_scores |
The discriminator scores on real examples ($D(x)$) |
fake_scores |
The discriminator scores on fake examples ($D(G(z))$) |
The function returns a named list with the entries d_loss and g_loss
Provides a function that makes it easy to sample synthetic data from a Generator
sample_synthetic_data(trained_gan, transformer = NULL)
trained_gan |
A trained RGAN object of class "trained_RGAN" |
transformer |
The transformer object used to pre-process the data |
Function to sample from a
## Not run:
# Before running the first time the torch backend needs to be installed
torch::install_torch()
# Load data
data <- sample_toydata()
# Build new transformer
transformer <- data_transformer$new()
# Fit transformer to data
transformer$fit(data)
# Transform data and store as new object
transformed_data <- transformer$transform(data)
# Train the default GAN
trained_gan <- gan_trainer(transformed_data)
# Sample synthetic data from the trained GAN
synthetic_data <- sample_synthetic_data(trained_gan, transformer)
# Plot the results
GAN_update_plot(data = data, synth_data = synthetic_data, main = "Real and Synthetic Data after Training")
## End(Not run)
Sample Toydata to reproduce the examples in the paper.
sample_toydata(n = 1000, sd = 0.3, seed = 20211111)
n |
Number of observations to generate |
sd |
Standard deviation of the normal distribution to generate y |
seed |
A seed for the pseudo random number generator |
A matrix with two columns x and y
## Not run:
# Before running the first time the torch backend needs to be installed
torch::install_torch()
# Load data
data <- sample_toydata()
# Build new transformer
transformer <- data_transformer$new()
# Fit transformer to data
transformer$fit(data)
# Transform data and store as new object
transformed_data <- transformer$transform(data)
# Train the default GAN
trained_gan <- gan_trainer(transformed_data)
# Sample synthetic data from the trained GAN
synthetic_data <- sample_synthetic_data(trained_gan, transformer)
# Plot the results
GAN_update_plot(data = data, synth_data = synthetic_data, main = "Real and Synthetic Data after Training")
## End(Not run)
Provides a function to sample torch tensors from an arbitrary uniform distribution.
torch_rand_ab(shape, a = -1, b = 1, ...)
shape |
Vector of dimensions of resulting tensor |
a |
Lower bound of uniform distribution to sample from |
b |
Upper bound of uniform distribution to sample from |
... |
Potential additional arguments |
A sample from the specified uniform distribution in a tensor with the specified shape
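The underlying idea, rescaling a standard uniform draw to the interval from a to b, can be sketched in base R. The function rand_ab is a hypothetical stand-in that returns a plain numeric vector instead of a torch tensor; it illustrates the transformation only and is not the package's code.

```r
# Hypothetical base R analogue: rescale U(0, 1) draws to U(a, b).
rand_ab <- function(n, a = -1, b = 1) {
  a + (b - a) * runif(n)
}

set.seed(1)
u <- rand_ab(1000, a = 2, b = 5)

# All draws lie between the lower and upper bound.
print(range(u))
```

The same affine rescaling applies elementwise to a tensor of any shape, which is why only the bounds a and b need to be specified in addition to the shape.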
Implements the Wasserstein GAN (WGAN) value function as a function to be called in gan_trainer. Note that for this to work properly you also need to implement a weight clipper (or other procedure) to constrain the Discriminator weights.
WGAN_value_fct(real_scores, fake_scores)
real_scores |
The discriminator scores on real examples ($D(x)$) |
fake_scores |
The discriminator scores on fake examples ($D(G(z))$) |
The function returns a named list with the entries d_loss and g_loss
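A base R sketch of the usual WGAN losses on plain numeric vectors may help when comparing this value function with the original one. The function wgan_value_sketch is hypothetical; it follows the standard WGAN formulation with unbounded critic scores (linear output layer) and is not copied from the package.

```r
# Hypothetical base R analogue of a WGAN value function.
# Scores are unbounded critic outputs (no sigmoid).
wgan_value_sketch <- function(real_scores, fake_scores) {
  # The critic maximises mean(D(x)) - mean(D(G(z))),
  # so its loss is the negative of that difference.
  d_loss <- -(mean(real_scores) - mean(fake_scores))
  # The generator tries to raise the critic scores of fake examples.
  g_loss <- -mean(fake_scores)
  list(d_loss = d_loss, g_loss = g_loss)
}

losses <- wgan_value_sketch(real_scores = c(1, 2), fake_scores = c(-1, 0))
print(losses)
```

Because the scores are unbounded, these losses are only meaningful when the critic is constrained, which is the role of the weight clipper mentioned above.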
A function that clips the weights of a Discriminator (for WGAN training).
WGAN_weight_clipper(d_net, clip_values = c(-0.01, 0.01))
d_net |
A torch::nn_module (typically a discriminator/critic) for which the weights should be clipped |
clip_values |
A vector with the lower and upper bound for weight values. Any value outside this range will be set to the closer value. |
The function modifies the torch::nn_module weights in place
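The clipping rule described for clip_values can be illustrated on an ordinary numeric vector. The function clip_weights is a hypothetical base R stand-in: it returns a clipped copy, whereas the package function modifies the weights of a torch::nn_module in place.

```r
# Hypothetical base R illustration of weight clipping.
clip_weights <- function(w, clip_values = c(-0.01, 0.01)) {
  # Values below the lower bound are raised to it,
  # values above the upper bound are lowered to it.
  pmin(pmax(w, clip_values[1]), clip_values[2])
}

w <- c(-0.5, -0.005, 0, 0.008, 0.3)
clipped <- clip_weights(w)
print(clipped)
```

Only the first and last entries fall outside the default range and are moved to the nearest bound; values already inside the range are left unchanged.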