Created April 14, 2014 21:49
-
-
Save katlogic/10685096 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
#include <stdint.h> | |
#include <stdlib.h> | |
#include <stdio.h> | |
#include <string.h> | |
#include "crypto.h" | |
/************************************************* | |
* math lib | |
*************************************************/ | |
/* doubledigit_a = digit_b * digit_c | |
* write replacement using carry if not supported by your compiler */ | |
static inline void b_dmul(DIGIT *a, const DIGIT b, const DIGIT c) | |
{ | |
BIGDIGIT big = ((BIGDIGIT) b) * c; | |
a[1] = big >> DIGIT_BITS; | |
a[0] = big & MAX_DIGIT; | |
} | |
/* digit_a = doubledigit_b / digit_c | |
* write replacement if not supported by your compiler */ | |
static inline void b_ddiv(DIGIT *a, const DIGIT *b, const DIGIT c) | |
{ | |
BIGDIGIT big = ((BIGDIGIT)b[1] << DIGIT_BITS) | b[0]; | |
*a = big / c; | |
} | |
/* a = b + c */ | |
static inline DIGIT b_add(DIGIT *a, const DIGIT *b, const DIGIT *c, const unsigned digits) | |
{ | |
DIGIT ai, carry; | |
unsigned int i; | |
carry = 0; | |
for (i = 0; i < digits; i++) { | |
if ((ai = b[i] + carry) < carry) | |
ai = c[i]; | |
else if ((ai += c[i]) < c[i]) | |
carry = 1; | |
else | |
carry = 0; | |
a[i] = ai; | |
} | |
return carry; | |
} | |
/* a = b - c */ | |
static inline DIGIT b_sub(DIGIT *a, const DIGIT *b, const DIGIT *c, const unsigned digits) | |
{ | |
DIGIT ai, borrow; | |
unsigned int i; | |
borrow = 0; | |
for (i = 0; i < digits; i++) { | |
if ((ai = b[i] - borrow) > (MAX_DIGIT - borrow)) | |
ai = MAX_DIGIT - c[i]; | |
else if ((ai -= c[i]) > (MAX_DIGIT - c[i])) | |
borrow = 1; | |
else | |
borrow = 0; | |
a[i] = ai; | |
} | |
return borrow; | |
} | |
/* a = b + c * d */ | |
static inline DIGIT b_add_mul(DIGIT *a, const DIGIT *b, const DIGIT c, const DIGIT *d, const unsigned digits) | |
{ | |
DIGIT carry, t[2]; | |
unsigned int i; | |
if (c == 0) | |
return 0; | |
carry = 0; | |
for (i = 0; i < digits; i++) { | |
b_dmul(t, c, d[i]); | |
if ((a[i] = b[i] + carry) < carry) | |
carry = 1; | |
else | |
carry = 0; | |
if ((a[i] += t[0]) < t[0]) | |
carry++; | |
carry += t[1]; | |
} | |
return carry; | |
} | |
/* a = b * c */ | |
static void b_mul(DIGIT *a, const DIGIT *b, const DIGIT *c, const unsigned digits) | |
{ | |
DIGIT t[2 * MAX_DIGITS]; | |
unsigned int bdigits, cdigits, i; | |
b_zero(t, 2 * digits); | |
bdigits = b_digits(b, digits); | |
cdigits = b_digits(c, digits); | |
for (i = 0; i < bdigits; i++) | |
t[i + cdigits] += | |
b_add_mul(&t[i], &t[i], b[i], c, cdigits); | |
b_copy(a, t, 2 * digits); | |
} | |
/* a = b << c, returns carry */ | |
static inline DIGIT b_shl(DIGIT *a, const DIGIT *b, const unsigned c, const unsigned digits) | |
{ | |
DIGIT bi, carry; | |
unsigned int i, t; | |
if (c >= DIGIT_BITS) | |
return 0; | |
t = DIGIT_BITS - c; | |
carry = 0; | |
for (i = 0; i < digits; i++) { | |
bi = b[i]; | |
a[i] = (bi << c) | carry; | |
carry = c ? (bi >> t) : 0; | |
} | |
return carry; | |
} | |
/* a = b >> c */ | |
static inline DIGIT b_shr(DIGIT *a, DIGIT *b, unsigned c, unsigned digits) | |
{ | |
DIGIT bi, carry; | |
int i; | |
unsigned int t; | |
if (c >= DIGIT_BITS) | |
return 0; | |
t = DIGIT_BITS - c; | |
carry = 0; | |
for (i = digits - 1; i >= 0; i--) { | |
bi = b[i]; | |
a[i] = (bi >> c) | carry; | |
carry = c?(bi << t):0; | |
} | |
return carry; | |
} | |
/* a = b - c*d */ | |
static inline DIGIT b_sub_mul(DIGIT * a, DIGIT * b, DIGIT c, DIGIT * d, | |
unsigned digits) | |
{ | |
DIGIT borrow, t[2]; | |
unsigned int i; | |
if (!c) | |
return 0; | |
borrow = 0; | |
for (i = 0; i < digits; i++) { | |
b_dmul(t, c, d[i]); | |
if ((a[i] = b[i] - borrow) > (MAX_DIGIT - borrow)) | |
borrow = 1; | |
else | |
borrow = 0; | |
if ((a[i] -= t[0]) > (MAX_DIGIT - t[0])) | |
borrow++; | |
borrow += t[1]; | |
} | |
return borrow; | |
} | |
/* sign of a - b */ | |
static inline int b_cmp(const DIGIT *a, const DIGIT *b, const unsigned digits) | |
{ | |
int i; | |
for (i = digits - 1; i >= 0; i--) { | |
if (a[i] > b[i]) | |
return 1; | |
if (a[i] < b[i]) | |
return -1; | |
} | |
return 0; | |
} | |
/* a = c div d */ | |
/* b = c mod d */ | |
static void b_div(DIGIT *a, DIGIT *b, const DIGIT *c, const unsigned cdigits, const DIGIT *d, const unsigned ddigits) | |
{ | |
DIGIT ai, cc[2 * MAX_DIGITS + 1], dd[MAX_DIGITS], t; | |
int i; | |
unsigned int dddigits, shift; | |
dddigits = b_digits(d, ddigits); | |
if (dddigits == 0) | |
return; | |
/* normalize */ | |
shift = DIGIT_BITS - b_digitbits(d[dddigits - 1]); | |
b_zero(cc, dddigits); | |
cc[cdigits] = b_shl(cc, c, shift, cdigits); | |
b_shl(dd, d, shift, dddigits); | |
t = dd[dddigits - 1]; | |
b_zero(a, cdigits); | |
for (i = cdigits - dddigits; i >= 0; i--) { | |
/* underestimate */ | |
if (t == MAX_DIGIT) | |
ai = cc[i + dddigits]; | |
else | |
b_ddiv(&ai, &cc[i + dddigits - 1], t + 1); | |
cc[i + dddigits] -= b_sub_mul(&cc[i], &cc[i], ai, dd, dddigits); | |
/* correct */ | |
while (cc[i + dddigits] || (b_cmp(&cc[i], dd, dddigits) >= 0)) { | |
ai++; | |
cc[i + dddigits] -= b_sub(&cc[i], &cc[i], dd, dddigits); | |
} | |
a[i] = ai; | |
} | |
b_zero(b, ddigits); | |
b_shr(b, cc, shift, dddigits); | |
} | |
/* a = b * c mod d */ | |
static inline void b_mod_mul(DIGIT *a, const DIGIT *b, const DIGIT *c, const DIGIT *d, const unsigned digits) | |
{ | |
DIGIT t[2 * MAX_DIGITS], tm[2*MAX_DIGITS]; | |
b_mul(t, b, c, digits); | |
b_div(tm, a, t, 2 * digits, d, digits); | |
} | |
/* a = b ^ c mod d */ | |
#define DIGIT_2MSB(x) (unsigned int)(((x) >> (DIGIT_BITS - 2)) & 3) | |
static void b_mod_exp(DIGIT *a, const DIGIT *b, const DIGIT *c, int cdigits, const DIGIT *d, const int ddigits) | |
{ | |
DIGIT bpower[3][MAX_DIGITS], ci, t[MAX_DIGITS]; | |
int i; | |
unsigned int ciBits, j, s; | |
b_copy(bpower[0], b, ddigits); | |
b_mod_mul(bpower[1], bpower[0], b, d, ddigits); /* b^2 mod d */ | |
b_mod_mul(bpower[2], bpower[1], b, d, ddigits); /* b^3 mod d */ | |
b_zero(t, ddigits); | |
t[0] = 1; | |
cdigits = b_digits(c, cdigits); | |
for (i = cdigits - 1; i >= 0; i--) { | |
ci = c[i]; | |
ciBits = DIGIT_BITS; | |
/* lesser-most bit */ | |
if (i == (int) (cdigits - 1)) { | |
while (!DIGIT_2MSB(ci)) { | |
ci <<= 2; | |
ciBits -= 2; | |
} | |
} | |
for (j = 0; j < ciBits; j += 2, ci <<= 2) { | |
/* t= t^4 * b^(msbof ci) mod d */ | |
b_mod_mul(t, t, t, d, ddigits); | |
b_mod_mul(t, t, t, d, ddigits); | |
if ((s = DIGIT_2MSB(ci)) != 0) | |
b_mod_mul(t, t, bpower[s - 1], d, ddigits); | |
} | |
} | |
b_copy(a, t, ddigits); | |
} | |
/************************************************* | |
* RSA functions | |
*************************************************/ | |
/* public key operation */ | |
static int do_rsa_public(u8 *output, const u8 *input, const int inlen, const RSA_PUBLIC *pub) | |
{ | |
DIGIT c[MAX_DIGITS], m[MAX_DIGITS]; | |
int ndigits, edigits; | |
b_decode(m, MAX_DIGITS, input, inlen); | |
bdump(m); | |
ndigits = b_digits(pub->n, MAX_DIGITS); | |
edigits = b_digits(pub->e, MAX_DIGITS); | |
if (b_cmp(m, pub->n, ndigits) >= 0) | |
return -1; | |
b_mod_exp(c, m, pub->e, edigits, pub->n, ndigits); | |
bdump(c); | |
b_encode(output, (pub->bits + 7) / 8, c, ndigits); | |
return (pub->bits + 7) / 8; | |
} | |
/* private key operation */ | |
static int do_rsa_private(u8 *output, const u8 *input, const int inlen, const RSA *priv) | |
{ | |
DIGIT c[MAX_DIGITS], cP[MAX_DIGITS], cQ[MAX_DIGITS], | |
mP[MAX_DIGITS], mQ[MAX_DIGITS], t[MAX_DIGITS*2]; | |
int cdigits, ndigits, pdigits; | |
b_decode(c, MAX_DIGITS, input, inlen); | |
cdigits = b_digits(c, MAX_DIGITS); | |
ndigits = b_digits(priv->pub.n, MAX_DIGITS); | |
pdigits = b_digits(priv->p, MAX_DIGITS); | |
if (b_cmp(c, priv->pub.n, ndigits) >= 0) | |
return -1; | |
/* mP = cP^dp mod p, mQ = cQ^dQ mod q */ | |
b_div(t, cP, c, cdigits, priv->p, pdigits); | |
b_div(t, cQ, c, cdigits, priv->q, pdigits); | |
b_mod_exp(mP, cP, priv->dP, pdigits, priv->p, pdigits); | |
b_zero(mQ, ndigits); | |
b_mod_exp(mQ, cQ, priv->dQ, pdigits, priv->q, pdigits); | |
/* chinese theorem: m = ((((mP - mQ) mod p) * qInv) mod p) * q + mQ. */ | |
if (b_cmp(mP, mQ, pdigits) >= 0) | |
b_sub(t, mP, mQ, pdigits); | |
else { | |
b_sub(t, mQ, mP, pdigits); | |
b_sub(t, priv->p, t, pdigits); | |
} | |
b_mod_mul(t, t, priv->qInv, priv->p, pdigits); | |
b_mul(t, t, priv->q, pdigits); | |
b_add(t, t, mQ, ndigits); | |
bdump(t); | |
b_encode(output, (priv->pub.bits + 7) / 8, t, ndigits); | |
return (priv->pub.bits + 7) / 8; | |
} | |
/* point is to prevent zero bytes .. */ | |
static inline u8 nzrnd() | |
{ | |
u8 b = 0; | |
while (!b) rand_bytes(&b, 1); | |
return b; | |
} | |
/************************************************* | |
* visible RSA functions | |
*************************************************/ | |
/* decode from pkcs stream (b) to machine order (a) */ | |
void b_decode(DIGIT *a, unsigned digits, const u8 *b, const int len) | |
{ | |
DIGIT t; | |
int j; | |
unsigned int i, u; | |
for (i = 0, j = len - 1; i < digits && j >= 0; i++) { | |
t = 0; | |
for (u = 0; j >= 0 && u < DIGIT_BITS; j--, u += 8) | |
t |= ((DIGIT) b[j]) << u; | |
a[i] = t; | |
} | |
for (; i < digits; i++) | |
a[i] = 0; | |
} | |
/* decode from machine order (b) to pkcs order (a) */ | |
void b_encode(u8 *a, int len, const DIGIT *b, const unsigned digits) | |
{ | |
DIGIT t; | |
int j; | |
unsigned int i, u; | |
for (i = 0, j = len - 1; i < digits && j >= 0; i++) { | |
t = b[i]; | |
for (u = 0; j >= 0 && u < DIGIT_BITS; j--, u += 8) | |
a[j] = (unsigned char) (t >> u); | |
} | |
for (; j >= 0; j--) | |
a[j] = 0; | |
} | |
/* Sign input with the private key using PKCS #1 v1.5 block type 1.
 *
 * The encoded block is 00 01 FF..FF 00 || input, padded to the modulus length
 * mlen = (key->pub.bits + 7) / 8; the FF run is at least 8 bytes.  input is
 * embedded as given.
 * Returns the signature length (mlen) on success.  Returns -1 when inlen is
 * negative or leaves fewer than 11 bytes of overhead, when the modulus is
 * wider than MAX_RSA_BYTES, or when the private operation fails. */
int rsa_sign(u8 *output, const u8 *input, const int inlen, const RSA *key)
{
    u8 pkcs[MAX_RSA_BYTES];
    const int mlen = (key->pub.bits + 7) / 8;
    int i;

    /* The check is written as inlen > mlen - 11 so that a huge inlen cannot
     * overflow.  Negative lengths, and moduli wider than the local buffer,
     * would otherwise reach memcpy or overflow pkcs. */
    if (inlen < 0 || mlen > MAX_RSA_BYTES || inlen > mlen - 11)
        return -1;

    pkcs[0] = 0;
    pkcs[1] = 1;
    /* 0xFF padding fills everything between the header and the 00 separator. */
    i = mlen - inlen - 1;
    memset(pkcs + 2, 0xff, (size_t) (i - 2));
    pkcs[i++] = 0;
    memcpy(pkcs + i, input, (size_t) inlen);
    return do_rsa_private(output, pkcs, mlen, key);
}
/* Encrypt input with the public key using PKCS #1 v1.5 block type 2.
 *
 * The encoded block is 00 02 PS 00 || input, where PS is non-zero random bytes
 * from nzrnd().  PS fills the block to mlen = (key->bits + 7) / 8 and is at
 * least 8 bytes long.
 * Returns the ciphertext length (mlen) on success.  Returns -1 when inlen is
 * negative or leaves fewer than 11 bytes of overhead, when the modulus is
 * wider than MAX_RSA_BYTES, or when the public operation fails. */
int rsa_encrypt(u8 *output, const u8 *input, const int inlen, const RSA_PUBLIC *key)
{
    u8 pkcs[MAX_RSA_BYTES];
    const int mlen = (key->bits + 7) / 8;
    int i;

    /* The check is written as inlen > mlen - 11 so that a huge inlen cannot
     * overflow.  Negative lengths, and moduli wider than the local buffer,
     * would otherwise reach memcpy or overflow pkcs. */
    if (inlen < 0 || mlen > MAX_RSA_BYTES || inlen > mlen - 11)
        return -1;

    pkcs[0] = 0;
    pkcs[1] = 2;
    for (i = 2; i < mlen - inlen - 1; i++)
        pkcs[i] = nzrnd();
    pkcs[i++] = 0;
    memcpy(pkcs + i, input, (size_t) inlen);
    return do_rsa_public(output, pkcs, mlen, key);
}
/* Verify a PKCS #1 v1.5 block-type-1 signature and recover the signed payload.
 *
 * Applies the public key to input and requires the block to be exactly
 * 00 01 FF..FF 00 || payload, with at least 8 0xFF bytes.  On success, copies
 * the payload to output and returns its length; output must have room for up
 * to mlen - 11 bytes, where mlen = (key->bits + 7) / 8.
 * Returns -1 for a bad input length or key size, or the public operation's
 * error.  Returns -2 when the block has an unexpected length and -3 when the
 * padding is malformed. */
int rsa_verify(u8 *output, const u8 *input, const int inlen, const RSA_PUBLIC *key)
{
    u8 pkcs[MAX_RSA_BYTES];
    int pkcslen;
    int ret, i;
    const int mlen = (key->bits + 7) / 8;

    /* A modulus under 11 bytes cannot hold a valid block, and one wider than
     * pkcs would overflow it. */
    if (inlen > mlen || mlen > MAX_RSA_BYTES || mlen < 11)
        return -1;
    if ((pkcslen = do_rsa_public(pkcs, input, inlen, key)) < 0)
        return pkcslen;
    if (pkcslen != mlen)
        return -2;
    if (pkcs[0] || pkcs[1] != 1)
        return -3;
    /* Every padding byte must be 0xFF.  Accepting arbitrary non-zero padding
     * gives an attacker freedom that enables low-exponent signature forgery. */
    for (i = 2; i < mlen - 1; i++)
        if (pkcs[i] != 0xff)
            break;
    /* The padding must end with the 00 separator.  Without this check, a block
     * with no separator would verify as an empty payload. */
    if (pkcs[i] != 0)
        return -3;
    ret = mlen - i - 1;
    /* at least 8 padding bytes are required */
    if (ret + 11 > mlen)
        return -3;
    memcpy(output, pkcs + i + 1, (size_t) ret);
    return ret;
}
/* Decrypt a PKCS #1 v1.5 block-type-2 ciphertext with the private key.
 *
 * Applies the private key to input and requires the block to be
 * 00 02 PS 00 || message.  PS is at least 8 bytes and zero-free; the first zero
 * after the header is taken as the separator.  On success, copies the message
 * to output and returns its length; output must have room for up to mlen - 11
 * bytes, where mlen = (key->pub.bits + 7) / 8.
 * Returns -1 for a bad input length or key size, or a failed private
 * operation.  Returns -2 when the block has an unexpected length and -3 when
 * the padding is malformed.
 * NOTE(review): if callers expose the distinct -2/-3 results to an attacker,
 * they act as a Bleichenbacher padding oracle; confirm callers collapse them. */
int rsa_decrypt(u8 *output, const u8 *input, const int inlen, const RSA *key)
{
    u8 pkcs[MAX_RSA_BYTES];
    volatile u8 *wipe = pkcs;
    size_t k;
    int pkcslen;
    int ret, i;
    const int mlen = (key->pub.bits + 7) / 8;

    /* A modulus under 11 bytes cannot hold a valid block, and one wider than
     * pkcs would overflow it. */
    if (inlen > mlen || mlen > MAX_RSA_BYTES || mlen < 11)
        return -1;
    if ((pkcslen = do_rsa_private(pkcs, input, inlen, key)) < 0)
        return -1;

    if (pkcslen != mlen) {
        ret = -2;
        goto out;
    }
    if (pkcs[0] || pkcs[1] != 2) {
        ret = -3;
        goto out;
    }
    /* find terminator: the first zero byte after the header */
    for (i = 2; i < mlen - 1; i++)
        if (!pkcs[i])
            break;
    /* No zero anywhere means no separator; without this check such a block
     * was accepted as an empty message. */
    if (pkcs[i] != 0) {
        ret = -3;
        goto out;
    }
    ret = mlen - i - 1;
    /* at least 8 padding bytes are required */
    if (ret + 11 > mlen) {
        ret = -3;
        goto out;
    }
    memcpy(output, pkcs + i + 1, (size_t) ret);

out:
    /* pkcs holds the decrypted block; clear it through a volatile pointer so
     * the stores are not dropped as dead writes. */
    for (k = 0; k < sizeof pkcs; k++)
        wipe[k] = 0;
    return ret;
}
Sign up for free to join this conversation on GitHub.
Already have an account?
Sign in to comment