/*
 * RSA (PKCS #1 v1.5) on top of a small multi-precision integer library.
 * Originally shared as GitHub gist katlogic/10685096 by @katlogic,
 * created April 14, 2014.
 */
#include <stdint.h>
#include <stdlib.h>
#include <stdio.h>
#include <string.h>
#include "crypto.h"
/*************************************************
* math lib
*************************************************/
/* Full double-width product of two digits: a[1] receives the high word and
 * a[0] the low word of b * c.
 * Relies on BIGDIGIT holding a complete DIGIT * DIGIT product; substitute a
 * carry-based version if the compiler has no such type. */
static inline void b_dmul(DIGIT *a, const DIGIT b, const DIGIT c)
{
	const BIGDIGIT product = (BIGDIGIT)b * (BIGDIGIT)c;

	a[0] = (DIGIT)(product & MAX_DIGIT);
	a[1] = (DIGIT)(product >> DIGIT_BITS);
}
/* Single-digit quotient of a double-width dividend: *a = (b[1]:b[0]) / c,
 * where b[1] is the high word. The quotient is truncated to one DIGIT, so
 * the caller must ensure it fits (presumably b[1] < c -- confirm at call
 * sites). Replace if BIGDIGIT division is unavailable on the target. */
static inline void b_ddiv(DIGIT *a, const DIGIT *b, const DIGIT c)
{
	BIGDIGIT dividend = (BIGDIGIT)b[1];

	dividend <<= DIGIT_BITS;
	dividend |= b[0];
	*a = (DIGIT)(dividend / c);
}
/* Multi-word addition a = b + c over `digits` words (least significant word
 * first). Returns the carry out of the top word (0 or 1).
 * Each word of b and c is read before a[k] is written, so a may be the same
 * array as b or c. */
static inline DIGIT b_add(DIGIT *a, const DIGIT *b, const DIGIT *c, const unsigned digits)
{
	DIGIT carry = 0;
	unsigned int k;

	for (k = 0; k < digits; k++) {
		DIGIT sum = b[k] + carry;

		if (sum < carry) {
			/* b[k] + carry wrapped to zero: the word is just c[k] and the
			 * carry stays 1 */
			sum = c[k];
		} else {
			sum += c[k];
			carry = (sum < c[k]) ? 1 : 0;
		}
		a[k] = sum;
	}
	return carry;
}
/* Multi-word subtraction a = b - c over `digits` words (least significant
 * word first). Returns the borrow out of the top word (0 or 1).
 * Each word of b and c is read before a[k] is written, so a may be the same
 * array as b or c. */
static inline DIGIT b_sub(DIGIT *a, const DIGIT *b, const DIGIT *c, const unsigned digits)
{
	DIGIT borrow = 0;
	unsigned int k;

	for (k = 0; k < digits; k++) {
		DIGIT diff = b[k] - borrow;

		if (diff > MAX_DIGIT - borrow) {
			/* b[k] - borrow wrapped to MAX_DIGIT: the word is
			 * MAX_DIGIT - c[k] and the borrow stays 1 */
			diff = MAX_DIGIT - c[k];
		} else {
			diff -= c[k];
			borrow = (diff > MAX_DIGIT - c[k]) ? 1 : 0;
		}
		a[k] = diff;
	}
	return borrow;
}
/* a = b + c * d over `digits` words, where c is a single digit.
 * Returns the carry out of the top word. The carry is a full DIGIT, not just
 * 0/1, because the high word of each c * d[i] product is added into it.
 * a may be the same array as b (each b[i] is read before a[i] is written). */
static inline DIGIT b_add_mul(DIGIT *a, const DIGIT *b, const DIGIT c, const DIGIT *d, const unsigned digits)
{
	DIGIT carry, t[2];
	unsigned int i;
	if (c == 0) {
		/* b + 0 * d == b. Store it explicitly: returning without writing
		 * would leave a stale whenever a and b are different arrays. */
		if (a != b) {
			for (i = 0; i < digits; i++)
				a[i] = b[i];
		}
		return 0;
	}
	carry = 0;
	for (i = 0; i < digits; i++) {
		b_dmul(t, c, d[i]);
		/* add the incoming carry, detecting wrap-around */
		if ((a[i] = b[i] + carry) < carry)
			carry = 1;
		else
			carry = 0;
		/* add the low word of c * d[i] */
		if ((a[i] += t[0]) < t[0])
			carry++;
		/* the high word of the product always moves to the next position */
		carry += t[1];
	}
	return carry;
}
/* Schoolbook multiplication a = b * c. b and c have `digits` words each; a
 * receives 2 * digits words. The product is built in a local buffer and
 * copied out at the end, so a may alias b or c. Only the significant words
 * of b and c are multiplied. */
static void b_mul(DIGIT *a, const DIGIT *b, const DIGIT *c, const unsigned digits)
{
	DIGIT prod[2 * MAX_DIGITS];
	const unsigned int blen = b_digits(b, digits);
	const unsigned int clen = b_digits(c, digits);
	unsigned int row;

	b_zero(prod, 2 * digits);
	for (row = 0; row < blen; row++) {
		/* add b[row] * c into prod starting at word `row`; the carry
		 * lands in the word just above that partial product */
		const DIGIT spill = b_add_mul(&prod[row], &prod[row], b[row], c, clen);

		prod[row + clen] += spill;
	}
	b_copy(a, prod, 2 * digits);
}
/* Multi-word left shift a = b << c over `digits` words, for
 * 0 <= c < DIGIT_BITS. Returns the bits pushed out of the top word,
 * right-aligned. When c >= DIGIT_BITS nothing is written and 0 is returned.
 * Each b[k] is read before a[k] is written, so a may be the same array as b. */
static inline DIGIT b_shl(DIGIT *a, const DIGIT *b, const unsigned c, const unsigned digits)
{
	DIGIT spill = 0;
	unsigned int k, back;

	if (c >= DIGIT_BITS)
		return 0;
	back = DIGIT_BITS - c;
	for (k = 0; k < digits; k++) {
		const DIGIT word = b[k];

		a[k] = (word << c) | spill;
		/* shifting by back == DIGIT_BITS (c == 0) would be undefined */
		spill = (c == 0) ? 0 : (word >> back);
	}
	return spill;
}
/* Multi-word right shift a = b >> c over `digits` words, for
 * 0 <= c < DIGIT_BITS. Returns the bits shifted out of b[0], left-aligned in
 * a DIGIT. When c >= DIGIT_BITS nothing is written and 0 is returned.
 * Each b[i] is read before a[i] is written, so a may be the same array as b. */
static inline DIGIT b_shr(DIGIT *a, const DIGIT *b, unsigned c, unsigned digits)
{
	DIGIT bi, carry;
	unsigned int i, t;
	if (c >= DIGIT_BITS)
		return 0;
	t = DIGIT_BITS - c;
	carry = 0;
	/* Walk from the most significant word down. The unsigned countdown
	 * avoids converting digits - 1 to int, which is implementation-defined
	 * when digits == 0. */
	for (i = digits; i-- > 0;) {
		bi = b[i];
		a[i] = (bi >> c) | carry;
		/* shifting by t == DIGIT_BITS (c == 0) would be undefined */
		carry = c ? (bi << t) : 0;
	}
	return carry;
}
/* a = b - c * d over `digits` words, where c is a single digit.
 * Returns the borrow out of the top word. The borrow is a full DIGIT, not
 * just 0/1, because the high word of each c * d[i] product is added into it.
 * a may be the same array as b (each b[i] is read before a[i] is written). */
static inline DIGIT b_sub_mul(DIGIT *a, const DIGIT *b, DIGIT c, const DIGIT *d,
	unsigned digits)
{
	DIGIT borrow, t[2];
	unsigned int i;
	if (!c) {
		/* b - 0 * d == b. Store it explicitly: returning without writing
		 * would leave a stale whenever a and b are different arrays. */
		if (a != b) {
			for (i = 0; i < digits; i++)
				a[i] = b[i];
		}
		return 0;
	}
	borrow = 0;
	for (i = 0; i < digits; i++) {
		b_dmul(t, c, d[i]);
		/* subtract the incoming borrow, detecting wrap-around */
		if ((a[i] = b[i] - borrow) > (MAX_DIGIT - borrow))
			borrow = 1;
		else
			borrow = 0;
		/* subtract the low word of c * d[i] */
		if ((a[i] -= t[0]) > (MAX_DIGIT - t[0]))
			borrow++;
		/* the high word of the product always moves to the next position */
		borrow += t[1];
	}
	return borrow;
}
/* Compare two `digits`-word numbers. Returns the sign of a - b:
 * 1 if a > b, -1 if a < b, 0 if they are equal (or digits == 0). */
static inline int b_cmp(const DIGIT *a, const DIGIT *b, const unsigned digits)
{
	unsigned int i;
	/* Scan from the most significant word. The unsigned countdown avoids
	 * converting digits - 1 to int, which is implementation-defined when
	 * digits == 0. */
	for (i = digits; i-- > 0;) {
		if (a[i] > b[i])
			return 1;
		if (a[i] < b[i])
			return -1;
	}
	return 0;
}
/* a = c div d */
/* b = c mod d */
/* Schoolbook long division, one quotient digit per step.
 *   a receives c div d as cdigits words (it is zeroed first).
 *   b receives c mod d as ddigits words (zero-extended).
 * c has cdigits words. d has ddigits words; leading zero words are allowed.
 * If d is zero, the function returns without writing a or b.
 * The local arrays impose these limits: cdigits <= 2 * MAX_DIGITS, and the
 * significant length of d <= MAX_DIGITS.
 * NOTE(review): cc and dd are not wiped before return. They hold shifted
 * copies of the operands, which are secret in the RSA private path. */
static void b_div(DIGIT *a, DIGIT *b, const DIGIT *c, const unsigned cdigits, const DIGIT *d, const unsigned ddigits)
{
	/* cc: shifted dividend plus one overflow word; dd: shifted divisor */
	DIGIT ai, cc[2 * MAX_DIGITS + 1], dd[MAX_DIGITS], t;
	int i;
	unsigned int dddigits, shift;
	/* significant length of the divisor */
	dddigits = b_digits(d, ddigits);
	if (dddigits == 0)
		return;
	/* normalize */
	/* Shift both operands left so the divisor's top word uses its full
	 * width. This assumes b_digitbits() returns the bit length of a word --
	 * confirm in its definition. */
	shift = DIGIT_BITS - b_digitbits(d[dddigits - 1]);
	/* The final b_shr reads dddigits words of cc. Zeroing them first keeps
	 * that read defined when c is shorter than d. */
	b_zero(cc, dddigits);
	cc[cdigits] = b_shl(cc, c, shift, cdigits);
	b_shl(dd, d, shift, dddigits);
	/* top word of the normalized divisor */
	t = dd[dddigits - 1];
	b_zero(a, cdigits);
	/* NOTE(review): cdigits - dddigits is computed in unsigned arithmetic.
	 * When cdigits < dddigits it wraps, and the conversion to int is
	 * implementation-defined (usually negative, so the loop is skipped). */
	for (i = cdigits - dddigits; i >= 0; i--) {
		/* underestimate */
		/* Estimate the quotient digit from the top two words of the current
		 * window, dividing by t + 1 so the estimate never exceeds the true
		 * digit. When t == MAX_DIGIT, t + 1 would overflow; dividing by
		 * 2^DIGIT_BITS is simply taking the top word. */
		if (t == MAX_DIGIT)
			ai = cc[i + dddigits];
		else
			b_ddiv(&ai, &cc[i + dddigits - 1], t + 1);
		/* window -= ai * dd; the returned borrow comes off the window's top word */
		cc[i + dddigits] -= b_sub_mul(&cc[i], &cc[i], ai, dd, dddigits);
		/* correct */
		/* While the window is still >= dd, increment the digit and
		 * subtract dd once more. */
		while (cc[i + dddigits] || (b_cmp(&cc[i], dd, dddigits) >= 0)) {
			ai++;
			cc[i + dddigits] -= b_sub(&cc[i], &cc[i], dd, dddigits);
		}
		a[i] = ai;
	}
	/* Undo the normalization shift: the low dddigits words of cc are the
	 * remainder. */
	b_zero(b, ddigits);
	b_shr(b, cc, shift, dddigits);
}
/* Modular multiplication a = (b * c) mod d, where all operands have `digits`
 * words. The full double-width product is formed first and then reduced with
 * b_div; the quotient is computed only because b_div requires it. */
static inline void b_mod_mul(DIGIT *a, const DIGIT *b, const DIGIT *c, const DIGIT *d, const unsigned digits)
{
	DIGIT product[2 * MAX_DIGITS];
	DIGIT unused_quotient[2 * MAX_DIGITS];

	b_mul(product, b, c, digits);
	b_div(unused_quotient, a, product, 2 * digits, d, digits);
}
/* a = b ^ c mod d */
/* DIGIT_2MSB(x): the two most significant bits of word x, as a value 0..3 */
#define DIGIT_2MSB(x) (unsigned int)(((x) >> (DIGIT_BITS - 2)) & 3)
/* Modular exponentiation a = b^c mod d.
 * The exponent is scanned from its most significant bit in 2-bit windows,
 * using precomputed b, b^2 and b^3 mod d.
 * c has cdigits words and d has ddigits words. b and every intermediate are
 * handled as ddigits-word values, so b is presumably expected to be already
 * reduced mod d. Both RSA callers check or reduce it first -- confirm for any
 * new caller.
 * If c is zero the result is 1 (not reduced mod d).
 * NOTE(review): the multiply by b^s is skipped for 00 windows, so the running
 * time depends on the exponent bits (not constant-time). */
static void b_mod_exp(DIGIT *a, const DIGIT *b, const DIGIT *c, int cdigits, const DIGIT *d, const int ddigits)
{
	/* bpower[k] = b^(k+1) mod d, for k = 0..2 */
	DIGIT bpower[3][MAX_DIGITS], ci, t[MAX_DIGITS];
	int i;
	unsigned int ciBits, j, s;
	b_copy(bpower[0], b, ddigits);
	b_mod_mul(bpower[1], bpower[0], b, d, ddigits); /* b^2 mod d */
	b_mod_mul(bpower[2], bpower[1], b, d, ddigits); /* b^3 mod d */
	/* running result t = 1 */
	b_zero(t, ddigits);
	t[0] = 1;
	/* ignore leading zero words of the exponent */
	cdigits = b_digits(c, cdigits);
	for (i = cdigits - 1; i >= 0; i--) {
		ci = c[i];
		ciBits = DIGIT_BITS;
		/* In the most significant exponent word, skip leading 00 bit pairs.
		 * This terminates because b_digits guarantees the word is nonzero. */
		if (i == (int) (cdigits - 1)) {
			while (!DIGIT_2MSB(ci)) {
				ci <<= 2;
				ciBits -= 2;
			}
		}
		for (j = 0; j < ciBits; j += 2, ci <<= 2) {
			/* t = t^4 * b^s mod d, where s is the next 2-bit window (the top two bits of ci) */
			b_mod_mul(t, t, t, d, ddigits);
			b_mod_mul(t, t, t, d, ddigits);
			if ((s = DIGIT_2MSB(ci)) != 0)
				b_mod_mul(t, t, bpower[s - 1], d, ddigits);
		}
	}
	b_copy(a, t, ddigits);
}
/*************************************************
* RSA functions
*************************************************/
/* Raw RSA public-key primitive. Interprets `inlen` big-endian bytes of
 * `input` as m and writes c = m^e mod n to `output` as exactly
 * (pub->bits + 7) / 8 big-endian bytes; the caller's buffer must be that
 * large. Returns that length, or -1 when m >= n.
 * NOTE(review): bdump() is defined elsewhere. If it prints, it exposes the
 * operands. */
static int do_rsa_public(u8 *output, const u8 *input, const int inlen, const RSA_PUBLIC *pub)
{
	DIGIT m[MAX_DIGITS], c[MAX_DIGITS];
	const int outlen = (pub->bits + 7) / 8;
	int ndigits, edigits;

	b_decode(m, MAX_DIGITS, input, inlen);
	bdump(m);

	ndigits = b_digits(pub->n, MAX_DIGITS);
	edigits = b_digits(pub->e, MAX_DIGITS);

	/* the message representative must be strictly below the modulus */
	if (b_cmp(m, pub->n, ndigits) >= 0)
		return -1;

	b_mod_exp(c, m, pub->e, edigits, pub->n, ndigits);
	bdump(c);
	b_encode(output, outlen, c, ndigits);
	return outlen;
}
/* private key operation */
/* Raw RSA private-key primitive using the Chinese Remainder Theorem.
 * Decodes `inlen` big-endian bytes of `input` into c. Computes
 * mP = (c mod p)^dP mod p and mQ = (c mod q)^dQ mod q, then recombines them
 * as ((((mP - mQ) mod p) * qInv) mod p) * q + mQ. Writes the result to
 * `output` as (priv->pub.bits + 7) / 8 big-endian bytes.
 * Returns that length, or -1 if c >= n.
 * Every mod-q operation is sized by p's digit count, so this assumes q is
 * no longer than p (and, for the subtraction below, q <= p) -- TODO confirm
 * that key generation guarantees this.
 * NOTE(review): the secret intermediates cP, cQ, mP, mQ and t are not wiped
 * before return. bdump() is defined elsewhere; if it prints, it exposes the
 * result. */
static int do_rsa_private(u8 *output, const u8 *input, const int inlen, const RSA *priv)
{
	/* t first collects the discarded quotients, then serves as the CRT
	 * accumulator. It has 2 * MAX_DIGITS words because b_mul writes
	 * 2 * pdigits words into it. */
	DIGIT c[MAX_DIGITS], cP[MAX_DIGITS], cQ[MAX_DIGITS],
	mP[MAX_DIGITS], mQ[MAX_DIGITS], t[MAX_DIGITS*2];
	int cdigits, ndigits, pdigits;
	b_decode(c, MAX_DIGITS, input, inlen);
	cdigits = b_digits(c, MAX_DIGITS);
	ndigits = b_digits(priv->pub.n, MAX_DIGITS);
	pdigits = b_digits(priv->p, MAX_DIGITS);
	if (b_cmp(c, priv->pub.n, ndigits) >= 0)
		return -1;
	/* mP = cP^dp mod p, mQ = cQ^dQ mod q */
	/* cP = c mod p, cQ = c mod q (the quotients written to t are unused) */
	b_div(t, cP, c, cdigits, priv->p, pdigits);
	b_div(t, cQ, c, cdigits, priv->q, pdigits);
	b_mod_exp(mP, cP, priv->dP, pdigits, priv->p, pdigits);
	/* mQ is added over ndigits words below, but b_mod_exp writes only
	 * pdigits words, so clear the upper words first */
	b_zero(mQ, ndigits);
	b_mod_exp(mQ, cQ, priv->dQ, pdigits, priv->q, pdigits);
	/* chinese theorem: m = ((((mP - mQ) mod p) * qInv) mod p) * q + mQ. */
	if (b_cmp(mP, mQ, pdigits) >= 0)
		b_sub(t, mP, mQ, pdigits);
	else {
		/* (mP - mQ) mod p == p - (mQ - mP), valid while mQ - mP < p */
		b_sub(t, mQ, mP, pdigits);
		b_sub(t, priv->p, t, pdigits);
	}
	b_mod_mul(t, t, priv->qInv, priv->p, pdigits);
	b_mul(t, t, priv->q, pdigits);
	b_add(t, t, mQ, ndigits);
	bdump(t);
	b_encode(output, (priv->pub.bits + 7) / 8, t, ndigits);
	return (priv->pub.bits + 7) / 8;
}
/* Return a random byte that is never zero. PKCS#1 v1.5 type-2 padding must
 * not contain 0x00, because 0x00 marks the end of the padding.
 * Draws single bytes from rand_bytes() until one is nonzero.
 * NOTE(review): rand_bytes()'s result is not checked. If it can fail without
 * writing, this loop relies on it eventually succeeding. */
static inline u8 nzrnd(void)
{
	u8 b = 0;

	while (!b)
		rand_bytes(&b, 1);
	return b;
}
/*************************************************
* visible RSA functions
*************************************************/
/* Convert a big-endian byte string (b, len bytes) into `digits` words in a,
 * least significant word first. The most significant bytes that do not fit
 * in `digits` words are ignored; words beyond the end of the input are
 * zeroed. */
void b_decode(DIGIT *a, unsigned digits, const u8 *b, const int len)
{
	int pos = len - 1;	/* next byte to consume, walking from the least significant */
	unsigned int k;

	for (k = 0; k < digits; k++) {
		DIGIT word = 0;
		unsigned int shift;

		/* once the input is exhausted this inner loop does nothing and the
		 * word stays zero */
		for (shift = 0; pos >= 0 && shift < DIGIT_BITS; shift += 8, pos--)
			word |= ((DIGIT)b[pos]) << shift;
		a[k] = word;
	}
}
/* Encode `digits` words of b (least significant word first) into the
 * big-endian byte string a of exactly `len` bytes. Words that do not fit are
 * dropped from the top; unused leading bytes of a are zero-filled. */
void b_encode(u8 *a, int len, const DIGIT *b, const unsigned digits)
{
	int pos = len - 1;	/* next output byte, filled from the least significant end */
	unsigned int k;

	for (k = 0; k < digits && pos >= 0; k++) {
		const DIGIT word = b[k];
		unsigned int shift;

		for (shift = 0; pos >= 0 && shift < DIGIT_BITS; shift += 8)
			a[pos--] = (unsigned char)(word >> shift);
	}
	while (pos >= 0)
		a[pos--] = 0;
}
/* PKCS#1 v1.5 signature (block type 1). Builds
 *   00 01 FF..FF 00 || input
 * padded to the modulus length and applies the private-key operation.
 * `input` is the already-encoded digest. `output` must hold
 * (key->pub.bits + 7) / 8 bytes.
 * Returns the signature length, or -1 if any of the following holds:
 *   - inlen is negative;
 *   - the payload leaves fewer than 8 padding bytes;
 *   - the modulus is larger than the local block buffer;
 *   - the private operation fails. */
int rsa_sign(u8 *output, const u8 *input, const int inlen, const RSA *key)
{
	u8 pkcs[MAX_RSA_BYTES];
	int i, mlen = (key->pub.bits + 7) / 8;

	/* `inlen > mlen - 11` is `inlen + 11 > mlen` without signed overflow.
	 * The size check keeps the padding loop and the memcpy inside pkcs[]. */
	if (inlen < 0 || mlen > (int)sizeof(pkcs) || inlen > mlen - 11)
		return -1;
	pkcs[0] = 0;
	pkcs[1] = 1;
	for (i = 2; i < mlen - inlen - 1; i++)
		pkcs[i] = 0xff;
	pkcs[i++] = 0;	/* separator */
	memcpy(pkcs + i, input, inlen);
	return do_rsa_private(output, pkcs, mlen, key);
}
/* PKCS#1 v1.5 encryption (block type 2). Builds
 *   00 02 <nonzero random bytes> 00 || input
 * padded to the modulus length and applies the public-key operation.
 * `output` must hold (key->bits + 7) / 8 bytes.
 * Returns the ciphertext length, or -1 if any of the following holds:
 *   - inlen is negative;
 *   - the payload leaves fewer than 8 padding bytes;
 *   - the modulus is larger than the local block buffer;
 *   - the public operation fails. */
int rsa_encrypt(u8 *output, const u8 *input, const int inlen, const RSA_PUBLIC *key)
{
	u8 pkcs[MAX_RSA_BYTES];
	int i, mlen = (key->bits + 7) / 8;

	/* `inlen > mlen - 11` is `inlen + 11 > mlen` without signed overflow.
	 * The size check keeps the padding loop and the memcpy inside pkcs[]. */
	if (inlen < 0 || mlen > (int)sizeof(pkcs) || inlen > mlen - 11)
		return -1;
	pkcs[0] = 0;
	pkcs[1] = 2;
	/* random padding must not contain 0x00, which marks its end */
	for (i = 2; i < mlen - inlen - 1; i++)
		pkcs[i] = nzrnd();
	pkcs[i++] = 0;	/* separator */
	memcpy(pkcs + i, input, inlen);
	return do_rsa_public(output, pkcs, mlen, key);
}
/* Verify a PKCS#1 v1.5 signature.
 * Applies the public-key operation and checks that the result is
 *   00 01 FF..FF 00 || payload
 * with at least 8 0xff padding bytes. The payload is copied into `output`,
 * which must hold up to (key->bits + 7) / 8 - 11 bytes.
 * Returns the payload length, or one of these error codes:
 *   -1  bad length, modulus too large for the local buffer, or input >= n
 *   -2  unexpected block length
 *   -3  malformed padding
 * This function does not check the digest: the caller must still compare
 * the returned payload with the expected encoded digest. */
