-
-
Save marknuzz/c7598efa664fe944d790e16464280cd1 to your computer and use it in GitHub Desktop.
rspec-sorbet compiler, work in progress
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# typed: true
# frozen_string_literal: true
# WORK IN PROGRESS
# Authors: Hongli Lai, Mark Nuzzolilo
# https://github.com/FooBarWidget/sorbet-rspec
# Put this in an RSpec test:
#   FakeSig.sig(self) { returns(Integer) }
#   let(:some_var) { 1 }
# Records Sorbet-style `sig` blocks for methods defined via RSpec DSL calls
# such as `let`. Usage inside an example group:
#
#   FakeSig.sig(self) { returns(Integer) }
#   let(:some_var) { 1 }
module FakeSig
  # Stores `blk` on `ctx` in its `@__fakesig_blocks` instance variable, keyed
  # first by the absolute path and then by the line number of this call, paired
  # with that call's backtrace location. A second call from the same path and
  # line overwrites the earlier entry.
  #
  # Raises (via T.must) if the caller's location has no absolute path.
  T::Sig::WithoutRuntime.sig { params(ctx: Module, blk: T.proc.bind(T::Private::Methods::DeclBuilder).void).void }
  def self.sig(ctx, &blk)
    # Frame 1 is whoever called FakeSig.sig.
    location = T.must(T.must(Kernel.caller_locations(1, 1)).first)
    ivar = :@__fakesig_blocks
    ctx.instance_variable_set(ivar, {}) unless ctx.instance_variable_defined?(ivar)
    registry = T.cast(
      ctx.instance_variable_get(ivar),
      T::Hash[String, T::Hash[Integer, [Thread::Backtrace::Location, Proc]]]
    )
    by_line = T.cast(
      registry[T.must(location.absolute_path)] ||= {},
      T::Hash[Integer, [Thread::Backtrace::Location, Proc]]
    )
    by_line[location.lineno] = [location, blk]
  end
end
require 'rspec/core' | |
require 'parser/current' | |
module Tapioca
  module Dsl
    module Compilers
      # Tapioca DSL compiler for RSpec example groups. For every example group
      # class it declares the directly included modules, one module per directly
      # included `RSpec::ExampleGroups::*` module (with its direct public
      # instance methods), and the singleton DSL methods (`let`, `subject`,
      # `before`, `it`, ...) whose blocks are bound to the example group.
      #
      # WORK IN PROGRESS: generated instance methods are always typed
      # `T.untyped` for now; the *Semantics classes below are scaffolding for
      # recovering real types from `FakeSig.sig` declarations.
      class RSpecCustom < Compiler
        extend T::Sig

        ConstantType = type_member { { fixed: T.class_of(::RSpec::Core::ExampleGroup) } }

        # Singleton DSL methods declared on every example group, mapped to the
        # name of their `*rest` parameter (nil when the method only takes a
        # block). Methods are emitted in this order.
        DSL_METHOD_REST_PARAMS = T.let(
          {
            'let' => 'name',
            'subject' => nil,
            'let!' => 'name',
            'before' => 'args',
            'after' => 'args',
            'it' => 'all_args',
            'specify' => 'all_args'
          }.freeze,
          T::Hash[String, T.nilable(String)]
        )

        class << self
          extend T::Sig

          sig { override.returns(T::Enumerable[Module]) }
          def gather_constants
            all_classes.select { |c| c < ::RSpec::Core::ExampleGroup }
          end

          # Semantics cache shared by all instances of this compiler. Created on
          # first access, so the reader never returns nil (its sig promises a
          # non-nilable value).
          sig { returns(CompilationWorkspace) }
          def workspace
            @workspace ||= T.let(CompilationWorkspace.new, T.nilable(CompilationWorkspace))
          end

          sig { params(workspace: CompilationWorkspace).returns(CompilationWorkspace) }
          def workspace=(workspace)
            @workspace = workspace
          end

          private

          # Requires every file matched by `spec_glob` so RSpec defines the
          # example-group classes that `gather_constants` selects.
          sig { void }
          def require_spec_files!
            Dir.glob(spec_glob).each { |file| require(file) }
          end

          # Spec file glob; overridable through the SORBET_RSPEC_GLOB env var.
          sig { returns(String) }
          def spec_glob
            ENV.fetch('SORBET_RSPEC_GLOB') { File.join('.', 'spec', '**', '*.rb') }
          end
        end

        # Load all spec files while this compiler class is being defined.
        require_spec_files!

        sig { override.void }
        def decorate
          klass = root.create_class(T.must(constant.name), superclass_name: T.must(constant.superclass).name)
          create_includes(klass)
          create_example_group_submodules(klass)
          create_singleton_methods(klass)
        end

        private

        sig { params(klass: RBI::Scope).void }
        def create_includes(klass)
          directly_included_modules_for(constant).each do |mod|
            klass.create_include("::#{mod}")
          end
        end

        # Declares `prop`/`const` instance methods with empty bodies.
        # NOTE(review): not referenced anywhere in this file; presumably a typing
        # aid for T::Struct-style DSL calls - confirm before relying on it.
        module StructExtension
          # `abstract!` is provided by T::Helpers; without this it is a NoMethodError.
          extend T::Helpers

          abstract!

          def prop(name, type, **options); end

          def const(name, type, **options); end
        end

        # Typed wrapper around a Parser::AST::Node. Holds the AST node, the
        # semantic nodes built for its AST children, and its parent. Subclasses
        # (chosen by `divine_type`) add type-specific fields.
        class SemanticNode < T::InexactStruct
          extend T::Sig

          const :ast, Parser::AST::Node, inspect: ->(ast) { ast.class.to_s }
          const :children, T::Array[SemanticNode],
                factory: -> { T.let([], T::Array[SemanticNode]) },
                inspect: ->(v) { v.map { |c| "\n #{c.pp_label}" }.join('') }
          const :parent, T.nilable(SemanticNode), inspect: ->(n) { n&.pp_label }
          # One entry per child of `ast`: the semantic node built for that child,
          # or nil when the child is not an AST node (e.g. a method-name symbol).
          const :ast_map, T::Array[T.nilable(SemanticNode)],
                factory: -> { [] },
                inspect: ->(a) { a.map { |n| n.nil? ? '.' : '+' }.join('') }

          sig { overridable.returns(String) }
          def pp_label = "@'#{name}'"

          sig { params(idx: Integer).returns(SemanticNode) }
          def [](idx) = children.fetch(idx)

          sig { returns(Integer) }
          def count = children.count

          # This node followed by all its descendants, depth-first pre-order.
          sig { returns(T::Array[SemanticNode]) }
          def flatten = [self] + children.flat_map(&:flatten)

          sig { returns(T::Array[SendNode]) }
          def sends = flatten.grep(SendNode)

          sig { returns(T::Array[SendNode]) }
          def sends_no_receivers = sends.reject(&:receiver?)

          # Every `FakeSig.sig(...)` call in this subtree.
          sig { returns(T::Array[SendNode]) }
          def fake_sigs = sends.select(&:fake_sig?)

          # AST node types that `divine_type` did not recognise and mapped to the
          # base SemanticNode class. Collected for diagnostics only.
          sig { returns(T::Set[Symbol]) }
          def self.unknown_types
            @unknown_types ||= T.let(Set.new, T.nilable(T::Set[Symbol]))
          end

          # Picks the SemanticNode subclass that wraps `ast`, falling back to
          # SemanticNode itself (and recording the type) for unsupported types.
          sig { params(ast: Parser::AST::Node).returns(T.class_of(SemanticNode)) }
          def self.divine_type(ast)
            case ast.type
            when :args then ArgsNode
            when :block then BlockNode
            when :array then ArrayNode
            when :begin then BeginNode
            when :block_pass then BlockPassNode
            when :const then ConstNode
            when :false then FalseNode # rubocop:disable Lint/BooleanSymbol
            when :hash then HashNode
            when :int then IntNode
            when :lvar then LvarNode
            when :lvasgn then LvasgnNode
            when :nil then NilNode
            when :pair then PairNode
            when :self then SelfNode
            when :send then SendNode
            when :str then StrNode
            when :sym then SymNode
            when :true then TrueNode # rubocop:disable Lint/BooleanSymbol
            else
              # Anchored on SemanticNode so all subclasses share one set.
              SemanticNode.unknown_types.add(ast.type)
              SemanticNode
            end
          end

          # region Nodes
          class ArgsNode < SemanticNode; end
          class BlockNode < SemanticNode; end
          class ArrayNode < SemanticNode; end
          class BeginNode < SemanticNode; end
          class BlockPassNode < SemanticNode; end

          # `const` node: records the constant's own name (AST child 1).
          class ConstNode < SemanticNode
            const :const_name, T.nilable(Symbol)

            def initialize(*args, ast:, **kwargs)
              kwargs[:const_name] = ast.children&.[](1)
              super
            end

            sig { override.returns(String) }
            def pp_label = const_name&.to_s || super
          end

          class FalseNode < SemanticNode; end
          class HashNode < SemanticNode; end
          class IntNode < SemanticNode; end
          class LvarNode < SemanticNode; end
          class LvasgnNode < SemanticNode; end
          class NilNode < SemanticNode; end
          class PairNode < SemanticNode; end
          class SelfNode < SemanticNode; end

          # `send` node: records the method name (AST child 1) and, once built via
          # `make`, the semantic node for the receiver (AST child 0; nil when that
          # slot is not an AST node, i.e. an implicit-self call).
          class SendNode < SemanticNode
            const :method_name, Symbol
            prop :receiver, T.nilable(SemanticNode), inspect: ->(n) { n&.pp_label }

            def initialize(*args, ast:, **kwargs)
              kwargs[:method_name] = ast.children&.[](1) || :__UNKNOWN__
              super
            end

            # Plain nil check; avoids depending on ActiveSupport's `present?`.
            sig { returns(T::Boolean) }
            def receiver? = !receiver.nil?

            # True for calls of the form `FakeSig.sig(...)`.
            sig { returns(T::Boolean) }
            def fake_sig?
              method_name == :sig && receiver.is_a?(ConstNode) && T.cast(receiver, ConstNode).const_name == :FakeSig
            end

            sig { override.returns(String) }
            def pp_label = ":#{method_name}"

            # Builds the node (and its children), then links `receiver`. `parent`
            # may be nil when the send is the root of a file.
            sig { override.params(ast: Parser::AST::Node, parent: T.nilable(SemanticNode)).returns(SendNode) }
            def self.make(ast:, parent: nil)
              node = T.cast(super, SendNode)
              node.receiver = node.ast_map[0]
              node
            end
          end

          class StrNode < SemanticNode; end
          class SymNode < SemanticNode; end
          class TrueNode < SemanticNode; end
          # endregion

          # Recursively builds the semantic tree for `ast`.
          #
          # Construction is always routed through the `make` of the class chosen
          # by `divine_type`, so subclass hooks (e.g. SendNode linking its
          # receiver) also run when building via a less specific class, such as
          # `SemanticNode.make` on a file's root node.
          sig { overridable.params(ast: Parser::AST::Node, parent: T.nilable(SemanticNode)).returns(SemanticNode) }
          def self.make(ast:, parent: nil)
            node_class = divine_type(ast)
            return node_class.make(ast:, parent:) unless node_class == self

            this = new(ast:, parent:)
            (ast.children || []).each do |child_ast|
              unless child_ast.is_a?(Parser::AST::Node)
                this.ast_map.push(nil)
                next
              end

              child = SemanticNode.make(ast: child_ast, parent: this)
              this.children.append(child)
              this.ast_map.push(child)
            end
            this
          end

          sig { overridable.returns(String) }
          def name = self.class.name.to_s
        end

        # Semantics of one method definition.
        class MethodSemantics < T::Struct
          # const :parent_file, FileSemantics
          const :method_name, Symbol
          const :ast, Parser::AST::Node
        end

        # Parsed semantics of one source file that defines example-group methods.
        class FileSemantics < T::Struct
          extend T::Sig
          include ::AST::Sexp

          const :mod, Module
          const :path, String
          const :root, SemanticNode
          const :method_source_nodes, T::Hash[Symbol, SemanticNode], factory: -> { {} }
          const :methods_by_line, T::Hash[Integer, MethodSemantics], factory: -> { {} }

          # Returns the semantics recorded for the line where `unbound_method` is
          # defined.
          #
          # Raises NotImplementedError when no entry exists for that line;
          # `methods_by_line` is not populated anywhere in this compiler yet.
          sig { params(unbound_method: UnboundMethod).returns(MethodSemantics) }
          def method_semantics(unbound_method)
            line = T.must(unbound_method.source_location)[1]
            return methods_by_line.fetch(line) if methods_by_line.key?(line)

            raise NotImplementedError, "no method semantics recorded for #{path}:#{line}"
          end

          class << self
            extend T::Sig

            sig { params(mod: Module, path: String, ast: Parser::AST::Node).returns(T.attached_class) }
            def make(mod:, path:, ast:)
              root = SemanticNode.make(ast:)
              new(mod:, path:, root:)
            end

            # Types of the children of `ast` that are AST nodes.
            sig { params(ast: Parser::AST::Node).returns(T::Array[Symbol]) }
            def node_children_types(ast) = ast.children&.grep(Parser::AST::Node)&.map(&:type) || []

            # Depth-first pre-order search for nodes of `type` in `ast`, including
            # `ast` itself. Matches are appended to `state` when one is given.
            sig {
              params(ast: Parser::AST::Node, type: Symbol,
                     state: T.nilable(T::Array[Parser::AST::Node])).returns(T::Array[Parser::AST::Node])
            }
            def deep_find_nodes(ast, type, state = nil)
              state ||= T.let([], T::Array[Parser::AST::Node])
              state.push(ast) if ast.type == type
              (ast.children || []).grep(Parser::AST::Node).each { |child| deep_find_nodes(child, type, state) }
              state
            end

            # `send` nodes in `ast` that call a method named `raw_method_name`
            # (AST child 1 of a send node is the method name).
            sig { params(ast: Parser::AST::Node, raw_method_name: Symbol).returns(T::Array[Parser::AST::Node]) }
            def find_raw_method_calls(ast, raw_method_name)
              deep_find_nodes(ast, :send).select { |node| node.children[1] == raw_method_name }
            end
          end
        end

        # Per-example-group cache of FileSemantics, keyed by source file path.
        class ModuleSemantics < T::Struct
          extend T::Sig

          SemanticsCache = T.type_alias { T::Hash[String, FileSemantics] }

          const :mod, Module
          const :semantics, SemanticsCache, factory: -> { {} }

          # Parses the file defining `unbound_method` (once per path) and caches
          # the result.
          sig { params(unbound_method: UnboundMethod).returns(FileSemantics) }
          def file_semantics(unbound_method)
            path, = T.must(unbound_method.source_location)
            return semantics.fetch(path) if semantics.key?(path)

            buffer = Parser::Source::Buffer.new(path, source: File.read(path))
            ast = T.let(Parser::CurrentRuby.new.parse(buffer), Parser::AST::Node)
            res = semantics[path] = FileSemantics.make(mod:, path:, ast:)
            # TODO: the signature is built but not consumed yet.
            build_fake_signature(unbound_method)
            res
          end

          sig { returns(T::Hash[String, T::Hash[Integer, [Thread::Backtrace::Location, Proc]]]) }
          def proc_fake_sigs
            return {} unless mod.instance_variable_defined?(:@__fakesig_blocks)

            T.cast(mod.instance_variable_get(:@__fakesig_blocks),
                   T::Hash[String, T::Hash[Integer, [Thread::Backtrace::Location, Proc]]])
          end

          sig { params(unbound_method: UnboundMethod).returns(MethodSemantics) }
          def method_semantics(unbound_method) = file_semantics(unbound_method).method_semantics(unbound_method)

          private

          # WIP: evaluates the first FakeSig block recorded on `mod` through
          # sorbet-runtime's private declaration builder and wraps the result in a
          # T::Private::Methods::Signature for `unbound_method`. Returns nil when
          # no FakeSig blocks were recorded.
          # NOTE(review): always uses the first recorded block, regardless of which
          # method it was written for.
          sig { params(unbound_method: UnboundMethod).returns(T.untyped) }
          def build_fake_signature(unbound_method)
            fake_sigs = proc_fake_sigs
            return if fake_sigs.empty?

            caller_location, decl_proc = T.must(T.must(fake_sigs.values.first).values[0])
            decl_block = T::Private::Methods::DeclarationBlock.new(nil, caller_location, decl_proc, false, true)
            builder = T.unsafe(T::Private::Methods::DeclBuilder).new(decl_block.mod, decl_block.raw)
            builder.instance_exec(&decl_block.blk)
            builder.finalize!
            decl = builder.decl
            # Resolved by name at runtime instead of via ActiveSupport's
            # `constantize`, which this file never requires.
            Object.const_get('T::Private::Methods::Signature').new(
              method: unbound_method,
              method_name: unbound_method.name,
              raw_arg_types: {},
              raw_return_type: decl.returns,
              bind: nil,
              mode: 'standard',
              check_level: :never,
              on_failure: nil,
              override_allow_incompatible: nil,
              defined_raw: true
            )
          end
        end

        # ModuleSemantics for every example group seen by this compiler.
        class CompilationWorkspace < T::Struct
          extend T::Sig

          SemanticsCache = T.type_alias { T::Hash[T.class_of(::RSpec::Core::ExampleGroup), ModuleSemantics] }

          const :semantics, SemanticsCache, factory: -> { {} }

          sig { params(mod: T.class_of(::RSpec::Core::ExampleGroup)).returns(ModuleSemantics) }
          def module_semantics(mod)
            semantics[mod] ||= ModuleSemantics.new(mod:)
          end

          sig {
            params(mod: T.class_of(::RSpec::Core::ExampleGroup),
                   unbound_method: UnboundMethod).returns(MethodSemantics)
          }
          def method_semantics(mod, unbound_method) = module_semantics(mod).method_semantics(unbound_method)
        end

        # For each directly included `RSpec::ExampleGroups::*` module, emits a
        # module declaring that module's direct public instance methods. Return
        # types fall back to `T.untyped` until `methods_types` is filled in.
        sig { params(_klass: RBI::Scope).void }
        def create_example_group_submodules(_klass)
          modules = directly_included_modules_for(constant).select do |mod|
            mod.name&.start_with?('RSpec::ExampleGroups::')
          end
          modules.each do |mod|
            scope = root.create_module(T.must(mod.name))
            methods_types = T.let({}, T::Hash[Symbol, String])
            direct_public_instance_methods_for(mod).each do |method_name|
              method_def = mod.instance_method(method_name)
              # TODO: derive `methods_types` from this. It currently raises
              # NotImplementedError until FileSemantics#methods_by_line is populated.
              _method_semantics = self.class.workspace.method_semantics(constant, method_def)
              scope.create_method(
                method_def.name.to_s,
                parameters: compile_method_parameters_to_rbi(method_def),
                return_type: methods_types[method_name] || 'T.untyped',
                class_method: false
              )
            end
          end
        end

        # Declares the RSpec DSL singleton methods (blocks bound to the example
        # group) plus the group's own direct public singleton methods.
        sig { params(klass: RBI::Scope).void }
        def create_singleton_methods(klass)
          scope = klass.create_class('<< self')
          block_type = "T.proc.bind(#{constant.name}).void"
          DSL_METHOD_REST_PARAMS.each do |method_name, rest_name|
            parameters = [
              (create_rest_param(rest_name, type: 'T.untyped') if rest_name),
              create_block_param('block', type: block_type)
            ].compact
            scope.create_method(method_name, parameters:)
          end

          singleton_class = constant.singleton_class
          direct_public_instance_methods_for(singleton_class).each do |method_name|
            create_method_from_def(scope, singleton_class.instance_method(method_name))
          end
        end

        # Modules included by `constant` itself: its included modules minus those
        # pulled in through another included module or through the superclass.
        sig { params(constant: Module).returns(T::Enumerable[Module]) }
        def directly_included_modules_for(constant)
          included = constant.included_modules
          result = included - included.flat_map { |mod| mod.ancestors - [mod] }
          result -= T.must(constant.superclass).included_modules if constant.is_a?(Class) && constant.superclass
          result
        end

        # Public instance methods defined by `constant` itself, excluding those
        # also provided by its included modules or its superclass.
        sig { params(constant: Module).returns(T::Enumerable[Symbol]) }
        def direct_public_instance_methods_for(constant)
          result = constant.public_instance_methods -
                   constant.included_modules.flat_map(&:public_instance_methods)
          result -= T.must(constant.superclass).public_instance_methods if constant.is_a?(Class) && constant.superclass
          result
        end
      end
    end
  end
end
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment