Skip to content

Instantly share code, notes, and snippets.

@marknuzz
Created October 7, 2024 07:11
Show Gist options
  • Save marknuzz/c7598efa664fe944d790e16464280cd1 to your computer and use it in GitHub Desktop.
Save marknuzz/c7598efa664fe944d790e16464280cd1 to your computer and use it in GitHub Desktop.
rspec-sorbet compiler, work in progress
# typed: true
# frozen_string_literal: true
# WORK IN PROGRESS
# Authors: Hongli Lai, Mark Nuzzolilo
# https://github.com/FooBarWidget/sorbet-rspec
# Usage: inside an RSpec example group, call FakeSig.sig right before a `let`, e.g.:
# FakeSig.sig(self) { returns(Integer) }
# let(:some_var) { 1 }
# Records "fake" signature blocks written inside spec files.
#
# Each call stores the given block, together with the caller's location, in the
# `@__fakesig_blocks` instance variable of `ctx`. The store is a two-level hash:
# absolute path of the calling file => line number of the call => [location, block].
module FakeSig
  T::Sig::WithoutRuntime.sig { params(ctx: Module, blk: T.proc.bind(T::Private::Methods::DeclBuilder).void).void }
  def self.sig(ctx, &blk)
    # Frame 1 is whoever called FakeSig.sig.
    location = T.must(T.must(Kernel.caller_locations(1, 1)).first)

    # Lazily create the per-context registry the first time it is needed.
    ctx.instance_variable_set(:@__fakesig_blocks, {}) unless ctx.instance_variable_defined?(:@__fakesig_blocks)
    registry = T.cast(
      ctx.instance_variable_get(:@__fakesig_blocks),
      T::Hash[String, T::Hash[Integer, [Thread::Backtrace::Location, Proc]]]
    )

    file = T.must(location.absolute_path)
    registry[file] ||= {}
    blocks_by_line = T.cast(registry.fetch(file), T::Hash[Integer, [Thread::Backtrace::Location, Proc]])
    blocks_by_line[location.lineno] = [location, blk]
  end
end
require 'rspec/core'
require 'parser/current'
module Tapioca
module Dsl
module Compilers
class RSpecCustom < Compiler
extend T::Sig
ConstantType = type_member { { fixed: T.class_of(::RSpec::Core::ExampleGroup) } }
class << self
extend T::Sig
# Selects, out of `all_classes`, every class that is a strict descendant of
# RSpec::Core::ExampleGroup (the base class itself is not selected).
sig { override.returns(T::Enumerable[Module]) }
def gather_constants
  example_group_base = ::RSpec::Core::ExampleGroup
  all_classes.select { |klass| klass < example_group_base }
end
private
# Requires every file matched by `spec_glob`.
# Nothing is rescued here, so an error raised while loading a file propagates
# to the caller.
sig { void }
def require_spec_files!
  Dir.glob(spec_glob).each { |spec_path| require(spec_path) }
end
sig { returns(String) }
# Glob pattern of the files to load: the SORBET_RSPEC_GLOB environment variable
# when it is set, otherwise every `.rb` file below `./spec`.
def spec_glob
  ENV.fetch('SORBET_RSPEC_GLOB') { File.join('.', 'spec', '**', '*.rb') }
end
end
# Load all spec files while this compiler class is being defined, i.e. as a
# side effect of loading this file.
# NOTE(review): presumably this is what makes the RSpec example group classes
# exist before `gather_constants` looks for them — confirm.
require_spec_files!
# Emits the RBI for `constant`: a class with the same name and superclass,
# followed by its includes, its example-group submodules and its singleton
# (class-level DSL) methods.
sig { override.void }
def decorate
  class_name = T.must(constant.name)
  parent_name = T.must(constant.superclass).name
  klass = root.create_class(class_name, superclass_name: parent_name)

  create_includes(klass)
  create_example_group_submodules(klass)
  create_singleton_methods(klass)
end
private
# Adds one `include ::Mod` to `klass` for every module reported by
# `directly_included_modules_for(constant)`.
sig { params(klass: RBI::Scope).void }
def create_includes(klass)
  directly_included_modules_for(constant)
    .map { |mod| "::#{mod}" }
    .each { |qualified_name| klass.create_include(qualified_name) }
end
# Declares the `prop` / `const` class-level DSL entry points as an abstract module.
# Nothing else in this file references it.
module StructExtension
  # `abstract!` is provided by T::Helpers; without this extend the call below
  # only works if the host application has added T::Helpers to Module globally.
  extend T::Helpers

  abstract!

  def prop(name, type, **options); end

  def const(name, type, **options); end
end
# A tree node wrapping one Parser::AST::Node.
#
# `SemanticNode.make` builds the tree recursively, picking a node subclass per
# AST node type via `divine_type`. AST types without a dedicated subclass are
# represented by plain SemanticNode instances.
class SemanticNode < T::InexactStruct
  # `sig` is used throughout this class and its subclasses; extend explicitly
  # (as RSpecCustom does) instead of relying on a global Module patch.
  extend T::Sig

  const :ast, Parser::AST::Node, inspect: ->(ast) { ast.class.to_s }
  const :children, T::Array[SemanticNode],
        factory: -> { T.let([], T::Array[SemanticNode]) },
        inspect: ->(v) { v.map { |c| "\n #{c.pp_label}" }.join('') }
  const :parent, T.nilable(SemanticNode), inspect: ->(n) { n&.pp_label }
  # One entry per element of `ast.children`, in the same order: the semantic
  # node built for that element, or nil when the element is not an AST node
  # (for example a Symbol or nil).
  const :ast_map, T::Array[T.nilable(SemanticNode)],
        factory: -> { [] },
        inspect: ->(a) { a.map { |n| n.nil? ? '.' : '+' }.join('') }

  # Short label used by the `inspect:` lambdas above.
  sig { overridable.returns(String) }
  def pp_label = "@'#{name}'"

  # Child at position `idx`; raises IndexError when out of range.
  sig { params(idx: Integer).returns(SemanticNode) }
  def [](idx) = children.fetch(idx)

  # Number of direct children.
  sig { returns(Integer) }
  def count = children.count

  # This node followed by all of its descendants, depth-first.
  sig { returns(T::Array[SemanticNode]) }
  def flatten = [self] + children.flat_map(&:flatten)

  # Every SendNode in this subtree (including self).
  sig { returns(T::Array[SendNode]) }
  def sends = flatten.grep(SendNode)

  # Every SendNode in this subtree that has no receiver.
  sig { returns(T::Array[SendNode]) }
  def sends_no_receivers = sends.reject(&:receiver?)

  # Every `FakeSig.sig` call in this subtree; see SendNode#fake_sig?.
  sig { returns(T::Array[SendNode]) }
  def fake_sigs = sends.select(&:fake_sig?)

  # Maps an AST node to the SemanticNode subclass that represents its type.
  # Types without a dedicated subclass are remembered in `@unknown_types`
  # (an instance variable of whichever class this is called on) and fall back
  # to SemanticNode.
  def self.divine_type(ast)
    @unknown_types ||= T.let(Set[], T.nilable(T::Set[Symbol]))
    unknown_types = T.must(@unknown_types)
    case ast.type
    when :args then ArgsNode
    when :block then BlockNode
    when :array then ArrayNode
    when :begin then BeginNode
    when :block_pass then BlockPassNode
    when :const then ConstNode
    when :false then FalseNode # rubocop:disable Lint/BooleanSymbol
    when :hash then HashNode
    when :int then IntNode
    when :lvar then LvarNode
    when :lvasgn then LvasgnNode
    when :nil then NilNode
    when :pair then PairNode
    when :self then SelfNode
    when :send then SendNode
    when :str then StrNode
    when :sym then SymNode
    when :true then TrueNode # rubocop:disable Lint/BooleanSymbol
    else
      # Set#add is already a no-op for members, so no include? check is needed.
      unknown_types.add(ast.type)
      SemanticNode
    end
  end

  # region Nodes
  class ArgsNode < SemanticNode; end
  class BlockNode < SemanticNode; end
  class ArrayNode < SemanticNode; end
  class BeginNode < SemanticNode; end
  class BlockPassNode < SemanticNode; end

  # A constant reference; `const_name` is taken from the second AST child.
  class ConstNode < SemanticNode
    const :const_name, T.nilable(Symbol)

    def initialize(*args, ast:, **kwargs)
      kwargs[:const_name] = ast.children&.[](1)
      super
    end

    sig { override.returns(String) }
    def pp_label = const_name&.to_s || super
  end

  class FalseNode < SemanticNode; end
  class HashNode < SemanticNode; end
  class IntNode < SemanticNode; end
  class LvarNode < SemanticNode; end
  class LvasgnNode < SemanticNode; end
  class NilNode < SemanticNode; end
  class PairNode < SemanticNode; end
  class SelfNode < SemanticNode; end

  # A method call; `method_name` is taken from the second AST child and
  # `receiver` is the semantic node built for the first AST child (if any).
  class SendNode < SemanticNode
    const :method_name, Symbol
    prop :receiver, T.nilable(SemanticNode), inspect: ->(n) { n&.pp_label }

    # True when the call has an explicit receiver.
    # Plain nil check: avoids depending on ActiveSupport's `present?`, which
    # this file never requires.
    sig { returns(T::Boolean) }
    def receiver? = !receiver.nil?

    # True for calls of the form `FakeSig.sig ...`.
    def fake_sig?
      target = receiver
      method_name == :sig && target.is_a?(ConstNode) && target.const_name == :FakeSig
    end

    sig { override.returns(String) }
    def pp_label = ":#{method_name}"

    def initialize(*args, ast:, **kwargs)
      kwargs[:method_name] = ast.children&.[](1) || :__UNKNOWN__
      super
    end

    # Builds the node via SemanticNode.make, then wires `receiver` to the
    # semantic node of the first AST child. `parent` may be nil, as the sig
    # declares (the previous `T.must(parent)` contradicted that).
    sig { override.params(ast: Parser::AST::Node, parent: T.nilable(SemanticNode)).returns(SendNode) }
    def self.make(ast:, parent: nil)
      this = T.cast(super, SendNode)
      this.receiver = this.ast_map[0]
      this
    end
  end

  class StrNode < SemanticNode; end
  class SymNode < SemanticNode; end
  class TrueNode < SemanticNode; end
  # endregion

  # Builds the semantic node for `ast` and, recursively, for each of its
  # children that is itself an AST node.
  #
  # NOTE: the node for `ast` itself is created with `new`; only children go
  # through their class's `make`. So subclass `make` overrides (such as
  # SendNode's receiver wiring) do not run for the node this method is
  # called on directly.
  sig { overridable.params(ast: Parser::AST::Node, parent: T.nilable(SemanticNode)).returns(SemanticNode) }
  def self.make(ast:, parent: nil)
    this = divine_type(ast).new(ast:, parent:)
    (ast.children || []).each do |node|
      unless node.is_a?(Parser::AST::Node)
        # Keep ast_map aligned with ast.children.
        this.ast_map.push(nil)
        next
      end

      child = divine_type(node).make(ast: node, parent: this)
      this.children.append(child)
      this.ast_map.push(child)
    end
    this
  end

  # Class name of this node ("" for an anonymous class).
  sig { overridable.returns(String) }
  def name = self.class.name.to_s
end
# Value object pairing a method name with an AST node.
# NOTE(review): nothing in this file constructs it yet; it is only referenced
# as a return/value type by the other semantics classes.
class MethodSemantics < T::Struct
# Disabled back-reference to the owning file's semantics.
# const :parent_file, FileSemantics
const :method_name, Symbol
const :ast, Parser::AST::Node
end
# Semantic view of one parsed source file: the module it was loaded for, its
# path, and the SemanticNode tree built from its AST.
#
# `method_source_nodes` and `methods_by_line` start empty; nothing in this
# class fills them.
class FileSemantics < T::Struct
  # `sig` is used below; extend explicitly instead of relying on a global
  # Module patch.
  extend T::Sig
  include ::AST::Sexp

  const :mod, Module
  const :path, String
  const :root, SemanticNode
  const :method_source_nodes, T::Hash[Symbol, SemanticNode], factory: -> { {} }
  const :methods_by_line, T::Hash[Integer, MethodSemantics], factory: -> { {} }

  # Looks up the MethodSemantics registered for the line on which
  # `unbound_method` is defined.
  #
  # Raises NotImplementedError when no entry exists for that line.
  sig { params(unbound_method: UnboundMethod).returns(MethodSemantics) }
  def method_semantics(unbound_method)
    line = T.must(unbound_method.source_location)[1]
    # The original evaluated this lookup but discarded the result (missing
    # `return`), so the method raised even when an entry existed.
    return methods_by_line.fetch(line) if methods_by_line.key?(line)

    raise NotImplementedError, "no method semantics recorded for #{path}:#{line}"
  end

  class << self
    extend T::Sig

    # Builds a FileSemantics whose root is the semantic tree of `ast`.
    sig { params(mod: Module, path: String, ast: Parser::AST::Node).returns(T.attached_class) }
    def make(mod:, path:, ast:)
      root = SemanticNode.make(ast:)
      new(mod:, path:, root:)
    end

    # Types of the direct children of `ast` that are themselves AST nodes.
    sig { params(ast: Parser::AST::Node).returns(T::Array[Symbol]) }
    def node_children_types(ast) = ast.children&.grep(Parser::AST::Node)&.map(&:type) || []

    # Collects, depth-first, every node of the given `type` in `ast`
    # (including `ast` itself). `state` is the accumulator used by the
    # recursion; callers normally omit it.
    sig {
      params(ast: Parser::AST::Node, type: Symbol,
             state: T.nilable(T::Array[Parser::AST::Node])).returns(T::Array[Parser::AST::Node])
    }
    def deep_find_nodes(ast, type, state = nil)
      state ||= T.let([], T::Array[Parser::AST::Node])
      state.push(ast) if ast.type == type
      (ast.children || []).grep(Parser::AST::Node).each { |child| deep_find_nodes(child, type, state) }
      state
    end

    # Every `:send` node in `ast` whose method name is `raw_method_name`.
    # The method name is the second child of a `:send` node, as read by
    # SendNode#initialize. The original ignored `raw_method_name` and
    # returned every send node.
    sig { params(ast: Parser::AST::Node, raw_method_name: Symbol).returns(T::Array[Parser::AST::Node]) }
    def find_raw_method_calls(ast, raw_method_name)
      deep_find_nodes(ast, :send).select { |send_node| send_node.children[1] == raw_method_name }
    end
  end
end
# Per-module cache of FileSemantics, keyed by source file path.
class ModuleSemantics < T::Struct
  # `sig` is used below; extend explicitly instead of relying on a global
  # Module patch.
  extend T::Sig

  SemanticsCache = T.type_alias { T::Hash[String, FileSemantics] }
  const :mod, Module
  const :semantics, SemanticsCache, factory: -> { {} }

  # Returns the FileSemantics for the file in which `unbound_method` is
  # defined, parsing and caching the file on first use.
  sig { params(unbound_method: UnboundMethod).returns(FileSemantics) }
  def file_semantics(unbound_method)
    path, = T.must(unbound_method.source_location)
    return semantics.fetch(path) if semantics.key?(path)

    buffer = Parser::Source::Buffer.new(path, source: File.read(path))
    ast = T.let(Parser::CurrentRuby.new.parse(buffer), Parser::AST::Node)
    res = semantics[path] = FileSemantics.make(mod:, path:, ast:)

    fake_sigs = proc_fake_sigs
    unless fake_sigs.empty?
      # WORK IN PROGRESS: evaluates the first recorded FakeSig block and builds
      # a Signature from it. The results are not used yet, hence the underscore
      # prefixes. A leftover `binding.pry` breakpoint was removed from the end
      # of this branch.
      # TODO: select the fake sig belonging to `path` / `unbound_method`
      # instead of always taking the first one.
      _node_fake_sigs = res.root.fake_sigs
      caller_location, decl_proc = T.must(T.must(fake_sigs.values.first).values[0])
      decl_block = T::Private::Methods::DeclarationBlock.new(nil, caller_location, decl_proc, false, true)
      builder = T.unsafe(T::Private::Methods::DeclBuilder).new(decl_block.mod, decl_block.raw)
      builder.instance_exec(&decl_block.blk)
      builder.finalize!
      decl = builder.decl
      # Dynamic lookup via core Ruby rather than ActiveSupport's
      # String#constantize, which this file never requires.
      _signature = Object.const_get('T::Private::Methods::Signature').new(
        method: unbound_method,
        method_name: unbound_method.name,
        raw_arg_types: {},
        raw_return_type: decl.returns,
        bind: nil,
        mode: 'standard',
        check_level: :never,
        on_failure: nil,
        override_allow_incompatible: nil,
        defined_raw: true
      )
    end
    res
  end

  # The FakeSig registry stored on `mod` (file path => line => [location, block]),
  # or an empty hash when `mod` has none.
  sig { returns(T::Hash[String, T::Hash[Integer, [Thread::Backtrace::Location, Proc]]]) }
  def proc_fake_sigs
    return {} unless mod.instance_variable_defined?(:@__fakesig_blocks)

    T.cast(mod.instance_variable_get(:@__fakesig_blocks),
           T::Hash[String, T::Hash[Integer, [Thread::Backtrace::Location, Proc]]])
  end

  # Shortcut: the MethodSemantics of `unbound_method` within its file.
  sig { params(unbound_method: UnboundMethod).returns(MethodSemantics) }
  def method_semantics(unbound_method) = file_semantics(unbound_method).method_semantics(unbound_method)
end
class << self
  # Workspace shared by all instances of this compiler class, created on
  # first access.
  #
  # The previous `attr_accessor` reader was declared as returning a non-nil
  # CompilationWorkspace but yielded nil until something assigned it, which
  # contradicted its own sig.
  sig { returns(CompilationWorkspace) }
  def workspace
    @workspace ||= T.let(CompilationWorkspace.new, T.nilable(CompilationWorkspace))
  end

  # Replaces the shared workspace.
  sig { params(value: CompilationWorkspace).returns(CompilationWorkspace) }
  def workspace=(value)
    @workspace = value
  end
end
# Forwards all arguments to the superclass initializer, then makes sure the
# class-level workspace has been set.
def initialize(*args, **kwargs)
  super
  compiler_class = T.unsafe(self).class
  compiler_class.workspace ||= CompilationWorkspace.new
end
# Cache of ModuleSemantics, one per example group class.
class CompilationWorkspace < T::Struct
  # `sig` is used below; extend explicitly instead of relying on a global
  # Module patch.
  extend T::Sig

  SemanticsCache = T.type_alias { T::Hash[T.class_of(::RSpec::Core::ExampleGroup), ModuleSemantics] }
  const :semantics, SemanticsCache, factory: -> { {} }

  # Returns the ModuleSemantics for `mod`, creating it on first request.
  sig { params(mod: T.class_of(::RSpec::Core::ExampleGroup)).returns(ModuleSemantics) }
  def module_semantics(mod)
    semantics[mod] ||= ModuleSemantics.new(mod:)
  end

  # Shortcut: the MethodSemantics of `unbound_method`, resolved through the
  # ModuleSemantics of `mod`.
  sig {
    params(mod: T.class_of(::RSpec::Core::ExampleGroup),
           unbound_method: UnboundMethod).returns(MethodSemantics)
  }
  def method_semantics(mod, unbound_method) = module_semantics(mod).method_semantics(unbound_method)
end
# For every directly included module whose name starts with
# 'RSpec::ExampleGroups::', emits an RBI module containing one method per
# direct public instance method of that module.
#
# WORK IN PROGRESS: `methods_types` is never filled, so every return type is
# 'T.untyped'. The values assigned to `let_defs_module`, `return_value` and
# `file_semantics` are not used yet, but the calls producing them are kept
# in their original order.
sig { params(_klass: RBI::Scope).void }
def create_example_group_submodules(_klass)
  example_group_modules = directly_included_modules_for(constant).select do |candidate|
    candidate.name&.start_with?('RSpec::ExampleGroups::')
  end

  example_group_modules.each do |mod|
    scope = root.create_module(T.must(mod.name))
    let_defs_module = ::RSpec::Core::MemoizedHelpers.module_for(constant)
    methods_types = T.let({}, T::Hash[Symbol, String])
    return_value = T.unsafe(nil)

    direct_public_instance_methods_for(mod).each do |method_name|
      method_def = mod.instance_method(method_name)
      file_semantics = self.class.workspace.method_semantics(constant, method_def)
      scope.create_method(
        method_def.name.to_s,
        parameters: compile_method_parameters_to_rbi(method_def),
        return_type: methods_types.fetch(method_name, 'T.untyped'),
        class_method: false
      )
    end
  end
end
# Emits a `class << self` scope inside `klass` containing:
#   * the RSpec DSL methods listed below, each taking a block bound to
#     `constant` (plus an untyped splat where a name is given), and
#   * every direct public method of `constant`'s singleton class.
sig { params(klass: RBI::Scope).void }
def create_singleton_methods(klass)
  scope = klass.create_class('<< self')

  # DSL method name => name of its splat parameter (nil: block parameter only).
  # Hash order is the order in which the methods are emitted.
  rest_param_names = {
    'let' => 'name',
    'subject' => nil,
    'let!' => 'name',
    'before' => 'args',
    'after' => 'args',
    'it' => 'all_args',
    'specify' => 'all_args'
  }
  rest_param_names.each do |dsl_method, rest_name|
    parameters = []
    parameters << create_rest_param(rest_name, type: 'T.untyped') if rest_name
    parameters << create_block_param('block', type: "T.proc.bind(#{constant.name}).void")
    scope.create_method(dsl_method, parameters: parameters)
  end

  singleton = constant.singleton_class
  direct_public_instance_methods_for(singleton).each do |method_name|
    create_method_from_def(scope, singleton.instance_method(method_name))
  end
end
sig { params(constant: Module).returns(T::Enumerable[Module]) }
# Modules included by `constant` itself: its included modules, minus those
# that only arrive through another included module, minus (for a class) those
# already included by its superclass.
def directly_included_modules_for(constant)
  mixins = constant.included_modules
  # Everything each mixin brings along besides itself.
  transitive = mixins.flat_map { |mixin| mixin.ancestors - [mixin] }
  direct = mixins - transitive
  if constant.is_a?(Class) && constant.superclass
    direct -= T.must(constant.superclass).included_modules
  end
  direct
end
sig { params(constant: Module).returns(T::Enumerable[Symbol]) }
# Public instance methods of `constant` whose names are not also public
# instance methods of any included module or (for a class) of its superclass.
def direct_public_instance_methods_for(constant)
  from_mixins = constant.included_modules.flat_map(&:public_instance_methods)
  own = constant.public_instance_methods - from_mixins
  if constant.is_a?(Class) && constant.superclass
    own -= T.must(constant.superclass).public_instance_methods
  end
  own
end
end
end
end
end
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment