Skip to content

Instantly share code, notes, and snippets.

@SeverTopan
Created September 18, 2020 00:01
Show Gist options
  • Save SeverTopan/5a88c69cc845390ca701f54ae600baf8 to your computer and use it in GitHub Desktop.
Save SeverTopan/5a88c69cc845390ca701f54ae600baf8 to your computer and use it in GitHub Desktop.
3SAT to MAX2SAT
def _3_sat_to_max_2_sat(S):
# Using https://math.stackexchange.com/questions/1633005/how-exactly-does-a-max-2-sat-reduce-to-a-3-sat
S = torch.FloatTensor(S)
S_prime = torch.zeros([S.size()[0]*10, S.size()[1] + S.size()[0]])
num_aux = len(S)
for aux, (t, l1, l2, l3) in enumerate(S):
aux_true = [0]*num_aux
aux_true[aux] = 1
aux_false = [0]*num_aux
aux_true[aux] = -1
aux_absent = [0]*num_aux
clauses = torch.FloatTensor([
[-1, l1, 0, 0] + aux_absent,
[-1, 0, l2, 0] + aux_absent,
[-1, 0, 0, l3] + aux_absent,
[-1, 0, 0, 0] + aux_true,
[-1, -l1, -l2, 0] + aux_absent,
[-1, 0, -l2, -l3] + aux_absent,
[-1, -l1, 0, -l3] + aux_absent,
[-1, l1, 0, 0] + aux_false,
[-1, 0, l2, 0] + aux_false,
[-1, 0, 0, l3] + aux_false,
])
clauses[:4] *= 1/math.sqrt(4*3)
clauses[4:] *= 1/math.sqrt(4*4)
S_prime[aux*10:(aux + 1)*10] = clauses
return S_prime
# Sanity-check the reduction by loading a hand-built clause matrix into SATNet.
# CNF -- works for all 2-input functions aside from XOR
# Each row is one 3-SAT clause [truth, x1, x2, y] with +1/-1 selecting the
# literal's polarity. Reading +1 as a positive literal, these four clauses
# force y = x1 OR x2 on each of the four input pairs.
S = _3_sat_to_max_2_sat([
    [-1, 1, 1, -1],
    [-1, 1, -1, 1],
    [-1, -1, 1, 1],
    [-1, -1, -1, 1],
])
num_clauses, num_cols = S.size()
# 3 visible variables (x1, x2, y). The columns past the first 4 are the
# auxiliary variables added by the reduction.
model = satnet.SATNet(3, num_clauses, num_cols - 4, prox_lam=1e-1, eps=1e-4, max_iter=100)
# Replace the model's clause-matrix parameter with the hand-built one.
model.S = torch.nn.Parameter(S.transpose(0, 1))
# model = satnet.SATNet(3, 8, 1, prox_lam=1e-1)
# model.load_state_dict(torch.load('/data/logs/parity.aux1-m8-lr0.1-bsz100/it2.pth'))
for x in [torch.FloatTensor([[a, b, 0.]]) for a in (0., 1.) for b in (0., 1.)]:
    # NOTE(review): the mask presumably marks x1, x2 as given inputs and y
    # (last slot, fed as 0.) as the value to solve for.
    y = model(x, torch.IntTensor([[1, 1, 0]]))
    print(f'x: {x} == {y}')
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment