Skip to content

Instantly share code, notes, and snippets.

@tgascoigne
Created May 12, 2020 11:03
Show Gist options
  • Save tgascoigne/f8d6c6538a5841bcb5f135668279b93b to your computer and use it in GitHub Desktop.
Save tgascoigne/f8d6c6538a5841bcb5f135668279b93b to your computer and use it in GitHub Desktop.
package script
import (
"errors"
"fmt"
"sync"
tengo "github.com/d5/tengo/v2"
"github.com/d5/tengo/v2/parser"
)
// Script wraps a tengo program that is compiled once and whose functions can
// then be invoked repeatedly from Go. Methods guard its state with mu.
type Script struct {
	mu sync.RWMutex

	// Inputs configured before Compile.
	src       []byte                     // tengo source code
	variables map[string]*tengo.Variable // Go-supplied globals (see Add)
	modules   *tengo.ModuleMap           // import modules (see SetImports)

	// State populated by Compile.
	bytecode *tengo.Bytecode
	symbols  *tengo.SymbolTable
	globals  []tengo.Object // global slots shared by every VM run
	indexes  map[string]int // global name -> slot in globals
	outIdx   int            // slot that receives a call's return value
}
// NewScript returns a Script for the given tengo source. Compile must be
// called before any of the script's functions can be invoked.
func NewScript(src []byte) *Script {
	s := new(Script)
	s.src = src
	s.variables = map[string]*tengo.Variable{}
	return s
}
// prepCompile builds a fresh symbol table and globals array for Compile.
//
// tengo's builtin functions (len, append, ...) are registered first so that
// scripts can call them; builtins live in their own scope and do not consume
// global slots. Every variable registered via Add is then defined as a global
// and its current value copied into the matching slot of the returned globals.
//
// An error is returned when there are so many variables that no global slot
// would remain for the result slot Compile defines afterwards.
func (s *Script) prepCompile() (*tengo.SymbolTable, []tengo.Object, error) {
	// Keep one slot free for the "$out" global that Compile defines next.
	if len(s.variables) >= tengo.GlobalsSize {
		return nil, nil, fmt.Errorf("too many variables: %d (limit %d)",
			len(s.variables), tengo.GlobalsSize-1)
	}

	symbolTable := tengo.NewSymbolTable()
	for idx, fn := range tengo.GetAllBuiltinFunctions() {
		symbolTable.DefineBuiltin(idx, fn.Name)
	}

	globals := make([]tengo.Object, tengo.GlobalsSize)
	idx := 0
	for name, variable := range s.variables {
		symbol := symbolTable.Define(name)
		if symbol.Index != idx {
			// Globals of a fresh table are numbered 0, 1, 2, ...; anything
			// else is a programming error, not a user error.
			panic(fmt.Errorf("wrong symbol index: %d != %d",
				idx, symbol.Index))
		}
		globals[symbol.Index] = variable.Object()
		idx++
	}
	return symbolTable, globals, nil
}
// Compile parses and compiles the script source, then runs its top level once
// so that its globals (typically function definitions) are populated. After a
// successful Compile, script functions can be invoked with CallByName, Call or
// a Callback.
//
// The compiled state (globals, bytecode, symbols, indexes) is replaced only
// when every step succeeds; on error, the state from any earlier successful
// Compile is left untouched, so it stays internally consistent.
//
// NOTE(review): function values obtained from an earlier compilation (e.g. held
// by a Callback) refer to that compilation's constants; calling them after a
// recompile presumably misbehaves — confirm before relying on recompilation.
func (s *Script) Compile() error {
	s.mu.Lock()
	defer s.mu.Unlock()

	symbolTable, globals, err := s.prepCompile()
	if err != nil {
		return err
	}

	fileSet := parser.NewFileSet()
	srcFile := fileSet.AddFile("(main)", -1, len(s.src))
	p := parser.NewParser(srcFile, s.src, nil)
	file, err := p.ParseFile()
	if err != nil {
		return err
	}

	// Reserve a global slot; callCompiled stores each call's result there.
	out := symbolTable.Define("$out")
	globals[out.Index] = tengo.UndefinedValue

	cc := tengo.NewCompiler(srcFile, symbolTable, nil, s.modules, nil)
	if err := cc.Compile(file); err != nil {
		return err
	}

	// Shrink globals to the slots the compiled program uses. Reslicing past
	// the fixed-size array would panic, so report it as an error instead.
	n := symbolTable.MaxSymbols() + 1
	if n > len(globals) {
		return fmt.Errorf("too many global symbols: %d (limit %d)", n, len(globals))
	}
	globals = globals[:n]

	// Map each global symbol name to its slot in globals.
	globalIndexes := make(map[string]int, len(globals))
	for _, name := range symbolTable.Names() {
		symbol, _, _ := symbolTable.Resolve(name)
		if symbol.Scope == tengo.ScopeGlobal {
			globalIndexes[name] = symbol.Index
		}
	}

	// Run the top level once to fill globals; this assumes the script's top
	// level only defines functions/values.
	bc := cc.Bytecode()
	vm := tengo.NewVM(bc, globals, -1)
	if err := vm.Run(); err != nil {
		return err
	}
	bc.RemoveDuplicates()

	// Everything succeeded: publish the new state.
	s.globals = globals
	s.bytecode = bc
	s.symbols = symbolTable
	s.outIdx = out.Index
	s.indexes = globalIndexes
	return nil
}
// CallByName calls the compiled tengo function stored in the global named fn.
// Go arguments are converted with tengo.FromInterface and the result is
// converted back with tengo.ToInterface. An error is returned if no such
// global exists or it does not hold a compiled function.
func (s *Script) CallByName(fn string, args ...interface{}) (interface{}, error) {
	s.mu.Lock()
	defer s.mu.Unlock()

	idx, found := s.indexes[fn]
	if !found {
		return nil, errors.New("not found")
	}
	if cfn, isFunc := s.globals[idx].(*tengo.CompiledFunction); isFunc {
		return s.call(cfn, args...)
	}
	return nil, errors.New("not a compiled function")
}
// Call invokes fn, which must be a *tengo.CompiledFunction (for example a
// function value the script passed to Go), with the given Go arguments and
// returns its result converted back to a Go value.
func (s *Script) Call(fn tengo.Object,
	args ...interface{}) (interface{}, error) {
	s.mu.Lock()
	defer s.mu.Unlock()

	switch cfn := fn.(type) {
	case *tengo.CompiledFunction:
		return s.call(cfn, args...)
	default:
		return nil, errors.New("not a compiled function")
	}
}
// call converts the Go arguments to tengo objects, runs cfn through
// callCompiled and converts the returned object back to a Go value.
// Callers hold s.mu.
func (s *Script) call(cfn *tengo.CompiledFunction,
	args ...interface{}) (interface{}, error) {
	objs := make([]tengo.Object, len(args))
	for i, arg := range args {
		obj, err := tengo.FromInterface(arg)
		if err != nil {
			return nil, err
		}
		objs[i] = obj
	}

	res, err := s.callCompiled(cfn, objs...)
	if err != nil {
		return nil, err
	}
	return tengo.ToInterface(res), nil
}
// callCompiled invokes fn with args on a fresh VM that shares the script's
// globals, and returns the function's result. Callers hold s.mu.
//
// It runs a small stand-in main function:
//
//	CONST fn; CONST arg0 ... CONST argN; CALL N; SETGLOBAL $out; SUSPEND
//
// fn and args are appended after the script's constants in a per-call copy of
// the Bytecode header, so s.bytecode's fields are never reassigned and the
// script stays usable even if the VM panics. An error is returned if the
// script has not been compiled, or if the constant indexes or argument count
// do not fit the fixed-width instruction operands (they would otherwise be
// silently truncated).
func (s *Script) callCompiled(fn *tengo.CompiledFunction,
	args ...tengo.Object) (tengo.Object, error) {
	if s.bytecode == nil {
		return nil, errors.New("script not compiled")
	}

	constsOffset := len(s.bytecode.Constants)
	if last := constsOffset + len(args); last > maxOperand(parser.OpConstant) {
		return nil, fmt.Errorf("too many constants: %d", last+1)
	}
	if len(args) > maxOperand(parser.OpCall) {
		return nil, fmt.Errorf("too many arguments: %d", len(args))
	}

	// Load fn, load each argument, call, store the result in $out, stop.
	inst := tengo.MakeInstruction(parser.OpConstant, constsOffset)
	for i := range args {
		inst = append(inst,
			tengo.MakeInstruction(parser.OpConstant, constsOffset+i+1)...)
	}
	inst = append(inst, tengo.MakeInstruction(parser.OpCall, len(args))...)
	inst = append(inst, tengo.MakeInstruction(parser.OpSetGlobal, s.outIdx)...)
	inst = append(inst, tengo.MakeInstruction(parser.OpSuspend)...)

	consts := append(s.bytecode.Constants, fn)
	consts = append(consts, args...)
	defer func() {
		// append may have written into spare capacity shared with
		// s.bytecode.Constants; clear it so fn and args can be collected.
		for i := constsOffset; i < len(consts); i++ {
			consts[i] = nil
		}
	}()

	bc := *s.bytecode
	bc.Constants = consts
	bc.MainFunction = &tengo.CompiledFunction{Instructions: inst}

	vm := tengo.NewVM(&bc, s.globals, -1)
	if err := vm.Run(); err != nil {
		return nil, err
	}
	return s.globals[s.outIdx], nil
}

// maxOperand returns the largest value representable by op's first operand.
func maxOperand(op parser.Opcode) int {
	return 1<<(8*parser.OpcodeOperands[op][0]) - 1
}
// Callback pairs a tengo function value with the Script it belongs to, so Go
// code can invoke it later through Call.
type Callback struct {
	script *Script      // script whose globals and bytecode the call uses
	fn     tengo.Object // function value to invoke
}
// MakeCallback wraps fn so it can later be invoked through Callback.Call.
// fn is not validated here; a non-function value is reported when called.
func (s *Script) MakeCallback(fn tengo.Object) *Callback {
	cb := new(Callback)
	cb.script = s
	cb.fn = fn
	return cb
}
// Call invokes the wrapped function with args; it is shorthand for
// c.script.Call(c.fn, args...).
func (c *Callback) Call(args ...interface{}) (interface{}, error) {
	script, fn := c.script, c.fn
	return script.Call(fn, args...)
}
// Add adds a new variable to the script or updates an existing one.
// The value is converted with tengo.NewVariable; if that fails the error is
// returned and any existing variable of the same name is left unchanged.
// Variables are read by Compile, so changes made after compilation only take
// effect on the next Compile (use Set to change a compiled global).
func (s *Script) Add(name string, value interface{}) error {
	v, err := tengo.NewVariable(name, value)
	if err != nil {
		// Do not store the nil *Variable NewVariable returns on failure:
		// it would replace a valid entry and crash the next Compile.
		return err
	}

	s.mu.Lock()
	defer s.mu.Unlock()
	s.variables[name] = v
	return nil
}
// Remove removes (undefines) an existing variable for the script. It reports
// whether a variable of that name was present. Like Add, it only affects
// subsequent compilations.
func (s *Script) Remove(name string) bool {
	s.mu.Lock()
	defer s.mu.Unlock()

	_, exists := s.variables[name]
	if exists {
		delete(s.variables, name)
	}
	return exists
}
// SetImports sets the module map passed to the compiler by the next Compile
// (presumably used to resolve the script's imports).
func (s *Script) SetImports(modules *tengo.ModuleMap) {
	s.mu.Lock()
	s.modules = modules
	s.mu.Unlock()
}
// Get returns the current value of the global identified by name. Names not
// defined during compilation, and globals that hold no value yet, yield a
// variable containing tengo.UndefinedValue.
func (s *Script) Get(name string) *tengo.Variable {
	s.mu.RLock()
	defer s.mu.RUnlock()

	var value tengo.Object = tengo.UndefinedValue
	if idx, ok := s.indexes[name]; ok && s.globals[idx] != nil {
		value = s.globals[idx]
	}

	v, err := tengo.NewVariable(name, value)
	if err != nil {
		// Unreachable in practice: value is already a tengo object, so the
		// conversion inside NewVariable cannot fail. tengo's own
		// (*Script).Get sidesteps this by constructing the Variable directly.
		panic(fmt.Errorf("unable to create new variable from global value, type was %T", value))
	}
	return v
}
// Set replaces the value of a compiled global identified by name. The value is
// converted with tengo.FromInterface first; an error is returned if that
// conversion fails or if the name was not defined during compilation.
func (s *Script) Set(name string, value interface{}) error {
	s.mu.Lock()
	defer s.mu.Unlock()

	obj, err := tengo.FromInterface(value)
	if err != nil {
		return err
	}
	if idx, ok := s.indexes[name]; ok {
		s.globals[idx] = obj
		return nil
	}
	return fmt.Errorf("'%s' is not defined", name)
}
package script
import (
"testing"
"github.com/d5/tengo/assert"
"github.com/d5/tengo/v2"
)
// TestCallCompiledFunc calls top-level script functions by name, including one
// that calls another script function.
func TestCallCompiledFunc(t *testing.T) {
	const mathModule = `
add := func(a, b) {
return a + b
}
mul := func(a, b) {
return a * b
}
square := func(a) {
return mul(a, a)
}
`
	scr := NewScript([]byte(mathModule))
	assert.NoError(t, scr.Compile())

	cases := []struct {
		fn   string
		args []interface{}
		want int64
	}{
		{"add", []interface{}{3, 4}, 7},
		{"mul", []interface{}{3, 4}, 12},
		{"square", []interface{}{3}, 9},
	}
	for _, tc := range cases {
		got, err := scr.CallByName(tc.fn, tc.args...)
		assert.NoError(t, err)
		assert.Equal(t, tc.want, got)
	}
}
// TestCallback checks that a function value handed to Go while the script's
// top level runs can be called after Compile, and that it observes changes to
// globals made through Set.
func TestCallback(t *testing.T) {
	const callbackModule = `
b := 2
pass(func(a) {
return a * b
})
`
	scr := NewScript([]byte(callbackModule))
	var callback *Callback
	err := scr.Add("pass", &tengo.UserFunction{
		Value: func(args ...tengo.Object) (tengo.Object, error) {
			if len(args) != 1 {
				return nil, tengo.ErrWrongNumArguments
			}
			callback = scr.MakeCallback(args[0])
			return tengo.UndefinedValue, nil
		},
	})
	assert.NoError(t, err)
	err = scr.Compile()
	assert.NoError(t, err)
	if callback == nil {
		t.Fatal("pass was not called during Compile")
	}

	result, err := callback.Call(3)
	assert.NoError(t, err)
	assert.Equal(t, int64(6), result)
	result, err = callback.Call(5)
	assert.NoError(t, err)
	assert.Equal(t, int64(10), result)

	// Modify the global and check the new value is reflected in the function
	err = scr.Set("b", 3)
	assert.NoError(t, err)
	result, err = callback.Call(5)
	assert.NoError(t, err)
	assert.Equal(t, int64(15), result)
}
// TestClosure calls closures created by the script's top level by name and
// checks each keeps its own captured value.
func TestClosure(t *testing.T) {
	const closureModule = `
mulClosure := func(a) {
return func(b) {
return a * b
}
}
mul2 := mulClosure(2)
mul3 := mulClosure(3)
`
	scr := NewScript([]byte(closureModule))
	assert.NoError(t, scr.Compile())

	for _, tc := range []struct {
		fn   string
		arg  int
		want int64
	}{
		{"mul2", 3, 6},
		{"mul2", 5, 10},
		{"mul3", 5, 15},
	} {
		got, err := scr.CallByName(tc.fn, tc.arg)
		assert.NoError(t, err)
		assert.Equal(t, tc.want, got)
	}
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment