Last active
March 13, 2019 12:04
-
-
Save scf37/f6506e0bf8eafdf5e1495587340692d5 to your computer and use it in GitHub Desktop.
Encryption of integers in the range [0, 10**n), based on a Feistel network.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
package digcrypt | |
import java.nio.ByteBuffer | |
import java.util.Random | |
import javax.crypto.Mac | |
import javax.crypto.spec.SecretKeySpec | |
/**
 * Digit encryption.
 *
 * Transforms non-negative decimal numbers of n digits into the same space of n digits,
 * i.e. every i in the range 0 <= i < 10**n is encrypted into some j in the same range,
 * and `decrypt` inverts `encrypt`, so the mapping is a permutation of [0, 10**n).
 * Useful for encryption of pin codes, public incremental ids (to hide the total count of rows), etc.
 *
 * Implementation: a Feistel network over the high and low decimal halves of the number,
 * using addition / subtraction modulo a power of ten to combine a half with the round function.
 *
 * Security.
 *
 * This cipher is secure as long as the encryption transformation used is secure.
 * The encryption transformation need not be reversible.
 * Recommended approach is to use the first 8 bytes of an HMAC of the input (see [[makeEnc]]).
 */
object DigCrypt {

  /**
   * Number of decimal digits that always fit into a non-negative Long:
   * floor(log10(Long.MaxValue)) = 18.
   */
  val maxN: Int = Math.log10(Long.MaxValue).toInt // 18

  /**
   * Number of rounds in the Feistel network.
   * Minimum is 3. For odd n the two halves have different sizes and trade places on every
   * round, so an even number of rounds is required for the final halves to end up in the
   * positions that `left`/`right` (and therefore `decrypt`) split the number into.
   */
  private val rounds = 4

  /** pow10(k) == 10**k for 0 <= k <= maxN. */
  private val pow10: Array[Long] = Array.iterate(1L, maxN + 1)(_ * 10)

  /** High part of an n-digit number: everything above its low n/2 digits. */
  private def left(l: Long, n: Int): Long = l / pow10(n / 2)

  /** Low n/2 digits of an n-digit number. */
  private def right(l: Long, n: Int): Long = l % pow10(n / 2)

  /** Digit count of the high part (the larger half when n is odd). */
  private def leftSize(n: Int): Int = n - n / 2

  /** Digit count of the low part. */
  private def rightSize(n: Int): Int = n / 2

  /** Remainder of l modulo 10**n, normalized into [0, 10**n) even for negative l. */
  private def mod(l: Long, n: Int): Long = {
    val r = l % pow10(n)
    if (r < 0) r + pow10(n) else r
  }

  /** Concatenates the low lsz digits of l with the low rsz digits of r. */
  private def combine(l: Long, r: Long, lsz: Int, rsz: Int): Long =
    mod(l, lsz) * pow10(rsz) + mod(r, rsz)

  /**
   * Argument validation shared by encrypt and decrypt.
   * n is checked first so that an out-of-range n fails with IllegalArgumentException
   * rather than an ArrayIndexOutOfBoundsException from indexing pow10(n).
   */
  private def checkArgs(v: Long, n: Int): Unit = {
    require(n > 1, "n must be at least 2")
    require(n <= maxN, "n must be not more than " + maxN)
    require(v >= 0, "value must be non-negative")
    require(v < pow10(n), "value must fit to n digits")
  }

  /**
   * Encrypt number v keeping n digits
   *
   * @param v number to encrypt, 0 <= v < 10**n
   * @param n count of digits we use, 2 <= n <= maxN
   * @param enc strong encryption transformation, not necessarily reversible
   * @return encrypted number 0 <= result < 10**n
   * @throws IllegalArgumentException if n or v is out of range
   */
  def encrypt(v: Long, n: Int, enc: Long => Long): Long = {
    checkArgs(v, n)
    var l = left(v, n)
    var r = right(v, n)
    var lsz = leftSize(n)
    var rsz = rightSize(n)
    var i = rounds
    while (i != 0) {
      // enc(l) is reduced before the addition: adding a raw Long near Long.MaxValue would
      // overflow (wrap modulo 2**64), which is not congruent modulo 10**rsz and would make
      // the round non-invertible. Both operands are now < 10**9, so the sum cannot overflow.
      r = mod(r + mod(enc(l), rsz), rsz)
      // swap the halves together with their digit counts
      val m = l
      l = r
      r = m
      val msz = rsz
      rsz = lsz
      lsz = msz
      i -= 1
    }
    combine(l, r, lsz, rsz)
  }

  /**
   * Decrypt number encrypted by encrypt function
   *
   * @param v encrypted value, 0 <= v < 10**n
   * @param n count of digits used, same as passed to encrypt
   * @param enc strong encryption transformation, same as used in encrypt
   * @return decrypted value
   * @throws IllegalArgumentException if n or v is out of range
   */
  def decrypt(v: Long, n: Int, enc: Long => Long): Long = {
    checkArgs(v, n)
    var l = left(v, n)
    var r = right(v, n)
    var lsz = leftSize(n)
    var rsz = rightSize(n)
    var i = rounds
    while (i != 0) {
      // undo the swap performed at the end of the matching encryption round
      val msz = rsz
      rsz = lsz
      lsz = msz
      val m = l
      l = r
      r = m
      // exact inverse of the encryption step; both operands lie in [0, 10**rsz),
      // so the difference cannot overflow
      r = mod(r - mod(enc(l), rsz), rsz)
      i -= 1
    }
    combine(l, r, lsz, rsz)
  }

  /**
   * Make strong encryption transformation from given secret key.
   *
   * The returned function computes HMAC-SHA256 over the 8 big-endian bytes of its argument
   * and returns the first 8 bytes of the digest as a Long. Calls are serialized on the
   * underlying Mac, so the function may be shared between threads.
   *
   * @param key key to use (must be non-empty; SecretKeySpec rejects an empty key)
   * @return encryption transformation for encrypt/decrypt functions
   */
  def makeEnc(key: Array[Byte]): Long => Long = {
    val hmac = Mac.getInstance("HmacSHA256")
    // Initialized once: doFinal() resets the Mac to this initialized state after every call.
    hmac.init(new SecretKeySpec(key, "HmacSHA256"))
    val buf = new Array[Byte](8)
    v => hmac.synchronized {
      long2Array(v, buf)
      array2Long(hmac.doFinal(buf))
    }
  }

  /** Writes l into the first 8 bytes of arr (big-endian, ByteBuffer's default order). */
  private def long2Array(l: Long, arr: Array[Byte]): Unit = {
    ByteBuffer.wrap(arr).putLong(l)
    ()
  }

  /** Reads the first 8 bytes of arr as a big-endian Long. */
  private def array2Long(arr: Array[Byte]): Long =
    ByteBuffer.wrap(arr).getLong

  def main(args: Array[String]): Unit = {
    tests()
    val e: Long => Long = makeEnc("secret".getBytes)
    println("i\tenc\t\tdec")
    for (i <- 0 to 20) {
      val v = encrypt(i, 5, e)
      println(s"$i\t$v\t${decrypt(v, 5, e)}")
    }
  }

  /** Self-test of helpers and encrypt/decrypt round trips; throws on the first failure. */
  def tests(): Unit = {
    require(pow10(2) == 100L)
    require(pow10(maxN) == 1000000000000000000L)
    require(combine(111111111L, 2222222222L, 3, 2) == 11122L)
    require(combine(111111111L, 2222222222L, 3, 3) == 111222L)
    require(mod(123, 2) == 23)
    require(mod(23, 2) == 23)
    require(mod(3, 2) == 3)
    require(mod(-3, 2) == -3 + 100)
    val e: Long => Long = makeEnc("hello".getBytes)
    for (i <- 0 to 99) {
      val v = encrypt(i, 2, e)
      require(i == decrypt(v, 2, e), i.toString)
    }
    for (i <- 0 to 9999) {
      val v = encrypt(i, 4, e)
      require(i == decrypt(v, 4, e), i.toString)
    }
    for (i <- 0 to 99999) {
      val v = encrypt(i, 5, e)
      require(i == decrypt(v, 5, e), i.toString)
    }
    // Transformations returning extreme values used to overflow the round arithmetic.
    val extremes = Seq[Long => Long](_ => Long.MaxValue, _ => Long.MinValue)
    for (f <- extremes; n <- 2 to maxN; i <- 0L to 99L) {
      val v = encrypt(i, n, f)
      require(i == decrypt(v, n, f), s"$i (n = $n)")
    }
    val r = new Random()
    for (_ <- 0 to 100000) {
      // mod() always yields a value in [0, 10**maxN); Math.abs(Long.MinValue) stays negative
      val i = mod(r.nextLong(), maxN)
      val v = encrypt(i, maxN, e)
      require(i == decrypt(v, maxN, e), i.toString)
    }
    println("tests OK")
  }
}
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment