-
-
Save elichai/dbbb0360dd0a6a6180a2e749588a05f9 to your computer and use it in GitHub Desktop.
Replace rand's thread_rng with a crate-local local_rng in tests
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
diff --git a/src/ecdh.rs b/src/ecdh.rs | |
index 2174399..7de89e3 100644 | |
--- a/src/ecdh.rs | |
+++ b/src/ecdh.rs | |
@@ -164,15 +164,15 @@ impl SharedSecret { | |
#[cfg(test)] | |
mod tests { | |
- use rand::thread_rng; | |
+ use test_rng::local_rng; // per-thread, time-seeded Pcg32 from src/test_rng.rs; replaces rand::thread_rng | |
use super::SharedSecret; | |
use super::super::Secp256k1; | |
#[test] | |
fn ecdh() { | |
let s = Secp256k1::signing_only(); | |
- let (sk1, pk1) = s.generate_keypair(&mut thread_rng()); | |
- let (sk2, pk2) = s.generate_keypair(&mut thread_rng()); | |
+ let (sk1, pk1) = s.generate_keypair(&mut local_rng()); | |
+ let (sk2, pk2) = s.generate_keypair(&mut local_rng()); | |
let sec1 = SharedSecret::new(&pk1, &sk2); | |
let sec2 = SharedSecret::new(&pk2, &sk1); | |
@@ -184,8 +184,8 @@ mod tests { | |
#[test] | |
fn ecdh_with_hash() { | |
let s = Secp256k1::signing_only(); | |
- let (sk1, pk1) = s.generate_keypair(&mut thread_rng()); | |
- let (sk2, pk2) = s.generate_keypair(&mut thread_rng()); | |
+ let (sk1, pk1) = s.generate_keypair(&mut local_rng()); | |
+ let (sk2, pk2) = s.generate_keypair(&mut local_rng()); | |
let sec1 = SharedSecret::new_with_hash(&pk1, &sk2, |x,_| x.into()); | |
let sec2 = SharedSecret::new_with_hash(&pk2, &sk1, |x,_| x.into()); | |
@@ -197,7 +197,7 @@ mod tests { | |
#[test] | |
fn ecdh_with_hash_callback() { | |
let s = Secp256k1::signing_only(); | |
- let (sk1, pk1) = s.generate_keypair(&mut thread_rng()); | |
+ let (sk1, pk1) = s.generate_keypair(&mut local_rng()); | |
let expect_result: [u8; 64] = [123; 64]; | |
let mut x_out = [0u8; 32]; | |
let mut y_out = [0u8; 32]; | |
@@ -229,7 +229,7 @@ mod tests { | |
#[cfg(all(test, feature = "unstable"))] | |
mod benches { | |
- use rand::thread_rng; | |
+ use test_rng::local_rng; // same per-thread Pcg32 test RNG, used only for bench setup | |
use test::{Bencher, black_box}; | |
use super::SharedSecret; | |
@@ -238,7 +238,7 @@ mod benches { | |
#[bench] | |
pub fn bench_ecdh(bh: &mut Bencher) { | |
let s = Secp256k1::signing_only(); | |
- let (sk, pk) = s.generate_keypair(&mut thread_rng()); | |
+ let (sk, pk) = s.generate_keypair(&mut local_rng()); | |
bh.iter( || { | |
let res = SharedSecret::new(&pk, &sk); | |
diff --git a/src/key.rs b/src/key.rs | |
index 8d65185..0c0310c 100644 | |
--- a/src/key.rs | |
+++ b/src/key.rs | |
@@ -477,8 +477,9 @@ mod test { | |
use super::super::Error::{InvalidPublicKey, InvalidSecretKey}; | |
use super::{PublicKey, SecretKey}; | |
use super::super::constants; | |
+ use test_rng::local_rng; // per-thread, time-seeded Pcg32 from src/test_rng.rs; replaces rand::thread_rng | |
- use rand::{Error, ErrorKind, RngCore, thread_rng}; | |
+ use rand::{Error, ErrorKind, RngCore}; | |
use rand_core::impls; | |
use std::iter; | |
use std::str::FromStr; | |
@@ -516,7 +517,7 @@ mod test { | |
fn keypair_slice_round_trip() { | |
let s = Secp256k1::new(); | |
- let (sk1, pk1) = s.generate_keypair(&mut thread_rng()); | |
+ let (sk1, pk1) = s.generate_keypair(&mut local_rng()); | |
assert_eq!(SecretKey::from_slice(&sk1[..]), Ok(sk1)); | |
assert_eq!(PublicKey::from_slice(&pk1.serialize()[..]), Ok(pk1)); | |
assert_eq!(PublicKey::from_slice(&pk1.serialize_uncompressed()[..]), Ok(pk1)); | |
@@ -747,8 +748,8 @@ mod test { | |
fn test_addition() { | |
let s = Secp256k1::new(); | |
- let (mut sk1, mut pk1) = s.generate_keypair(&mut thread_rng()); | |
- let (mut sk2, mut pk2) = s.generate_keypair(&mut thread_rng()); | |
+ let (mut sk1, mut pk1) = s.generate_keypair(&mut local_rng()); | |
+ let (mut sk2, mut pk2) = s.generate_keypair(&mut local_rng()); | |
assert_eq!(PublicKey::from_secret_key(&s, &sk1), pk1); | |
assert!(sk1.add_assign(&sk2[..]).is_ok()); | |
@@ -765,8 +766,8 @@ mod test { | |
fn test_multiplication() { | |
let s = Secp256k1::new(); | |
- let (mut sk1, mut pk1) = s.generate_keypair(&mut thread_rng()); | |
- let (mut sk2, mut pk2) = s.generate_keypair(&mut thread_rng()); | |
+ let (mut sk1, mut pk1) = s.generate_keypair(&mut local_rng()); | |
+ let (mut sk2, mut pk2) = s.generate_keypair(&mut local_rng()); | |
assert_eq!(PublicKey::from_secret_key(&s, &sk1), pk1); | |
assert!(sk1.mul_assign(&sk2[..]).is_ok()); | |
@@ -783,7 +784,7 @@ mod test { | |
fn test_negation() { | |
let s = Secp256k1::new(); | |
- let (mut sk, mut pk) = s.generate_keypair(&mut thread_rng()); | |
+ let (mut sk, mut pk) = s.generate_keypair(&mut local_rng()); | |
let original_sk = sk; | |
let original_pk = pk; | |
@@ -816,7 +817,7 @@ mod test { | |
let mut set = HashSet::new(); | |
const COUNT : usize = 1024; | |
let count = (0..COUNT).map(|_| { | |
- let (_, pk) = s.generate_keypair(&mut thread_rng()); | |
+ let (_, pk) = s.generate_keypair(&mut local_rng()); | |
let hash = hash(&pk); | |
assert!(!set.contains(&hash)); | |
set.insert(hash); | |
diff --git a/src/lib.rs b/src/lib.rs | |
index 10a4ed3..d60fe8f 100644 | |
--- a/src/lib.rs | |
+++ b/src/lib.rs | |
@@ -145,6 +145,9 @@ pub mod key; | |
#[cfg(feature = "recovery")] | |
pub mod recovery; | |
+#[cfg(test)] | |
+mod test_rng; // time-seeded per-thread Pcg32 (`local_rng`) shared by tests and benches | |
+ | |
pub use key::SecretKey; | |
pub use key::PublicKey; | |
pub use context::*; | |
@@ -752,10 +755,11 @@ fn from_hex(hex: &str, target: &mut [u8]) -> Result<usize, ()> { | |
#[cfg(test)] | |
mod tests { | |
- use rand::{RngCore, thread_rng}; | |
+ use rand::RngCore; | |
use std::str::FromStr; | |
use std::marker::PhantomData; | |
+ use test_rng::local_rng; | |
use key::{SecretKey, PublicKey}; | |
use super::from_hex; | |
use super::constants; | |
@@ -784,7 +788,7 @@ mod tests { | |
let sign: Secp256k1<SignOnlyPreallocated> = Secp256k1{ctx: ctx_sign, phantom: PhantomData, buf}; | |
let vrfy: Secp256k1<VerifyOnlyPreallocated> = Secp256k1{ctx: ctx_vrfy, phantom: PhantomData, buf}; | |
- let (sk, pk) = full.generate_keypair(&mut thread_rng()); | |
+ let (sk, pk) = full.generate_keypair(&mut local_rng()); | |
let msg = Message::from_slice(&[2u8; 32]).unwrap(); | |
// Try signing | |
assert_eq!(sign.sign(&msg, &sk), full.sign(&msg, &sk)); | |
@@ -811,7 +815,7 @@ mod tests { | |
let sign = unsafe {Secp256k1::from_raw_signining_only(ctx_sign.ctx)}; | |
let vrfy = unsafe {Secp256k1::from_raw_verification_only(ctx_vrfy.ctx)}; | |
- let (sk, pk) = full.generate_keypair(&mut thread_rng()); | |
+ let (sk, pk) = full.generate_keypair(&mut local_rng()); | |
let msg = Message::from_slice(&[2u8; 32]).unwrap(); | |
// Try signing | |
assert_eq!(sign.sign(&msg, &sk), full.sign(&msg, &sk)); | |
@@ -830,7 +834,7 @@ mod tests { | |
fn test_panic_raw_ctx() { | |
let ctx_vrfy = Secp256k1::verification_only(); | |
let raw_ctx_verify_as_full = unsafe {Secp256k1::from_raw_all(ctx_vrfy.ctx)}; | |
- let (sk, _) = raw_ctx_verify_as_full.generate_keypair(&mut thread_rng()); | |
+ let (sk, _) = raw_ctx_verify_as_full.generate_keypair(&mut local_rng()); | |
let msg = Message::from_slice(&[2u8; 32]).unwrap(); | |
// Try signing | |
raw_ctx_verify_as_full.sign(&msg, &sk); | |
@@ -849,7 +853,7 @@ mod tests { | |
// drop(buf_vfy); // The buffer can't get dropped before the context. | |
// println!("{:?}", buf_ful[5]); // Can't even read the data thanks to the borrow checker. | |
- let (sk, pk) = full.generate_keypair(&mut thread_rng()); | |
+ let (sk, pk) = full.generate_keypair(&mut local_rng()); | |
let msg = Message::from_slice(&[2u8; 32]).unwrap(); | |
// Try signing | |
assert_eq!(sign.sign(&msg, &sk), full.sign(&msg, &sk)); | |
@@ -867,11 +871,11 @@ mod tests { | |
let full = Secp256k1::new(); | |
let mut msg = [0u8; 32]; | |
- thread_rng().fill_bytes(&mut msg); | |
+ local_rng().fill_bytes(&mut msg); | |
let msg = Message::from_slice(&msg).unwrap(); | |
// Try key generation | |
- let (sk, pk) = full.generate_keypair(&mut thread_rng()); | |
+ let (sk, pk) = full.generate_keypair(&mut local_rng()); | |
// Try signing | |
assert_eq!(sign.sign(&msg, &sk), full.sign(&msg, &sk)); | |
@@ -892,14 +896,14 @@ mod tests { | |
#[test] | |
fn signature_serialize_roundtrip() { | |
let mut s = Secp256k1::new(); | |
- s.randomize(&mut thread_rng()); | |
+ s.randomize(&mut local_rng()); | |
let mut msg = [0; 32]; | |
for _ in 0..100 { | |
- thread_rng().fill_bytes(&mut msg); | |
+ local_rng().fill_bytes(&mut msg); | |
let msg = Message::from_slice(&msg).unwrap(); | |
- let (sk, _) = s.generate_keypair(&mut thread_rng()); | |
+ let (sk, _) = s.generate_keypair(&mut local_rng()); | |
let sig1 = s.sign(&msg, &sk); | |
let der = sig1.serialize_der(); | |
let sig2 = Signature::from_der(&der[..]).unwrap(); | |
@@ -978,14 +982,14 @@ mod tests { | |
#[test] | |
fn sign_and_verify() { | |
let mut s = Secp256k1::new(); | |
- s.randomize(&mut thread_rng()); | |
+ s.randomize(&mut local_rng()); | |
let mut msg = [0; 32]; | |
for _ in 0..100 { | |
- thread_rng().fill_bytes(&mut msg); | |
+ local_rng().fill_bytes(&mut msg); | |
let msg = Message::from_slice(&msg).unwrap(); | |
- let (sk, pk) = s.generate_keypair(&mut thread_rng()); | |
+ let (sk, pk) = s.generate_keypair(&mut local_rng()); | |
let sig = s.sign(&msg, &sk); | |
assert_eq!(s.verify(&msg, &sig, &pk), Ok(())); | |
} | |
@@ -994,7 +998,7 @@ mod tests { | |
#[test] | |
fn sign_and_verify_extreme() { | |
let mut s = Secp256k1::new(); | |
- s.randomize(&mut thread_rng()); | |
+ s.randomize(&mut local_rng()); | |
// Wild keys: 1, CURVE_ORDER - 1 | |
// Wild msgs: 1, CURVE_ORDER - 1 | |
@@ -1023,18 +1027,18 @@ mod tests { | |
#[test] | |
fn sign_and_verify_fail() { | |
let mut s = Secp256k1::new(); | |
- s.randomize(&mut thread_rng()); | |
+ s.randomize(&mut local_rng()); | |
let mut msg = [0u8; 32]; | |
- thread_rng().fill_bytes(&mut msg); | |
+ local_rng().fill_bytes(&mut msg); | |
let msg = Message::from_slice(&msg).unwrap(); | |
- let (sk, pk) = s.generate_keypair(&mut thread_rng()); | |
+ let (sk, pk) = s.generate_keypair(&mut local_rng()); | |
let sig = s.sign(&msg, &sk); | |
let mut msg = [0u8; 32]; | |
- thread_rng().fill_bytes(&mut msg); | |
+ local_rng().fill_bytes(&mut msg); | |
let msg = Message::from_slice(&msg).unwrap(); | |
assert_eq!(s.verify(&msg, &sig, &pk), Err(IncorrectSignature)); | |
} | |
@@ -1174,7 +1178,8 @@ mod tests { | |
#[cfg(all(test, feature = "unstable"))] | |
mod benches { | |
- use rand::{thread_rng, RngCore}; | |
+ use rand::RngCore; | |
+ use test_rng::local_rng; | |
use test::{Bencher, black_box}; | |
use super::{Secp256k1, Message}; | |
@@ -1218,9 +1223,9 @@ mod benches { | |
pub fn bench_sign(bh: &mut Bencher) { | |
let s = Secp256k1::new(); | |
let mut msg = [0u8; 32]; | |
- thread_rng().fill_bytes(&mut msg); | |
+ local_rng().fill_bytes(&mut msg); | |
let msg = Message::from_slice(&msg).unwrap(); | |
- let (sk, _) = s.generate_keypair(&mut thread_rng()); | |
+ let (sk, _) = s.generate_keypair(&mut local_rng()); | |
bh.iter(|| { | |
let sig = s.sign(&msg, &sk); | |
@@ -1232,9 +1237,9 @@ mod benches { | |
pub fn bench_verify(bh: &mut Bencher) { | |
let s = Secp256k1::new(); | |
let mut msg = [0u8; 32]; | |
- thread_rng().fill_bytes(&mut msg); | |
+ local_rng().fill_bytes(&mut msg); | |
let msg = Message::from_slice(&msg).unwrap(); | |
- let (sk, pk) = s.generate_keypair(&mut thread_rng()); | |
+ let (sk, pk) = s.generate_keypair(&mut local_rng()); | |
let sig = s.sign(&msg, &sk); | |
bh.iter(|| { | |
diff --git a/src/recovery.rs b/src/recovery.rs | |
index 0152994..eb91edc 100644 | |
--- a/src/recovery.rs | |
+++ b/src/recovery.rs | |
@@ -193,7 +193,8 @@ impl<C: Verification> Secp256k1<C> { | |
#[cfg(test)] | |
mod tests { | |
- use rand::{RngCore, thread_rng}; | |
+ use rand::RngCore; | |
+ use test_rng::local_rng; // per-thread, time-seeded Pcg32 from src/test_rng.rs | |
use key::SecretKey; | |
use super::{RecoveryId, RecoverableSignature}; | |
@@ -207,11 +208,11 @@ mod tests { | |
let full = Secp256k1::new(); | |
let mut msg = [0u8; 32]; | |
- thread_rng().fill_bytes(&mut msg); | |
+ local_rng().fill_bytes(&mut msg); | |
let msg = Message::from_slice(&msg).unwrap(); | |
// Try key generation | |
- let (sk, pk) = full.generate_keypair(&mut thread_rng()); | |
+ let (sk, pk) = full.generate_keypair(&mut local_rng()); | |
// Try signing | |
assert_eq!(sign.sign_recoverable(&msg, &sk), full.sign_recoverable(&msg, &sk)); | |
@@ -235,7 +236,7 @@ mod tests { | |
#[test] | |
fn sign() { | |
let mut s = Secp256k1::new(); | |
- s.randomize(&mut thread_rng()); | |
+ s.randomize(&mut local_rng()); | |
let one = [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, | |
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1]; | |
@@ -258,19 +259,19 @@ mod tests { | |
#[test] | |
fn sign_and_verify_fail() { | |
let mut s = Secp256k1::new(); | |
- s.randomize(&mut thread_rng()); | |
+ s.randomize(&mut local_rng()); | |
let mut msg = [0u8; 32]; | |
- thread_rng().fill_bytes(&mut msg); | |
+ local_rng().fill_bytes(&mut msg); | |
let msg = Message::from_slice(&msg).unwrap(); | |
- let (sk, pk) = s.generate_keypair(&mut thread_rng()); | |
+ let (sk, pk) = s.generate_keypair(&mut local_rng()); | |
let sigr = s.sign_recoverable(&msg, &sk); | |
let sig = sigr.to_standard(); | |
let mut msg = [0u8; 32]; | |
- thread_rng().fill_bytes(&mut msg); | |
+ local_rng().fill_bytes(&mut msg); | |
let msg = Message::from_slice(&msg).unwrap(); | |
assert_eq!(s.verify(&msg, &sig, &pk), Err(IncorrectSignature)); | |
@@ -281,13 +282,13 @@ mod tests { | |
#[test] | |
fn sign_with_recovery() { | |
let mut s = Secp256k1::new(); | |
- s.randomize(&mut thread_rng()); | |
+ s.randomize(&mut local_rng()); | |
let mut msg = [0u8; 32]; | |
- thread_rng().fill_bytes(&mut msg); | |
+ local_rng().fill_bytes(&mut msg); | |
let msg = Message::from_slice(&msg).unwrap(); | |
- let (sk, pk) = s.generate_keypair(&mut thread_rng()); | |
+ let (sk, pk) = s.generate_keypair(&mut local_rng()); | |
let sig = s.sign_recoverable(&msg, &sk); | |
@@ -297,7 +298,7 @@ mod tests { | |
#[test] | |
fn bad_recovery() { | |
let mut s = Secp256k1::new(); | |
- s.randomize(&mut thread_rng()); | |
+ s.randomize(&mut local_rng()); | |
let msg = Message::from_slice(&[0x55; 32]).unwrap(); | |
@@ -363,7 +364,8 @@ mod tests { | |
#[cfg(all(test, feature = "unstable"))] | |
mod benches { | |
- use rand::{thread_rng, RngCore}; | |
+ use rand::RngCore; | |
+ use test_rng::local_rng; | |
use test::{Bencher, black_box}; | |
use super::{Message, Secp256k1}; | |
@@ -371,9 +373,9 @@ mod benches { | |
pub fn bench_recover(bh: &mut Bencher) { | |
let s = Secp256k1::new(); | |
let mut msg = [0u8; 32]; | |
- thread_rng().fill_bytes(&mut msg); | |
+ local_rng().fill_bytes(&mut msg); | |
let msg = Message::from_slice(&msg).unwrap(); | |
- let (sk, _) = s.generate_keypair(&mut thread_rng()); | |
+ let (sk, _) = s.generate_keypair(&mut local_rng()); | |
let sig = s.sign_recoverable(&msg, &sk); | |
bh.iter(|| { | |
diff --git a/src/test_rng.rs b/src/test_rng.rs | |
new file mode 100644 | |
index 0000000..75fdb65 | |
--- /dev/null | |
+++ b/src/test_rng.rs | |
@@ -0,0 +1,137 @@ | |
+//! Fast, non-cryptographic random number generation for the test suite. | |
+//! | |
+//! Tests obtain a generator through [`local_rng`], which returns a handle to a | |
+//! per-thread [`Pcg32`] seeded from the system clock. | |
+ | |
+use rand_core::{impls, Error, RngCore}; | |
+use std::cell::RefCell; | |
+use std::num::Wrapping; | |
+use std::time; | |
+ | |
+thread_local! { | |
+    /// The generator behind every `ThreadFastRng` handle used on this thread. | |
+    static THREAD_FAST_RNG: RefCell<Pcg32> = RefCell::new(Pcg32::new()); | |
+} | |
+ | |
+/// A handle to the calling thread's `Pcg32` instance. | |
+/// | |
+/// Each method looks up the generator of the thread it runs on, so the handle needs | |
+/// no `unsafe` and can never point at a destroyed thread-local. | |
+/// | |
+/// This struct is created by [`local_rng`]. | |
+pub struct ThreadFastRng(()); | |
+ | |
+/// Returns a handle to this thread's time-seeded [`Pcg32`]. | |
+pub fn local_rng() -> ThreadFastRng { | |
+    ThreadFastRng(()) | |
+} | |
+ | |
+/// A PCG32 (XSH-RR) pseudo-random number generator. | |
+/// | |
+/// Not cryptographically secure; only meant for producing test inputs. | |
+pub struct Pcg32 { | |
+    state: Wrapping<u64>, | |
+    inc: Wrapping<u64>, | |
+} | |
+ | |
+impl Pcg32 { | |
+    const PCG_DEFAULT_MULTIPLIER_64: Wrapping<u64> = Wrapping(6_364_136_223_846_793_005); | |
+ | |
+    /// Creates a new instance of `Pcg32` seeded with the system time. | |
+    /// | |
+    /// Whole seconds since the Unix epoch are used as `seed` and the sub-second | |
+    /// nanoseconds as `seq` (see [`Pcg32::seed`]). | |
+    /// | |
+    /// # Panics | |
+    /// | |
+    /// Panics if the system clock reports a time before the Unix epoch. | |
+    /// | |
+    /// # Examples | |
+    /// | |
+    /// ```ignore | |
+    /// let mut rng = Pcg32::new(); | |
+    /// let random_u32 = rng.next_u32(); | |
+    /// let mut arr = [0u8; 32]; | |
+    /// rng.fill_bytes(&mut arr); | |
+    /// ``` | |
+    pub fn new() -> Self { | |
+        let unix = time::SystemTime::now() | |
+            .duration_since(time::UNIX_EPOCH) | |
+            .expect("system time should be after the unix epoch"); | |
+        Self::seed(unix.as_secs(), u64::from(unix.subsec_nanos())) | |
+    } | |
+ | |
+    /// Manually seeds the generator. | |
+    /// | |
+    /// Ideally both `seed` and `seq` are random numbers. `seed` is the starting state of the | |
+    /// algorithm, and `seq` selects the constant odd increment that is added at every step | |
+    /// (i.e. which output stream is produced). | |
+    pub fn seed(seed: u64, seq: u64) -> Self { | |
+        // The increment must be odd so the underlying LCG has full period. | |
+        let inc = Wrapping((seq << 1) | 1); | |
+        let mut rng = Pcg32 { | |
+            state: Wrapping(seed) + inc, | |
+            inc, | |
+        }; | |
+        // Same result as the reference `pcg32_srandom_r`: step from 0, add seed, step again. | |
+        rng.step(); | |
+        rng | |
+    } | |
+ | |
+    /// Advances the underlying 64-bit LCG by one step. | |
+    fn step(&mut self) { | |
+        self.state = self.state * Self::PCG_DEFAULT_MULTIPLIER_64 + self.inc; | |
+    } | |
+ | |
+    /// Returns the next 32-bit output, using the XSH-RR output permutation. | |
+    fn gen_u32(&mut self) -> u32 { | |
+        let old_state = self.state; | |
+        self.step(); | |
+ | |
+        // Truncating to the low 32 bits is part of the XSH-RR permutation. | |
+        let xorshifted = (((old_state >> 18) ^ old_state) >> 27).0 as u32; | |
+        // The top 5 bits of the old state choose the rotation (0..=31). | |
+        let rot = (old_state >> 59).0 as u32; | |
+        xorshifted.rotate_right(rot) | |
+    } | |
+} | |
+ | |
+impl RngCore for ThreadFastRng { | |
+    fn next_u32(&mut self) -> u32 { | |
+        THREAD_FAST_RNG.with(|rng| rng.borrow_mut().next_u32()) | |
+    } | |
+ | |
+    fn next_u64(&mut self) -> u64 { | |
+        THREAD_FAST_RNG.with(|rng| rng.borrow_mut().next_u64()) | |
+    } | |
+ | |
+    fn fill_bytes(&mut self, dest: &mut [u8]) { | |
+        THREAD_FAST_RNG.with(|rng| rng.borrow_mut().fill_bytes(dest)) | |
+    } | |
+ | |
+    fn try_fill_bytes(&mut self, dest: &mut [u8]) -> Result<(), Error> { | |
+        THREAD_FAST_RNG.with(|rng| rng.borrow_mut().try_fill_bytes(dest)) | |
+    } | |
+} | |
+ | |
+impl RngCore for Pcg32 { | |
+    fn next_u32(&mut self) -> u32 { | |
+        self.gen_u32() | |
+    } | |
+ | |
+    fn next_u64(&mut self) -> u64 { | |
+        // The first output becomes the high half, the second the low half. | |
+        let high = u64::from(self.next_u32()); | |
+        let low = u64::from(self.next_u32()); | |
+        (high << 32) | low | |
+    } | |
+ | |
+    fn fill_bytes(&mut self, dest: &mut [u8]) { | |
+        impls::fill_bytes_via_next(self, dest) | |
+    } | |
+ | |
+    fn try_fill_bytes(&mut self, dest: &mut [u8]) -> Result<(), Error> { | |
+        self.fill_bytes(dest); | |
+        Ok(()) | |
+    } | |
+} | |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment