Skip to content

Instantly share code, notes, and snippets.

@maratori
Last active September 13, 2022 06:08
Show Gist options
  • Save maratori/15cf2a41d52ac6bd27dfced7922f0d28 to your computer and use it in GitHub Desktop.
Function to normalize decimal.Decimal in Go
package util
import (
"math"
"math/big"
"github.com/shopspring/decimal"
)
// Package-level values that are treated as constants. Go has no const form
// for *big.Int or decimal.Decimal, so they are vars that must never be
// mutated.
var (
	// ten is the divisor Normalize uses to strip one trailing digit at a time.
	ten = big.NewInt(10)

	// zero is what Normalize returns for any zero input: coefficient 0,
	// exponent 0.
	zero = decimal.New(0, 0)
)
// Normalize returns a [decimal.Decimal] equal to d, rescaled so that its
// [decimal.Decimal.Coefficient] has no trailing decimal zeroes. Every factor
// of ten removed from the coefficient is added to [decimal.Decimal.Exponent].
//
// The exponent is an int32, so stripping stops once the exponent reaches
// [math.MaxInt32]. In that case the coefficient may still end in zeroes.
// As a special case, any zero decimal normalizes to zero with exponent 0.
// If nothing can be stripped, d is returned unchanged.
//
// # Examples
//
//	Normalize(decimal.New(100, 0))               -> decimal.New(1, 2)
//	Normalize(decimal.New(100, -3))              -> decimal.New(1, -1)
//	Normalize(decimal.New(123, -2))              -> decimal.New(123, -2)
//	Normalize(decimal.New(0, 1))                 -> decimal.New(0, 0)
//	Normalize(decimal.New(12340, math.MaxInt32)) -> decimal.New(12340, math.MaxInt32)
func Normalize(d decimal.Decimal) decimal.Decimal {
	if d.IsZero() {
		return zero
	}

	exp := d.Exponent()
	// maxShift is how many digits may move into the exponent before it would
	// exceed math.MaxInt32. It is computed in int64 because exp may be
	// negative, which makes the headroom larger than math.MaxInt32.
	maxShift := math.MaxInt32 - int64(exp)

	// Coefficient returns a fresh copy, so value may be mutated freely.
	value := d.Coefficient()
	quo, rem := new(big.Int), new(big.Int)
	var shift int64
	for shift < maxShift {
		quo.QuoRem(value, ten, rem)
		if rem.Sign() != 0 {
			break
		}
		// Swap buffers instead of copying the quotient back. value now holds
		// the quotient, and the old value becomes scratch for the next step.
		value, quo = quo, value
		shift++
	}

	if shift == 0 {
		return d
	}
	// exp <= int64(exp)+shift <= math.MaxInt32, so this conversion is exact
	// even if shift itself would not fit in an int32.
	return decimal.NewFromBigInt(value, int32(int64(exp)+shift))
}
package util_test
import (
"math"
"math/big"
"math/rand"
"testing"
"time"
"github.com/maratori/pairedbrackets/util"
"github.com/shopspring/decimal"
"github.com/stretchr/testify/assert"
)
// TestNormalize checks Normalize against a table of hand-computed results:
// zero handling, stripping trailing zeroes for both signs and all exponent
// signs, values with nothing to strip, and the int32 exponent boundaries.
func TestNormalize(t *testing.T) {
	t.Parallel()

	// A slice (instead of a map keyed by decimal.Decimal, whose keys compare
	// by pointer) keeps the cases in a stable, readable order.
	cases := []struct{ in, want decimal.Decimal }{
		// Zero always normalizes to exponent 0.
		{decimal.New(0, 0), decimal.New(0, 0)},
		{decimal.New(0, 10), decimal.New(0, 0)},
		{decimal.New(0, -10), decimal.New(0, 0)},

		// Trailing zeroes move into the exponent.
		{decimal.New(10, 0), decimal.New(1, 1)},
		{decimal.New(100, 0), decimal.New(1, 2)},
		{decimal.New(1000, 0), decimal.New(1, 3)},
		{decimal.New(10, 1), decimal.New(1, 2)},
		{decimal.New(100, 1), decimal.New(1, 3)},
		{decimal.New(1000, 1), decimal.New(1, 4)},
		{decimal.New(10, -1), decimal.New(1, 0)},
		{decimal.New(100, -1), decimal.New(1, 1)},
		{decimal.New(1000, -1), decimal.New(1, 2)},
		{decimal.New(-10, 0), decimal.New(-1, 1)},
		{decimal.New(-100, 0), decimal.New(-1, 2)},
		{decimal.New(-1000, 0), decimal.New(-1, 3)},
		{decimal.New(-10, 1), decimal.New(-1, 2)},
		{decimal.New(-100, 1), decimal.New(-1, 3)},
		{decimal.New(-1000, 1), decimal.New(-1, 4)},
		{decimal.New(-10, -1), decimal.New(-1, 0)},
		{decimal.New(-100, -1), decimal.New(-1, 1)},
		{decimal.New(-1000, -1), decimal.New(-1, 2)},

		// Nothing to strip.
		{decimal.New(1234, 2), decimal.New(1234, 2)},
		{decimal.New(1234, -2), decimal.New(1234, -2)},

		// The exponent is capped at math.MaxInt32.
		{decimal.New(1234, math.MaxInt32), decimal.New(1234, math.MaxInt32)},
		{decimal.New(12340, math.MaxInt32), decimal.New(12340, math.MaxInt32)},
		{decimal.New(123400, math.MaxInt32-1), decimal.New(12340, math.MaxInt32)},

		// The most negative exponent is handled too.
		{decimal.New(1, math.MinInt32), decimal.New(1, math.MinInt32)},
		{decimal.New(10, math.MinInt32), decimal.New(1, math.MinInt32+1)},
	}

	for _, c := range cases {
		got := util.Normalize(c.in)
		// Report the input as coefficient and exponent rather than via
		// String(), whose output grows with the exponent.
		assert.Equalf(t, c.want, got, "in = %s * 10 ^ %d", c.in.Coefficient(), c.in.Exponent())
	}
}
// FuzzNormalize checks properties of Normalize for arbitrary (val, exp) pairs:
//   - zero normalizes to zero with exponent 0;
//   - the result equals the input numerically;
//   - the exponent never decreases;
//   - a coefficient still divisible by ten only remains when the exponent is
//     already math.MaxInt32.
func FuzzNormalize(f *testing.F) {
	// Seed corpus, part 1: zero and powers of ten, crossed with exponents of
	// increasing magnitude, in every sign combination.
	magnitudes := []int64{0, 1, 10, 100, 1000, 1_000_000_000}
	exponents := []int32{0, 1, 10, 100, 1000, 1_000_000_000}
	for _, m := range magnitudes {
		for _, v := range []int64{m, -m} {
			for _, e := range exponents {
				for _, signedExp := range []int32{e, -e} {
					f.Add(v, signedExp)
				}
			}
		}
	}

	// Seed corpus, part 2: extreme coefficients crossed with extreme exponents.
	extremeVals := []int64{math.MinInt64, math.MaxInt64, 1, 10, 100, 1000, -1, -10, -100, -1000}
	extremeExps := []int32{math.MinInt32, math.MaxInt32, math.MinInt32 + 1, math.MaxInt32 - 1}
	for _, v := range extremeVals {
		for _, e := range extremeExps {
			f.Add(v, e)
		}
	}

	f.Fuzz(func(t *testing.T, val int64, exp int32) {
		x := decimal.New(val, exp)
		y := util.Normalize(x)

		xc, xe := x.Coefficient(), x.Exponent()
		yc, ye := y.Coefficient(), y.Exponent()
		// Shared failure context. testify formats msgAndArgs with Sprintf
		// when the first element is a format string.
		details := []any{"x = %s * 10 ^ %d\ny = %s * 10 ^ %d", xc, xe, yc, ye}

		if x.IsZero() {
			// y.IsZero() is checked instead of comparing with decimal.Zero:
			// decimal.New(0, 1_000_000).Equal(decimal.Zero) takes too long.
			assert.True(t, y.IsZero(), details...)
			assert.Zero(t, ye, details...)
			return
		}

		assert.True(t, y.Equal(x), details...)
		assert.GreaterOrEqual(t, ye, xe, details...)
		// A trailing zero may only survive when the exponent cannot grow.
		if new(big.Int).Rem(yc, big.NewInt(10)).Sign() == 0 {
			assert.EqualValues(t, math.MaxInt32, ye, details...)
		}
	})
}
// xxx is a package-level sink for benchmark results, so the compiler cannot
// eliminate the Normalize calls being measured.
var xxx decimal.Decimal

// BenchmarkNormalize measures Normalize on random coefficients with random
// non-negative exponents.
func BenchmarkNormalize(b *testing.B) {
	r := rand.New(rand.NewSource(time.Now().Unix()))

	// Generate inputs before timing starts. Calling b.StopTimer/b.StartTimer
	// on every iteration adds heavy per-iteration overhead (each call reads
	// runtime memory statistics), which swamps the measured work.
	inputs := make([]decimal.Decimal, 1024)
	for i := range inputs {
		inputs[i] = decimal.New(r.Int63(), r.Int31())
	}

	b.ReportAllocs()
	b.ResetTimer()

	var x decimal.Decimal
	for i := 0; i < b.N; i++ {
		x = util.Normalize(inputs[i%len(inputs)])
	}
	xxx = x
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment