Skip to content

Instantly share code, notes, and snippets.

@hakatashi
Created September 22, 2020 14:33
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save hakatashi/6a7d9d03f013116f110c46527e5fac5c to your computer and use it in GitHub Desktop.
Save hakatashi/6a7d9d03f013116f110c46527e5fac5c to your computer and use it in GitHub Desktop.
TokyoWesterns CTF 2020 sqrt solver script
require "big"
# Integer modulo the fixed modulus MOD, stored as a BigInt in `value`.
#
# Binary operators take another Mint or any value responding to `to_i64`
# (primitive Ints; Mint#to_i64 returns the underlying BigInt). `+`, `-` and `*`
# reduce their result modulo MOD. Ordering comparisons deliberately raise
# NotImplementedError.
#
# NOTE(review): #inv only succeeds for values coprime to MOD, and #sqrt runs
# Tonelli-Shanks, which assumes MOD is an odd prime — confirm if MOD changes.
record Mint, value : BigInt do
  MOD = BigInt.new("6722156186149423473586056936189163112345526308304739592548269432948561498704906497631759731744824085311511299618196491816929603296108414569727189748975204102209646335725406551943711581704258725226874414399572244863268492324353927787818836752142254189928999592648333789131233670456465647924867060170327150559233")
  ZERO = BigInt.new("0")
  ONE = BigInt.new("1")
  TWO = BigInt.new("2")

  # Memo table: @@factorials[i] == i! mod MOD. Declared with an explicit
  # element type so the compiler can type the class variable.
  @@factorials = [] of Mint

  # Returns n! mod MOD, extending the memo table up to n on demand.
  #
  # Raises ArgumentError for negative n (a negative index would otherwise
  # silently return the last cached entry).
  def self.factorial(n)
    raise ArgumentError.new("n cannot be negative") if n < 0
    @@factorials << self.new(ONE) if @@factorials.empty?
    @@factorials.size.upto(n) do |i|
      @@factorials << @@factorials.last * i
    end
    @@factorials[n]
  end

  # Returns n! / (n - k)! mod MOD.
  #
  # Raises ArgumentError if k > n.
  def self.permutation(n, k)
    raise ArgumentError.new("k cannot be greater than n") unless n >= k
    factorial(n) // factorial(n - k)
  end

  # Returns the binomial coefficient C(n, k) mod MOD.
  #
  # Raises ArgumentError if k > n.
  def self.combination(n, k)
    raise ArgumentError.new("k cannot be greater than n") unless n >= k
    # Go through factorial(k) rather than indexing the memo table directly.
    permutation(n, k) // factorial(k)
  end

  # Number of size-k multisets over n kinds: C(n + k - 1, k) mod MOD.
  def self.repeated_combination(n, k)
    combination(n + k - 1, k)
  end

  # Extended Euclidean algorithm.
  #
  # Returns {g, x, y} where g = gcd(|a|, |b|) and a * x + b * y == g.
  def self.extended_gcd(a, b)
    last_remainder, remainder = a.abs, b.abs
    x, last_x, y, last_y = ZERO, ONE, ONE, ZERO
    while remainder != 0
      new_last_remainder = remainder
      quotient, remainder = last_remainder.divmod(remainder)
      last_remainder = new_last_remainder
      x, last_x = last_x - quotient * x, x
      y, last_y = last_y - quotient * y, y
    end
    # The loop ran on |a| and |b|; restore the signs of the coefficients.
    return last_remainder, last_x * (a < 0 ? -1 : 1), last_y * (b < 0 ? -1 : 1)
  end

  # The additive identity.
  def self.zero
    self.new(ZERO)
  end

  # Multiplicative inverse of this value modulo MOD.
  #
  # Raises DivisionByZeroError when the value shares a factor with MOD
  # (in particular when it is 0 mod MOD), since no inverse exists.
  def inv
    g, x, _ = self.class.extended_gcd(@value, MOD)
    raise DivisionByZeroError.new unless g == 1
    self.class.new(x % MOD)
  end

  # Modular addition.
  def +(value)
    self.class.new((@value + value.to_i64 % MOD) % MOD)
  end

  # Modular subtraction; adding MOD keeps the intermediate non-negative.
  def -(value)
    self.class.new((@value + MOD - value.to_i64 % MOD) % MOD)
  end

  # Modular multiplication.
  def *(value)
    self.class.new((@value * value.to_i64 % MOD) % MOD)
  end

  # Modular division by another Mint. Raises DivisionByZeroError on zero.
  def /(value : self)
    raise DivisionByZeroError.new if value == 0
    self * value.inv
  end

  # Modular division by an integer. Raises DivisionByZeroError on zero.
  def /(value)
    raise DivisionByZeroError.new if value == 0
    self * self.class.new(value.to_i64 % MOD).inv
  end

  # Alias for modular division.
  def //(value)
    self./(value)
  end

  # Raises this value to an integer power by square-and-multiply.
  # Negative exponents use #inv; exponent 0 returns 1 without inverting.
  def **(value)
    b = value >= 0 ? self : self.inv
    e = value.abs
    ret = self.class.new(ONE)
    while e > 0
      ret *= b if e % 2 == 1
      b *= b
      e //= 2
    end
    ret
  end

  # Multiplies by 2**value mod MOD.
  def <<(value)
    self * self.class.new(TWO) ** value
  end

  # Modular square root via Tonelli-Shanks.
  #
  # Returns {r, -r} with r * r == self, or nil when self has no square root.
  def sqrt
    # Smallest quadratic non-residue, found via Euler's criterion.
    z = self.class.new(ONE)
    until z ** ((MOD - 1) // 2) == MOD - 1
      z += 1
    end

    # Factor MOD - 1 as q * 2**m with q odd.
    q = MOD - 1
    m = 0
    while q % 2 == 0
      q //= 2
      m += 1
    end

    c = z ** q
    t = self ** q
    r = self ** ((q + 1) // 2)
    m.downto(2) do |i|
      # tmp = t ** (2 ** (i - 2)), computed by repeated squaring so the
      # exponent never has to fit in an Int32.
      tmp = t
      (i - 2).times { tmp *= tmp }
      if tmp != 1
        r *= c
        t *= c * c
      end
      c *= c
    end

    if r * r == self
      {r, -r}
    else
      nil
    end
  end

  # Returns the underlying BigInt (despite the name, not an Int64).
  def to_i64
    @value
  end

  # Equality with another Mint compares the stored values.
  def ==(value : self)
    @value == value.to_i64
  end

  # Equality with a plain number compares against the stored value.
  def ==(value)
    @value == value
  end

  # Additive inverse modulo MOD.
  def -
    self.class.new(ZERO) - self
  end

  # Unary plus: identity.
  def +
    self
  end

  # Residues carry no sign, so abs is the identity.
  def abs
    self
  end

  # ac-library compatibility
  def pow(value)
    self.**(value)
  end

  # ac-library compatibility: the stored value.
  def val
    self.to_i64
  end

  # ModInt shouldn't be compared: every ordering operator raises.
  def <(value)
    raise NotImplementedError.new("<")
  end

  def <=(value)
    raise NotImplementedError.new("<=")
  end

  def >(value)
    raise NotImplementedError.new(">")
  end

  def >=(value)
    raise NotImplementedError.new(">=")
  end

  delegate to_s, to: @value
  delegate inspect, to: @value
end
# Target value: the search below looks for x with x ** (2 ** 64) == c (mod mod).
c = BigInt.new("5602276430032875007249509644314357293319755912603737631044802989314683039473469151600643674831915676677562504743413434940280819915470852112137937963496770923674944514657123370759858913638782767380945111493317828235741160391407042689991007589804877919105123960837253705596164618906554015382923343311865102111160")
# Reuse Mint's modulus rather than repeating the 300-digit literal, so the two
# can never drift apart.
mod = Mint::MOD
# Bezout identity: (mod - 1) * _ + 2**64 * m == gcd(mod - 1, 2**64).
# Only the 2**64 coefficient m is used below.
_, _, m = Mint.extended_gcd(mod - 1, 1.to_big_i << 64)
# Takes `t` successive square roots of `n`, always continuing from the first
# root returned by `sqrt`.
#
# Returns the value reached after `t` steps (`n` itself when `t` is 0), or nil
# as soon as an intermediate value's `sqrt` returns nil.
#
# Raises ArgumentError when `t` is negative.
def multi_sqrt(n, t)
  raise ArgumentError.new("t must be non-negative") if t < 0
  current = n
  # Iterative rather than recursive, so large t cannot exhaust the stack.
  t.times do
    roots = current.sqrt
    return nil if roots.nil?
    current = roots[0]
  end
  current
end
# base = c ** m, where m is the Bezout coefficient computed from extended_gcd.
base = Mint.new(c) ** m
# One 2**30-th root of base, obtained by 30 successive square roots.
candidate = multi_sqrt(base, 30).not_nil!
# 29 successive square roots of -1 (i.e. mod - 1). Such an element has
# 2**29-th power -1, hence multiplicative order exactly 2**30.
root_of_unity = multi_sqrt(Mint.new(mod - 1), 29).not_nil!

# Sanity checks; results are only printed, not enforced.
p candidate ** (2.to_big_i ** 30) == base
p candidate ** (2.to_big_i ** 64) == c
p root_of_unity ** (2.to_big_i ** 29) == mod - 1
p root_of_unity ** (2.to_big_i ** 30) == 1

# Tests candidate * root_of_unity**k. Instance number `instance` (first CLI
# argument, default 0) covers k = 2**28 * instance + 1 .. 2**28 * (instance + 1),
# so instances 0..3 together cover one full period of root_of_unity.
chunk_size = 1_i64 << 28
# "TWCTF{" read as a big-endian integer: 0x54 'T', 0x57 'W', 0x43 'C',
# 0x54 'T', 0x46 'F', 0x7B '{' (== 92733768484475).
flag_prefix = 0x5457_4354_467B_i64

instance = ARGV.empty? ? 0 : ARGV[0].to_i
candidate *= root_of_unity ** (chunk_size * instance)
chunk_size.times do |i|
  puts "#{instance} #{i}" if i % 10_000_000 == 0
  candidate *= root_of_unity
  # Keep only the bits above bit 288 and compare them with the prefix.
  next unless candidate.to_i64 >> 288 == flag_prefix
  puts "#{instance} #{candidate}"
  exit
end
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment