
@J3698
Created December 21, 2021 20:43
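import numpy as np
import torch
# Class names "0".."1097" sorted as strings; position i in this list corresponds to model output index i.
order = sorted([str(i) for i in range(1098)])
chars = sorted(set(np.load("data_processed/dataY.npy")))
def fix_predictions(output):
    # Top-5 class indices for the first sample in the batch, mapped back to label strings.
    outs = [chars[int(order[i.item()])] for i in torch.topk(output, 5, dim=1).indices[0]]
    # Keep the part of each label after the first "_" and prefix it with a backslash.
    return ["\\" + i.split("_")[1] for i in outs]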
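The index remapping in fix_predictions is easiest to see on a toy case. Below is a minimal pure-Python sketch, with no torch and no data_processed/dataY.npy. The 12 labels and the helper names string_sorted_order and top_k_labels are hypothetical. Following the gist's order list, it assumes that output index j belongs to the j-th class name when "0".."n-1" are sorted as strings.

```python
# Hypothetical stand-in for fix_predictions: plain lists instead of tensors,
# made-up labels instead of data_processed/dataY.npy.
# Assumption: model output index j belongs to the j-th class name when the
# names "0".."n-1" are sorted as strings (what the gist's `order` implies).


def string_sorted_order(n):
    """Class names "0".."n-1" in the order a string sort puts them."""
    return sorted(str(i) for i in range(n))


def top_k_labels(scores, labels, k=5):
    """Map one row of model scores back to backslash-prefixed command names."""
    order = string_sorted_order(len(labels))
    # Indices of the k highest scores, best first.
    top = sorted(range(len(scores)), key=lambda j: scores[j], reverse=True)[:k]
    # order[j] is the original integer label of output index j.
    names = [labels[int(order[j])] for j in top]
    # Same post-processing as the gist: text after the first "_", prefixed with "\".
    return ["\\" + name.split("_")[1] for name in names]


# Hypothetical labels in the "prefix_command" form that split("_") expects.
labels = sorted(f"sym{i:02d}_cmd{i}" for i in range(12))

scores = [0.0] * 12
scores[2], scores[0], scores[11] = 0.9, 0.5, 0.3

print(string_sorted_order(12)[:4])        # → ['0', '1', '10', '11']
print(top_k_labels(scores, labels, k=3))  # → ['\\cmd10', '\\cmd0', '\\cmd9']
```

Output index 2 maps to label 10, not label 2, because "10" and "11" sort before "2". Indexing chars directly with the top-k indices would silently return the wrong commands.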