Skip to content

Instantly share code, notes, and snippets.

@elichai

elichai/diff.rs Secret

Last active November 11, 2020 11:36
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save elichai/dbbb0360dd0a6a6180a2e749588a05f9 to your computer and use it in GitHub Desktop.
Save elichai/dbbb0360dd0a6a6180a2e749588a05f9 to your computer and use it in GitHub Desktop.
replace thread_rng with local_rng
diff --git a/src/ecdh.rs b/src/ecdh.rs
index 2174399..7de89e3 100644
--- a/src/ecdh.rs
+++ b/src/ecdh.rs
@@ -164,15 +164,15 @@ impl SharedSecret {
#[cfg(test)]
mod tests {
- use rand::thread_rng;
+ use test_rng::local_rng;
use super::SharedSecret;
use super::super::Secp256k1;
#[test]
fn ecdh() {
let s = Secp256k1::signing_only();
- let (sk1, pk1) = s.generate_keypair(&mut thread_rng());
- let (sk2, pk2) = s.generate_keypair(&mut thread_rng());
+ let (sk1, pk1) = s.generate_keypair(&mut local_rng());
+ let (sk2, pk2) = s.generate_keypair(&mut local_rng());
let sec1 = SharedSecret::new(&pk1, &sk2);
let sec2 = SharedSecret::new(&pk2, &sk1);
@@ -184,8 +184,8 @@ mod tests {
#[test]
fn ecdh_with_hash() {
let s = Secp256k1::signing_only();
- let (sk1, pk1) = s.generate_keypair(&mut thread_rng());
- let (sk2, pk2) = s.generate_keypair(&mut thread_rng());
+ let (sk1, pk1) = s.generate_keypair(&mut local_rng());
+ let (sk2, pk2) = s.generate_keypair(&mut local_rng());
let sec1 = SharedSecret::new_with_hash(&pk1, &sk2, |x,_| x.into());
let sec2 = SharedSecret::new_with_hash(&pk2, &sk1, |x,_| x.into());
@@ -197,7 +197,7 @@ mod tests {
#[test]
fn ecdh_with_hash_callback() {
let s = Secp256k1::signing_only();
- let (sk1, pk1) = s.generate_keypair(&mut thread_rng());
+ let (sk1, pk1) = s.generate_keypair(&mut local_rng());
let expect_result: [u8; 64] = [123; 64];
let mut x_out = [0u8; 32];
let mut y_out = [0u8; 32];
@@ -229,7 +229,7 @@ mod tests {
#[cfg(all(test, feature = "unstable"))]
mod benches {
- use rand::thread_rng;
+ use test_rng::local_rng;
use test::{Bencher, black_box};
use super::SharedSecret;
@@ -238,7 +238,7 @@ mod benches {
#[bench]
pub fn bench_ecdh(bh: &mut Bencher) {
let s = Secp256k1::signing_only();
- let (sk, pk) = s.generate_keypair(&mut thread_rng());
+ let (sk, pk) = s.generate_keypair(&mut local_rng());
bh.iter( || {
let res = SharedSecret::new(&pk, &sk);
diff --git a/src/key.rs b/src/key.rs
index 8d65185..0c0310c 100644
--- a/src/key.rs
+++ b/src/key.rs
@@ -477,8 +477,9 @@ mod test {
use super::super::Error::{InvalidPublicKey, InvalidSecretKey};
use super::{PublicKey, SecretKey};
use super::super::constants;
+ use test_rng::local_rng;
- use rand::{Error, ErrorKind, RngCore, thread_rng};
+ use rand::{Error, ErrorKind, RngCore};
use rand_core::impls;
use std::iter;
use std::str::FromStr;
@@ -516,7 +517,7 @@ mod test {
fn keypair_slice_round_trip() {
let s = Secp256k1::new();
- let (sk1, pk1) = s.generate_keypair(&mut thread_rng());
+ let (sk1, pk1) = s.generate_keypair(&mut local_rng());
assert_eq!(SecretKey::from_slice(&sk1[..]), Ok(sk1));
assert_eq!(PublicKey::from_slice(&pk1.serialize()[..]), Ok(pk1));
assert_eq!(PublicKey::from_slice(&pk1.serialize_uncompressed()[..]), Ok(pk1));
@@ -747,8 +748,8 @@ mod test {
fn test_addition() {
let s = Secp256k1::new();
- let (mut sk1, mut pk1) = s.generate_keypair(&mut thread_rng());
- let (mut sk2, mut pk2) = s.generate_keypair(&mut thread_rng());
+ let (mut sk1, mut pk1) = s.generate_keypair(&mut local_rng());
+ let (mut sk2, mut pk2) = s.generate_keypair(&mut local_rng());
assert_eq!(PublicKey::from_secret_key(&s, &sk1), pk1);
assert!(sk1.add_assign(&sk2[..]).is_ok());
@@ -765,8 +766,8 @@ mod test {
fn test_multiplication() {
let s = Secp256k1::new();
- let (mut sk1, mut pk1) = s.generate_keypair(&mut thread_rng());
- let (mut sk2, mut pk2) = s.generate_keypair(&mut thread_rng());
+ let (mut sk1, mut pk1) = s.generate_keypair(&mut local_rng());
+ let (mut sk2, mut pk2) = s.generate_keypair(&mut local_rng());
assert_eq!(PublicKey::from_secret_key(&s, &sk1), pk1);
assert!(sk1.mul_assign(&sk2[..]).is_ok());
@@ -783,7 +784,7 @@ mod test {
fn test_negation() {
let s = Secp256k1::new();
- let (mut sk, mut pk) = s.generate_keypair(&mut thread_rng());
+ let (mut sk, mut pk) = s.generate_keypair(&mut local_rng());
let original_sk = sk;
let original_pk = pk;
@@ -816,7 +817,7 @@ mod test {
let mut set = HashSet::new();
const COUNT : usize = 1024;
let count = (0..COUNT).map(|_| {
- let (_, pk) = s.generate_keypair(&mut thread_rng());
+ let (_, pk) = s.generate_keypair(&mut local_rng());
let hash = hash(&pk);
assert!(!set.contains(&hash));
set.insert(hash);
diff --git a/src/lib.rs b/src/lib.rs
index 10a4ed3..d60fe8f 100644
--- a/src/lib.rs
+++ b/src/lib.rs
@@ -145,6 +145,9 @@ pub mod key;
#[cfg(feature = "recovery")]
pub mod recovery;
+#[cfg(test)]
+mod test_rng;
+
pub use key::SecretKey;
pub use key::PublicKey;
pub use context::*;
@@ -752,10 +755,11 @@ fn from_hex(hex: &str, target: &mut [u8]) -> Result<usize, ()> {
#[cfg(test)]
mod tests {
- use rand::{RngCore, thread_rng};
+ use rand::RngCore;
use std::str::FromStr;
use std::marker::PhantomData;
+ use test_rng::local_rng;
use key::{SecretKey, PublicKey};
use super::from_hex;
use super::constants;
@@ -784,7 +788,7 @@ mod tests {
let sign: Secp256k1<SignOnlyPreallocated> = Secp256k1{ctx: ctx_sign, phantom: PhantomData, buf};
let vrfy: Secp256k1<VerifyOnlyPreallocated> = Secp256k1{ctx: ctx_vrfy, phantom: PhantomData, buf};
- let (sk, pk) = full.generate_keypair(&mut thread_rng());
+ let (sk, pk) = full.generate_keypair(&mut local_rng());
let msg = Message::from_slice(&[2u8; 32]).unwrap();
// Try signing
assert_eq!(sign.sign(&msg, &sk), full.sign(&msg, &sk));
@@ -811,7 +815,7 @@ mod tests {
let sign = unsafe {Secp256k1::from_raw_signining_only(ctx_sign.ctx)};
let vrfy = unsafe {Secp256k1::from_raw_verification_only(ctx_vrfy.ctx)};
- let (sk, pk) = full.generate_keypair(&mut thread_rng());
+ let (sk, pk) = full.generate_keypair(&mut local_rng());
let msg = Message::from_slice(&[2u8; 32]).unwrap();
// Try signing
assert_eq!(sign.sign(&msg, &sk), full.sign(&msg, &sk));
@@ -830,7 +834,7 @@ mod tests {
fn test_panic_raw_ctx() {
let ctx_vrfy = Secp256k1::verification_only();
let raw_ctx_verify_as_full = unsafe {Secp256k1::from_raw_all(ctx_vrfy.ctx)};
- let (sk, _) = raw_ctx_verify_as_full.generate_keypair(&mut thread_rng());
+ let (sk, _) = raw_ctx_verify_as_full.generate_keypair(&mut local_rng());
let msg = Message::from_slice(&[2u8; 32]).unwrap();
// Try signing
raw_ctx_verify_as_full.sign(&msg, &sk);
@@ -849,7 +853,7 @@ mod tests {
// drop(buf_vfy); // The buffer can't get dropped before the context.
// println!("{:?}", buf_ful[5]); // Can't even read the data thanks to the borrow checker.
- let (sk, pk) = full.generate_keypair(&mut thread_rng());
+ let (sk, pk) = full.generate_keypair(&mut local_rng());
let msg = Message::from_slice(&[2u8; 32]).unwrap();
// Try signing
assert_eq!(sign.sign(&msg, &sk), full.sign(&msg, &sk));
@@ -867,11 +871,11 @@ mod tests {
let full = Secp256k1::new();
let mut msg = [0u8; 32];
- thread_rng().fill_bytes(&mut msg);
+ local_rng().fill_bytes(&mut msg);
let msg = Message::from_slice(&msg).unwrap();
// Try key generation
- let (sk, pk) = full.generate_keypair(&mut thread_rng());
+ let (sk, pk) = full.generate_keypair(&mut local_rng());
// Try signing
assert_eq!(sign.sign(&msg, &sk), full.sign(&msg, &sk));
@@ -892,14 +896,14 @@ mod tests {
#[test]
fn signature_serialize_roundtrip() {
let mut s = Secp256k1::new();
- s.randomize(&mut thread_rng());
+ s.randomize(&mut local_rng());
let mut msg = [0; 32];
for _ in 0..100 {
- thread_rng().fill_bytes(&mut msg);
+ local_rng().fill_bytes(&mut msg);
let msg = Message::from_slice(&msg).unwrap();
- let (sk, _) = s.generate_keypair(&mut thread_rng());
+ let (sk, _) = s.generate_keypair(&mut local_rng());
let sig1 = s.sign(&msg, &sk);
let der = sig1.serialize_der();
let sig2 = Signature::from_der(&der[..]).unwrap();
@@ -978,14 +982,14 @@ mod tests {
#[test]
fn sign_and_verify() {
let mut s = Secp256k1::new();
- s.randomize(&mut thread_rng());
+ s.randomize(&mut local_rng());
let mut msg = [0; 32];
for _ in 0..100 {
- thread_rng().fill_bytes(&mut msg);
+ local_rng().fill_bytes(&mut msg);
let msg = Message::from_slice(&msg).unwrap();
- let (sk, pk) = s.generate_keypair(&mut thread_rng());
+ let (sk, pk) = s.generate_keypair(&mut local_rng());
let sig = s.sign(&msg, &sk);
assert_eq!(s.verify(&msg, &sig, &pk), Ok(()));
}
@@ -994,7 +998,7 @@ mod tests {
#[test]
fn sign_and_verify_extreme() {
let mut s = Secp256k1::new();
- s.randomize(&mut thread_rng());
+ s.randomize(&mut local_rng());
// Wild keys: 1, CURVE_ORDER - 1
// Wild msgs: 1, CURVE_ORDER - 1
@@ -1023,18 +1027,18 @@ mod tests {
#[test]
fn sign_and_verify_fail() {
let mut s = Secp256k1::new();
- s.randomize(&mut thread_rng());
+ s.randomize(&mut local_rng());
let mut msg = [0u8; 32];
- thread_rng().fill_bytes(&mut msg);
+ local_rng().fill_bytes(&mut msg);
let msg = Message::from_slice(&msg).unwrap();
- let (sk, pk) = s.generate_keypair(&mut thread_rng());
+ let (sk, pk) = s.generate_keypair(&mut local_rng());
let sig = s.sign(&msg, &sk);
let mut msg = [0u8; 32];
- thread_rng().fill_bytes(&mut msg);
+ local_rng().fill_bytes(&mut msg);
let msg = Message::from_slice(&msg).unwrap();
assert_eq!(s.verify(&msg, &sig, &pk), Err(IncorrectSignature));
}
@@ -1174,7 +1178,7 @@ mod tests {
#[cfg(all(test, feature = "unstable"))]
mod benches {
- use rand::{thread_rng, RngCore};
+ use rand::{local_rng, RngCore};
use test::{Bencher, black_box};
use super::{Secp256k1, Message};
@@ -1218,9 +1222,9 @@ mod benches {
pub fn bench_sign(bh: &mut Bencher) {
let s = Secp256k1::new();
let mut msg = [0u8; 32];
- thread_rng().fill_bytes(&mut msg);
+ local_rng().fill_bytes(&mut msg);
let msg = Message::from_slice(&msg).unwrap();
- let (sk, _) = s.generate_keypair(&mut thread_rng());
+ let (sk, _) = s.generate_keypair(&mut local_rng());
bh.iter(|| {
let sig = s.sign(&msg, &sk);
@@ -1232,9 +1236,9 @@ mod benches {
pub fn bench_verify(bh: &mut Bencher) {
let s = Secp256k1::new();
let mut msg = [0u8; 32];
- thread_rng().fill_bytes(&mut msg);
+ local_rng().fill_bytes(&mut msg);
let msg = Message::from_slice(&msg).unwrap();
- let (sk, pk) = s.generate_keypair(&mut thread_rng());
+ let (sk, pk) = s.generate_keypair(&mut local_rng());
let sig = s.sign(&msg, &sk);
bh.iter(|| {
diff --git a/src/recovery.rs b/src/recovery.rs
index 0152994..eb91edc 100644
--- a/src/recovery.rs
+++ b/src/recovery.rs
@@ -193,7 +193,7 @@ impl<C: Verification> Secp256k1<C> {
#[cfg(test)]
mod tests {
- use rand::{RngCore, thread_rng};
+ use rand::{RngCore, local_rng};
use key::SecretKey;
use super::{RecoveryId, RecoverableSignature};
@@ -207,11 +207,11 @@ mod tests {
let full = Secp256k1::new();
let mut msg = [0u8; 32];
- thread_rng().fill_bytes(&mut msg);
+ local_rng().fill_bytes(&mut msg);
let msg = Message::from_slice(&msg).unwrap();
// Try key generation
- let (sk, pk) = full.generate_keypair(&mut thread_rng());
+ let (sk, pk) = full.generate_keypair(&mut local_rng());
// Try signing
assert_eq!(sign.sign_recoverable(&msg, &sk), full.sign_recoverable(&msg, &sk));
@@ -235,7 +235,7 @@ mod tests {
#[test]
fn sign() {
let mut s = Secp256k1::new();
- s.randomize(&mut thread_rng());
+ s.randomize(&mut local_rng());
let one = [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1];
@@ -258,19 +258,19 @@ mod tests {
#[test]
fn sign_and_verify_fail() {
let mut s = Secp256k1::new();
- s.randomize(&mut thread_rng());
+ s.randomize(&mut local_rng());
let mut msg = [0u8; 32];
- thread_rng().fill_bytes(&mut msg);
+ local_rng().fill_bytes(&mut msg);
let msg = Message::from_slice(&msg).unwrap();
- let (sk, pk) = s.generate_keypair(&mut thread_rng());
+ let (sk, pk) = s.generate_keypair(&mut local_rng());
let sigr = s.sign_recoverable(&msg, &sk);
let sig = sigr.to_standard();
let mut msg = [0u8; 32];
- thread_rng().fill_bytes(&mut msg);
+ local_rng().fill_bytes(&mut msg);
let msg = Message::from_slice(&msg).unwrap();
assert_eq!(s.verify(&msg, &sig, &pk), Err(IncorrectSignature));
@@ -281,13 +281,13 @@ mod tests {
#[test]
fn sign_with_recovery() {
let mut s = Secp256k1::new();
- s.randomize(&mut thread_rng());
+ s.randomize(&mut local_rng());
let mut msg = [0u8; 32];
- thread_rng().fill_bytes(&mut msg);
+ local_rng().fill_bytes(&mut msg);
let msg = Message::from_slice(&msg).unwrap();
- let (sk, pk) = s.generate_keypair(&mut thread_rng());
+ let (sk, pk) = s.generate_keypair(&mut local_rng());
let sig = s.sign_recoverable(&msg, &sk);
@@ -297,7 +297,7 @@ mod tests {
#[test]
fn bad_recovery() {
let mut s = Secp256k1::new();
- s.randomize(&mut thread_rng());
+ s.randomize(&mut local_rng());
let msg = Message::from_slice(&[0x55; 32]).unwrap();
@@ -363,7 +363,7 @@ mod tests {
#[cfg(all(test, feature = "unstable"))]
mod benches {
- use rand::{thread_rng, RngCore};
+ use rand::{local_rng, RngCore};
use test::{Bencher, black_box};
use super::{Message, Secp256k1};
@@ -371,9 +371,9 @@ mod benches {
pub fn bench_recover(bh: &mut Bencher) {
let s = Secp256k1::new();
let mut msg = [0u8; 32];
- thread_rng().fill_bytes(&mut msg);
+ local_rng().fill_bytes(&mut msg);
let msg = Message::from_slice(&msg).unwrap();
- let (sk, _) = s.generate_keypair(&mut thread_rng());
+ let (sk, _) = s.generate_keypair(&mut local_rng());
let sig = s.sign_recoverable(&msg, &sk);
bh.iter(|| {
diff --git a/src/test_rng.rs b/src/test_rng.rs
new file mode 100644
index 0000000..75fdb65
--- /dev/null
+++ b/src/test_rng.rs
@@ -0,0 +1,107 @@
+use rand_core::{impls, Error, RngCore};
+use std::cell::UnsafeCell;
+use std::num::Wrapping;
+use std::time;
+
+ /// A handle to the calling thread's `Pcg32` generator, created by `local_rng()`.
+ ///
+ /// The generator itself lives in thread-local storage; this handle only stores a raw pointer
+ /// to it. Because of that raw-pointer field the handle is neither `Send` nor `Sync`, so it can
+ /// never be moved to or shared with another thread. Several handles on the same thread alias
+ /// the same generator; that is fine because each `RngCore` call borrows it only for the
+ /// duration of that call.
+ ///
+ /// PCG is not cryptographically secure; this type exists only for the test suite.
+ pub struct ThreadFastRng(*mut Pcg32);
+
+ /// Returns a handle to this thread's `Pcg32`, seeding it from the system clock on first use.
+ pub fn local_rng() -> ThreadFastRng {
+     thread_local! {
+         // Function-scoped, so no visibility modifier is needed.
+         static THREAD_FAST_RNG: UnsafeCell<Pcg32> = UnsafeCell::new(Pcg32::new());
+     }
+     // `UnsafeCell::get` yields a raw pointer that stays valid until this thread exits.
+     let ptr = THREAD_FAST_RNG.with(|r| r.get());
+     ThreadFastRng(ptr)
+ }
+
+ /// A minimal PCG32 (XSH RR, 64-bit state / 32-bit output) pseudo-random number generator.
+ ///
+ /// Not cryptographically secure: it only gives tests a cheap source of varied input.
+ pub struct Pcg32 {
+     /// Current 64-bit LCG state.
+     state: Wrapping<u64>,
+     /// Odd increment of the LCG; it selects which of the 2^63 streams is produced.
+     inc: Wrapping<u64>,
+ }
+
+ impl Pcg32 {
+     /// Multiplier of the 64-bit LCG that advances the internal state.
+     const PCG_DEFAULT_MULTIPLIER_64: Wrapping<u64> = Wrapping(6_364_136_223_846_793_005);
+
+     /// Creates a new instance of `Pcg32` seeded from the system clock.
+     ///
+     /// Whole seconds since the Unix epoch become the state seed and the sub-second
+     /// nanoseconds select the stream (see `Pcg32::seed`).
+     ///
+     /// # Panics
+     /// Panics if the system clock reports a time before the Unix epoch.
+     ///
+     /// # Examples
+     /// ```rust,ignore
+     /// // Ignored: `test_rng` is a private, `#[cfg(test)]`-only module, so doctests can't reach it.
+     /// use rand_core::RngCore;
+     ///
+     /// let mut rng = Pcg32::new();
+     /// let word = rng.next_u32();
+     /// let mut arr = [0u8; 32];
+     /// rng.fill_bytes(&mut arr);
+     /// ```
+     pub fn new() -> Self {
+         let unix = time::SystemTime::now()
+             .duration_since(time::UNIX_EPOCH)
+             // `duration_since` only fails when the clock is *before* the epoch.
+             .expect("system time should be after the unix epoch");
+         let sec = unix.as_secs();
+         let subsec = u64::from(unix.subsec_nanos());
+         Self::seed(sec, subsec)
+     }
+
+     /// Creates a generator from an explicit seed.
+     ///
+     /// `seed` is the starting state of the algorithm. `seq` selects the stream: it fixes the
+     /// odd increment added to the state on every step. Ideally both are random. The top bit
+     /// of `seq` is discarded, because it is shifted out when the increment is formed.
+     pub fn seed(seed: u64, seq: u64) -> Self {
+         let inc = Wrapping((seq << 1) | 1);
+         let mut rng = Pcg32 {
+             state: Wrapping(seed) + inc,
+             inc,
+         };
+         // Mirrors the reference `pcg32_srandom_r` initialisation: mix in the seed, then step once.
+         rng.state = rng.state * Self::PCG_DEFAULT_MULTIPLIER_64 + rng.inc;
+         rng
+     }
+
+     /// Advances the LCG and returns 32 output bits derived from the *previous* state.
+     fn gen_u32(&mut self) -> u32 {
+         let old_state = self.state;
+         self.state = self.state * Self::PCG_DEFAULT_MULTIPLIER_64 + self.inc;
+
+         // XSH RR output permutation: xorshift the high bits down, keep the low 32 bits
+         // (the `as u32` truncation is intentional), then rotate right by the value held in
+         // the top 5 bits of the old state.
+         let xorshift = (((old_state >> 18) ^ old_state) >> 27).0 as u32;
+         let rot = (old_state >> 59).0 as u32;
+         xorshift.rotate_right(rot)
+     }
+ }
+
+ impl ThreadFastRng {
+     /// Reborrows the thread-local generator this handle points to.
+     fn pcg(&mut self) -> &mut Pcg32 {
+         // SAFETY:
+         // 1. `self.0` comes from this thread's `THREAD_FAST_RNG` cell, and the handle can't
+         //    leave the thread because its raw-pointer field makes it `!Send`.
+         // 2. The returned `&mut` is bounded by `&mut self`, and every `RngCore` method below
+         //    drops it before returning, so no two live `&mut Pcg32` ever coexist.
+         unsafe { &mut *self.0 }
+     }
+ }
+
+ /// Every call is forwarded to the thread-local `Pcg32`.
+ impl RngCore for ThreadFastRng {
+     fn next_u32(&mut self) -> u32 {
+         self.pcg().next_u32()
+     }
+
+     fn next_u64(&mut self) -> u64 {
+         self.pcg().next_u64()
+     }
+
+     fn fill_bytes(&mut self, dest: &mut [u8]) {
+         self.pcg().fill_bytes(dest)
+     }
+
+     fn try_fill_bytes(&mut self, dest: &mut [u8]) -> Result<(), Error> {
+         self.pcg().try_fill_bytes(dest)
+     }
+ }
+
+ /// `RngCore` on top of the raw 32-bit PCG output.
+ impl RngCore for Pcg32 {
+     fn next_u32(&mut self) -> u32 {
+         self.gen_u32()
+     }
+
+     /// Two consecutive 32-bit outputs: the first becomes the high half, the second the low half.
+     fn next_u64(&mut self) -> u64 {
+         let high = u64::from(self.next_u32());
+         let low = u64::from(self.next_u32());
+         (high << 32) | low
+     }
+
+     fn fill_bytes(&mut self, dest: &mut [u8]) {
+         impls::fill_bytes_via_next(self, dest)
+     }
+
+     /// Infallible: fills via `fill_bytes` and always reports success.
+     fn try_fill_bytes(&mut self, dest: &mut [u8]) -> Result<(), Error> {
+         self.fill_bytes(dest);
+         Ok(())
+     }
+ }
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment