-
-
Save maratori/15cf2a41d52ac6bd27dfced7922f0d28 to your computer and use it in GitHub Desktop.
Function to normalize decimal.Decimal in Go
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
package util | |
import ( | |
"math" | |
"math/big" | |
"github.com/shopspring/decimal" | |
) | |
var /* const */ ten = big.NewInt(10) | |
var /* const */ zero = decimal.New(0, 0) | |
// Normalize returns [decimal.Decimal] equal to input d, but rescaled. | |
// [decimal.Decimal.Exponent] is changed such a way to remove all trailing zeroes | |
// from decimal representation of [decimal.Decimal.Coefficient]. | |
// As a special case zero decimal has zero [decimal.Decimal.Exponent]. | |
// | |
// Examples | |
// | |
// Normalize(decimal.New(100, 0)) -> decimal.New(1, 2) | |
// Normalize(decimal.New(100, -3)) -> decimal.New(1, -1) | |
// Normalize(decimal.New(123, -2)) -> decimal.New(123, -2) | |
// Normalize(decimal.New(0, 1)) -> decimal.New(0, 0) | |
func Normalize(d decimal.Decimal) decimal.Decimal { | |
if d.IsZero() { | |
return zero | |
} | |
value := d.Coefficient() | |
remainder := new(big.Int) | |
mayValue := new(big.Int) | |
n := int64(0) | |
maxN := math.MaxInt32 - int64(d.Exponent()) | |
for { | |
if n == maxN { | |
break // avoid int32 overflow | |
} | |
mayValue.QuoRem(value, ten, remainder) | |
if remainder.Sign() != 0 { | |
break | |
} | |
value.Set(mayValue) | |
n++ | |
} | |
if n == 0 { | |
return d | |
} | |
return decimal.NewFromBigInt(value, d.Exponent()+int32(n)) | |
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
package util_test | |
import ( | |
"math" | |
"math/big" | |
"math/rand" | |
"testing" | |
"time" | |
"github.com/maratori/pairedbrackets/util" | |
"github.com/shopspring/decimal" | |
"github.com/stretchr/testify/assert" | |
) | |
func TestNormalize(t *testing.T) { | |
t.Parallel() | |
for in, out := range map[decimal.Decimal]decimal.Decimal{ | |
decimal.New(0, 0): decimal.New(0, 0), | |
decimal.New(0, 10): decimal.New(0, 0), | |
decimal.New(0, -10): decimal.New(0, 0), | |
decimal.New(10, 0): decimal.New(1, 1), | |
decimal.New(100, 0): decimal.New(1, 2), | |
decimal.New(1000, 0): decimal.New(1, 3), | |
decimal.New(10, 1): decimal.New(1, 2), | |
decimal.New(100, 1): decimal.New(1, 3), | |
decimal.New(1000, 1): decimal.New(1, 4), | |
decimal.New(10, -1): decimal.New(1, 0), | |
decimal.New(100, -1): decimal.New(1, 1), | |
decimal.New(1000, -1): decimal.New(1, 2), | |
decimal.New(-10, 0): decimal.New(-1, 1), | |
decimal.New(-100, 0): decimal.New(-1, 2), | |
decimal.New(-1000, 0): decimal.New(-1, 3), | |
decimal.New(-10, 1): decimal.New(-1, 2), | |
decimal.New(-100, 1): decimal.New(-1, 3), | |
decimal.New(-1000, 1): decimal.New(-1, 4), | |
decimal.New(-10, -1): decimal.New(-1, 0), | |
decimal.New(-100, -1): decimal.New(-1, 1), | |
decimal.New(-1000, -1): decimal.New(-1, 2), | |
decimal.New(1234, 2): decimal.New(1234, 2), | |
decimal.New(1234, -2): decimal.New(1234, -2), | |
decimal.New(1234, math.MaxInt32): decimal.New(1234, math.MaxInt32), | |
decimal.New(12340, math.MaxInt32): decimal.New(12340, math.MaxInt32), | |
decimal.New(123400, math.MaxInt32-1): decimal.New(12340, math.MaxInt32), | |
} { | |
actual := util.Normalize(in) | |
assert.Equal(t, out, actual) | |
} | |
} | |
func FuzzNormalize(f *testing.F) { | |
for _, val := range []int64{0, 1, 10, 100, 1000, 1_000_000_000} { | |
for _, valS := range []int64{1, -1} { | |
for _, exp := range []int32{0, 1, 10, 100, 1000, 1_000_000_000} { | |
for _, expS := range []int32{1, -1} { | |
f.Add(val*valS, exp*expS) | |
} | |
} | |
} | |
} | |
for _, val := range []int64{math.MinInt64, math.MaxInt64, 1, 10, 100, 1000, -1, -10, -100, -1000} { | |
for _, exp := range []int32{math.MinInt32, math.MaxInt32, math.MinInt32 + 1, math.MaxInt32 - 1} { | |
f.Add(val, exp) | |
} | |
} | |
f.Fuzz(func(t *testing.T, val int64, exp int32) { | |
x := decimal.New(val, exp) | |
y := util.Normalize(x) | |
xc := x.Coefficient() | |
xe := x.Exponent() | |
yc := y.Coefficient() | |
ye := y.Exponent() | |
if x.IsZero() { | |
// Using y.IsZero() because decimal.New(0, 1_000_000).Equal(decimal.Zero) takes too long to execute | |
assert.Truef(t, y.IsZero(), "x = %s * 10 ^ %d\ny = %s * 10 ^ %d", xc, xe, yc, ye) | |
assert.Zerof(t, ye, "x = %s * 10 ^ %d\ny = %s * 10 ^ %d", xc, xe, yc, ye) | |
return | |
} | |
assert.Truef(t, y.Equal(x), "x = %s * 10 ^ %d\ny = %s * 10 ^ %d", xc, xe, yc, ye) | |
assert.GreaterOrEqualf(t, ye, xe, "x = %s * 10 ^ %d\ny = %s * 10 ^ %d", xc, xe, yc, ye) | |
if new(big.Int).Rem(yc, big.NewInt(10)).Sign() == 0 { | |
assert.EqualValuesf(t, math.MaxInt32, ye, "x = %s * 10 ^ %d\ny = %s * 10 ^ %d", xc, xe, yc, ye) | |
} | |
}) | |
} | |
var xxx decimal.Decimal | |
// 225.0 ns/op | |
func BenchmarkNormalize(b *testing.B) { | |
r := rand.New(rand.NewSource(time.Now().Unix())) | |
var x decimal.Decimal | |
b.ResetTimer() | |
for i := 0; i < b.N; i++ { | |
b.StopTimer() | |
d := decimal.New(r.Int63(), r.Int31()) | |
b.StartTimer() | |
x = util.Normalize(d) | |
} | |
xxx = x | |
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment