# Root directory shared by both MNIST splits; with download=True torchvision
# fetches the raw files here on first use and skips the download afterwards.
MNIST_ROOT = '~/.pytorch/MNIST_data/'

# Convert images to float tensors in [0, 1] (ToTensor), then normalize the
# single channel as (x - 0.5) / 0.5 so pixel values lie in [-1, 1].
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5,), (0.5,)),
])

# Download (if needed) and load both MNIST splits with the same preprocessing:
#   - teacher_data: the training split (train=True)
#   - student_data: the held-out test split (train=False)
# NOTE(review): the student set is the MNIST *test* split rather than a subset
# of the training data — confirm this is intended for the teacher/student setup.
teacher_data = datasets.MNIST(MNIST_ROOT, download=True, train=True, transform=transform)
student_data = datasets.MNIST(MNIST_ROOT, download=True, train=False, transform=transform)