Skip to content

Instantly share code, notes, and snippets.

@mrnugget
Last active March 20, 2019 10:30
Show Gist options
  • Save mrnugget/a2bc0794b7d1a249de77a19ea0807389 to your computer and use it in GitHub Desktop.
Save mrnugget/a2bc0794b7d1a249de77a19ea0807389 to your computer and use it in GitHub Desktop.
Fix for recursive closures that are defined in other functions. These break in version 1.0 of "Writing A Compiler In Go". This fix adds another opcode, OpGetSelf, and emits it whenever there's a reference to the currently executed function.
diff --git a/ast/ast.go b/ast/ast.go
index 8db3b39..f0420f4 100644
--- a/ast/ast.go
+++ b/ast/ast.go
@@ -2,6 +2,7 @@ package ast
import (
"bytes"
+ "fmt"
"strings"
"github.com/mrnugget/monkey/token"
@@ -223,6 +224,7 @@ type FunctionLiteral struct {
Token token.Token // The 'fn' token
Parameters []*Identifier
Body *BlockStatement
+ Name string
}
func (fl *FunctionLiteral) expressionNode() {}
@@ -236,6 +238,9 @@ func (fl *FunctionLiteral) String() string {
}
out.WriteString(fl.TokenLiteral())
+ if fl.Name != "" {
+ out.WriteString(fmt.Sprintf("<%s>", fl.Name))
+ }
out.WriteString("(")
out.WriteString(strings.Join(params, ", "))
out.WriteString(") ")
diff --git a/code/code.go b/code/code.go
index 8931973..443482b 100644
--- a/code/code.go
+++ b/code/code.go
@@ -36,6 +36,7 @@ const (
OpSetLocal // [idx] -- Store local variable
OpGetFree // [idx] -- Load a free variable from the current closure's Free store
OpGetBuiltin // [idx] -- Load the builtin function at index [idx] onto the stack
+ OpGetSelf // [] -- Load the current closure onto the stack
OpCall // [n] -- Call function that sits on top of stack with n arguments
OpReturnValue // [] -- Returns from the function -- Value sits on stack
@@ -90,6 +91,7 @@ var definitions = map[Opcode]*Definition{
OpSetLocal: {"OpSetLocal", []int{1}},
OpGetFree: {"OpGetFree", []int{1}},
OpGetBuiltin: {"OpGetBuiltin", []int{1}},
+ OpGetSelf: {"OpGetSelf", []int{}},
// [idx, lenFree] -- Turn the CONSTANT at [idx] into a Closure and put on stack
OpClosure: {"OpClosure", []int{2, 1}},
diff --git a/compiler/compiler.go b/compiler/compiler.go
index 478147d..0021f31 100644
--- a/compiler/compiler.go
+++ b/compiler/compiler.go
@@ -254,6 +254,10 @@ func (c *Compiler) Compile(node ast.Node) error {
case *ast.FunctionLiteral:
c.enterScope()
+ if node.Name != "" {
+ c.symbolTable.DefineSelf(node.Name)
+ }
+
for _, p := range node.Parameters {
c.symbolTable.Define(p.Value)
}
@@ -431,6 +435,8 @@ func (c *Compiler) loadSymbol(s Symbol) {
c.emit(code.OpGetBuiltin, s.Index)
case FreeScope:
c.emit(code.OpGetFree, s.Index)
+ case SelfScope:
+ c.emit(code.OpGetSelf)
}
}
diff --git a/compiler/compiler_test.go b/compiler/compiler_test.go
index 45b168a..c3c4edb 100644
--- a/compiler/compiler_test.go
+++ b/compiler/compiler_test.go
@@ -395,6 +395,106 @@ func TestFunctions(t *testing.T) {
runCompilerTests(t, tests)
}
+func TestRecursiveFunctions(t *testing.T) { // Checks the constants and instructions expected for let-bound functions that call themselves.
+ tests := []compilerTestCase{
+ { // Case 1: the recursive function is bound at the top level.
+ input: `
+ let inner = fn(x) {
+ if (x == 0) {
+ return 0;
+ } else {
+ inner(x - 1);
+ }
+ };
+ inner(1);
+ `,
+ expectedConstants: []interface{}{
+ 0,
+ 0,
+ 1,
+ []code.Instructions{ // constant 3: expected body of inner
+ code.Make(code.OpGetLocal, 0),
+ code.Make(code.OpConstant, 0),
+ code.Make(code.OpEqual),
+ code.Make(code.OpJumpNotTruthy, 16),
+ code.Make(code.OpConstant, 1),
+ code.Make(code.OpReturnValue),
+ code.Make(code.OpJump, 25),
+ code.Make(code.OpGetSelf), // the reference to inner inside its own body is expected as OpGetSelf
+ code.Make(code.OpGetLocal, 0),
+ code.Make(code.OpConstant, 2),
+ code.Make(code.OpSub),
+ code.Make(code.OpCall, 1),
+ code.Make(code.OpReturnValue),
+ },
+ 1,
+ },
+ expectedInstructions: []code.Instructions{
+ code.Make(code.OpClosure, 3, 0), // constant 3 is the instruction list for inner above
+ code.Make(code.OpSetGlobal, 0),
+ code.Make(code.OpGetGlobal, 0), // the call from outside inner still goes through the global binding
+ code.Make(code.OpConstant, 4),
+ code.Make(code.OpCall, 1),
+ code.Make(code.OpPop),
+ },
+ },
+ { // Case 2: the same recursive function, now defined inside wrapper.
+ input: `
+ let wrapper = fn() {
+ let inner = fn(x) {
+ if (x == 0) {
+ return 0;
+ } else {
+ inner(x - 1);
+ }
+ };
+ inner(1);
+ };
+ wrapper();
+ `,
+ expectedConstants: []interface{}{
+ 0,
+ 0,
+ 1,
+ []code.Instructions{ // constant 3: expected body of inner, identical to case 1
+ code.Make(code.OpGetLocal, 0),
+ code.Make(code.OpConstant, 0),
+ code.Make(code.OpEqual),
+ code.Make(code.OpJumpNotTruthy, 16),
+ code.Make(code.OpConstant, 1),
+ code.Make(code.OpReturnValue),
+ code.Make(code.OpJump, 25),
+ code.Make(code.OpGetSelf), // still OpGetSelf even though inner is a local of wrapper
+ code.Make(code.OpGetLocal, 0),
+ code.Make(code.OpConstant, 2),
+ code.Make(code.OpSub),
+ code.Make(code.OpCall, 1),
+ code.Make(code.OpReturnValue),
+ },
+ 1,
+ []code.Instructions{ // constant 5: expected body of wrapper
+
+ code.Make(code.OpClosure, 3, 0),
+ code.Make(code.OpSetLocal, 0), // inner is stored in local slot 0 of wrapper
+ code.Make(code.OpGetLocal, 0),
+ code.Make(code.OpConstant, 4),
+ code.Make(code.OpCall, 1),
+ code.Make(code.OpReturnValue),
+ },
+ },
+ expectedInstructions: []code.Instructions{
+ code.Make(code.OpClosure, 5, 0), // constant 5 is the instruction list for wrapper above
+ code.Make(code.OpSetGlobal, 0),
+ code.Make(code.OpGetGlobal, 0),
+ code.Make(code.OpCall, 0),
+ code.Make(code.OpPop),
+ },
+ },
+ }
+
+ runCompilerTests(t, tests)
+}
+
type compilerTestCase struct {
input string
expectedConstants []interface{}
diff --git a/compiler/symbol_table.go b/compiler/symbol_table.go
index 70d0280..0d6b4eb 100644
--- a/compiler/symbol_table.go
+++ b/compiler/symbol_table.go
@@ -7,6 +7,7 @@ const (
GlobalScope SymbolScope = "GLOBAL"
BuiltinScope SymbolScope = "BUILTIN"
FreeScope SymbolScope = "FREE"
+ SelfScope SymbolScope = "SELF"
)
type Symbol struct {
@@ -73,6 +74,12 @@ func (s *SymbolTable) DefineBuiltin(index int, name string) Symbol {
return symbol
}
+func (s *SymbolTable) DefineSelf(name string) Symbol { // DefineSelf binds name to a SelfScope symbol in this table and returns that symbol.
+ symbol := Symbol{Name: name, Index: 0, Scope: SelfScope} // the index is always 0 for a self symbol
+ s.store[name] = symbol // replaces any entry already stored under name in this table
+ return symbol
+}
+
func (s *SymbolTable) defineFree(original Symbol) Symbol {
s.FreeSymbols = append(s.FreeSymbols, original)
diff --git a/compiler/symbol_table_test.go b/compiler/symbol_table_test.go
index 8ceac0e..3b0315e 100644
--- a/compiler/symbol_table_test.go
+++ b/compiler/symbol_table_test.go
@@ -300,3 +300,38 @@ func TestResolveUnresolvableFree(t *testing.T) {
}
}
}
+
+func TestDefineAndResolveSelf(t *testing.T) { // A name registered with DefineSelf must resolve to a SelfScope symbol with index 0.
+ expected := Symbol{Name: "a", Scope: SelfScope, Index: 0}
+
+ global := NewSymbolTable()
+ global.DefineSelf("a")
+
+ result, ok := global.Resolve(expected.Name)
+ if !ok {
+ t.Fatalf("self name %s not resolvable", expected.Name)
+ }
+
+ if result != expected { // Symbol values are compared field by field with !=
+ t.Errorf("expected %s to resolve to %+v, got=%+v",
+ expected.Name, expected, result)
+ }
+}
+
+func TestShadowingSelf(t *testing.T) { // A Define issued after DefineSelf for the same name must win when the name is resolved.
+ expected := Symbol{Name: "a", Scope: GlobalScope, Index: 0} // what Define is expected to produce on a fresh table
+
+ global := NewSymbolTable()
+ global.DefineSelf(expected.Name)
+ global.Define(expected.Name) // same name again: this definition should shadow the self symbol
+
+ result, ok := global.Resolve(expected.Name)
+ if !ok {
+ t.Fatalf("self name %s not resolvable", expected.Name)
+ }
+
+ if result != expected {
+ t.Errorf("expected %s to resolve to %+v, got=%+v",
+ expected.Name, expected, result)
+ }
+}
diff --git a/parser/parser.go b/parser/parser.go
index 94635b9..dbd581d 100644
--- a/parser/parser.go
+++ b/parser/parser.go
@@ -172,6 +172,10 @@ func (p *Parser) parseLetStatement() *ast.LetStatement {
stmt.Value = p.parseExpression(LOWEST)
+ if fl, ok := stmt.Value.(*ast.FunctionLiteral); ok {
+ fl.Name = stmt.Name.Value
+ }
+
if p.peekTokenIs(token.SEMICOLON) {
p.nextToken()
}
diff --git a/parser/parser_test.go b/parser/parser_test.go
index 9fccfce..98229ac 100644
--- a/parser/parser_test.go
+++ b/parser/parser_test.go
@@ -587,6 +587,37 @@ func TestFunctionParameterParsing(t *testing.T) {
}
}
+func TestFunctionDefinitionParsing(t *testing.T) { // A function literal bound by a let statement must carry the bound name in its Name field.
+ input := `let myFunction = fn() { };`
+
+ l := lexer.New(input)
+ p := New(l)
+ program := p.ParseProgram()
+ checkParserErrors(t, p)
+
+ if len(program.Statements) != 1 { // the input is a single let statement
+ t.Fatalf("program.Body does not contain %d statements. got=%d\n",
+ 1, len(program.Statements))
+ }
+
+ stmt, ok := program.Statements[0].(*ast.LetStatement)
+ if !ok {
+ t.Fatalf("program.Statements[0] is not ast.LetStatement. got=%T",
+ program.Statements[0])
+ }
+
+ function, ok := stmt.Value.(*ast.FunctionLiteral)
+ if !ok {
+ t.Fatalf("stmt.Value is not ast.FunctionLiteral. got=%T",
+ stmt.Value)
+ }
+
+ if function.Name != "myFunction" { // the name must match the identifier on the left of the let
+ t.Fatalf("function literal name wrong. want 'myFunction', got=%q\n",
+ function.Name)
+ }
+}
+
func TestCallExpressionParsing(t *testing.T) {
input := "add(1, 2 * 3, 4 + 5);"
diff --git a/vm/vm.go b/vm/vm.go
index 31b1699..8490912 100644
--- a/vm/vm.go
+++ b/vm/vm.go
@@ -278,6 +278,13 @@ func (vm *VM) Run() error {
if err != nil {
return err
}
+
+ case code.OpGetSelf:
+ currentClosure := vm.currentFrame().cl
+ err := vm.push(currentClosure)
+ if err != nil {
+ return err
+ }
}
if vm.trace {
diff --git a/vm/vm_test.go b/vm/vm_test.go
index aace5d6..10bcd3e 100644
--- a/vm/vm_test.go
+++ b/vm/vm_test.go
@@ -564,6 +564,83 @@ func TestRecursiveFibonacci(t *testing.T) {
runVmTests(t, tests)
}
+func TestRecursiveFunctions(t *testing.T) { // Hands self-recursive and self-shadowing Monkey programs to runVmTests with their expected results.
+ tests := []vmTestCase{
+ {
+ // This works
+ input: `
+ let inner = fn(x) {
+ if (x == 0) {
+ return 0;
+ } else {
+ inner(x - 1);
+ }
+ };
+ inner(1);
+ `,
+ expected: 0,
+ },
+ {
+ // This also works
+ input: `
+ let inner = fn(x) {
+ if (x == 0) {
+ return 0;
+ } else {
+ inner(x - 1);
+ }
+ };
+ let wrapper = fn() {
+ inner(1);
+ };
+ wrapper();
+ `,
+ expected: 0,
+ },
+ {
+ // This did _NOT_ work
+ input: `
+ let wrapper = fn() {
+ let inner = fn(x) {
+ if (x == 0) {
+ return 0;
+ } else {
+ inner(x - 1);
+ }
+ };
+ inner(1);
+ };
+ wrapper();
+ `,
+ expected: 0,
+ },
+ {
+ // Test that shadowing still works
+ input: `
+ let one = fn() { let one = 1; return one };
+ one();
+ `,
+ expected: 1, // the inner let must win over the function's own name
+ },
+ {
+ // Test that shadowing still works
+ input: `
+ let wrapper = fn() {
+ let inner = fn(x) {
+ let inner = 2;
+ x + inner
+ };
+ inner(1);
+ };
+ wrapper();
+ `,
+ expected: 3, // 1 + 2: inside the body, inner refers to the local value 2
+ },
+ }
+
+ runVmTests(t, tests)
+}
+
type vmTestCase struct {
input string
expected interface{}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment