@eamartin
Created July 9, 2013 22:36
from pylearn2.costs.cost import Cost


class GSNCost(Cost):
    def __init__(self, costs):
        """costs is a list of (int, float, Cost) tuples:
        int : index of the layer to apply the cost to
        float : coefficient for the cost
        Cost : actual reconstruction cost for that layer
        """
        self.costs = costs

    def expr(self, model, data):
        """
        data is a list of the same length as costs.
        Each element is either a Theano tensor-like object or None.
        The order of data is the same as the order of costs (i.e. the
        first element of data is the layer described by the first cost
        tuple in self.costs).
        """
        indices = [c[0] for c in self.costs]
        # pass a sparse list of (layer index, data) pairs to the model
        output = model.reconstruct(list(zip(indices, data)))
        # output has the same format as data, except that no element is None
        cost = 0.0  # TODO: make this a Theano tensor instead of a Python float
        # enumerate yields (position, tuple), so the cost tuple is unpacked in parentheses
        for cost_idx, (_, coeff, cost_obj) in enumerate(self.costs):
            if data[cost_idx] is not None:
                cost += coeff * cost_obj(data[cost_idx], output[cost_idx])
        return cost

    def get_data_specs(self, model):
        # what should go here?
        raise NotImplementedError()
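The weighted sum in expr can be checked outside Theano with plain Python lists. This is a minimal sketch, assuming each cost is a callable cost(target, reconstruction) as the docstring describes; weighted_reconstruction_cost and squared_error are hypothetical names, and lists of floats stand in for Theano tensors.

```python
def weighted_reconstruction_cost(costs, data, output):
    """Sum coeff * cost(target, reconstruction) over the entries of data that are not None.

    costs:  list of (layer_index, coefficient, cost_fn) tuples
    data:   list aligned with costs; None marks a layer with no target
    output: list aligned with costs; a reconstruction for every entry
    """
    total = 0.0
    # enumerate yields (position, tuple), so the tuple must be unpacked in parentheses
    for cost_idx, (_, coeff, cost_fn) in enumerate(costs):
        if data[cost_idx] is not None:
            total += coeff * cost_fn(data[cost_idx], output[cost_idx])
    return total


def squared_error(target, reconstruction):
    """Hypothetical stand-in for a per-layer reconstruction cost."""
    return sum((t - r) ** 2 for t, r in zip(target, reconstruction))


costs = [(0, 1.0, squared_error), (2, 0.5, squared_error)]
data = [[1.0, 2.0], None]
output = [[1.5, 2.0], [0.0, 0.0]]
print(weighted_reconstruction_cost(costs, data, output))  # 0.25
```

The second cost is skipped because its data entry is None, so only 1.0 * 0.25 contributes.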

lamblin commented Jul 10, 2013

You would need to know in advance (when the GSNCost object is built) which costs you will include. That choice cannot depend on the data alone.

Assuming you have a sparse list of indices to include in self.cost_indices, you could do something like:

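    from pylearn2.space import CompositeSpace, NullSpace

    def get_data_specs(self, model):
        spaces = []
        sources = []
        # enumerate yields (position, tuple), so the cost tuple is unpacked in parentheses
        for cost_idx, (_, coeff, cost_obj) in enumerate(self.costs):
            if cost_idx in self.cost_indices:
                space, source = cost_obj.get_data_specs(model)
                spaces.append(space)
                sources.append(source)
            else:
                # empty placeholder keeps the data list aligned with self.costs
                spaces.append(NullSpace())
                sources.append('')
        return (CompositeSpace(spaces), tuple(sources))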

The `else` branch is not necessary if, in `self.expr`, you test `if cost_idx in self.cost_indices` instead of `if data[...] is not None`.
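Without the else branch, the composite space (and therefore data) only has entries for the included costs, so expr has to match them by position among the included costs rather than by index into self.costs. A minimal plain-Python sketch of one way to do that, assuming data and the reconstructions arrive in the same order as the included costs; compact_expr and abs_error are hypothetical helpers, not Pylearn2 API.

```python
def compact_expr(costs, cost_indices, data, output):
    """Weighted cost when data has entries only for the included costs.

    costs:        list of (layer_index, coefficient, cost_fn) tuples
    cost_indices: positions in costs chosen when the cost object was built
    data, output: one entry per included cost, in the same order as costs
    """
    # Position i in data matches the i-th included cost
    included = [c for i, c in enumerate(costs) if i in cost_indices]
    if len(data) != len(included) or len(output) != len(included):
        raise ValueError("data and output need one entry per included cost")
    total = 0.0
    for (_, coeff, cost_fn), target, recon in zip(included, data, output):
        total += coeff * cost_fn(target, recon)
    return total


def abs_error(target, reconstruction):
    """Hypothetical stand-in for a per-layer reconstruction cost."""
    return abs(target - reconstruction)


costs = [(0, 1.0, abs_error), (1, 2.0, abs_error), (2, 0.5, abs_error)]
print(compact_expr(costs, {0, 2}, data=[3.0, 4.0], output=[1.0, 0.0]))  # 4.0
```

The cost at position 1 never appears in data, so no None check is needed: 1.0 * 2.0 + 0.5 * 4.0 = 4.0.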