int rsa_verify(u8 *output, const u8 *input, const int inlen, const RSA_PUBLIC *key)
{
	u8 pkcs[MAX_RSA_BYTES];
	int pkcslen;
	int ret, i, mlen = (key->bits + 7) / 8;

	/* do_rsa_public writes mlen bytes into pkcs[]. Requiring mlen >= 11
	 * (the minimum PKCS#1 overhead) keeps the header and separator reads
	 * below within the bytes it wrote. */
	if (inlen < 0 || inlen > mlen || mlen < 11 || mlen > (int)sizeof(pkcs))
		return -1;
	if ((pkcslen = do_rsa_public(pkcs, input, inlen, key)) < 0)
		return pkcslen;
	if (pkcslen != mlen)
		return -2;
	if (pkcs[0] || pkcs[1] != 1)
		return -3;
	/* Block type 1 padding must be all 0xff. Accepting arbitrary nonzero
	 * filler gives an attacker free bits with which to forge signatures
	 * under small public exponents (Bleichenbacher-style forgery). */
	for (i = 2; i < mlen - 1; i++)
		if (pkcs[i] != 0xff)
			break;
	/* the first non-0xff byte must be the 00 separator */
	if (pkcs[i] != 0)
		return -3;
	ret = mlen - i - 1;
	/* at least 8 padding bytes, i.e. the separator is at index >= 10 */
	if (ret + 11 > mlen)
		return -3;
	memcpy(output, pkcs + i + 1, ret);
	return ret;
}
/* Decrypt a PKCS#1 v1.5 (block type 2) ciphertext.
 * Applies the private-key operation, checks the result is
 *   00 02 <at least 8 nonzero bytes> 00 || message
 * and copies the message into `output`, which must hold up to
 * (key->pub.bits + 7) / 8 - 11 bytes.
 * Returns the message length, or one of these error codes:
 *   -1  bad length, modulus too large for the local buffer, or input >= n
 *   -2  unexpected block length
 *   -3  malformed padding
 * NOTE(review): distinguishable padding errors form a Bleichenbacher (1998)
 * padding oracle. Callers should not reveal which failure occurred. */
int rsa_decrypt(u8 *output, const u8 *input, const int inlen, const RSA *key)
{
	u8 pkcs[MAX_RSA_BYTES];
	int pkcslen;
	int ret, i, mlen = (key->pub.bits + 7) / 8;

	/* do_rsa_private writes mlen bytes into pkcs[]. Requiring mlen >= 11
	 * (the minimum PKCS#1 overhead) keeps the header and separator reads
	 * below within the bytes it wrote. */
	if (inlen < 0 || inlen > mlen || mlen < 11 || mlen > (int)sizeof(pkcs))
		return -1;
	if ((pkcslen = do_rsa_private(pkcs, input, inlen, key)) < 0)
		return -1;
	if (pkcslen != mlen)
		return -2;
	if (pkcs[0] || pkcs[1] != 2)
		return -3;
	/* find the 00 separator that ends the random padding */
	for (i = 2; i < mlen - 1; i++)
		if (!pkcs[i])
			break;
	/* no separator anywhere: reject instead of returning an empty message */
	if (pkcs[i])
		return -3;
	ret = mlen - i - 1;
	/* at least 8 padding bytes, i.e. the separator is at index >= 10 */
	if (ret + 11 > mlen)
		return -3;
	memcpy(output, pkcs + i + 1, ret);
	return ret;
}
/* end of gist katlogic/10685096 */