@Extraltodeus
Last active November 17, 2023 09:47
The function that I mainly used to create my model.
import torch

# Input tensors are generally stacked layers from multiple models.
# autoweight_power will exponentially eliminate discrepancies. I used values from 2 to 16.
# proximity_power also powers the distance, but at the moment of being evaluated, not after averaging.
# This reinforces clustered information and can prevent discrepancies from being eliminated.
# Useful in some cases but generally not recommended.
# The first_model_as_ref options are best left alone; use False.
# This also works with any tensor of course, but conditioning is not that interesting to merge with it.
def merge_tensors_by_consensus(tensors, autoweight_power, proximity_power, first_model_as_ref, first_model_as_ref_ignore, use_cuda=True):
    # Nothing to merge if every tensor is identical.
    if all(torch.equal(tensors[0], tensor) for tensor in tensors[1:]):
        return tensors[0]
    tensor_shape = tensors[0].shape
    min_val = torch.full(tensor_shape, float("inf"))
    max_val = torch.full(tensor_shape, float("-inf"))
    min_val_cluster = torch.full(tensor_shape, float("inf"))
    max_val_cluster = torch.full(tensor_shape, float("-inf"))
    if use_cuda:
        min_val = min_val.cuda()
        max_val = max_val.cuda()
        min_val_cluster = min_val_cluster.cuda()
        max_val_cluster = max_val_cluster.cuda()
    m_diff = []
    mean_diff = []
    # First pass: per-element min/max over all pairwise absolute differences,
    # used to normalize each distance when proximity_power is active.
    if proximity_power > 1 and not first_model_as_ref:
        for idx1, t1 in enumerate(tensors):
            for idx2, t2 in enumerate(tensors):
                if idx1 != idx2:
                    diff_tensor = torch.abs(torch.sub(t1, t2))
                    min_val_cluster = torch.minimum(min_val_cluster, diff_tensor)
                    max_val_cluster = torch.maximum(max_val_cluster, diff_tensor)
    # Second pass: build one divergence map per input tensor.
    for idx1, t1 in enumerate(tensors):
        if first_model_as_ref:
            # Divergence is measured against the first tensor only.
            if idx1 != 0:
                if not first_model_as_ref_ignore:
                    diff_tensor = torch.abs(torch.sub(tensors[0], t1))
                else:
                    diff_tensor = torch.abs(t1)
            else:
                diff_tensor = torch.zeros_like(tensors[0])
            m_diff.append(diff_tensor)
            min_val = torch.minimum(min_val, diff_tensor)
            max_val = torch.maximum(max_val, diff_tensor)
        else:
            # Divergence is the mean absolute difference against every other tensor.
            temp_diffs = []
            for idx2, t2 in enumerate(tensors):
                if idx1 != idx2:
                    diff_tensor = torch.abs(torch.sub(t1, t2))
                    if proximity_power > 1:
                        diff_tensor = torch.div(torch.sub(diff_tensor, min_val_cluster), torch.sub(max_val_cluster, min_val_cluster))
                        diff_tensor = torch.sub(1.0, diff_tensor)
                        diff_tensor = torch.pow(diff_tensor, proximity_power)
                        diff_tensor = torch.sub(1.0, diff_tensor)
                    temp_diffs.append(diff_tensor)
            mean_diff = torch.mean(torch.stack(temp_diffs), dim=0)
            m_diff.append(mean_diff)
            min_val = torch.minimum(min_val, mean_diff)
            max_val = torch.maximum(max_val, mean_diff)
    del min_val_cluster, max_val_cluster
    # Normalize the divergence maps to [0, 1] and invert them so that values
    # closer to the consensus get larger weights (NaNs from constant elements become 0).
    m_diff = [torch.div(torch.sub(diff_tensor, min_val), torch.sub(max_val, min_val)) for diff_tensor in m_diff]
    m_diff = [torch.where(torch.isnan(tensor), torch.zeros_like(tensor), tensor) for tensor in m_diff]
    m_diff = [torch.sub(1.0, diff_tensor) for diff_tensor in m_diff]
    # Turn the agreement maps into weights summing to 1, sharpen them with
    # autoweight_power, then renormalize.
    sum_m_diff = torch.sum(torch.stack(m_diff), dim=0, keepdim=True)
    m_diff = [torch.div(val, sum_m_diff) for val in m_diff]
    m_diff = [torch.pow(val, autoweight_power) for val in m_diff]
    sum_m_diff = torch.sum(torch.stack(m_diff), dim=0)
    m_diff = [torch.div(val, sum_m_diff) for val in m_diff]
    # Per-element weighted sum of the input tensors.
    result = torch.sum(torch.stack([torch.mul(m, tensor) for m, tensor in zip(m_diff, tensors)]), dim=0)[0]
    return result
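
# ---------------------------------------------------------------------------
# Minimal usage sketch (not part of the original gist): the three random
# tensors below are placeholders standing in for the same layer taken from
# three different checkpoints, stacked along a new first dimension.
if __name__ == "__main__":
    layer_variants = torch.stack([torch.randn(320, 320) for _ in range(3)])
    merged = merge_tensors_by_consensus(
        layer_variants,
        autoweight_power=8,              # values from 2 to 16 per the note above
        proximity_power=1,               # <= 1 leaves the cluster reinforcement off
        first_model_as_ref=False,
        first_model_as_ref_ignore=False,
        use_cuda=False,                  # set True only if the inputs already live on the GPU
    )
    print(merged.shape)  # torch.Size([320, 320])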