Last active
October 27, 2023 07:40
-
-
Save maxpushka/e8b02eb86e1c2d60baf25350bad2606f to your computer and use it in GitHub Desktop.
decimal.Decimal to significant figures converter
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
package digits_decimals | |
import ( | |
"strings" | |
"unsafe" | |
"github.com/shopspring/decimal" | |
) | |
// params holds the positions and limits shared by ToSignificant's byte-level
// helpers. Positions are byte indices into the decimal's textual form.
type params struct {
	digitsBefore int // value returned by getDigitsBefore for the integer part
	dot          int // index of '.' in the text, or -1 when there is none
	leadingZeros int // value returned by getLeadingZeros for the fractional part
	roundTo      int // lowest byte index rewritten by roundFloat's rounding loop

	significantDigits int // requested number of significant digits
	maxDecimals       int // maximum number of fractional digits kept
}
// ToSignificant rounds x to significantDigits significant figures, after first
// truncating it to at most maxDecimals digits after the decimal point.
//
// The maxDecimals limit is applied by truncation, not rounding, so
// ToSignificant(0.000123456, 5, 8) yields 0.00012345. The significant-figure
// rounding uses decimal.Decimal.Round (half away from zero). It handles
// integers, negative values, and values with both an integer and a fractional
// part, for example:
//
//	ToSignificant(100006, 5, 8)    == 100010
//	ToSignificant(1.2345678, 5, 8) == 1.2346
//	ToSignificant(1234.5, 5, 8)    == 1234.5
//
// If x is zero, or becomes zero once truncated to maxDecimals, that zero value
// is returned. significantDigits is expected to be at least 1, and both limits
// are expected to fit in an int32.
func ToSignificant(x decimal.Decimal, significantDigits, maxDecimals uint) decimal.Decimal {
	if x.IsZero() {
		return x
	}

	// Enforce the decimal-places limit first; digits beyond it are dropped.
	x = x.Truncate(int32(maxDecimals))
	if x.IsZero() {
		return x
	}

	// x == coefficient * 10^exponent, so its most significant digit sits at
	// 10^(digits-1+exponent). Keeping significantDigits digits therefore means
	// rounding to `places` decimal places. A negative places value rounds the
	// integer part to tens, hundreds, and so on.
	coefficient := x.Coefficient()
	digits := int32(len(strings.TrimPrefix(coefficient.String(), "-")))
	places := int32(significantDigits) - digits - x.Exponent()

	// After truncation x has at most maxDecimals fractional digits, so rounding
	// any further to the right would only pad the representation with zeros.
	if places > int32(maxDecimals) {
		places = int32(maxDecimals)
	}
	return x.Round(places)
}
// unsafeToString reinterprets b as a string without copying its contents.
//
// The returned string shares b's backing memory. The runtime assumes strings
// never change, so b MUST NOT be modified while the returned string (or
// anything derived from it) is still in use. Doing so makes the string's
// contents change underneath its users, which is undefined behavior.
//
// The conversion reads b's slice header as a string header, relying on both
// starting with the same data pointer and length.
func unsafeToString(b []byte) string {
	header := unsafe.Pointer(&b)
	return *(*string)(header)
}
// getDigitsBefore reports how many characters of the integer part are counted
// as digits. When there is no dot (dot == -1), every byte counts. Otherwise
// only the non-'0' bytes before the dot count, so "0.0012" yields 0.
//
// NOTE(review): zeros inside the integer part are skipped too ("10.5" yields 1,
// not 2), and a leading '-' is counted as a digit. ToSignificant's
// rounding-position arithmetic consumes this value, so changing it in
// isolation alters that function's output; the behavior is preserved here.
func getDigitsBefore(bytes []byte, dot int) int {
	if dot == -1 {
		return len(bytes)
	}
	count := 0
	for _, c := range bytes[:dot] {
		if c != '0' {
			count++
		}
	}
	return count
}
// getLeadingZeros counts the '0' bytes that follow the dot at index dot.
//
// NOTE(review): despite its name, this counts every '0' in the fractional part,
// not only those before the first non-zero digit: "0.10203" yields 2, not 0.
// The count feeds ToSignificant's rounding-position arithmetic, so it is kept
// unchanged here.
func getLeadingZeros(bytes []byte, dot int) int {
	fraction := bytes[dot+1:]
	zeros := 0
	for i := range fraction {
		if fraction[i] == '0' {
			zeros++
		}
	}
	return zeros
}
// truncateFloat truncates decimals that are greater than max allowed decimals | |
func truncateFloat(bytes []byte, p params) []byte { | |
bytesBeforeDot := p.dot + 1 | |
bytesAfterDot := len(bytes) - bytesBeforeDot | |
if bytesAfterDot > p.maxDecimals { | |
return bytes[:bytesBeforeDot+p.maxDecimals] | |
} | |
return nil | |
} | |
// roundInt rewrites, in place, the final to+1 bytes of an integer string. It
// walks them right to left:
//   - a '9' becomes '0';
//   - any other byte is incremented, and the byte to its right (if any) is
//     reset to '0'.
//
// In the common case this zeroes the last `to` digits and bumps the digit
// before them ("100006" with to=1 becomes "100010").
//
// NOTE(review): known limitations, kept because changing them in isolation
// alters ToSignificant's output:
//   - the kept digit is bumped regardless of the discarded digits, so
//     "100001" with to=1 becomes "100010" rather than "100000";
//   - when the kept digit is '9', the carry is lost and the discarded digit
//     stays incremented: "12396" with to=1 becomes "12307";
//   - bytes is assumed to hold only digits, and to+1 must not exceed
//     len(bytes), otherwise the index goes out of range.
func roundInt(bytes []byte, to int) {
	last := len(bytes) - 1
	for pos := last; pos >= last-to; pos-- {
		switch {
		case bytes[pos] == '9':
			bytes[pos] = '0'
		case pos == last:
			bytes[pos]++
		default:
			bytes[pos]++
			bytes[pos+1] = '0'
		}
	}
}
func roundFloat(bytes []byte, p params) { | |
existingSignificantDigits := len(bytes[p.dot+1+p.leadingZeros:]) | |
enoughDecimals := p.leadingZeros+existingSignificantDigits <= p.maxDecimals | |
if p.digitsBefore == 0 && enoughDecimals && existingSignificantDigits <= p.significantDigits { | |
// No rounding is required for 0.0001, | |
// as it already has enough significant digits | |
// and there's nothing to round | |
return | |
} | |
// Carry mechanism to round up | |
var carry bool | |
for i := len(bytes) - 1; i > p.roundTo-1; i-- { | |
if bytes[i] == '9' { | |
bytes[i] = '0' | |
carry = true | |
} else if bytes[i] >= '5' { | |
bytes[i] = '0' | |
carry = true | |
} else if carry { | |
bytes[i]++ | |
carry = false | |
} else { | |
bytes[i] = '0' | |
} | |
} | |
// If there's still a carry after the dot, increment the part before the dot | |
if carry { | |
i := p.dot - 1 | |
for i >= 0 { | |
if bytes[i] == '9' { | |
bytes[i] = '0' | |
i-- | |
} else { | |
bytes[i]++ | |
break | |
} | |
} | |
// If we've gone through all the bytes, and they were all '9' | |
if i == -1 { | |
bytes = append([]byte{'1'}, bytes...) | |
} | |
} | |
} | |
func buildDecimal(bytes []byte) decimal.Decimal { | |
result, err := decimal.NewFromString(unsafeToString(bytes)) | |
if err != nil { | |
panic("failed to parse converted bytes to decimal") | |
} | |
return result | |
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
package digits_decimals | |
import ( | |
"fmt" | |
"math" | |
"testing" | |
"github.com/shopspring/decimal" | |
) | |
// TestToSignificant checks ToSignificant with 5 significant digits and at most
// 8 decimals against a table of inputs and expected results.
func TestToSignificant(t *testing.T) {
	tests := []struct {
		input    decimal.Decimal
		expected decimal.Decimal
	}{
		{decimal.NewFromFloat(0.0), decimal.NewFromFloat(0.0)},
		// Far below the 8-decimal limit: truncates to zero.
		{decimal.NewFromFloat(math.SmallestNonzeroFloat64), decimal.NewFromFloat(0.0)},
		// The 8-decimal limit truncates (does not round) the trailing 6.
		{decimal.NewFromFloat(0.000123456), decimal.NewFromFloat(0.00012345)},
		{decimal.NewFromFloat(0.0001), decimal.NewFromFloat(0.0001)},
		{decimal.NewFromFloat(0.012344789), decimal.NewFromFloat(0.012345)},
		{decimal.NewFromFloat(0.012345), decimal.NewFromFloat(0.012345)},
		{decimal.NewFromFloat(1.00000234), decimal.NewFromFloat(1.0)},
		{decimal.NewFromInt(2), decimal.NewFromInt(2)},
		// More integer digits than significant digits: rounds at the tens.
		{decimal.NewFromInt(100006), decimal.NewFromInt(100010)},
	}
	for _, tt := range tests {
		tt := tt // per-iteration copy for the closure on Go versions before 1.22
		t.Run(fmt.Sprintf("%s -> %s", tt.input, tt.expected), func(t *testing.T) {
			actual := ToSignificant(tt.input, 5, 8)
			// Equal compares numeric values, so 1.0000 and 1 match.
			// (Equals is the deprecated alias.)
			if !actual.Equal(tt.expected) {
				t.Errorf("ToSignificant(%s): expected %s, got %s", tt.input, tt.expected, actual)
			}
		})
	}
}
// BenchmarkToSignificant_DecimalWithLeadingZeros measures ToSignificant on a
// value below 1 whose significant digits start after several zeros.
func BenchmarkToSignificant_DecimalWithLeadingZeros(b *testing.B) {
	// Build the input once, outside the timed region, so the timings and the
	// reported allocations cover ToSignificant only, not decimal.NewFromFloat.
	input := decimal.NewFromFloat(0.000123456)
	b.ReportAllocs()
	b.ResetTimer()
	for i := 0; i < b.N; i++ {
		ToSignificant(input, 5, 8)
	}
}
// BenchmarkToSignificant_TruncateDecimals measures ToSignificant on a value
// with an integer part followed by a long fraction.
//
// NOTE(review): 1.00000234 has exactly 8 fractional digits, so with
// maxDecimals=8 nothing is truncated. The trailing 234 is removed by the
// significant-figure rounding instead. Consider an input with more than 8
// decimals if truncation is what this benchmark should exercise.
func BenchmarkToSignificant_TruncateDecimals(b *testing.B) {
	// Build the input once, outside the timed region, so the timings and the
	// reported allocations cover ToSignificant only, not decimal.NewFromFloat.
	input := decimal.NewFromFloat(1.00000234)
	b.ReportAllocs()
	b.ResetTimer()
	for i := 0; i < b.N; i++ {
		ToSignificant(input, 5, 8)
	}
}
// BenchmarkToSignificant_IntegralPartSizeGreaterThanSignificantDigits measures
// ToSignificant on an integer with more digits (6) than the requested
// significant figures (5).
func BenchmarkToSignificant_IntegralPartSizeGreaterThanSignificantDigits(b *testing.B) {
	// Build the input once, outside the timed region, so the timings and the
	// reported allocations cover ToSignificant only, not decimal.NewFromInt.
	input := decimal.NewFromInt(100006)
	b.ReportAllocs()
	b.ResetTimer()
	for i := 0; i < b.N; i++ {
		ToSignificant(input, 5, 8)
	}
}
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment