Skip to content

Instantly share code, notes, and snippets.

@kris-singh
Last active April 2, 2017 09:34
Show Gist options
  • Save kris-singh/d8798b858d1c1e87b2ffd32974cb4e76 to your computer and use it in GitHub Desktop.
/**
 * Adversarial loss layer used by both networks of a GAN.
 *
 * @tparam InputDataType Type of the input data (default arma::mat).
 * @tparam OutputDataType Type of the output data (default arma::mat).
 */
template <
typename InputDataType = arma::mat,
typename OutputDataType = arma::mat
>
class AdverserialLoss
{
 public:
  /**
   * Create the AdverserialLoss object.
   */
  AdverserialLoss();

  /**
   * Forward pass: computes log D(x^(i)) + log(1 - D(G(z^(i)))).
   *
   * When input_data is omitted it defaults to a matrix of ones, so the
   * real-data term contributes log(1) = 0 and the loss reduces to the
   * generator loss log(1 - D(G(z^(i)))).  (A zeros default would yield
   * log(0) = -inf, which was a bug.)
   *
   * @param input_generator Discriminator response on generated data,
   *        i.e. the (1 - D(G(z^(i)))) term.
   * @param input_data Discriminator response on real data, D(x^(i)).
   * @return The scalar adversarial loss.
   */
  template<typename eT>
  double Forward(const arma::Mat<eT>& input_generator,
                 const arma::Mat<eT>& input_data = arma::ones(1, 1))
  {
    // arma::log() is element-wise; accumulate to obtain the scalar loss
    // required by the double return type.
    return arma::accu(arma::log(input_data)) +
           arma::accu(arma::log(input_generator));
  }

  /**
   * Ordinary feed backward pass of a neural network.
   *
   * @param input_generator The generator output.
   * @param output The calculated error (gradient of the loss).
   * @param input The propagated input activation D(x^(i)).
   */
  template<typename eT>
  void Backward(const arma::Mat<eT>& input_generator,
                arma::Mat<eT>& output,
                const arma::Mat<eT>& input = arma::ones(1, 1))
  {
    // TODO: implement the gradient of the adversarial loss.
  }
}; // class AdverserialLoss
/*
Base class for GAN.
*/
template <typename Discriminator, typename Generator>
class GAN
{
public:
/*
The discriminator & generator loss function has to be of Adverserial loss.
The discriminator and the generator network last layers has to be of size 1.
@param Discriminator The Discriminator Network(Very similar to the) f
@param Generator can refer to Neural Network/ Function Approximator
@param Args Variadic Args Needed for other intialisations if any
*/
template <typename ... Args>
GAN(Discriminator discriminator, Generator generator, Args... args)
{
/* Intialise all the variable here */
}
/*
The train function of GAN.
The training proceeds in the altenative manner.
First we minimise the Discriminator network for k steps and then we
minimise the generator netwok. k=1 also work.
@param predictors Training data
@param responses Training Labels
@param k num of steps for training the discriminator
*/
template<typename eT, template<
template<typename> class OptimizerType = mlpack::optimization::StandardSGD>
Train(const arma::Mat<eT>& predictors,const size_t m, const size_t k)
{
/* Training Proceeds in Alternative Fashion */
for(size_t i = 0; i < k ; i++)
{
/* Train the Discriminator*/
TrainDiscriminator(predictors, m);
}
/*Optimisier function for the generator*/
TrainGenerator(m);
}
/*
The train function for the generator network.
We generate m samples from the Noise Distribution.
Then pass these samples to the train method for training
Rember the loss function for the generative network is the adverserial loss
@param: m : number of smaples to train for
*/
template<typename eT, template<
template<typename> class OptimizerType = mlpack::optimization::StandardSGD>
TrainGenerator(const size_t m)
{
/* Training Proceeds in Alternative Fashion */
arma::Mat<eT> generatedOutput;
for(size_t i=0; i<m; i++ )
{
/*Sample from the noise and store this in the generated output*/
generatedOutput(i) = noise.sample();
}
numFunctions = m;
Generator.Train(generatedOutput, m);
}
/*
The train discriminator network.
We generate m samples from the Noise Distribution and the actual distribution.
Then pass these samples to the train method for training
Rember the loss function for the adverserail network is the adverserial loss
@param predictors: samples from the true distribution
@param: m : number of smaples to train for
*/
template<typename eT, template<
template<typename> class OptimizerType = mlpack::optimization::StandardSGD>
TrainDiscriminator(const arma::Mat<eT>& predictors, const size_t m)
{
/* Training Proceeds in Alternative Fashion */
arma::Mat<eT> generatedOutput;
for(size_t i=0; i<m; i++ )
{
/*Sample from the noise and store this in the generated output*/
generatedOutput(i) = noise.sample();
}
numFunctions = 2*m;
arma::Mat<eT> data = join_cols(generatedOutput, predictors);
Discriminator.Train(data, m);
}
/*
Generate a sample using the Generator Network.
This should produce realistic outputs once the pe
network is trained
@param: noise noise which we use to generate the data
*/
template<typename Noise>
Generate(Noise noise)
{
/* Return generated output*/
generator.Generate(noise.Sample());
}
/**
* Evaluate the GAN network with the given parameters. This function
* is usually called by the optimizer to train the model.
* We also use the generators output along with index i to evaluate the
* function
* @param parameters Matrix model parameters.
* @param i Index of point to use for objective function evaluation.
* @param deterministic Whether or not to train or test the model. Note some
* layer act differently in training or testing mode.
*/
double Evaluate(const arma::mat& parameters,
const size_t i
const bool deterministic = true);
{
/*
Other Helper code that sets/ resets
parameters. Outlayer is the adverseraial Loss Layer
*/
/*Todo Implement This*/
}
/**
* Evaluate the gradient of the GAN network with the given parameters,
* and with respect to only one point in the dataset + Generators Outpur.
* This is useful for optimizers such as SGD, which require a separable objective function.
*
* @param parameters Matrix of the model parameters to be optimized.
* @param i Index of points to use for objective function gradient evaluation.
* @param gradient Matrix to output gradient into.
*/
void Gradient(const arma::mat& parameters,
const size_t i,
arma::mat& gradient)
{
/* Todo Implements This*/
}
size_t NumFunctions() const { return numFunctions; }
private:
// Helper functions.
/**
* The Forward algorithm (part of the Forward-Backward algorithm). Computes
* forward probabilities for each module. We call the forward function on
* input and and generated output
*
* @param input Data sequence to compute probabilities for.
*/
void Forward(const arma::mat& input)
{
/* call the forward function of the generetaor and discriminator*/
/* the forward visiotr of the generator takes 2 inputs*/
arma::mat response;
boost::apply_visitor(ForwardVisitor(
(ForwardVisitor(noise.sample), generator), input), discriminator);
}
/**
* The Backward algorithm (part of the Forward-Backward algorithm). Computes
* backward pass for module.
*/
void Backward()
{
/* Apply backward visitor*/
}
}
class StackedGAN
{
public:
/*
This is the base constructor for the stacked GAN class
@param gan: The generative adversial network to use
@param encoder: The encoder network to use
@param num_times: num of times to unroll the GAN(this should be equal to the number of
elemenetss in extraction_list)
@param extraction_list: list of layers that we use to extract the layers from
*/
template <typename Encoder>
StackedGAN(GAN& gan, Encoder encoder , const size_t num_times, aram::vec extraction_list)
{}
/*
Train method for stacked GAN's
Independent Training
Concat the noise and featurea vector from every feature layer.
for feeding the generator independent Training.
For the discriminator The input now becomes the features
and the generator input.
Train the the whole network using training the gan's one by one.
*/
void Train()
{
/*Todo implement this*/
}
/* Forward pass through every gan feeding the generators output of one
gan to other gan and also every gan forward pass is feed the
features extracted from each layer
*/
void Forward()
{
/*Todo implement this*/
}
/*
Backward pass throught the each
GANS
*/
void Backward()
{
/*Todo implement this*/
}
/*Evaluate function
Do a forward pass for the stacked network
*/
double Evaluate(const arma::mat& parameters,
const size_t i
const bool deterministic = true);
/**
* Evaluate the gradient of the stackedGAN network with the given parameters,
*
*/
void Gradient(const arma::mat& parameters,
const size_t i,
arma::mat& gradient)
{
/* Todo Implements This*/
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment