-
-
Save tgascoigne/f8d6c6538a5841bcb5f135668279b93b to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
package script | |
import ( | |
"errors" | |
"fmt" | |
"sync" | |
tengo "github.com/d5/tengo/v2" | |
"github.com/d5/tengo/v2/parser" | |
) | |
type Script struct { | |
mu sync.RWMutex | |
bytecode *tengo.Bytecode | |
symbols *tengo.SymbolTable | |
indexes map[string]int | |
globals []tengo.Object | |
outIdx int | |
src []byte | |
variables map[string]*tengo.Variable | |
modules *tengo.ModuleMap | |
} | |
func NewScript(src []byte) *Script { | |
return &Script{ | |
src: src, | |
variables: make(map[string]*tengo.Variable), | |
} | |
} | |
func (s *Script) prepCompile() ( | |
symbolTable *tengo.SymbolTable, | |
globals []tengo.Object, | |
err error, | |
) { | |
var names []string | |
for name := range s.variables { | |
names = append(names, name) | |
} | |
symbolTable = tengo.NewSymbolTable() | |
globals = make([]tengo.Object, tengo.GlobalsSize) | |
for idx, name := range names { | |
symbol := symbolTable.Define(name) | |
if symbol.Index != idx { | |
panic(fmt.Errorf("wrong symbol index: %d != %d", | |
idx, symbol.Index)) | |
} | |
globals[symbol.Index] = s.variables[name].Object() | |
} | |
return | |
} | |
func (s *Script) Compile() error { | |
s.mu.Lock() | |
defer s.mu.Unlock() | |
var err error | |
var symbolTable *tengo.SymbolTable | |
symbolTable, s.globals, err = s.prepCompile() | |
if err != nil { | |
return err | |
} | |
fileSet := parser.NewFileSet() | |
srcFile := fileSet.AddFile("(main)", -1, len(s.src)) | |
p := parser.NewParser(srcFile, s.src, nil) | |
file, err := p.ParseFile() | |
if err != nil { | |
return err | |
} | |
out := symbolTable.Define("$out") | |
s.globals[out.Index] = tengo.UndefinedValue | |
cc := tengo.NewCompiler(srcFile, symbolTable, nil, s.modules, nil) | |
if err := cc.Compile(file); err != nil { | |
return err | |
} | |
// reduce globals size | |
s.globals = s.globals[:symbolTable.MaxSymbols()+1] | |
// global symbol names to indexes | |
globalIndexes := make(map[string]int, len(s.globals)) | |
for _, name := range symbolTable.Names() { | |
symbol, _, _ := symbolTable.Resolve(name) | |
if symbol.Scope == tengo.ScopeGlobal { | |
globalIndexes[name] = symbol.Index | |
} | |
} | |
vm := tengo.NewVM(cc.Bytecode(), s.globals, -1) | |
// fill globals, assume script only has function definitions | |
if err := vm.Run(); err != nil { | |
return err | |
} | |
bc := cc.Bytecode() | |
bc.RemoveDuplicates() | |
s.bytecode = bc | |
s.symbols = symbolTable | |
s.outIdx = out.Index | |
s.indexes = globalIndexes | |
return nil | |
} | |
func (s *Script) CallByName(fn string, args ...interface{}) (interface{}, error) { | |
s.mu.Lock() | |
defer s.mu.Unlock() | |
idx, ok := s.indexes[fn] | |
if !ok { | |
return nil, errors.New("not found") | |
} | |
cfn, ok := s.globals[idx].(*tengo.CompiledFunction) | |
if !ok { | |
return nil, errors.New("not a compiled function") | |
} | |
return s.call(cfn, args...) | |
} | |
func (s *Script) Call(fn tengo.Object, | |
args ...interface{}) (interface{}, error) { | |
s.mu.Lock() | |
defer s.mu.Unlock() | |
cfn, ok := fn.(*tengo.CompiledFunction) | |
if !ok { | |
return nil, errors.New("not a compiled function") | |
} | |
return s.call(cfn, args...) | |
} | |
func (s *Script) call(cfn *tengo.CompiledFunction, | |
args ...interface{}) (interface{}, error) { | |
targs := make([]tengo.Object, 0, len(args)) | |
for i := range args { | |
v, err := tengo.FromInterface(args[i]) | |
if err != nil { | |
return nil, err | |
} | |
targs = append(targs, v) | |
} | |
v, err := s.callCompiled(cfn, targs...) | |
if err != nil { | |
return nil, err | |
} | |
return tengo.ToInterface(v), nil | |
} | |
func (s *Script) callCompiled(fn *tengo.CompiledFunction, | |
args ...tengo.Object) (tengo.Object, error) { | |
constsOffset := len(s.bytecode.Constants) | |
// Load fn | |
inst := tengo.MakeInstruction(parser.OpConstant, constsOffset) | |
// Load args | |
for i := range args { | |
inst = append(inst, | |
tengo.MakeInstruction(parser.OpConstant, constsOffset+i+1)...) | |
} | |
// Call, set value to a global, stop | |
inst = append(inst, tengo.MakeInstruction(parser.OpCall, len(args))...) | |
inst = append(inst, tengo.MakeInstruction(parser.OpSetGlobal, s.outIdx)...) | |
inst = append(inst, tengo.MakeInstruction(parser.OpSuspend)...) | |
s.bytecode.Constants = append(s.bytecode.Constants, fn) | |
s.bytecode.Constants = append(s.bytecode.Constants, args...) | |
orig := s.bytecode.MainFunction | |
s.bytecode.MainFunction = &tengo.CompiledFunction{ | |
Instructions: inst, | |
} | |
vm := tengo.NewVM(s.bytecode, s.globals, -1) | |
err := vm.Run() | |
// go back to normal if required | |
s.bytecode.MainFunction = orig | |
s.bytecode.Constants = s.bytecode.Constants[:constsOffset] | |
// get symbol using index and return it | |
return s.globals[s.outIdx], err | |
} | |
type Callback struct { | |
script *Script | |
fn tengo.Object | |
} | |
func (s *Script) MakeCallback(fn tengo.Object) *Callback { | |
return &Callback{ | |
script: s, | |
fn: fn, | |
} | |
} | |
func (c *Callback) Call(args ...interface{}) (interface{}, error) { | |
return c.script.Call(c.fn, args...) | |
} | |
// Add adds a new variable or updates an existing variable to the script. | |
func (s *Script) Add(name string, value interface{}) error { | |
s.mu.Lock() | |
defer s.mu.Unlock() | |
var err error | |
s.variables[name], err = tengo.NewVariable(name, value) | |
return err | |
} | |
// Remove removes (undefines) an existing variable for the script. It returns | |
// false if the variable name is not defined. | |
func (s *Script) Remove(name string) bool { | |
s.mu.Lock() | |
defer s.mu.Unlock() | |
if _, ok := s.variables[name]; !ok { | |
return false | |
} | |
delete(s.variables, name) | |
return true | |
} | |
// SetImports sets import modules. | |
func (s *Script) SetImports(modules *tengo.ModuleMap) { | |
s.mu.Lock() | |
defer s.mu.Unlock() | |
s.modules = modules | |
} | |
// Get returns a variable identified by the name. | |
func (s *Script) Get(name string) *tengo.Variable { | |
s.mu.RLock() | |
defer s.mu.RUnlock() | |
value := tengo.UndefinedValue | |
if idx, ok := s.indexes[name]; ok { | |
value = s.globals[idx] | |
if value == nil { | |
value = tengo.UndefinedValue | |
} | |
} | |
v, err := tengo.NewVariable(name, value) | |
if err != nil { | |
// This should never happen, since the value will already be a tengo object | |
// tengo's (*Script).Get avoids this because it constructs the object directly. | |
panic(fmt.Errorf("unable to create new variable from global value, type was %T", value)) | |
} | |
return v | |
} | |
// Set replaces the value of a global variable identified by the name. An error | |
// will be returned if the name was not defined during compilation. | |
func (s *Script) Set(name string, value interface{}) error { | |
s.mu.Lock() | |
defer s.mu.Unlock() | |
obj, err := tengo.FromInterface(value) | |
if err != nil { | |
return err | |
} | |
idx, ok := s.indexes[name] | |
if !ok { | |
return fmt.Errorf("'%s' is not defined", name) | |
} | |
s.globals[idx] = obj | |
return nil | |
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
package script | |
import ( | |
"testing" | |
"github.com/d5/tengo/assert" | |
"github.com/d5/tengo/v2" | |
) | |
func TestCallCompiledFunc(t *testing.T) { | |
const mathModule = ` | |
add := func(a, b) { | |
return a + b | |
} | |
mul := func(a, b) { | |
return a * b | |
} | |
square := func(a) { | |
return mul(a, a) | |
} | |
` | |
scr := NewScript([]byte(mathModule)) | |
err := scr.Compile() | |
assert.NoError(t, err) | |
result, err := scr.CallByName("add", 3, 4) | |
assert.NoError(t, err) | |
assert.Equal(t, int64(7), result) | |
result, err = scr.CallByName("mul", 3, 4) | |
assert.NoError(t, err) | |
assert.Equal(t, int64(12), result) | |
result, err = scr.CallByName("square", 3) | |
assert.NoError(t, err) | |
assert.Equal(t, int64(9), result) | |
} | |
func TestCallback(t *testing.T) { | |
const callbackModule = ` | |
b := 2 | |
pass(func(a) { | |
return a * b | |
}) | |
` | |
scr := NewScript([]byte(callbackModule)) | |
var callback *Callback | |
scr.Add("pass", &tengo.UserFunction{ | |
Value: func(args ...tengo.Object) (tengo.Object, error) { | |
callback = scr.MakeCallback(args[0]) | |
return tengo.UndefinedValue, nil | |
}, | |
}) | |
err := scr.Compile() | |
assert.NoError(t, err) | |
result, err := callback.Call(3) | |
assert.NoError(t, err) | |
assert.Equal(t, int64(6), result) | |
result, err = callback.Call(5) | |
assert.NoError(t, err) | |
assert.Equal(t, int64(10), result) | |
// Modify the global and check the new value is reflected in the function | |
scr.Set("b", 3) | |
result, err = callback.Call(5) | |
assert.NoError(t, err) | |
assert.Equal(t, int64(15), result) | |
} | |
func TestClosure(t *testing.T) { | |
const closureModule = ` | |
mulClosure := func(a) { | |
return func(b) { | |
return a * b | |
} | |
} | |
mul2 := mulClosure(2) | |
mul3 := mulClosure(3) | |
` | |
scr := NewScript([]byte(closureModule)) | |
err := scr.Compile() | |
assert.NoError(t, err) | |
result, err := scr.CallByName("mul2", 3) | |
assert.NoError(t, err) | |
assert.Equal(t, int64(6), result) | |
result, err = scr.CallByName("mul2", 5) | |
assert.NoError(t, err) | |
assert.Equal(t, int64(10), result) | |
result, err = scr.CallByName("mul3", 5) | |
assert.NoError(t, err) | |
assert.Equal(t, int64(15), result) | |
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment