Created
April 26, 2013 15:58
-
-
Save tjwei/5468350 to your computer and use it in GitHub Desktop.
Fake operator overload for big.Int in Go
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
package main | |
import ( | |
"fmt" | |
"go/ast" | |
"go/parser" | |
"go/token" | |
"math/big" | |
"reflect" | |
"strings" | |
) | |
// IntFunc is a function that can be called by name from an evaluated
// expression. It is an empty interface (rather than a fixed signature such as
// func(...*big.Int) *big.Int) because calls are dispatched through reflection:
// the evaluator passes every argument as a *big.Int and uses the first result,
// which must be a *big.Int.
type IntFunc interface{}

// NewEvalEnv returns an evaluator for integer expressions written in Go
// expression syntax and computed with math/big.
//
// Vars supplies the values of identifiers and Funcs the named functions that
// call expressions may invoke. Both maps are consulted at evaluation time, so
// entries added after NewEvalEnv returns are visible to the evaluator. A call
// to a name missing from Funcs is resolved as a method of *big.Int with that
// name, invoked on a fresh receiver.
//
// Supported syntax:
//   - integer literals (base chosen from the literal's prefix),
//   - + - * / % << >>, with ^ meaning exponentiation (not XOR),
//   - comparisons, which yield 1 or 0,
//   - || and &&, which short-circuit and yield one of their operands
//     (zero is false, anything else is true),
//   - unary +, - and !,
//   - parentheses, identifiers and calls by plain name.
//
// The result may be the very pointer stored in Vars (for example when the
// expression is a bare identifier), so callers should not modify it in place.
//
// The evaluator panics on any parse or evaluation error.
func NewEvalEnv(Vars map[string]*big.Int, Funcs map[string]IntFunc) func(string) *big.Int {
	var evalx func(node ast.Node) *big.Int

	// tf encodes a boolean as the integer 1 (true) or 0 (false).
	tf := func(b bool) *big.Int {
		if b {
			return big.NewInt(1)
		}
		return big.NewInt(0)
	}

	// shiftCount converts a shift operand to uint, rejecting negative or
	// oversized values instead of letting them wrap to a huge unsigned count.
	shiftCount := func(b *big.Int) uint {
		if !b.IsUint64() {
			panic("invalid shift count " + b.String())
		}
		return uint(b.Uint64())
	}

	// opMap holds the binary operators that evaluate both operands eagerly.
	// Each entry receives a fresh result value followed by the two operands.
	opMap := map[token.Token](func(*big.Int, *big.Int, *big.Int) *big.Int){
		token.ADD: (*big.Int).Add,
		token.SUB: (*big.Int).Sub,
		token.MUL: (*big.Int).Mul,
		token.QUO: (*big.Int).Quo,
		token.REM: (*big.Int).Rem,
		// ^ is repurposed as exponentiation without a modulus.
		token.XOR: func(r, a, b *big.Int) *big.Int { return r.Exp(a, b, nil) },
		token.SHL: func(r, a, b *big.Int) *big.Int { return r.Lsh(a, shiftCount(b)) },
		token.SHR: func(r, a, b *big.Int) *big.Int { return r.Rsh(a, shiftCount(b)) },
		token.GEQ: func(_, a, b *big.Int) *big.Int { return tf(a.Cmp(b) >= 0) },
		token.LEQ: func(_, a, b *big.Int) *big.Int { return tf(a.Cmp(b) <= 0) },
		token.GTR: func(_, a, b *big.Int) *big.Int { return tf(a.Cmp(b) > 0) },
		token.LSS: func(_, a, b *big.Int) *big.Int { return tf(a.Cmp(b) < 0) },
		token.EQL: func(_, a, b *big.Int) *big.Int { return tf(a.Cmp(b) == 0) },
		token.NEQ: func(_, a, b *big.Int) *big.Int { return tf(a.Cmp(b) != 0) },
	}

	// evalx recursively evaluates one AST node.
	evalx = func(node ast.Node) *big.Int {
		switch n := node.(type) {
		case *ast.BasicLit:
			switch n.Kind {
			case token.INT:
				rtn := new(big.Int)
				// Base 0 lets SetString pick the base from the literal's prefix.
				if _, ok := rtn.SetString(n.Value, 0); !ok {
					panic("can not convert " + n.Value)
				}
				return rtn
			default:
				panic("unkown basic lit" + fmt.Sprint("[", n.Kind, n.Value) + "]")
			}
		case *ast.CallExpr:
			nf, isIdent := n.Fun.(*ast.Ident)
			if !isIdent {
				// Only calls by plain name are supported (no selectors,
				// function literals, and so on).
				panic(fmt.Sprint("function is too complicate", n.Fun))
			}
			fn, found := Funcs[nf.Name]
			fnValue := reflect.ValueOf(fn)
			if !found {
				// Fall back to a *big.Int method of the same name, bound to
				// a fresh receiver.
				fnValue = reflect.ValueOf(new(big.Int)).MethodByName(nf.Name)
				if fnValue.Kind() == reflect.Invalid {
					panic("unknown function")
				}
			}
			if fnValue.Kind() != reflect.Func {
				panic("not a function: " + nf.Name)
			}
			args := make([]reflect.Value, 0, len(n.Args))
			for _, arg := range n.Args {
				args = append(args, reflect.ValueOf(evalx(arg)))
			}
			results := fnValue.Call(args)
			if len(results) == 0 {
				panic("function " + nf.Name + " returns no value")
			}
			// Only the first result is used; it must be a *big.Int.
			rtn, isInt := results[0].Interface().(*big.Int)
			if !isInt {
				panic("function " + nf.Name + " does not return *big.Int")
			}
			return rtn
		case *ast.BinaryExpr:
			if fn, ok := opMap[n.Op]; ok {
				return fn(new(big.Int), evalx(n.X), evalx(n.Y))
			}
			// The logical operators short-circuit, so they are handled
			// outside opMap and evaluate the right operand only if needed.
			switch n.Op {
			case token.LOR:
				a := evalx(n.X)
				if a.Sign() == 0 {
					return evalx(n.Y)
				}
				return a
			case token.LAND:
				a := evalx(n.X)
				if a.Sign() == 0 {
					return a
				}
				return evalx(n.Y)
			}
			panic(n.Op)
		case *ast.UnaryExpr:
			switch n.Op {
			case token.ADD:
				return evalx(n.X)
			case token.SUB:
				return new(big.Int).Neg(evalx(n.X))
			case token.NOT:
				return tf(evalx(n.X).Sign() == 0)
			}
			// Any other unary operator falls through to the generic panic.
		case *ast.Ident:
			v, ok := Vars[n.Name]
			if !ok {
				panic("Ident[" + n.Name + "]")
			}
			return v
		case *ast.ParenExpr:
			return evalx(n.X)
		default:
			fmt.Println("not handled", n, reflect.TypeOf(n))
		}
		fmt.Println("???", node, reflect.TypeOf(node))
		panic("???")
	}

	return func(s string) *big.Int {
		fset := token.NewFileSet()
		expr, err := parser.ParseExprFrom(fset, "", s, 0)
		if err != nil {
			fmt.Println(expr)
			panic(err.Error())
		}
		// Compare byte offsets rather than raw token.Pos values, and tolerate
		// trailing whitespace: only real leftover text means the expression
		// did not cover the whole input.
		end := fset.Position(expr.End()).Offset
		if end < len(s) && strings.TrimSpace(s[end:]) != "" {
			fmt.Println("string longer than the expression", end, len(s))
			panic("...")
		}
		return evalx(expr)
	}
}
// Bool reports whether v is non-zero, i.e. "true" under the evaluator's
// convention that zero is false and any other value is true.
func Bool(v *big.Int) bool {
	return v.Sign() != 0
}
func main() { | |
V := make(map[string]*big.Int) | |
F := make(map[string]IntFunc) | |
E := NewEvalEnv(V, F) | |
lambda := func(params, formula string) func(...*big.Int) *big.Int { | |
paramNames := strings.FieldsFunc(params, func(x rune) bool { return x == ',' || x == ' ' }) | |
return func(v ...*big.Int) *big.Int { | |
_Vars := make(map[string]*big.Int) | |
for i := range v { | |
_Vars[paramNames[i]] = v[i] | |
} | |
return NewEvalEnv(_Vars, F)(formula) | |
} | |
} | |
F["exp"] = func(n ...*big.Int) *big.Int { | |
return new(big.Int).Exp(n[0], n[1], n[2]) | |
} | |
fmt.Println("(50000000000+3)*2*(5^10)=", E("(50000000000+3)*2*(5^10)")) | |
V["x"] = E("1") | |
for V["i"] = E("1"); E("i-10").Sign() < 0; V["i"] = E("i+1") { | |
V["x"] = E("x*i") | |
} | |
fmt.Println("9!=", E("x")) | |
V["x"] = E("1") | |
for V["i"] = E("1"); Bool(E("i<10")); V["i"] = E("i+1") { | |
V["x"] = E("x*i") | |
} | |
fmt.Println("9!=", E("x")) | |
fac := lambda("n", " n==0 || n*fac(n-1)") | |
F["fac"] = fac | |
fmt.Println("10!=", E("fac(10)")) | |
fmt.Println("10!=", fac(big.NewInt(10))) | |
fmt.Println("exp(2,16,100)=", E("exp(2,16,100)")) | |
fmt.Println("Exp(2,100,100)=", E("Exp(2,100,100)")) | |
fmt.Println("100!=", fac(E("100"))) | |
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment