Skip to content

Instantly share code, notes, and snippets.

@tjwei
Created April 26, 2013 15:58
Show Gist options
  • Save tjwei/5468350 to your computer and use it in GitHub Desktop.
Save tjwei/5468350 to your computer and use it in GitHub Desktop.
Fake operator overload for big.Int in Go
package main
import (
"fmt"
"go/ast"
"go/parser"
"go/token"
"math/big"
"reflect"
"strings"
)
// IntFunc is a user-defined function callable from expressions evaluated by
// NewEvalEnv. It is invoked through reflection with one *big.Int per call
// argument and must return a *big.Int as its first result; the usual shape
// is func(...*big.Int) *big.Int.
type IntFunc interface{}

// NewEvalEnv returns an evaluator that parses its argument as a Go
// expression and computes the result with math/big integers, giving
// *big.Int a form of "fake" operator overloading.
//
// Identifiers are looked up in Vars. A call f(a, b, ...) goes to Funcs[f]
// when present, otherwise to the *big.Int method named f invoked on a fresh
// receiver, so Exp(2, 10, 1000) computes 2**10 mod 1000. Supported syntax:
//
//   - integer literals (decimal, 0x, 0o, 0b and 0-prefixed octal)
//   - + - * / % (Quo/Rem, truncating toward zero)
//   - ^ as exponentiation (not XOR)
//   - & | &^ as bitwise operations, << >> as shifts
//   - == != < <= > >=, yielding 1 for true and 0 for false
//   - ! (1 if the operand is 0, else 0), unary + and -, parentheses
//   - || and &&, which short-circuit and return an operand's value:
//     a || b is a if a != 0, else b; a && b is a if a == 0, else b
//
// Vars and Funcs are consulted on every evaluation, so later changes to the
// maps are visible to the evaluator. Values read from Vars are copied, so
// results never alias map entries. The evaluator reports every error
// (syntax, unknown name, division by zero, bad shift count, ...) by
// panicking.
func NewEvalEnv(Vars map[string]*big.Int, Funcs map[string]IntFunc) func(string) *big.Int {
	// boolToInt maps a Go truth value onto the 1/0 convention used by the
	// comparison operators.
	boolToInt := func(b bool) *big.Int {
		if b {
			return big.NewInt(1)
		}
		return big.NewInt(0)
	}

	// shiftCount converts a shift operand to uint. A negative count must not
	// be allowed to wrap into a huge uint, which would make Lsh attempt an
	// enormous allocation instead of failing with a clear message.
	shiftCount := func(b *big.Int) uint {
		if b.Sign() < 0 {
			panic("negative shift count " + b.String())
		}
		if !b.IsUint64() || b.Uint64() > uint64(^uint(0)) {
			panic("shift count too large " + b.String())
		}
		return uint(b.Uint64())
	}

	// opMap holds the eagerly evaluated binary operators. Every entry has the
	// shape of a (*big.Int) method expression: the first argument is a fresh
	// result value, the other two are the already-evaluated operands.
	opMap := map[token.Token]func(r, a, b *big.Int) *big.Int{
		token.ADD: (*big.Int).Add,
		token.SUB: (*big.Int).Sub,
		token.MUL: (*big.Int).Mul,
		token.QUO: (*big.Int).Quo,
		token.REM: (*big.Int).Rem,
		// ^ is repurposed as power; there is no XOR operator.
		token.XOR: func(r, a, b *big.Int) *big.Int { return r.Exp(a, b, nil) },
		token.SHL: func(r, a, b *big.Int) *big.Int { return r.Lsh(a, shiftCount(b)) },
		token.SHR: func(r, a, b *big.Int) *big.Int { return r.Rsh(a, shiftCount(b)) },
		token.GEQ: func(_, a, b *big.Int) *big.Int { return boolToInt(a.Cmp(b) >= 0) },
		token.LEQ: func(_, a, b *big.Int) *big.Int { return boolToInt(a.Cmp(b) <= 0) },
		token.GTR: func(_, a, b *big.Int) *big.Int { return boolToInt(a.Cmp(b) > 0) },
		token.LSS: func(_, a, b *big.Int) *big.Int { return boolToInt(a.Cmp(b) < 0) },
		token.EQL: func(_, a, b *big.Int) *big.Int { return boolToInt(a.Cmp(b) == 0) },
		token.NEQ: func(_, a, b *big.Int) *big.Int { return boolToInt(a.Cmp(b) != 0) },

		// Bitwise operators use big.Int's two's-complement semantics.
		token.AND:     (*big.Int).And,
		token.OR:      (*big.Int).Or,
		token.AND_NOT: (*big.Int).AndNot,
	}

	var evalx func(node ast.Node) *big.Int

	// evalCall evaluates name(args...), preferring a registered IntFunc and
	// falling back to the *big.Int method of the same name.
	evalCall := func(n *ast.CallExpr) *big.Int {
		ident, ok := n.Fun.(*ast.Ident)
		if !ok {
			panic(fmt.Sprintf("function is too complicated: %T", n.Fun))
		}
		var fnValue reflect.Value
		if fn, found := Funcs[ident.Name]; found {
			fnValue = reflect.ValueOf(fn)
		} else {
			// Bind the method to a fresh receiver, so Exp(a, b, m) becomes
			// new(big.Int).Exp(a, b, m).
			fnValue = reflect.ValueOf(new(big.Int)).MethodByName(ident.Name)
		}
		if !fnValue.IsValid() {
			panic("unknown function " + ident.Name)
		}
		if fnValue.Kind() != reflect.Func {
			panic(ident.Name + " is not a function")
		}
		args := make([]reflect.Value, 0, len(n.Args))
		for _, a := range n.Args {
			args = append(args, reflect.ValueOf(evalx(a)))
		}
		results := fnValue.Call(args)
		if len(results) > 0 {
			if v, isInt := results[0].Interface().(*big.Int); isInt {
				return v
			}
		}
		panic("function " + ident.Name + " does not return *big.Int")
	}

	evalx = func(node ast.Node) *big.Int {
		switch n := node.(type) {
		case *ast.BasicLit:
			if n.Kind != token.INT {
				panic("unknown basic lit [" + n.Kind.String() + " " + n.Value + "]")
			}
			// Base 0 honours Go's literal prefixes (0x, 0o, 0b, leading 0).
			v, ok := new(big.Int).SetString(n.Value, 0)
			if !ok {
				panic("can not convert " + n.Value)
			}
			return v
		case *ast.CallExpr:
			return evalCall(n)
		case *ast.BinaryExpr:
			if op, ok := opMap[n.Op]; ok {
				return op(new(big.Int), evalx(n.X), evalx(n.Y))
			}
			switch n.Op {
			case token.LOR:
				// Short-circuit: the right side is only evaluated when the
				// left side is zero.
				if a := evalx(n.X); a.Sign() != 0 {
					return a
				}
				return evalx(n.Y)
			case token.LAND:
				if a := evalx(n.X); a.Sign() == 0 {
					return a
				}
				return evalx(n.Y)
			}
			panic("unsupported operator " + n.Op.String())
		case *ast.UnaryExpr:
			switch n.Op {
			case token.ADD:
				return evalx(n.X)
			case token.SUB:
				return new(big.Int).Neg(evalx(n.X))
			case token.NOT:
				return boolToInt(evalx(n.X).Sign() == 0)
			}
			panic("unsupported unary operator " + n.Op.String())
		case *ast.Ident:
			v, ok := Vars[n.Name]
			if !ok {
				panic("Ident[" + n.Name + "]")
			}
			// Return a copy: neither the caller nor a method that writes into
			// its arguments (e.g. GCD's cofactors) may modify Vars' entries.
			return new(big.Int).Set(v)
		case *ast.ParenExpr:
			return evalx(n.X)
		}
		panic(fmt.Sprintf("not handled: %T", node))
	}

	return func(s string) *big.Int {
		// ParseExpr itself rejects anything after the expression except
		// whitespace and comments, so no separate length check is needed.
		expr, err := parser.ParseExpr(s)
		if err != nil {
			panic(err.Error())
		}
		return evalx(expr)
	}
}
// Bool reports whether v is non-zero, the truth convention used by the
// evaluator's comparison and logical operators. v must not be nil.
func Bool(v *big.Int) bool {
	// Sign avoids allocating a zero constant on every call.
	return v.Sign() != 0
}
// main exercises the evaluator: big literal arithmetic, counting loops whose
// conditions are expressions, a registered helper (exp), a recursive
// function defined from an expression string (fac) and a direct call to a
// *big.Int method (Exp).
func main() {
	vars := make(map[string]*big.Int)
	funcs := make(map[string]IntFunc)
	eval := NewEvalEnv(vars, funcs)

	// exp(a, b, m) forwards its three operands to (*big.Int).Exp.
	funcs["exp"] = func(a ...*big.Int) *big.Int {
		return new(big.Int).Exp(a[0], a[1], a[2])
	}

	// lambda turns a parameter list ("a, b" or "a b") and a formula into a
	// variadic function. Each call binds its idx-th argument to the idx-th
	// name in a fresh variable scope and evaluates formula there. The
	// function table is shared, which is what lets fac call itself below.
	lambda := func(params, formula string) func(...*big.Int) *big.Int {
		names := strings.FieldsFunc(params, func(r rune) bool {
			return r == ' ' || r == ','
		})
		return func(args ...*big.Int) *big.Int {
			scope := make(map[string]*big.Int)
			for idx, arg := range args {
				scope[names[idx]] = arg
			}
			return NewEvalEnv(scope, funcs)(formula)
		}
	}

	fmt.Println("(50000000000+3)*2*(5^10)=", eval("(50000000000+3)*2*(5^10)"))

	// 9!, with the loop condition tested through the sign of an expression.
	vars["x"] = eval("1")
	vars["i"] = eval("1")
	for eval("i-10").Sign() < 0 {
		vars["x"] = eval("x*i")
		vars["i"] = eval("i+1")
	}
	fmt.Println("9!=", eval("x"))

	// 9! again, this time through a comparison operator and Bool.
	vars["x"] = eval("1")
	vars["i"] = eval("1")
	for Bool(eval("i<10")) {
		vars["x"] = eval("x*i")
		vars["i"] = eval("i+1")
	}
	fmt.Println("9!=", eval("x"))

	// || yields its left operand when that is non-zero, otherwise its right
	// operand, so n==0 produces 1 and any other n the recursive product.
	fac := lambda("n", " n==0 || n*fac(n-1)")
	funcs["fac"] = fac
	fmt.Println("10!=", eval("fac(10)"))
	fmt.Println("10!=", fac(big.NewInt(10)))
	fmt.Println("exp(2,16,100)=", eval("exp(2,16,100)"))
	fmt.Println("Exp(2,100,100)=", eval("Exp(2,100,100)"))
	fmt.Println("100!=", fac(eval("100")))
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment