Skip to content

Instantly share code, notes, and snippets.

@jsrimr
Created September 5, 2021 13:24
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save jsrimr/3d0ad97ecc45286630e6efbaefe68f0f to your computer and use it in GitHub Desktop.
Save jsrimr/3d0ad97ecc45286630e6efbaefe68f0f to your computer and use it in GitHub Desktop.
function preservation
def _wider_conv(self, teacher_w1, teacher_b1, teacher_w2, width_coeff, verification):
    """Widen a conv layer (Net2Wider-style) without changing the pair's output.

    Randomly chosen output channels of the first layer are duplicated, and
    the matching input channels of the next layer are divided by the number
    of copies, so the composition of the two layers stays the same.

    Args:
        teacher_w1: Kernel of the layer being widened, indexed as
            ``[row, col, in_channel, out_channel]``.
        teacher_b1: Bias of that layer, indexed by output channel
            (assumed 1-D -- TODO confirm against callers).
        teacher_w2: Kernel of the following layer, same index layout; its
            input-channel axis (axis 2) must match ``teacher_w1.shape[3]``.
        width_coeff: Growth factor; the new width is
            ``int(width_coeff * teacher_w1.shape[3])``.
        verification: If true, push a random input through the teacher and
            student layer pairs and compare the results.

    Returns:
        Tuple ``(student_w1, student_b1, student_w2)`` of freshly allocated
        arrays; the teacher arrays are not modified.

    Raises:
        ValueError: If the new width is smaller than the current one, or if
            the channel counts of the two kernels do not line up.
        AssertionError: If ``verification`` is set and the largest absolute
            output difference is not below ``self._error_th``.
    """
    old_width = teacher_w1.shape[3]
    new_width = int(width_coeff * old_width)
    if new_width < old_width:
        raise ValueError(
            'width_coeff={} would shrink the layer from {} to {} channels'.format(
                width_coeff, old_width, new_width))
    if teacher_w2.shape[2] != old_width:
        raise ValueError(
            'teacher_w2 has {} input channels but teacher_w1 has {} output channels'.format(
                teacher_w2.shape[2], old_width))

    # rand[k] is the teacher channel that the k-th added channel copies.
    rand = np.random.randint(old_width, size=new_width - old_width)
    # Total number of units (original + copies) sharing each teacher channel.
    copies = np.bincount(rand, minlength=old_width) + 1

    # Target layer (i): append the sampled filters and biases in one shot
    # instead of re-allocating the arrays once per added channel.
    student_w1 = np.concatenate((teacher_w1, teacher_w1[:, :, :, rand]), axis=3)
    student_b1 = np.concatenate((teacher_b1, teacher_b1[rand]))

    # Next layer (i+1): every unit in a replication group gets the teacher's
    # weights divided by the group size; unreplicated channels are scaled by
    # exactly 1.0 and therefore stay unchanged.
    scale = 1.0 / copies
    scaled_w2 = teacher_w2 * scale[np.newaxis, np.newaxis, :, np.newaxis]
    if np.issubdtype(teacher_w2.dtype, np.floating):
        # Keep the teacher's floating dtype rather than whatever the
        # multiplication with the float64 scale promoted it to.
        scaled_w2 = scaled_w2.astype(teacher_w2.dtype, copy=False)
    student_w2 = np.concatenate((scaled_w2, scaled_w2[:, :, rand, :]), axis=2)

    if verification:
        import scipy.signal

        def conv_layer(feature_maps, kernels, bias=None):
            # Reference forward pass: sum of per-channel 2-D 'same'
            # convolutions, plus an optional per-output-channel bias.
            rows, cols = feature_maps.shape[:2]
            result = np.zeros((rows, cols, kernels.shape[3]))
            for out_ch in range(kernels.shape[3]):
                for in_ch in range(feature_maps.shape[2]):
                    result[:, :, out_ch] += scipy.signal.convolve2d(
                        feature_maps[:, :, in_ch], kernels[:, :, in_ch, out_ch], mode='same')
                if bias is not None:
                    result[:, :, out_ch] += bias[out_ch]
            return result

        inputs = np.random.rand(
            teacher_w1.shape[0] * 4, teacher_w1.shape[1] * 4, teacher_w1.shape[2])
        ori2 = conv_layer(conv_layer(inputs, teacher_w1, teacher_b1), teacher_w2)
        new2 = conv_layer(conv_layer(inputs, student_w1, student_b1), student_w2)

        # Largest element-wise deviation: a signed sum would let positive
        # and negative errors cancel and hide a broken transformation.
        deviation = np.abs(ori2 - new2)
        err = float(deviation.max()) if deviation.size else 0.0
        # Raised explicitly so the check survives `python -O`.
        if not err < self._error_th:
            raise AssertionError('Verification failed: [ERROR] {}'.format(err))

    return student_w1, student_b1, student_w2
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment