Skip to content

Instantly share code, notes, and snippets.

@ozanh
Last active May 12, 2020 02:40
Show Gist options
  • Save ozanh/f81b2eb7aff4b14cf871d8d29c9e60e7 to your computer and use it in GitHub Desktop.
Save ozanh/f81b2eb7aff4b14cf871d8d29c9e60e7 to your computer and use it in GitHub Desktop.
Call Tengo CompiledFunction
package main
import (
"errors"
"fmt"
"sync"
"github.com/d5/tengo/v2"
"github.com/d5/tengo/v2/parser"
)
// mathFunctions is the Tengo source that main compiles. Its top level
// declares the global variable x and the functions mul, square, add and
// test. Each call to test increments x and multiplies its result by x, so
// repeated calls with the same arguments return different values.
const mathFunctions = `
x := 0
mul := func(a, b) {
return a * b
}
square := func(a) {
return mul(a, a)
}
add := func(a, b) {
return a + b
}
test := func(a, b, ...c) {
x++
return mul(add(add(a, b), c[0]), x)
}
`
// main compiles mathFunctions once, then calls several of its functions
// from Go and prints each result. Any compile or call error panics.
func main() {
	script := NewScript()
	if err := script.Compile([]byte(mathFunctions)); err != nil {
		panic(err)
	}

	calls := []struct {
		fn   string
		args []interface{}
	}{
		{"add", []interface{}{1, 2}},
		{"square", []interface{}{11}},
		{"mul", []interface{}{6, 6}},
		// test reads and increments the script's global x.
		{"test", []interface{}{1, 2, 3}}, // (1 + 2 + 3) * 1
		{"test", []interface{}{1, 2, 3}}, // x is now 2: (1 + 2 + 3) * 2
	}

	for _, c := range calls {
		result, err := script.Call(c.fn, c.args...)
		if err != nil {
			panic(err)
		}
		fmt.Println(result)
	}
}
// Script is a compiled Tengo program whose global functions can be invoked
// from Go with Call. Compile must succeed before Call can find a function.
//
// The state fields below mu are set by Compile:
//   - bytecode: the compiled program; callCompiled temporarily swaps its
//     MainFunction and extends its Constants while running a call.
//   - symbols: the compiler's symbol table (stored, but not read by the
//     code shown here).
//   - globals: the global variable slots shared by every VM run.
//   - indexes: global symbol name -> index into globals.
//   - outIdx: index of the reserved "$out" global that receives results.
type Script struct {
	// mu serializes Compile and Call, which both read and write the state.
	mu sync.Mutex

	bytecode *tengo.Bytecode
	symbols  *tengo.SymbolTable
	globals  []tengo.Object
	indexes  map[string]int
	outIdx   int
}
// NewScript returns an empty Script; call Compile before Call.
func NewScript() *Script {
	return new(Script)
}
// Compile parses and compiles src, then runs its top-level statements once
// so that every global (function values and variables such as x) is
// populated and later callable through Call.
//
// All new state is built in locals and committed only after every step has
// succeeded. If parsing, compilation or the initial run fails, the error is
// returned and the previously compiled state, if any, is left intact.
func (s *Script) Compile(src []byte) error {
	s.mu.Lock()
	defer s.mu.Unlock()

	fileSet := parser.NewFileSet()
	srcFile := fileSet.AddFile("(main)", -1, len(src))
	p := parser.NewParser(srcFile, src, nil)
	file, err := p.ParseFile()
	if err != nil {
		return err
	}

	symbolTable := tengo.NewSymbolTable()
	// Reserve a global slot that callCompiled uses to receive return values.
	out := symbolTable.Define("$out")
	globals := make([]tengo.Object, tengo.GlobalsSize)
	globals[out.Index] = tengo.UndefinedValue

	cc := tengo.NewCompiler(srcFile, symbolTable, nil, nil, nil)
	if err := cc.Compile(file); err != nil {
		return err
	}
	// Shrink globals to the slots the compiled program actually uses.
	globals = globals[:symbolTable.MaxSymbols()+1]

	names := symbolTable.Names()
	indexes := make(map[string]int, len(names))
	for _, name := range names {
		symbol, _, ok := symbolTable.Resolve(name)
		if ok && symbol.Scope == tengo.ScopeGlobal {
			indexes[name] = symbol.Index
		}
	}

	// Build the bytecode once and dedupe its constants before running it.
	bc := cc.Bytecode()
	bc.RemoveDuplicates()

	// Run the top-level statements once to fill globals. This assumes the
	// script's top level only defines functions and variables.
	if err := tengo.NewVM(bc, globals, -1).Run(); err != nil {
		return err
	}

	// Commit: everything above succeeded.
	s.bytecode = bc
	s.symbols = symbolTable
	s.globals = globals
	s.indexes = indexes
	s.outIdx = out.Index
	return nil
}
// Call invokes the script-level function named fn. Each Go argument is
// converted with tengo.FromInterface, and the result is converted back with
// tengo.ToInterface.
//
// It returns an error if fn is not a global of the compiled script, if that
// global is not a *tengo.CompiledFunction, if an argument cannot be
// converted, or if running the function fails.
func (s *Script) Call(fn string, args ...interface{}) (interface{}, error) {
	s.mu.Lock()
	defer s.mu.Unlock()

	idx, found := s.indexes[fn]
	if !found {
		return nil, errors.New("not found")
	}
	compiled, isCompiled := s.globals[idx].(*tengo.CompiledFunction)
	if !isCompiled {
		return nil, errors.New("not a compiled function")
	}

	objs := make([]tengo.Object, len(args))
	for i, arg := range args {
		obj, err := tengo.FromInterface(arg)
		if err != nil {
			return nil, err
		}
		objs[i] = obj
	}

	result, err := s.callCompiled(compiled, objs...)
	if err != nil {
		return nil, err
	}
	return tengo.ToInterface(result), nil
}
// callCompiled runs fn with args on a fresh VM that shares s.globals.
//
// It temporarily appends fn and args to s.bytecode.Constants and swaps in a
// tiny main function. That function loads them, calls fn, stores the result
// in the reserved "$out" global and suspends. The original main function and
// constants are restored before returning, even if the VM panics. On error
// the returned object is nil.
//
// Callers must hold s.mu (Call does), because the shared bytecode and globals
// are mutated.
func (s *Script) callCompiled(fn *tengo.CompiledFunction,
	args ...tengo.Object) (tengo.Object, error) {
	// Tengo encodes the OpConstant index as a 2-byte operand and the OpCall
	// argument count as a 1-byte operand. Larger values would be truncated,
	// silently loading the wrong constant or passing the wrong arg count.
	const (
		maxConstIndex = 1<<16 - 1
		maxCallArgs   = 1<<8 - 1
	)
	if len(args) > maxCallArgs {
		return nil, fmt.Errorf("too many arguments: %d (max %d)", len(args), maxCallArgs)
	}
	constsOffset := len(s.bytecode.Constants)
	if constsOffset+len(args) > maxConstIndex {
		return nil, fmt.Errorf("too many constants for call: highest index %d exceeds %d",
			constsOffset+len(args), maxConstIndex)
	}

	// Load fn, then each argument, from the temporary constants.
	inst := tengo.MakeInstruction(parser.OpConstant, constsOffset)
	for i := range args {
		inst = append(inst,
			tengo.MakeInstruction(parser.OpConstant, constsOffset+i+1)...)
	}
	// Call fn, store its result in the "$out" global, and stop.
	inst = append(inst, tengo.MakeInstruction(parser.OpCall, len(args))...)
	inst = append(inst, tengo.MakeInstruction(parser.OpSetGlobal, s.outIdx)...)
	inst = append(inst, tengo.MakeInstruction(parser.OpSuspend)...)

	origMain := s.bytecode.MainFunction
	s.bytecode.Constants = append(s.bytecode.Constants, fn)
	s.bytecode.Constants = append(s.bytecode.Constants, args...)
	s.bytecode.MainFunction = &tengo.CompiledFunction{
		Instructions: inst,
	}
	defer func() {
		// Restore the program even on panic. Nil out the temporary tail so
		// the backing array does not keep fn and args alive.
		s.bytecode.MainFunction = origMain
		tail := s.bytecode.Constants[constsOffset:]
		for i := range tail {
			tail[i] = nil
		}
		s.bytecode.Constants = s.bytecode.Constants[:constsOffset]
	}()

	// Clear any previous result so a failed run cannot surface a stale value.
	s.globals[s.outIdx] = tengo.UndefinedValue
	if err := tengo.NewVM(s.bytecode, s.globals, -1).Run(); err != nil {
		return nil, err
	}
	result := s.globals[s.outIdx]
	// Do not keep the result reachable from the shared globals.
	s.globals[s.outIdx] = tengo.UndefinedValue
	return result, nil
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment