Skip to content

Instantly share code, notes, and snippets.

@sshleifer
Created May 13, 2019 22:25
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save sshleifer/741d795669d4f58ee146cc2937bcb501 to your computer and use it in GitHub Desktop.
Save sshleifer/741d795669d4f58ee146cc2937bcb501 to your computer and use it in GitHub Desktop.
"""Modified from https://github.com/gan3sh500/mixmatch-pytorch/blob/master/layer.py
Implementation of MixMatch batch construction (`mixmatch`) and its loss (`MixMatchLoss`)."""
def mixmatch(X_labeled, y, X_unlabeled, model, augment_fn, T=0.5, K=2, alpha=0.75):
    """Generate labeled and unlabeled batches for mixmatch. Use in dataloader.

    Args:
        X_labeled: Batch of labeled inputs; passed through ``augment_fn`` once.
        y: Targets for ``X_labeled``. They are concatenated with the guessed
            labels along axis 0, so they must have the same trailing shape as
            the model output (e.g. one-hot / soft labels).
        X_unlabeled: Batch of unlabeled inputs; augmented ``K`` times.
        model: Callable applied to each augmented unlabeled batch.
            NOTE(review): assumed to return array-like class probabilities
            that NumPy can average and concatenate -- confirm against caller.
        augment_fn: Callable that returns an augmented copy of a batch.
        T: Sharpening temperature forwarded to ``sharpen``.
        K: Number of augmentations of the unlabeled batch.
        alpha: Beta-distribution parameter forwarded to ``mixup``.

    Returns:
        A tuple ``(inputs, targets, n_labeled)``. ``inputs`` and ``targets``
        are stacked along the batch axis with the ``n_labeled`` mixed labeled
        rows first and the mixed unlabeled rows after them.
    """
    xb = augment_fn(X_labeled)
    n_labeled = len(xb)
    ub = [augment_fn(X_unlabeled) for _ in range(K)]  # K augmented views
    # Guessed labels: average the model output over the K views, then sharpen.
    qb = sharpen(sum(map(model, ub)) / K, T)
    Ux = np.concatenate(ub, axis=0)
    # Each of the K views shares the same guessed labels.
    Uy = np.concatenate([qb] * K, axis=0)
    # np.random.shuffle works in place and returns None; permutation returns
    # the shuffled index array that is actually needed for fancy indexing.
    indices = np.random.permutation(n_labeled + len(Ux))
    Wx = np.concatenate([Ux, xb], axis=0)[indices]
    # Use Uy (not qb) so the label pool lines up row-for-row with Wx.
    Wy = np.concatenate([Uy, y], axis=0)[indices]
    X, p = mixup(xb, Wx[:n_labeled], y, Wy[:n_labeled], alpha)
    U, q = mixup(Ux, Wx[n_labeled:], Uy, Wy[n_labeled:], alpha)
    # Stack along the batch axis so callers can split with [:n_labeled].
    inputs = np.concatenate([X, U], axis=0)
    targets = np.concatenate([p, q], axis=0)
    return inputs, targets, n_labeled
def sharpen(x, T):
    """Raise each entry of ``x`` to the power ``1 / T`` and renormalize rows.

    Every row (axis 1) of the result is divided by its own sum, so each row
    of the output sums to one whenever that row sum is non-zero.
    """
    powered = x ** (1 / T)
    row_totals = powered.sum(axis=1, keepdims=True)
    return powered / row_totals
def lin_comb(a, b, frac_a):
    """Return the blend of ``a`` and ``b`` that weights ``a`` by ``frac_a``.

    The remaining weight, ``1 - frac_a``, is applied to ``b``.
    """
    weighted_a = frac_a * a
    weighted_b = (1 - frac_a) * b
    return weighted_a + weighted_b
def mixup(x1, x2, y1, y2, alpha):
    """Mix two batches of inputs and targets with per-sample Beta weights.

    Args:
        x1, x2: Input batches with the same shape; axis 0 is the sample axis.
        y1, y2: Target batches with the same shape; axis 0 is the sample axis.
        alpha: Positive parameter of the symmetric ``Beta(alpha, alpha)``
            distribution the mixing weights are drawn from.

    Returns:
        A tuple ``(x, y)`` where each row is ``lam * first + (1 - lam) * second``
        with the same ``lam`` used for a row's input and its target.
    """
    # Both Beta parameters must be positive; a negative second parameter
    # makes np.random.beta raise ValueError.
    lam = np.random.beta(alpha, alpha, x1.shape[0])
    # Keep the weight >= 0.5 so the result stays closer to the first batch.
    lam = np.maximum(lam, 1 - lam)
    # Reshape to (n, 1, ..., 1) so the weight broadcasts over each sample's
    # trailing axes instead of being aligned with the last axis.
    lam_x = lam.reshape((-1,) + (1,) * (np.ndim(x1) - 1))
    lam_y = lam.reshape((-1,) + (1,) * (np.ndim(y1) - 1))
    return lin_comb(x1, x2, lam_x), lin_comb(y1, y2, lam_y)
class MixMatchLoss(torch.nn.Module):
    """Combined loss for a batch laid out as labeled rows then unlabeled rows.

    The loss is ``cross_entropy(labeled) + lambda_u * mse(unlabeled)``.

    Args:
        lambda_u: Weight applied to the unlabeled (MSE) term.
    """

    def __init__(self, lambda_u=100):
        super().__init__()
        self.lambda_u = lambda_u

    def cross_entropy(self, preds, y):
        """Return the mean cross-entropy between ``preds`` and soft targets ``y``.

        ``preds`` is passed through ``log_softmax`` over dim 1, so it is
        treated as unnormalized scores; ``y`` must have the same shape as
        ``preds`` and hold per-class target weights.
        """
        log_probs = F.log_softmax(preds, dim=1)
        return -(y * log_probs).sum(dim=1).mean()

    def forward(self, preds, y, n_labeled):
        """Compute the loss for one batch.

        Args:
            preds: Model outputs; the first ``n_labeled`` rows belong to the
                labeled part of the batch, the rest to the unlabeled part.
            y: Targets with the same row layout and shape as ``preds``.
            n_labeled: Number of labeled rows at the start of the batch.

        Returns:
            Scalar tensor ``labeled_loss + lambda_u * unlabeled_loss``.
        """
        labeled_loss = self.cross_entropy(preds[:n_labeled], y[:n_labeled])
        # NOTE(review): MSE is taken on ``preds`` as given; if the model emits
        # logits, consider applying softmax first -- confirm with the caller.
        unlabeled_loss = F.mse_loss(preds[n_labeled:], y[n_labeled:])
        return labeled_loss + (self.lambda_u * unlabeled_loss)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment