josemanuel22/MMD_GAN.jl

MMD_GAN.jl

Overview

MMD_GAN.jl is a Julia module implementing a Maximum Mean Discrepancy (MMD) Generative Adversarial Network. It provides functions for training GAN models that use MMD to measure the discrepancy between the generated and real data distributions, and it is designed for easy experimentation with different hyperparameters and model architectures.
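To make the discrepancy measure concrete, here is a minimal sketch of the standard biased estimate of the squared MMD between two 1-D samples. The README does not say which kernel MMD_GAN.jl uses, so the Gaussian kernel, its bandwidth `σ = 1.0`, and the names `gaussian_kernel` and `mmd2` are assumptions made for illustration only.

```julia
# Gaussian (RBF) kernel between two scalars; the bandwidth σ is an assumption.
gaussian_kernel(x, y; σ = 1.0) = exp(-(x - y)^2 / (2 * σ^2))

# Biased (V-statistic) estimate of the squared MMD between 1-D samples xs and ys:
#   MMD² ≈ mean k(x, x') + mean k(y, y') - 2 mean k(x, y)
function mmd2(xs::AbstractVector, ys::AbstractVector; σ = 1.0)
    k(a, b) = gaussian_kernel(a, b; σ = σ)
    kxx = sum(k(a, b) for a in xs, b in xs) / (length(xs) * length(xs))
    kyy = sum(k(a, b) for a in ys, b in ys) / (length(ys) * length(ys))
    kxy = sum(k(a, b) for a in xs, b in ys) / (length(xs) * length(ys))
    return kxx + kyy - 2 * kxy
end

xs = collect(range(22.0, 24.0; length = 50))  # stand-in for samples near 23
ys = collect(range(-1.0, 1.0; length = 50))   # stand-in for samples near 0
mmd2(xs, xs)  # ≈ 0: identical samples
mmd2(xs, ys)  # clearly positive: the two samples are far apart
```

A GAN trained on this quantity pushes the generated sample toward the real one by driving the estimate toward zero.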

Installation

To use MMD_GAN.jl, clone this repository onto your local machine. Make sure Julia is installed and set up on your system.

git clone git@github.com:josemanuel22/MMD_GAN.jl.git

Usage

The module provides the core functionality for defining hyperparameters, setting up models, and training an MMD GAN. Key components include:

  • HyperParamsMMD: a structure holding the hyperparameters of the MMD GAN.
  • train_mmd_gan: a function that trains the MMD GAN from the given encoder, decoder, and generator models and hyperparameters.
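The example that follows constructs HyperParamsMMD with keyword arguments and elides the remaining fields. How HyperParamsMMD is actually declared is not shown in this README; as an assumption, the sketch below uses `Base.@kwdef` on a hypothetical `ToyHyperParams` struct to illustrate how a keyword-constructed hyperparameter struct with defaults behaves. Every field name and default here is invented for illustration and is not the real HyperParamsMMD definition.

```julia
using Random: randn

# Hypothetical stand-in for a keyword-constructed hyperparameter struct.
# The name, fields, and defaults are illustrative, NOT those of HyperParamsMMD.
Base.@kwdef struct ToyHyperParams
    target_model::Function = () -> 23.0f0 + randn(Float32)  # sampler for "real" data
    noise_model::Function = () -> randn(Float32)            # sampler for generator input
    samples::Int = 1000   # samples drawn per training step (hypothetical)
    epochs::Int = 100     # number of training epochs (hypothetical)
    η::Float64 = 1e-3     # learning rate (hypothetical)
end

# Only the fields you pass change; the rest keep their declared defaults.
hp = ToyHyperParams(epochs = 10)
hp.epochs   # 10
hp.samples  # 1000
```

If HyperParamsMMD follows this pattern, any field without a default would still have to be supplied explicitly.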

Example

using MMD_GAN
using Distributions  # provides Normal(μ, σ), used for the target and noise models

# Define your models (encoder, decoder, generator)
# enc = ...
# dec = ...
# gen = ...

# Define hyperparameters
hparams = HyperParamsMMD(
    target_model = Normal(23.0f0, 1.0f0),
    noise_model = Normal(0.0f0, 1.0f0),
    # Other hyperparameters...
)

# Train the model
losses_gen, losses_dscr = train_mmd_gan(enc, dec, gen, hparams)
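The README does not spell out what train_mmd_gan optimises. In MMD GAN formulations where an autoencoder (encoder plus decoder) supplies the feature map for the kernel, the generator minimises the MMD² between encoded real and encoded generated samples, while the encoder/decoder maximise that MMD² and are penalised for poor reconstructions. The sketch below writes a simplified version of those two objectives for scalar data, omitting any extra penalties or weight clipping a real implementation may add. The names `rbf`, `mmd_sq`, `generator_loss`, `critic_loss`, the weight `λ`, and the toy data are hypothetical, and it is an assumption, not a documented fact, that train_mmd_gan (or what `losses_gen` and `losses_dscr` record) matches these objectives.

```julia
using Statistics: mean

# Compact biased MMD² estimate with a Gaussian kernel of bandwidth 1 (assumed).
rbf(a, b) = exp(-(a - b)^2 / 2)
mmd_sq(x, y) = mean(rbf(a, b) for a in x, b in x) +
               mean(rbf(a, b) for a in y, b in y) -
               2 * mean(rbf(a, b) for a in x, b in y)

# Generator objective (to minimise): MMD² between encoded real and encoded fake samples.
generator_loss(enc, gen, x, z) = mmd_sq(enc.(x), enc.(gen.(z)))

# Critic objective for the encoder/decoder (to minimise): -MMD² plus a weighted
# autoencoder reconstruction error on real and generated samples; λ is illustrative.
function critic_loss(enc, dec, gen, x, z; λ = 1.0)
    fake = gen.(z)
    recon = mean((dec.(enc.(x)) .- x) .^ 2) + mean((dec.(enc.(fake)) .- fake) .^ 2)
    return -mmd_sq(enc.(x), enc.(fake)) + λ * recon
end

x = collect(range(22.0, 24.0; length = 20))  # stand-in "real" samples near 23
z = collect(range(-1.0, 1.0; length = 20))   # stand-in noise samples
generator_loss(identity, identity, x, z)        # large: outputs stay near 0
generator_loss(identity, v -> v + 23.0, x, z)   # near 0: outputs shifted onto the target
```

Here plain functions stand in for the encoder, decoder, and generator; in practice these would be trainable models updated by gradient steps on the two losses in alternation.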

Contributing

Contributions to MMD_GAN.jl are welcome. Please read our contribution guidelines for more details.

License

This project is licensed under the MIT License.