Skip to content

Instantly share code, notes, and snippets.

@ymgve
Created November 6, 2017 02:12
Show Gist options
  • Star 2 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save ymgve/2ff6471847effb845be94bf8c01911c1 to your computer and use it in GitHub Desktop.
Save ymgve/2ff6471847effb845be94bf8c01911c1 to your computer and use it in GitHub Desktop.
import socket, struct, os, binascii, base64, random, time, itertools
import telnetlib
def readline(sc, show=True):
    """Read one newline-terminated line from the socket ``sc``.

    Data is pulled one byte per ``recv`` call so nothing beyond the newline
    is consumed from the connection.  The line is returned without its
    trailing newline; when ``show`` is true its repr is also printed.

    Raises:
        Exception: if the peer closes the connection before a newline
            arrives (the partial data received so far is printed first).
    """
    pieces = []
    while True:
        ch = sc.recv(1)
        if not ch:
            print(repr("".join(pieces)))
            raise Exception("Server disconnected")
        pieces.append(ch)
        if ch.endswith("\n"):
            break
    line = "".join(pieces)[:-1]
    if show:
        print(repr(line))
    return line
def read_until(sc, s):
    """Consume data from ``sc`` until what was received ends with ``s``.

    Reads one byte per ``recv`` call so nothing past the delimiter is taken
    from the connection.  Returns everything received before ``s``.

    Raises:
        Exception: if the peer closes the connection first (the partial
            data received so far is printed).
    """
    received = ""
    while True:
        if received.endswith(s):
            return received[:len(received) - len(s)]
        ch = sc.recv(1)
        if not ch:
            print(repr(received))
            raise Exception("Server disconnected")
        received += ch
def read_all(sc, n):
    """Receive exactly ``n`` bytes from ``sc``, looping over short reads.

    Raises:
        Exception: if the peer closes the connection before ``n`` bytes
            arrive (the data received so far is printed first).
    """
    parts = []
    received = 0
    while received < n:
        block = sc.recv(n - received)
        if not block:
            print(repr("".join(parts)))
            raise Exception("Server disconnected")
        parts.append(block)
        received += len(block)
    return "".join(parts)
def untempering(y):
    """Invert the MT19937 output tempering, returning the raw state word.

    The four tempering steps are undone in reverse order.  Each xorshift step
    ``y' = y ^ op(y)`` is inverted by fixed-point iteration
    ``guess = y' ^ op(guess)``, where every pass fixes another block of bits.
    """
    # Undo y ^= y >> 18: the shifted-in bits (18..31) pass through the
    # forward step unchanged, so one pass is exact.
    y ^= y >> 18
    # Undo y ^= (y << 15) & 0xEFC60000: the mask only touches bits 17..31,
    # whose sources (bits 2..16) the forward step left intact -- one pass.
    y ^= (y << 15) & 0xEFC60000
    # Undo y ^= (y << 7) & 0x9D2C5680: four passes recover all 32 bits.
    observed = y
    for _ in range(4):
        y = observed ^ ((y << 7) & 0x9D2C5680)
    # Undo y ^= y >> 11: two passes recover all 32 bits.
    observed = y
    for _ in range(2):
        y = observed ^ (y >> 11)
    return y
def tempering(y):
    """Apply the MT19937 output tempering transform to a raw state word."""
    y ^= y >> 11
    y ^= (y << 7) & 0x9D2C5680   # == 2636928640
    y ^= (y << 15) & 0xEFC60000  # == 4022730752
    return y ^ (y >> 18)
def gen_cand(v, bits):
    """Return every raw MT19937 state word whose tempered output ends in ``v``.

    ``v`` holds the low ``bits`` bits of a tempered 32-bit output.  The
    ``32 - bits`` unknown high bits are enumerated, and each completed output
    is mapped back through ``untempering``.  The result has
    ``2 ** (32 - bits)`` entries, ordered by the value of the enumerated high
    bits, so the cost grows quickly as ``bits`` shrinks.

    Raises:
        ValueError: if ``bits`` is not in the range 1..32.  Silently
            returning an empty list here would look like "no consistent
            state" to the caller.
    """
    if not 0 < bits <= 32:
        raise ValueError("bits must be between 1 and 32, got %r" % (bits,))
    # Keep only the known low bits so they cannot overlap the enumerated part.
    low = v & ((1 << bits) - 1)
    return [untempering(low | (high << bits)) for high in range(1 << (32 - bits))]
class MTpred(object):
    """Narrow down MT19937 state words from partially observed outputs.

    ``self.mt`` holds one entry per generator state word, in generation
    order.  Each entry is either None (unknown, or too many possibilities to
    track) or a collection of candidate raw (untempered) 32-bit words.  The
    first 624 entries stand for the words preceding the first observation.
    They start as None but can be narrowed by backward propagation.  They
    also guarantee ``idx - 624`` is a valid index for every appended word.

    Entries are linked by the recurrence evaluated in ``calc_value0``::

        y = (mt[i-624] & 0x80000000) | (mt[i-623] & 0x7fffffff)
        mt[i] = mt[i-227] ^ (y >> 1) ^ (0x9908b0df if y & 1 else 0)

    ``set_value`` appends an observation and then re-evaluates this relation
    forwards and backwards to shrink the candidate sets of related words.
    """
    def __init__(self):
        self.mt = [None] * 624
        # Indices whose recurrence equation is pending re-evaluation.
        self.queue = []
        # Upper bound on len(A) * len(B) for the two candidate lists an
        # enumeration combines; larger combinations are skipped.
        self.comblimit = 131072
    def gen_a624(self, idx):
        """Return the possible values of bit 31 of mt[idx - 624].

        Only that top bit enters the recurrence.  For an unknown word both
        possibilities are returned (as a tuple); otherwise the set of top
        bits found among its candidates.
        """
        a624 = self.mt[idx - 624]
        if a624 is None:
            return (0, 0x80000000)
        else:
            a624temp = set()
            for v624 in a624:
                a624temp.add(v624 & 0x80000000)
            return a624temp
    def calc_value0(self, idx):
        """Compute the candidates for mt[idx] allowed by its three inputs.

        Returns a set of candidate words.  Returns None when mt[idx - 623] or
        mt[idx - 227] is unknown.  Also returns None when enumerating their
        combination would exceed ``self.comblimit``.
        """
        a623 = self.mt[idx - 623]
        a227 = self.mt[idx - 227]
        a624 = self.gen_a624(idx)
        if a623 != None and a227 != None:
            if len(a623) * len(a227) <= self.comblimit:
                cands = set()
                for v624 in a624:
                    for v623 in a623:
                        for v227 in a227:
                            y = (v624 & 0x80000000) | (v623 & 0x7fffffff)
                            # Parsed as v227 ^ (y >> 1): >> binds tighter than ^.
                            x = v227 ^ y >> 1
                            if y & 1 == 1:
                                x ^= 0x9908b0df
                            cands.add(x)
                return cands
        return None
    def set_value(self, value, bits):
        """Record the next generator output and propagate constraints.

        ``value`` supplies the low ``bits`` bits of the tempered output for
        the next state word.  With bits == 0 the mask is 0, so the word is
        recorded unconstrained.  A new entry is appended to ``self.mt``.  The
        recurrence around it is then re-evaluated through ``self.queue``,
        narrowing the candidate lists of related words where possible.
        """
        idx = len(self.mt)
        cands = self.calc_value0(idx)
        if cands == None:
            # No forward prediction available: enumerate from the observed
            # bits when few enough are missing, otherwise leave unknown.
            if bits >= 24:
                self.mt.append(gen_cand(value, bits))
            else:
                self.mt.append(None)
        else:
            # Keep only predicted words whose tempered low bits match.
            mask = (1 << bits) - 1
            cands2 = set()
            for cand in cands:
                if (tempering(cand) & mask) == (value & mask):
                    cands2.add(cand)
            self.mt.append(list(cands2))
        self.queue.append(idx)
        while len(self.queue) > 0:
            idx = self.queue.pop(0)
            a624 = self.gen_a624(idx)
            a623 = self.mt[idx - 623]
            a227 = self.mt[idx - 227]
            a0 = self.mt[idx]
            changed = False
            # 1) Forward: intersect mt[idx] with what its inputs allow.
            cands = self.calc_value0(idx)
            if cands != None:
                if a0 == None:
                    res = list(cands)
                else:
                    res = []
                    for cand in a0:
                        if cand in cands:
                            res.append(cand)
                if a0 == None or len(res) < len(a0):
                    self.mt[idx] = res
                    changed = True
            # NOTE(review): the local a0 is not refreshed after the update
            # above, so steps 2 and 3 use the pre-filter list.  That is
            # sound (a superset) but prunes less than it could.
            # 2) Backward: mt[idx-227] = mt[idx] ^ (y >> 1) ^ (0x9908b0df if y odd).
            if a623 != None and a0 != None:
                if len(a623) * len(a0) <= self.comblimit:
                    cands = set()
                    for v624 in a624:
                        for v623 in a623:
                            for v0 in a0:
                                y = (v624 & 0x80000000) | (v623 & 0x7fffffff)
                                # Rebinding the loop variable only affects
                                # this iteration; the next one rebinds v0.
                                if y & 1 == 1:
                                    v0 ^= 0x9908b0df
                                cands.add(v0 ^ (y >> 1))
                    if a227 == None:
                        res = list(cands)
                    else:
                        res = []
                        for cand in a227:
                            if cand in cands:
                                res.append(cand)
                    if a227 == None or len(res) < len(a227):
                        self.mt[idx - 227] = res
                        changed = True
            # 3) Backward: recover mt[idx-623] from mt[idx] and mt[idx-227].
            # t = mt[idx] ^ mt[idx-227] [^ 0x9908b0df] must equal y >> 1.  The
            # unknown low bit of y (extrabit) is guessed.  Bit 31 of
            # mt[idx-623] does not enter this equation, so both values
            # (upperbit) are kept.
            if a227 != None and a0 != None:
                if len(a227) * len(a0) <= self.comblimit:
                    cands = set()
                    for v227 in a227:
                        for extrabit in (0, 1):
                            for upperbit in (0, 0x80000000):
                                for v0 in a0:
                                    t = v0
                                    if extrabit == 1:
                                        t ^= 0x9908b0df
                                    t ^= v227
                                    # y >> 1 always has bit 31 clear.
                                    if t & 0x80000000:
                                        continue
                                    t = ((t << 1) | extrabit) ^ upperbit
                                    cands.add(t)
                    if a623 == None:
                        res = list(cands)
                    else:
                        res = []
                        for cand in a623:
                            if cand in cands:
                                res.append(cand)
                    if a623 == None or len(res) < len(a623):
                        self.mt[idx - 623] = res
                        changed = True
            # Re-queue equations that involve the words just narrowed.
            # NOTE(review): neighbours are only queued when the queue is
            # empty.  Equations idx+227 and idx+623 also read mt[idx] (as
            # their a227/a623 operand) but are never queued, so propagation
            # looks incomplete.  This may be intentional to bound work --
            # confirm.
            if changed and len(self.queue) == 0:
                if idx >= 624 + 227:
                    self.queue.append(idx - 227)
                if idx >= 624 + 623:
                    self.queue.append(idx - 623)
                if idx >= 624 + 624:
                    self.queue.append(idx - 624)
                if idx + 1 < len(self.mt):
                    self.queue.append(idx + 1)
                if idx + 397 < len(self.mt):
                    self.queue.append(idx + 397)
                if idx + 624 < len(self.mt):
                    self.queue.append(idx + 624)
def gen_subsets(input, offset=0):
    """Enumerate every subset sum of ``input``, shifted by ``offset``.

    Returns ``(sums, index)``:

    * ``sums[i]`` is ``offset`` plus the sum of ``input[j]`` for every bit
      ``j`` set in ``i``, for ``i`` in ``0 .. 2 ** len(input) - 1``;
    * ``index`` maps each distinct sum to the largest ``i`` producing it.

    The list is built by doubling: after element ``j`` is processed, entry
    ``k + 2 ** j`` is entry ``k`` plus ``input[j]``.  This costs O(2 ** n)
    additions instead of re-summing each subset (O(n * 2 ** n)), while
    producing exactly the same sums in the same order.

    (The parameter name shadows the ``input`` builtin; it is kept for
    keyword compatibility with existing callers.)
    """
    sums = [offset]
    for x in input:
        # The comprehension is fully built before extend() runs, so this
        # appends the "with x" copy of every subset seen so far.
        sums.extend([s + x for s in sums])
    index = {}
    for mask, total in enumerate(sums):
        index[total] = mask  # later masks overwrite earlier duplicates
    return sums, index
def solve(numbers, target):
    """Return a sub-list of ``numbers`` whose elements sum to ``target``.

    Meet-in-the-middle: every subset sum of the first half is matched
    against ``target`` minus every subset sum of the second half.  This
    takes O(2 ** (len(numbers) / 2)) time and memory.

    The chosen first-half elements come first, then the second-half ones,
    each in their original order.  Returns None when no subset sums to
    ``target``, the same convention ``solve3`` uses.
    """
    # Floor division keeps the split integral on both Python 2 and 3.
    half = len(numbers) // 2
    left, right = numbers[:half], numbers[half:]
    _, left_masks = gen_subsets(left)
    # Negate the right half and start from target.  A value s produced here
    # equals target - (sum of the chosen right-half elements), so any s that
    # is also a left-half sum yields a full solution.
    right_sums, right_masks = gen_subsets([-x for x in right], target)
    for meet in right_sums:
        if meet in left_masks:
            break
    else:
        return None
    left_mask, right_mask = left_masks[meet], right_masks[meet]
    answer = [x for i, x in enumerate(left) if left_mask >> i & 1]
    answer.extend(x for i, x in enumerate(right) if right_mask >> i & 1)
    return answer
def solve3(numbers, target, set1size, set2size):
    """Find a subset of ``numbers`` summing to ``target`` via a three-way split.

    The first ``set1size`` numbers and the next ``set2size`` numbers are
    enumerated as plain subset sums.  The remainder is negated and offset by
    ``target``, so a pair of sums ``(a, b)`` from the first two parts
    solves the instance when ``a + b`` appears among the third part's sums.
    Progress and elapsed time are printed.

    Returns the chosen numbers (part order, original order within each
    part), or None when no combination works.
    """
    started = time.time()
    split = set1size + set2size
    print(".")
    sums_a, masks_a = gen_subsets(numbers[:set1size])
    print(". %d" % len(sums_a))
    sums_b, masks_b = gen_subsets(numbers[set1size:split])
    print(". %d" % len(sums_b))
    sums_c, masks_c = gen_subsets([-x for x in numbers[split:]], target)
    print(". %d" % len(sums_c))
    for a in sums_a:
        for b in sums_b:
            if a + b not in masks_c:
                continue
            print("FOUND")
            print("elapsed %s" % (time.time() - started))
            mask_a, mask_b, mask_c = masks_a[a], masks_b[b], masks_c[a + b]
            picked = [numbers[i] for i in range(set1size) if mask_a >> i & 1]
            picked += [numbers[set1size + i] for i in range(set2size) if mask_b >> i & 1]
            picked += [numbers[split + i] for i in range(len(numbers) - split) if mask_c >> i & 1]
            return picked
    print("no answer, elapsed %s" % (time.time() - started))
    return None
def consume_bits(t, bitsleft, bits):
    """Take the next chunk of at most ``bits`` bits from the top of ``t``.

    ``bitsleft`` is how many not-yet-consumed low-order bits of ``t`` remain;
    the chunk is read from the most significant end of that region.

    Returns ``(chunk, remaining_bits, chunk_width)``.
    """
    width = bits if bits < bitsleft else bitsleft
    remaining = bitsleft - width
    chunk = (t >> remaining) & ((1 << width) - 1)
    return chunk, remaining, width
def task():
    """Play one full session against the remote subset-sum service.

    For each of up to 30 rounds:

    * read the target and the list of numbers;
    * feed the bits of every number into an MTpred;
    * use the predictor to classify numbers as in or out of the subset;
    * solve the remaining subset-sum instance and send the chosen numbers.

    After the last round it prints whatever the server sends and calls
    exit() when the connection closes.

    Returns early (closing the socket) when the residual instance is too
    large, or when the 48..52 branch finds no answer.  Other failures
    propagate as exceptions.
    NOTE(review): sc is only closed on those early-return paths.
    """
    pred = MTpred()
    # sc = socket.create_connection(("10.0.0.97", 12345))
    sc = socket.create_connection(("54.92.67.18", 50216))
    prob = 1
    while prob <= 30:
        read_until(sc, ": ")
        # Token 0 is the target and tokens 2.. are the numbers.  Token 1 is
        # skipped (presumably a separator -- confirm against the server).
        res = readline(sc, False).split()
        target = int(res[0])
        arr = [int(x) for x in res[2:]]
        # Trailing comma: Python 2 print without a newline.
        print "problem %d size %d" % (prob, len(arr)),
        # Round size is expected to be 4 * prob + 7 (assert is stripped under -O).
        n = 4 * prob + 7
        assert n == len(arr)
        # Feed each number's bits to the predictor.  This assumes the server
        # built |t| from min(4 * prob + 20, 120) PRNG bits, most-significant
        # 32-bit chunk first, and drew the sign from one more output.  That
        # is inferred from how the values are fed here -- confirm.
        for t in arr:
            tt = abs(t)
            bitsleft = min(4 * prob + 20, 120)
            while bitsleft > 0:
                res, bitsleft, bitsused = consume_bits(tt, bitsleft, 32)
                pred.set_value(res, bitsused)
            if t < 0:
                pred.set_value(1, 1)
            else:
                pred.set_value(0, 1)
        res = []
        arr2 = []
        target2 = target
        # Shallow snapshot.  The speculative set_value(0, 0) calls below are
        # undone by restoring it.  This works because set_value, as written,
        # only replaces entries of pred.mt and never mutates candidate
        # lists in place.
        savedstate = list(pred.mt)
        # Predict the low bit of the next output for each number:
        # all candidates 1 -> include, all 0 -> exclude, mixed/unknown ->
        # leave it to the subset-sum solver.
        for v in arr:
            cands = pred.calc_value0(len(pred.mt))
            pred.set_value(0, 0)
            if cands is not None:
                weight = [0, 0]
                for cand in cands:
                    weight[tempering(cand) & 1] += 1
                if weight[0] == 0 and weight[1] != 0:
                    res.append(v)
                    target2 -= v
                elif weight[0] != 0 and weight[1] == 0:
                    pass
                else:
                    arr2.append(v)
            else:
                arr2.append(v)
        probsize = len(arr2)
        print "SOLVE FOR NUMBER OF BITS", probsize
        # Choose a solver by residual size.  solve() raises IndexError when
        # no subset matches.
        if probsize == 0:
            res2 = []
        elif probsize <= 32:
            res2 = solve(arr2, target2)
        elif probsize <= 43:
            large = 18
            small = (probsize - large) / 2
            res2 = solve3(arr2, target2, small, probsize - large - small)
        elif probsize <= 47:
            large = 20
            small = (probsize - large) / 2
            res2 = solve3(arr2, target2, small, probsize - large - small)
        elif probsize <= 52:
            # Heuristic: drop the last six undetermined numbers, betting they
            # are excluded, to keep the instance tractable.
            arr2 = arr2[:-6]
            probsize = len(arr2)
            large = 20
            small = (probsize - large) / 2
            res2 = solve3(arr2, target2, small, probsize - large - small)
            if res2 is None:
                print "no solution"
                sc.close()
                return
        else:
            print "problem too large"
            sc.close()
            return
        # NOTE(review): for probsize <= 47, solve3 can also return None;
        # extend(None) then raises TypeError.  Only the 48..52 branch checks.
        res.extend(res2)
        # Rewind the speculative state, then replay the actual inclusion bit
        # of every number in order.
        pred.mt = savedstate
        res3 = []
        for v in arr:
            # NOTE(review): membership is by value.  If arr contains
            # duplicates, every copy is selected even when only one was chosen.
            if v in res:
                res3.append(v)
                pred.set_value(1, 1)
            else:
                pred.set_value(0, 1)
        # Reply: count, then the chosen numbers, space separated.
        sc.send(str(len(res3)) + " " + " ".join([str(x) for x in res3]) + "\n")
        prob += 1
    # t = telnetlib.Telnet()
    # t.sock = sc
    # t.interact()
    # All rounds done: print whatever the server sends until it disconnects.
    while True:
        data = sc.recv(16384)
        if len(data) == 0:
            exit()  # raises SystemExit
        for line in data.split("\n"):
            print repr(line)
import traceback

# Keep reconnecting until task() finishes.  Ordinary failures (disconnects,
# unsolvable rounds, socket errors) are reported and retried with a fresh
# connection.  SystemExit and KeyboardInterrupt are deliberately not caught:
# SystemExit comes from task()'s exit() once the server closes after the
# final round, and letting both through means the script can actually stop.
while True:
    print("-----------------------------------------------------------------")
    try:
        task()
    except Exception:
        traceback.print_exc()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment