Skip to content

Instantly share code, notes, and snippets.

Embed
What would you like to do?
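Fix for recursive closures that are defined inside other functions, which break in version 1.0 of "Writing A Compiler In Go": the inner function captures its own name as a free variable before the enclosing function has stored the closure in that local. This fix adds a new opcode, OpGetSelf, which the compiler emits whenever a function refers to itself by the name it was bound to in a `let` statement; the VM executes it by pushing the currently running closure onto the stack.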
diff --git a/ast/ast.go b/ast/ast.go
index 8db3b39..f0420f4 100644
--- a/ast/ast.go
+++ b/ast/ast.go
@@ -2,6 +2,7 @@ package ast
import (
"bytes"
+ "fmt"
"strings"
"github.com/mrnugget/monkey/token"
@@ -223,6 +224,7 @@ type FunctionLiteral struct {
Token token.Token // The 'fn' token
Parameters []*Identifier
Body *BlockStatement
+ Name string
}
func (fl *FunctionLiteral) expressionNode() {}
@@ -236,6 +238,9 @@ func (fl *FunctionLiteral) String() string {
}
out.WriteString(fl.TokenLiteral())
+ if fl.Name != "" {
+ out.WriteString(fmt.Sprintf("<%s>", fl.Name))
+ }
out.WriteString("(")
out.WriteString(strings.Join(params, ", "))
out.WriteString(") ")
diff --git a/code/code.go b/code/code.go
index 8931973..443482b 100644
--- a/code/code.go
+++ b/code/code.go
@@ -36,6 +36,7 @@ const (
OpSetLocal // [idx] -- Store local variable
OpGetFree // [idx] -- Load a free variable from the current closure's Free store
OpGetBuiltin // [idx] -- Load a free variable from the current closure's Free store
+ OpGetSelf // [] -- Load the current closure onto the stack
OpCall // [n] -- Call function that sits on top of stack with n arguments
OpReturnValue // [] -- Returns from the function -- Value sits on stack
@@ -90,6 +91,7 @@ var definitions = map[Opcode]*Definition{
OpSetLocal: {"OpSetLocal", []int{1}},
OpGetFree: {"OpGetFree", []int{1}},
OpGetBuiltin: {"OpGetBuiltin", []int{1}},
+ OpGetSelf: {"OpGetSelf", []int{}},
// [idx, lenFree, numLocals] -- Turn the CONSTANT at [idx] into a Closure and put on stack
OpClosure: {"OpClosure", []int{2, 1}},
diff --git a/compiler/compiler.go b/compiler/compiler.go
index 478147d..0021f31 100644
--- a/compiler/compiler.go
+++ b/compiler/compiler.go
@@ -254,6 +254,10 @@ func (c *Compiler) Compile(node ast.Node) error {
case *ast.FunctionLiteral:
c.enterScope()
+ if node.Name != "" {
+ c.symbolTable.DefineSelf(node.Name)
+ }
+
for _, p := range node.Parameters {
c.symbolTable.Define(p.Value)
}
@@ -431,6 +435,8 @@ func (c *Compiler) loadSymbol(s Symbol) {
c.emit(code.OpGetBuiltin, s.Index)
case FreeScope:
c.emit(code.OpGetFree, s.Index)
+ case SelfScope:
+ c.emit(code.OpGetSelf)
}
}
diff --git a/compiler/compiler_test.go b/compiler/compiler_test.go
index 45b168a..c3c4edb 100644
--- a/compiler/compiler_test.go
+++ b/compiler/compiler_test.go
@@ -395,6 +395,106 @@ func TestFunctions(t *testing.T) {
runCompilerTests(t, tests)
}
+// TestRecursiveFunctions checks that a function referring to itself by its
+// `let`-bound name is compiled to OpGetSelf, wherever it is defined.
+func TestRecursiveFunctions(t *testing.T) {
+ tests := []compilerTestCase{
+ {
+ // Bound at global scope: the self-reference uses OpGetSelf
+ // instead of OpGetGlobal.
+ input: `
+ let inner = fn(x) {
+ if (x == 0) {
+ return 0;
+ } else {
+ inner(x - 1);
+ }
+ };
+ inner(1);
+ `,
+ expectedConstants: []interface{}{
+ 0,
+ 0,
+ 1,
+ []code.Instructions{
+ code.Make(code.OpGetLocal, 0),
+ code.Make(code.OpConstant, 0),
+ code.Make(code.OpEqual),
+ code.Make(code.OpJumpNotTruthy, 16),
+ code.Make(code.OpConstant, 1),
+ code.Make(code.OpReturnValue),
+ code.Make(code.OpJump, 25),
+ code.Make(code.OpGetSelf),
+ code.Make(code.OpGetLocal, 0),
+ code.Make(code.OpConstant, 2),
+ code.Make(code.OpSub),
+ code.Make(code.OpCall, 1),
+ code.Make(code.OpReturnValue),
+ },
+ 1,
+ },
+ expectedInstructions: []code.Instructions{
+ code.Make(code.OpClosure, 3, 0),
+ code.Make(code.OpSetGlobal, 0),
+ code.Make(code.OpGetGlobal, 0),
+ code.Make(code.OpConstant, 4),
+ code.Make(code.OpCall, 1),
+ code.Make(code.OpPop),
+ },
+ },
+ {
+ // Bound to a local inside another function: the self-reference
+ // also uses OpGetSelf, so the closure captures no free variables
+ // (note the 0 in OpClosure 3, 0).
+ input: `
+ let wrapper = fn() {
+ let inner = fn(x) {
+ if (x == 0) {
+ return 0;
+ } else {
+ inner(x - 1);
+ }
+ };
+ inner(1);
+ };
+ wrapper();
+ `,
+ expectedConstants: []interface{}{
+ 0,
+ 0,
+ 1,
+ []code.Instructions{
+ code.Make(code.OpGetLocal, 0),
+ code.Make(code.OpConstant, 0),
+ code.Make(code.OpEqual),
+ code.Make(code.OpJumpNotTruthy, 16),
+ code.Make(code.OpConstant, 1),
+ code.Make(code.OpReturnValue),
+ code.Make(code.OpJump, 25),
+ code.Make(code.OpGetSelf),
+ code.Make(code.OpGetLocal, 0),
+ code.Make(code.OpConstant, 2),
+ code.Make(code.OpSub),
+ code.Make(code.OpCall, 1),
+ code.Make(code.OpReturnValue),
+ },
+ 1,
+ []code.Instructions{
+ code.Make(code.OpClosure, 3, 0),
+ code.Make(code.OpSetLocal, 0),
+ code.Make(code.OpGetLocal, 0),
+ code.Make(code.OpConstant, 4),
+ code.Make(code.OpCall, 1),
+ code.Make(code.OpReturnValue),
+ },
+ },
+ expectedInstructions: []code.Instructions{
+ code.Make(code.OpClosure, 5, 0),
+ code.Make(code.OpSetGlobal, 0),
+ code.Make(code.OpGetGlobal, 0),
+ code.Make(code.OpCall, 0),
+ code.Make(code.OpPop),
+ },
+ },
+ {
+ // Compiled only, never run. A closure nested inside a named
+ // function captures that function's self-reference as a free
+ // variable: the enclosing scope loads it with OpGetSelf before
+ // OpClosure, and the nested closure reads it with OpGetFree.
+ input: `
+ let countDown = fn(x) {
+ let step = fn() { countDown(x - 1) };
+ step();
+ };
+ `,
+ expectedConstants: []interface{}{
+ 1,
+ []code.Instructions{
+ code.Make(code.OpGetFree, 0),
+ code.Make(code.OpGetFree, 1),
+ code.Make(code.OpConstant, 0),
+ code.Make(code.OpSub),
+ code.Make(code.OpCall, 1),
+ code.Make(code.OpReturnValue),
+ },
+ []code.Instructions{
+ code.Make(code.OpGetSelf),
+ code.Make(code.OpGetLocal, 0),
+ code.Make(code.OpClosure, 1, 2),
+ code.Make(code.OpSetLocal, 1),
+ code.Make(code.OpGetLocal, 1),
+ code.Make(code.OpCall, 0),
+ code.Make(code.OpReturnValue),
+ },
+ },
+ expectedInstructions: []code.Instructions{
+ code.Make(code.OpClosure, 2, 0),
+ code.Make(code.OpSetGlobal, 0),
+ },
+ },
+ }
+
+ runCompilerTests(t, tests)
+}
+
type compilerTestCase struct {
input string
expectedConstants []interface{}
diff --git a/compiler/symbol_table.go b/compiler/symbol_table.go
index 70d0280..0d6b4eb 100644
--- a/compiler/symbol_table.go
+++ b/compiler/symbol_table.go
@@ -7,6 +7,7 @@ const (
GlobalScope SymbolScope = "GLOBAL"
BuiltinScope SymbolScope = "BUILTIN"
FreeScope SymbolScope = "FREE"
+ SelfScope SymbolScope = "SELF"
)
type Symbol struct {
@@ -73,6 +74,12 @@ func (s *SymbolTable) DefineBuiltin(index int, name string) Symbol {
return symbol
}
+// DefineSelf registers name as a reference to the function currently being
+// compiled; the compiler calls it right after entering a named function
+// literal's scope. The returned SelfScope symbol is stored under name but
+// does not use up a definition slot, and its Index is always 0 because
+// OpGetSelf takes no operand.
+func (s *SymbolTable) DefineSelf(name string) Symbol {
+ self := Symbol{Name: name, Scope: SelfScope}
+ s.store[name] = self
+ return self
+}
+
func (s *SymbolTable) defineFree(original Symbol) Symbol {
s.FreeSymbols = append(s.FreeSymbols, original)
diff --git a/compiler/symbol_table_test.go b/compiler/symbol_table_test.go
index 8ceac0e..3b0315e 100644
--- a/compiler/symbol_table_test.go
+++ b/compiler/symbol_table_test.go
@@ -300,3 +300,38 @@ func TestResolveUnresolvableFree(t *testing.T) {
}
}
}
+
+// TestDefineAndResolveSelf checks that a name registered with DefineSelf
+// resolves to a SelfScope symbol, does not use up a definition slot, and is
+// captured as a free variable by a nested scope.
+func TestDefineAndResolveSelf(t *testing.T) {
+ expected := Symbol{Name: "a", Scope: SelfScope, Index: 0}
+
+ global := NewSymbolTable()
+ global.DefineSelf("a")
+
+ result, ok := global.Resolve(expected.Name)
+ if !ok {
+ t.Fatalf("self name %s not resolvable", expected.Name)
+ }
+
+ if result != expected {
+ t.Errorf("expected %s to resolve to %+v, got=%+v",
+ expected.Name, expected, result)
+ }
+
+ // DefineSelf must not take up an index: the next regular definition
+ // still gets index 0, so the function's local count is not inflated.
+ next := global.Define("b")
+ if next.Index != 0 {
+ t.Errorf("expected first definition after DefineSelf to get index 0, got=%d",
+ next.Index)
+ }
+
+ // A nested scope captures the self symbol as a free variable, so the
+ // enclosing scope can load it with OpGetSelf when building the closure.
+ nested := NewEnclosedSymbolTable(global)
+ expectedFree := Symbol{Name: "a", Scope: FreeScope, Index: 0}
+
+ free, ok := nested.Resolve(expected.Name)
+ if !ok {
+ t.Fatalf("self name %s not resolvable from nested scope", expected.Name)
+ }
+
+ if free != expectedFree {
+ t.Errorf("expected %s to resolve to %+v in nested scope, got=%+v",
+ expected.Name, expectedFree, free)
+ }
+
+ if len(nested.FreeSymbols) != 1 || nested.FreeSymbols[0] != expected {
+ t.Errorf("expected nested.FreeSymbols to be [%+v], got=%+v",
+ expected, nested.FreeSymbols)
+ }
+}
+
+// TestShadowingSelf checks that a regular definition made after DefineSelf
+// replaces the self symbol, both in a global and in an enclosed table.
+func TestShadowingSelf(t *testing.T) {
+ expected := Symbol{Name: "a", Scope: GlobalScope, Index: 0}
+
+ global := NewSymbolTable()
+ global.DefineSelf(expected.Name)
+ global.Define(expected.Name)
+
+ result, ok := global.Resolve(expected.Name)
+ if !ok {
+ t.Fatalf("name %s not resolvable", expected.Name)
+ }
+
+ if result != expected {
+ t.Errorf("expected %s to resolve to %+v, got=%+v",
+ expected.Name, expected, result)
+ }
+
+ // The compiler calls DefineSelf on a function's own (enclosed) table,
+ // so shadowing must also work there and yield a local.
+ expectedLocal := Symbol{Name: "a", Scope: LocalScope, Index: 0}
+
+ local := NewEnclosedSymbolTable(global)
+ local.DefineSelf(expectedLocal.Name)
+ local.Define(expectedLocal.Name)
+
+ result, ok = local.Resolve(expectedLocal.Name)
+ if !ok {
+ t.Fatalf("name %s not resolvable in local scope", expectedLocal.Name)
+ }
+
+ if result != expectedLocal {
+ t.Errorf("expected %s to resolve to %+v, got=%+v",
+ expectedLocal.Name, expectedLocal, result)
+ }
+}
diff --git a/parser/parser.go b/parser/parser.go
index 94635b9..dbd581d 100644
--- a/parser/parser.go
+++ b/parser/parser.go
@@ -172,6 +172,10 @@ func (p *Parser) parseLetStatement() *ast.LetStatement {
stmt.Value = p.parseExpression(LOWEST)
+ if fl, ok := stmt.Value.(*ast.FunctionLiteral); ok {
+ fl.Name = stmt.Name.Value
+ }
+
if p.peekTokenIs(token.SEMICOLON) {
p.nextToken()
}
diff --git a/parser/parser_test.go b/parser/parser_test.go
index 9fccfce..98229ac 100644
--- a/parser/parser_test.go
+++ b/parser/parser_test.go
@@ -587,6 +587,37 @@ func TestFunctionParameterParsing(t *testing.T) {
}
}
+// TestFunctionDefinitionParsing checks that the parser copies the name from a
+// `let` statement onto the function literal it binds, including for a
+// function defined inside another function's body.
+func TestFunctionDefinitionParsing(t *testing.T) {
+ input := `let myFunction = fn() { let inner = fn() { }; };`
+
+ l := lexer.New(input)
+ p := New(l)
+ program := p.ParseProgram()
+ checkParserErrors(t, p)
+
+ if len(program.Statements) != 1 {
+ t.Fatalf("program.Body does not contain %d statements. got=%d\n",
+ 1, len(program.Statements))
+ }
+
+ stmt, ok := program.Statements[0].(*ast.LetStatement)
+ if !ok {
+ t.Fatalf("program.Statements[0] is not ast.LetStatement. got=%T",
+ program.Statements[0])
+ }
+
+ function, ok := stmt.Value.(*ast.FunctionLiteral)
+ if !ok {
+ t.Fatalf("stmt.Value is not ast.FunctionLiteral. got=%T",
+ stmt.Value)
+ }
+
+ if function.Name != "myFunction" {
+ t.Fatalf("function literal name wrong. want 'myFunction', got=%q\n",
+ function.Name)
+ }
+
+ // Names must also be assigned inside function bodies: the compiler uses
+ // them to resolve recursive closures defined in other functions.
+ if len(function.Body.Statements) != 1 {
+ t.Fatalf("function.Body does not contain %d statements. got=%d\n",
+ 1, len(function.Body.Statements))
+ }
+
+ innerStmt, ok := function.Body.Statements[0].(*ast.LetStatement)
+ if !ok {
+ t.Fatalf("function.Body.Statements[0] is not ast.LetStatement. got=%T",
+ function.Body.Statements[0])
+ }
+
+ innerFunction, ok := innerStmt.Value.(*ast.FunctionLiteral)
+ if !ok {
+ t.Fatalf("innerStmt.Value is not ast.FunctionLiteral. got=%T",
+ innerStmt.Value)
+ }
+
+ if innerFunction.Name != "inner" {
+ t.Fatalf("inner function literal name wrong. want 'inner', got=%q\n",
+ innerFunction.Name)
+ }
+}
+
func TestCallExpressionParsing(t *testing.T) {
input := "add(1, 2 * 3, 4 + 5);"
diff --git a/vm/vm.go b/vm/vm.go
index 31b1699..8490912 100644
--- a/vm/vm.go
+++ b/vm/vm.go
@@ -278,6 +278,13 @@ func (vm *VM) Run() error {
if err != nil {
return err
}
+
+ case code.OpGetSelf:
+ currentClosure := vm.currentFrame().cl
+ err := vm.push(currentClosure)
+ if err != nil {
+ return err
+ }
}
if vm.trace {
diff --git a/vm/vm_test.go b/vm/vm_test.go
index aace5d6..10bcd3e 100644
--- a/vm/vm_test.go
+++ b/vm/vm_test.go
@@ -564,6 +564,83 @@ func TestRecursiveFibonacci(t *testing.T) {
runVmTests(t, tests)
}
+// TestRecursiveFunctions runs functions that call themselves by their
+// `let`-bound name, from global, local and nested-closure positions.
+func TestRecursiveFunctions(t *testing.T) {
+ tests := []vmTestCase{
+ {
+ // Recursive function bound at global scope.
+ input: `
+ let inner = fn(x) {
+ if (x == 0) {
+ return 0;
+ } else {
+ inner(x - 1);
+ }
+ };
+ inner(1);
+ `,
+ expected: 0,
+ },
+ {
+ // Global recursive function called from another function.
+ input: `
+ let inner = fn(x) {
+ if (x == 0) {
+ return 0;
+ } else {
+ inner(x - 1);
+ }
+ };
+ let wrapper = fn() {
+ inner(1);
+ };
+ wrapper();
+ `,
+ expected: 0,
+ },
+ {
+ // Recursive function bound to a local inside another function.
+ // This failed before OpGetSelf was introduced.
+ input: `
+ let wrapper = fn() {
+ let inner = fn(x) {
+ if (x == 0) {
+ return 0;
+ } else {
+ inner(x - 1);
+ }
+ };
+ inner(1);
+ };
+ wrapper();
+ `,
+ expected: 0,
+ },
+ {
+ // A local `let` shadows the name of a globally bound function.
+ input: `
+ let one = fn() { let one = 1; return one };
+ one();
+ `,
+ expected: 1,
+ },
+ {
+ // A local `let` shadows the name of a locally bound function.
+ input: `
+ let wrapper = fn() {
+ let inner = fn(x) {
+ let inner = 2;
+ x + inner
+ };
+ inner(1);
+ };
+ wrapper();
+ `,
+ expected: 3,
+ },
+ {
+ // A closure nested inside a named function calls that function
+ // through its captured self-reference.
+ input: `
+ let wrapper = fn() {
+ let countDown = fn(x) {
+ let step = fn() { countDown(x - 1) };
+ if (x == 0) { 0 } else { x + step() }
+ };
+ countDown(3);
+ };
+ wrapper();
+ `,
+ expected: 6,
+ },
+ {
+ // The self-reference still points at the right closure after it
+ // has been returned from the function that defined it.
+ input: `
+ let makeCounter = fn() {
+ let count = fn(n) { if (n == 0) { 0 } else { 1 + count(n - 1) } };
+ count
+ };
+ let counter = makeCounter();
+ counter(3);
+ `,
+ expected: 3,
+ },
+ }
+
+ runVmTests(t, tests)
+}
+
type vmTestCase struct {
input string
expected interface{}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
You can’t perform that action at this time.