Skip to content

Instantly share code, notes, and snippets.

@BlinkyStitt
Created March 11, 2022 00:11
Show Gist options
  • Save BlinkyStitt/b95840d089a8a38a613b9271858d0a66 to your computer and use it in GitHub Desktop.
Save BlinkyStitt/b95840d089a8a38a613b9271858d0a66 to your computer and use it in GitHub Desktop.
import brownie
import eth_abi
import pytest
from hexbytes import HexBytes
from flashprofits import weiroll
@pytest.fixture
def alice(accounts):
return accounts[0]
@pytest.fixture
def math(alice):
brownie_contract = alice.deploy(brownie.WeirollMath)
return weiroll.WeirollContract.createLibrary(brownie_contract)
@pytest.fixture
def subplanContract(alice):
brownie_contract = alice.deploy(brownie.WeirollTestSubplan)
return weiroll.WeirollContract.createLibrary(brownie_contract)
@pytest.fixture
def multiSubplanContract(alice):
brownie_contract = alice.deploy(brownie.WeirollTestMultiSubplan)
return weiroll.WeirollContract.createLibrary(brownie_contract)
@pytest.fixture
def multiStateSubplanContract(alice):
brownie_contract = alice.deploy(brownie.WeirollTestMultiStateSubplan)
return weiroll.WeirollContract.createLibrary(brownie_contract)
@pytest.fixture
def readonlySubplanContract(alice):
brownie_contract = alice.deploy(brownie.WeirollTestReadonlySubplan)
return weiroll.WeirollContract.createLibrary(brownie_contract)
@pytest.fixture
def testContract(alice):
brownie_contract = alice.deploy(brownie.WeirollTest)
return weiroll.WeirollContract.createLibrary(brownie_contract)
@pytest.fixture
def strings(alice):
brownie_contract = alice.deploy(brownie.WeirollStrings)
return weiroll.WeirollContract.createLibrary(brownie_contract)
def test_weiroll_contract(math):
assert hasattr(math, "add")
result = math.add(1, 2)
assert result.contract == math
# assert result.fragment.function.signature == math.add.signature
assert result.fragment.inputs == ["uint256", "uint256"]
assert result.fragment.name == "add"
assert result.fragment.outputs == ["uint256"]
assert result.fragment.signature == "0x771602f7"
assert result.callvalue == 0
assert result.flags == weiroll.CommandFlags.DELEGATECALL
args = result.args
assert len(args) == 2
assert args[0].param == "uint256"
assert args[0].value == eth_abi.encode_single("uint256", 1)
assert args[1].param == "uint256"
assert args[1].value == eth_abi.encode_single("uint256", 2)
def test_weiroll_planner_adds(alice, math):
planner = weiroll.WeirollPlanner(alice)
sum1 = planner.add(math.add(1, 2))
sum2 = planner.add(math.add(3, 4))
planner.add(math.add(sum1, sum2))
assert len(planner.commands) == 3
def test_weiroll_planner_simple_program(alice, math):
planner = weiroll.WeirollPlanner(alice)
planner.add(math.add(1, 2))
commands, state = planner.plan()
assert len(commands) == 1
# TODO: hexconcat?
assert commands[0] == weiroll.hexConcat("0x771602f7000001ffffffffff", math.address)
assert len(state) == 2
assert state[0] == eth_abi.encode_single("uint", 1)
assert state[1] == eth_abi.encode_single("uint", 2)
def test_weiroll_deduplicates_identical_literals(alice, math):
planner = weiroll.WeirollPlanner(alice)
planner.add(math.add(1, 1))
commands, state = planner.plan()
assert len(commands) == 1
assert len(state) == 1
assert state[0] == eth_abi.encode_single("uint", 1)
def test_weiroll_with_return_value(alice, math):
planner = weiroll.WeirollPlanner(alice)
sum1 = planner.add(math.add(1, 2))
planner.add(math.add(sum1, 3))
commands, state = planner.plan()
assert len(commands) == 2
assert commands[0] == weiroll.hexConcat("0x771602f7000001ffffffff01", math.address)
assert commands[1] == weiroll.hexConcat("0x771602f7000102ffffffffff", math.address)
assert len(state) == 3
assert state[0] == eth_abi.encode_single("uint", 1)
assert state[1] == eth_abi.encode_single("uint", 2)
assert state[2] == eth_abi.encode_single("uint", 3)
def test_weiroll_with_state_slots_for_intermediate_values(alice, math):
planner = weiroll.WeirollPlanner(alice)
sum1 = planner.add(math.add(1, 1))
planner.add(math.add(1, sum1))
commands, state = planner.plan()
assert len(commands) == 2
assert commands[0] == weiroll.hexConcat("0x771602f7000000ffffffff01", math.address)
assert commands[1] == weiroll.hexConcat("0x771602f7000001ffffffffff", math.address)
assert len(state) == 2
assert state[0] == eth_abi.encode_single("uint", 1)
assert state[1] == b""
@pytest.mark.parametrize(
"param,value,expected",
[
(
"string",
"Hello, world!",
"0x000000000000000000000000000000000000000000000000000000000000000d48656c6c6f2c20776f726c642100000000000000000000000000000000000000",
),
],
)
def test_weiroll_abi_encode_single(param, value, expected):
expected = HexBytes(expected)
print("expected:", expected)
literalValue = HexBytes(eth_abi.encode_single(param, value))
print("literalValue:", literalValue)
assert literalValue == expected
def test_weiroll_takes_dynamic_arguments(alice, strings):
test_str = "Hello, world!"
planner = weiroll.WeirollPlanner(alice)
planner.add(strings.strlen(test_str))
commands, state = planner.plan()
assert len(commands) == 1
assert commands[0] == weiroll.hexConcat("0x367bbd780080ffffffffffff", strings.address)
print(state)
assert len(state) == 1
assert state[0] == eth_abi.encode_single("string", test_str)
def test_weiroll_returns_dynamic_arguments(alice, strings):
planner = weiroll.WeirollPlanner(alice)
planner.add(strings.strcat("Hello, ", "world!"))
commands, state = planner.plan()
assert len(commands) == 1
assert commands[0] == weiroll.hexConcat("0xd824ccf3008081ffffffffff", strings.address)
assert len(state) == 2
assert state[0] == eth_abi.encode_single("string", "Hello, ")
assert state[1] == eth_abi.encode_single("string", "world!")
def test_weiroll_takes_dynamic_argument_from_a_return_value(alice, strings):
planner = weiroll.WeirollPlanner(alice)
test_str = planner.add(strings.strcat("Hello, ", "world!"))
planner.add(strings.strlen(test_str))
commands, state = planner.plan()
assert len(commands) == 2
assert commands[0] == weiroll.hexConcat("0xd824ccf3008081ffffffff81", strings.address)
assert commands[1] == weiroll.hexConcat("0x367bbd780081ffffffffffff", strings.address)
assert len(state) == 2
assert state[0] == eth_abi.encode_single("string", "Hello, ")
assert state[1] == eth_abi.encode_single("string", "world!")
def test_weiroll_argument_counts_match(math):
with pytest.raises(ValueError):
math.add(1)
def test_weiroll_func_takes_and_replaces_current_state(alice, testContract):
planner = weiroll.WeirollPlanner(alice)
planner.replaceState(testContract.useState(planner.state))
commands, state = planner.plan()
assert len(commands) == 1
assert commands[0] == weiroll.hexConcat("0x08f389c800fefffffffffffe", testContract.address)
assert len(state) == 0
def test_weiroll_supports_subplan(alice, math, subplanContract):
subplanner = weiroll.WeirollPlanner(alice)
subplanner.add(math.add(1, 2))
planner = weiroll.WeirollPlanner(alice)
planner.addSubplan(subplanContract.execute(subplanner, subplanner.state))
commands, state = planner.plan()
assert commands == [weiroll.hexConcat("0xde792d5f0082fefffffffffe", subplanContract.address)]
assert len(state) == 3
assert state[0] == eth_abi.encode_single("uint", 1)
assert state[1] == eth_abi.encode_single("uint", 2)
# TODO: javascript test is more complicated than this. but i think this is fine?
assert state[2] == weiroll.hexConcat("0x771602f7000001ffffffffff", math.address)
def test_weiroll_subplan_allows_return_in_parent_scope(alice, math, subplanContract):
subplanner = weiroll.WeirollPlanner(alice)
sum = subplanner.add(math.add(1, 2))
planner = weiroll.WeirollPlanner(alice)
planner.addSubplan(subplanContract.execute(subplanner, subplanner.state))
planner.add(math.add(sum, 3))
commands, _ = planner.plan()
assert len(commands) == 2
# Invoke subplanner
assert commands[0] == weiroll.hexConcat("0xde792d5f0083fefffffffffe", subplanContract.address)
# sum + 3
assert commands[1] == weiroll.hexConcat("0x771602f7000102ffffffffff", math.address)
def test_weiroll_return_values_across_scopes(alice, math, subplanContract):
subplanner1 = weiroll.WeirollPlanner(alice)
sum = subplanner1.add(math.add(1, 2))
subplanner2 = weiroll.WeirollPlanner(alice)
subplanner2.add(math.add(sum, 3))
planner = weiroll.WeirollPlanner(alice)
planner.addSubplan(subplanContract.execute(subplanner1, subplanner1.state))
planner.addSubplan(subplanContract.execute(subplanner2, subplanner2.state))
commands, state = planner.plan()
assert len(commands) == 2
assert commands[0] == weiroll.hexConcat("0xde792d5f0083fefffffffffe", subplanContract.address)
assert commands[1] == weiroll.hexConcat("0xde792d5f0084fefffffffffe", subplanContract.address)
assert len(state) == 5
# TODO: javascript tests were more complex than this
assert state[4] == weiroll.hexConcat("0x771602f7000102ffffffffff", math.address)
def test_weiroll_return_values_must_be_defined(alice, math):
subplanner = weiroll.WeirollPlanner(alice)
sum = subplanner.add(math.add(1, 2))
planner = weiroll.WeirollPlanner(alice)
planner.add(math.add(sum, 3))
with pytest.raises(ValueError, match="Return value from 'add' is not visible here"):
planner.plan()
def test_weiroll_add_subplan_needs_args(alice, math, subplanContract):
subplanner = weiroll.WeirollPlanner(alice)
subplanner.add(math.add(1, 2))
planner = weiroll.WeirollPlanner(alice)
with pytest.raises(ValueError, match="Subplans must take planner and state arguments"):
planner.addSubplan(subplanContract.execute(subplanner, []))
with pytest.raises(ValueError, match="Subplans must take planner and state arguments"):
planner.addSubplan(subplanContract.execute([], subplanner.state))
def test_weiroll_doesnt_allow_multiple_subplans_per_call(alice, math, multiSubplanContract):
subplanner = weiroll.WeirollPlanner(alice)
subplanner.add(math.add(1, 2))
planner = weiroll.WeirollPlanner(alice)
with pytest.raises(ValueError, match="Subplans can only take one planner argument"):
planner.addSubplan(multiSubplanContract.execute(subplanner, subplanner, subplanner.state))
def test_weiroll_doesnt_allow_state_array_per_call(alice, math, multiStateSubplanContract):
subplanner = weiroll.WeirollPlanner(alice)
subplanner.add(math.add(1, 2))
planner = weiroll.WeirollPlanner(alice)
with pytest.raises(ValueError, match="Subplans can only take one state argument"):
planner.addSubplan(multiStateSubplanContract.execute(subplanner, subplanner.state, subplanner.state))
def test_weiroll_subplan_has_correct_return_type(alice, math):
badSubplanContract = weiroll.WeirollContract.createLibrary(alice.deploy(brownie.WeirollBadSubplan))
subplanner = weiroll.WeirollPlanner(alice)
subplanner.add(math.add(1, 2))
planner = weiroll.WeirollPlanner(alice)
with pytest.raises(ValueError, match=r"Subplans must return a bytes\[\] replacement state or nothing"):
planner.addSubplan(badSubplanContract.execute(subplanner, subplanner.state))
def test_forbid_infinite_loops(alice, subplanContract):
planner = weiroll.WeirollPlanner(alice)
planner.addSubplan(subplanContract.execute(planner, planner.state))
with pytest.raises(ValueError, match="A planner cannot contain itself"):
planner.plan()
def test_subplans_without_returns(alice, math, readonlySubplanContract):
subplanner = weiroll.WeirollPlanner(alice)
subplanner.add(math.add(1, 2))
planner = weiroll.WeirollPlanner(alice)
planner.addSubplan(readonlySubplanContract.execute(subplanner, subplanner.state))
commands, _ = planner.plan()
assert len(commands) == 1
commands[0] == weiroll.hexConcat("0xde792d5f0082feffffffffff", readonlySubplanContract.address)
def test_read_only_subplans_requirements(alice, math, readonlySubplanContract):
"""it does not allow return values from inside read-only subplans to be used outside them"""
subplanner = weiroll.WeirollPlanner(alice)
sum = subplanner.add(math.add(1, 2))
planner = weiroll.WeirollPlanner(alice)
planner.addSubplan(readonlySubplanContract.execute(subplanner, subplanner.state))
planner.add(math.add(sum, 3))
with pytest.raises(ValueError, match="Return value from 'add' is not visible here"):
planner.plan()
@pytest.mark.xfail(reason="need to write this")
def test_plan_with_loop(alice):
target_calldata = "0xc6b6816900000000000000000000000000000000000000000000054b40b1f852bda0"
"""
[
'0x0000000000000000000000000000000000000000000000000000000000000005',
'0x000000000000000000000000cecad69d7d4ed6d52efcfa028af8732f27e08f70',
'0x0000000000000000000000000000000000000000000000000000000000000022c6b6816900000000000000000000000000000000000000000000054b40b1f852bda0000000000000000000000000000000000000000000000000000000000000'
]
"""
planner = weiroll.WeirollPlanner(alice)
raise NotImplementedError
def _test_more(math):
# TODO: test for curve add_liquidity encoding
"""
expect(() => planner.plan()).to.throw(
'Return value from "add" is not visible here'
);
});
it('plans CALLs', () => {
let Math = Contract.createContract(
new ethers.Contract(SAMPLE_ADDRESS, mathABI.abi)
);
const planner = new Planner();
planner.add(Math.add(1, 2));
const { commands } = planner.plan();
expect(commands.length).to.equal(1);
expect(commands[0]).to.equal(
'0x771602f7010001ffffffffffeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeee'
);
});
it('plans STATICCALLs', () => {
let Math = Contract.createContract(
new ethers.Contract(SAMPLE_ADDRESS, mathABI.abi),
CommandFlags.STATICCALL
);
const planner = new Planner();
planner.add(Math.add(1, 2));
const { commands } = planner.plan();
expect(commands.length).to.equal(1);
expect(commands[0]).to.equal(
'0x771602f7020001ffffffffffeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeee'
);
});
it('plans STATICCALLs via .staticcall()', () => {
let Math = Contract.createContract(
new ethers.Contract(SAMPLE_ADDRESS, mathABI.abi)
);
const planner = new Planner();
planner.add(Math.add(1, 2).staticcall());
const { commands } = planner.plan();
expect(commands.length).to.equal(1);
expect(commands[0]).to.equal(
'0x771602f7020001ffffffffffeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeee'
);
});
it('plans CALLs with value', () => {
const Test = Contract.createContract(
new ethers.Contract(SAMPLE_ADDRESS, ['function deposit(uint x) payable'])
);
const planner = new Planner();
planner.add(Test.deposit(123).withValue(456));
const { commands } = planner.plan();
expect(commands.length).to.equal(1);
expect(commands[0]).to.equal(
'0xb6b55f25030001ffffffffffeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeee'
);
});
it('allows returns from other calls to be used for the value parameter', () => {
const Test = Contract.createContract(
new ethers.Contract(SAMPLE_ADDRESS, ['function deposit(uint x) payable'])
);
const planner = new Planner();
const sum = planner.add(Math.add(1, 2));
planner.add(Test.deposit(123).withValue(sum));
const { commands } = planner.plan();
expect(commands.length).to.equal(2);
expect(commands).to.deep.equal([
'0x771602f7000001ffffffff01eeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeee',
'0xb6b55f25030102ffffffffffeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeee',
]);
});
it('does not allow value-calls for DELEGATECALL or STATICCALL', () => {
expect(() => Math.add(1, 2).withValue(3)).to.throw(
'Only CALL operations can send value'
);
const StaticMath = Contract.createContract(
new ethers.Contract(SAMPLE_ADDRESS, mathABI.abi),
CommandFlags.STATICCALL
);
expect(() => StaticMath.add(1, 2).withValue(3)).to.throw(
'Only CALL operations can send value'
);
});
it('does not allow making DELEGATECALL static', () => {
expect(() => Math.add(1, 2).staticcall()).to.throw(
'Only CALL operations can be made static'
);
});
it('uses extended commands where necessary', () => {
const Test = Contract.createLibrary(
new ethers.Contract(SAMPLE_ADDRESS, [
'function test(uint a, uint b, uint c, uint d, uint e, uint f, uint g) returns(uint)',
])
);
const planner = new Planner();
planner.add(Test.test(1, 2, 3, 4, 5, 6, 7));
const { commands } = planner.plan();
expect(commands.length).to.equal(2);
expect(commands[0]).to.equal(
'0xe473580d40000000000000ffeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeee'
);
expect(commands[1]).to.equal(
'0x00010203040506ffffffffffffffffffffffffffffffffffffffffffffffffff'
);
});
it('supports capturing the whole return value as a bytes', () => {
const Test = Contract.createLibrary(
new ethers.Contract(SAMPLE_ADDRESS, [
'function returnsTuple() returns(uint a, bytes32[] b)',
'function acceptsBytes(bytes raw)',
])
);
const planner = new Planner();
const ret = planner.add(Test.returnsTuple().rawValue());
planner.add(Test.acceptsBytes(ret));
const { commands } = planner.plan();
expect(commands).to.deep.equal([
'0x61a7e05e80ffffffffffff80eeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeee',
'0x3e9ef66a0080ffffffffffffeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeee',
]);
});
"""
import re
from collections import defaultdict, namedtuple
from enum import IntEnum, IntFlag
from functools import cache
from typing import Optional
import brownie
import eth_abi
import eth_abi.packed
from brownie.convert.utils import get_type_strings
from brownie.network.contract import OverloadedMethod
from hexbytes import HexBytes
from .web3_helpers import MAX_UINT256
# TODO: real types?
Value = namedtuple("Value", "param")
LiteralValue = namedtuple("LiteralValue", "param,value")
ReturnValue = namedtuple("ReturnValue", "param,command")
def simple_type_strings(inputs) -> tuple[Optional[list[str]], Optional[list[int]]]:
"""cut state variables that are too long into 32 byte chunks.
related: https://github.com/weiroll/weiroll.js/pull/34
"""
if not inputs:
return None, None
simple_inputs = []
simple_sizes = []
for i in inputs:
if i.endswith("]") and not i.endswith("[]"):
# fixed size array. cut it up
# TODO: i've seen a warning about "\[" is not valid, but i get errors without
m = re.match(r"([a-z0-9]+)\[([0-9]+)\]", i)
size = int(m.group(2))
simple_inputs.extend([m.group(1)] * size)
simple_sizes.append(size)
else:
simple_inputs.append(i)
simple_sizes.append(1)
if all([s == 1 for s in simple_sizes]):
# if no inputs or all the inputs are easily handled sizes, we don't need to simplify them
# we don't clear simple_inputs because its simpler for that to just be a copy of self.inputs
simple_sizes = None
return simple_inputs, simple_sizes
def simple_args(simple_sizes, args):
"""split up complex types into 32 byte chunks that weiroll state can handle."""
if not simple_sizes:
# no need to handle anything specially
return args
simplified = []
for i, size in enumerate(simple_sizes):
if size == 1:
# no need to do anything fancy
simplified.append(args[i])
else:
simplified.extend(args[i])
return simplified
# TODO: not sure about this class. its mostly here because this is how the javascript sdk works. now that this works, i think we can start refactoring to use brownie more directly
class FunctionFragment:
def __init__(self, brownieContract: brownie.Contract, selector):
function_name = brownieContract.selectors[selector]
function = getattr(brownieContract, function_name)
if isinstance(function, OverloadedMethod):
overloaded_func = None
for m in function.methods.values():
# TODO: everyone is inconsistent about signature vs selector vs name
if m.signature == selector:
overloaded_func = m
break
assert overloaded_func
function = overloaded_func
self.function = function
self.name = function_name
self.signature = function.signature
self.inputs = get_type_strings(function.abi["inputs"])
# look at the inputs that aren't dynamic types but also aren't 32 bytes long and cut them up
self.simple_inputs, self.simple_sizes = simple_type_strings(self.inputs)
self.outputs = get_type_strings(function.abi["outputs"])
# TODO: do something to handle outputs of uncommon types?
def encode_args(self, *args):
if len(args) != len(self.inputs):
raise ValueError(f"Function {self.name} has {len(self.inputs)} arguments but {len(self.args)} provided")
# split up complex types into 32 byte chunks that weiroll state can handle
args = simple_args(self.simple_sizes, args)
return [encodeArg(arg, self.simple_inputs[i]) for (i, arg) in enumerate(args)]
class StateValue:
def __init__(self):
self.param = "bytes[]"
class SubplanValue:
def __init__(self, planner):
self.param = "bytes[]"
self.planner = planner
# TODO: use python ABC or something like that?
def isValue(arg):
if isinstance(arg, Value):
return True
if isinstance(arg, LiteralValue):
return True
if isinstance(arg, ReturnValue):
return True
if isinstance(arg, StateValue):
return True
if isinstance(arg, SubplanValue):
return True
return False
# TODO: this needs tests! I'm 90% sure this is wrong for lists
# TODO: does eth_utils not already have this? it seems like other people should have written something like this
def hexConcat(*items) -> HexBytes:
result = b""
for item in items:
if isinstance(item, list):
item = hexConcat(*item)
else:
item = HexBytes(item)
result += bytes(item)
return HexBytes(result)
class CommandFlags(IntFlag):
# Specifies that a call should be made using the DELEGATECALL opcode
DELEGATECALL = 0x00
# Specifies that a call should be made using the CALL opcode
CALL = 0x01
# Specifies that a call should be made using the STATICCALL opcode
STATICCALL = 0x02
# Specifies that a call should be made using the CALL opcode, and that the first argument will be the value to send
CALL_WITH_VALUE = 0x03
# A bitmask that selects calltype flags
CALLTYPE_MASK = 0x03
# Specifies that this is an extended command, with an additional command word for indices. Internal use only.
EXTENDED_COMMAND = 0x40
# Specifies that the return value of this call should be wrapped in a `bytes`. Internal use only.
TUPLE_RETURN = 0x80
class FunctionCall:
def __init__(self, contract, flags: CommandFlags, fragment: FunctionFragment, args, callvalue=0):
self.contract = contract
self.flags = flags
self.fragment = fragment
self.args = args
self.callvalue = callvalue
def withValue(self, value):
"""
Returns a new [[FunctionCall]] that sends value with the call.
@param value The value (in wei) to send with the call
"""
if (self.flags & CommandFlags.CALLTYPE_MASK) != CommandFlags.CALL and (
self.flags & CommandFlags.CALLTYPE_MASK
) != CommandFlags.CALL_WITH_VALUE:
raise ValueError("Only CALL operations can send value")
return self.__class__(
self.contract,
(self.flags & ~CommandFlags.CALLTYPE_MASK) | CommandFlags.CALL_WITH_VALUE,
self.fragment,
self.args,
eth_abi.encode_single("uint", value),
)
def rawValue(self):
"""
Returns a new [[FunctionCall]] whose return value will be wrapped as a `bytes`.
This permits capturing the return values of functions with multiple return parameters,
which weiroll does not otherwise support.
"""
return self.__class__(
self.contract,
self.flags | CommandFlags.TUPLE_RETURN,
self.fragment,
self.args,
self.callvalue,
)
def staticcall(self):
"""Returns a new [[FunctionCall]] that executes a STATICCALL instead of a regular CALL."""
if (self.flags & CommandFlags.CALLTYPE_MASK) != CommandFlags.CALL:
raise ValueError("Only CALL operations can be made static")
return self.__class__(
self.contract,
(self.flags & ~CommandFlags.CALLTYPE_MASK) | CommandFlags.STATICCALL,
self.fragment,
self.args,
self.callvalue,
)
# TODO: this is probably not an accurate port. think about this more
def isDynamicType(param) -> bool:
if param.endswith("]"):
param = "array"
if param.startswith("tuple"):
param = "tuple"
return param in ["string", "bytes", "array", "tuple"]
def encodeArg(arg, param):
if isValue(arg):
if arg.param != param:
raise ValueError(f"Cannot pass value of type ${arg.param} to input of type ${param}")
return arg
if isinstance(arg, WeirollPlanner):
return SubplanValue(arg)
return LiteralValue(param, eth_abi.encode_single(param, arg))
class WeirollContract:
"""
* Provides a dynamically created interface to interact with Ethereum contracts via weiroll.
*
* Once created using the constructor or the [[Contract.createContract]] or [[Contract.createLibrary]]
* functions, the returned object is automatically populated with methods that match those on the
* supplied contract. For instance, if your contract has a method `add(uint, uint)`, you can call it on the
* [[Contract]] object:
* ```typescript
* // Assumes `Math` is an ethers.js Contract instance.
* const math = Contract.createLibrary(Math);
* const result = math.add(1, 2);
* ```
*
* Calling a contract function returns a [[FunctionCall]] object, which you can pass to [[Planner.add]],
* [[Planner.addSubplan]], or [[Planner.replaceState]] to add to the sequence of calls to plan.
"""
def __init__(self, brownieContract: brownie.Contract, commandFlags: CommandFlags = 0):
self.brownieContract = brownieContract
self.address = brownieContract.address
self.commandFlags = commandFlags
self.functions = {} # aka functionsBySelector
self.functionsBySignature = {}
self.fragmentsBySelector = {}
selectorsByName = defaultdict(list)
for selector, name in self.brownieContract.selectors.items():
fragment = FunctionFragment(self.brownieContract, selector)
# Check that the signature is unique; if not the ABI generation has
# not been cleaned or may be incorrectly generated
if selector in self.functions:
raise ValueError(f"Duplicate ABI entry for selector: {selector}")
self.fragmentsBySelector[selector] = fragment
plan_fn = buildCall(self, fragment)
# save this plan helper function fragment in self.functions
self.functions[selector] = plan_fn
# make the plan helper function available on self by selector
setattr(self, selector, plan_fn)
# Track unique names; we only expose bare named functions if they are ambiguous
selectorsByName[name].append(selector)
self.functionsByUniqueName = {}
for name, selectors in selectorsByName.items():
# Ambiguous names to not get attached as bare names
if len(selectors) == 1:
if hasattr(self, name):
# TODO: i think this is impossible
raise ValueError("duplicate name!")
plan_fn = self.functions[selectors[0]]
# make the plan helper function available on self
setattr(self, name, plan_fn)
self.functionsByUniqueName[name] = plan_fn
# attach full signatures (for methods with duplicate names)
for selector in selectors:
fragment = self.fragmentsBySelector[selector]
signature = name + "(" + ",".join(fragment.inputs) + ")"
plan_fn = self.functions[selector]
self.functionsBySignature[signature] = plan_fn
@classmethod
@cache
def createContract(
cls,
contract: brownie.Contract,
commandflags=CommandFlags.CALL,
):
"""
Creates a [[Contract]] object from an ethers.js contract.
All calls on the returned object will default to being standard CALL operations.
Use this when you want your weiroll script to call a standard external contract.
@param contract The ethers.js Contract object to wrap.
@param commandflags Optionally specifies a non-default call type to use, such as
[[CommandFlags.STATICCALL]].
"""
assert commandflags != CommandFlags.DELEGATECALL
return cls(
contract,
commandflags,
)
@classmethod
@cache
def createLibrary(
cls,
contract: brownie.Contract,
):
"""
* Creates a [[Contract]] object from an ethers.js contract.
* All calls on the returned object will default to being DELEGATECALL operations.
* Use this when you want your weiroll script to call a library specifically designed
* for use with weiroll.
* @param contract The ethers.js Contract object to wrap.
"""
return cls(contract, CommandFlags.DELEGATECALL)
# TODO: port getInterface?
# TODO: not sure about this one. this was just how the javascript code worked, but can probably be refactored
def buildCall(contract: WeirollContract, fragment: FunctionFragment):
def _call(*args) -> FunctionCall:
if len(args) != len(fragment.inputs):
raise ValueError(f"Function {fragment.name} has {len(fragment.inputs)} arguments but {len(args)} provided")
# TODO: maybe this should just be fragment.encode_args()
encodedArgs = fragment.encode_args(*args)
return FunctionCall(
contract,
contract.commandFlags,
fragment,
encodedArgs,
)
return _call
class CommandType(IntEnum):
CALL = 1
RAWCALL = 2
SUBPLAN = 3
Command = namedtuple("Command", "call,type")
# returnSlotMap: Maps from a command to the slot used for its return value
# literalSlotMap: Maps from a literal to the slot used to store it
# freeSlots: An array of unused state slots
# stateExpirations: Maps from a command to the slots that expire when it's executed
# commandVisibility: Maps from a command to the last command that consumes its output
# state: The initial state array
PlannerState = namedtuple(
"PlannerState",
"returnSlotMap, literalSlotMap, freeSlots, stateExpirations, commandVisibility, state",
)
def padArray(a, length, padValue) -> list:
return a + [padValue] * (length - len(a))
class WeirollPlanner:
def __init__(self, clone):
self.state = StateValue()
self.commands: list[Command] = []
self.unlimited_approvals = set()
self.clone = clone
def approve(self, token: brownie.Contract, spender: str, wei_needed, approve_wei=None) -> Optional[ReturnValue]:
key = (token, self.clone, spender)
if approve_wei is None:
approve_wei = MAX_UINT256
if key in self.unlimited_approvals and approve_wei != 0:
# we already planned an infinite approval for this token (and we aren't trying to set the approval to 0)
return
# check current allowance
if token.allowance(self.clone, spender) >= wei_needed:
return
if approve_wei == MAX_UINT256:
self.unlimited_approvals.add(key)
return self.call(token, "approve", spender, approve_wei)
def call(self, brownieContract: brownie.Contract, func_name, *args):
"""func_name can be just the name, or it can be the full signature.
If there are multiple functions with the same name, you must use the signature.
TODO: brownie has some logic for figuring out which overloaded method to use. we should use that here
"""
weirollContract = WeirollContract.createContract(brownieContract)
if func_name.endswith(")"):
# TODO: would be interesting to look at args and do this automatically
func = weirollContract.functionsBySignature[func_name]
else:
func = weirollContract.functionsByUniqueName[func_name]
return self.add(func(*args))
def delegatecall(self, brownieContract: brownie.Contract, func_name, *args):
contract = WeirollContract.createLibrary(brownieContract)
if func_name in contract.functionsByUniqueName:
func = contract.functionsByUniqueName[func_name]
elif func_name in contract.functionsBySignature:
func = contract.functionsBySignature[func_name]
else:
# print("func_name:", func_name)
# print("functionsByUniqueName:", contract.functionsByUniqueName)
# print("functionsBySignature:", contract.functionsBySignature)
raise ValueError(f"Unknown func_name ({func_name}) on {brownieContract}")
return self.add(func(*args))
def add(self, call: FunctionCall) -> Optional[ReturnValue]:
"""
* Adds a new function call to the planner. Function calls are executed in the order they are added.
*
* If the function call has a return value, `add` returns an object representing that value, which you
* can pass to subsequent function calls. For example:
* ```typescript
* const math = Contract.createLibrary(Math); // Assumes `Math` is an ethers.js contract object
* const events = Contract.createLibrary(Events); // Assumes `Events` is an ethers.js contract object
* const planner = new Planner();
* const sum = planner.add(math.add(21, 21));
* planner.add(events.logUint(sum));
* ```
* @param call The [[FunctionCall]] to add to the planner
* @returns An object representing the return value of the call, or null if it does not return a value.
"""
command = Command(call, CommandType.CALL)
self.commands.append(command)
for arg in call.args:
if isinstance(arg, SubplanValue):
raise ValueError("Only subplans can have arguments of type SubplanValue")
if call.flags & CommandFlags.TUPLE_RETURN:
return ReturnValue("bytes", command)
# TODO: test this more
if len(call.fragment.outputs) != 1:
return None
# print("call fragment outputs", call.fragment.outputs)
return ReturnValue(call.fragment.outputs[0], command)
def subcall(self, brownieContract: brownie.Contract, func_name, *args):
"""
* Adds a call to a subplan. This has the effect of instantiating a nested instance of the weiroll
* interpreter, and is commonly used for functionality such as flashloans, control flow, or anywhere
* else you may need to execute logic inside a callback.
*
* A [[FunctionCall]] passed to [[Planner.addSubplan]] must take another [[Planner]] object as one
* argument, and a placeholder representing the planner state, accessible as [[Planner.state]], as
* another. Exactly one of each argument must be provided.
*
* At runtime, the subplan is replaced by a list of commands for the subplanner (type `bytes32[]`),
* and `planner.state` is replaced by the current state of the parent planner instance (type `bytes[]`).
*
* If the `call` returns a `bytes[]`, this will be used to replace the parent planner's state after
* the call to the subplanner completes. Return values defined inside a subplan may be used outside that
* subplan - both in the parent planner and in subsequent subplans - only if the `call` returns the
* updated planner state.
*
* Example usage:
* ```
* const exchange = Contract.createLibrary(Exchange); // Assumes `Exchange` is an ethers.js contract
* const events = Contract.createLibrary(Events); // Assumes `Events` is an ethers.js contract
* const subplanner = new Planner();
* const outqty = subplanner.add(exchange.swap(tokenb, tokena, qty));
*
* const planner = new Planner();
* planner.addSubplan(exchange.flashswap(tokena, tokenb, qty, subplanner, planner.state));
* planner.add(events.logUint(outqty)); // Only works if `exchange.flashswap` returns updated state
* ```
* @param call The [[FunctionCall]] to add to the planner.
"""
contract = WeirollContract.createContract(brownieContract)
func = getattr(contract, func_name)
func_call = func(*args)
return self.addSubplan(func_call)
def subdelegatecall(self, brownieContract: brownie.Contract, func_name, *args):
contract = WeirollContract.createLibrary(brownieContract)
func = getattr(contract, func_name)
func_call = func(*args)
return self.addSubplan(func_call)
def addSubplan(self, call: FunctionCall):
hasSubplan = False
hasState = False
for arg in call.args:
if isinstance(arg, SubplanValue):
if hasSubplan:
raise ValueError("Subplans can only take one planner argument")
hasSubplan = True
elif isinstance(arg, StateValue):
if hasState:
raise ValueError("Subplans can only take one state argument")
hasState = True
if not hasSubplan or not hasState:
raise ValueError("Subplans must take planner and state arguments")
if call.fragment.outputs and len(call.fragment.outputs) == 1 and call.fragment.outputs[0] != "bytes[]":
raise ValueError("Subplans must return a bytes[] replacement state or nothing")
self.commands.append(Command(call, CommandType.SUBPLAN))
def replaceState(self, call: FunctionCall):
"""
* Executes a [[FunctionCall]], and replaces the planner state with the value it
* returns. This can be used to execute functions that make arbitrary changes to
* the planner state. Note that the planner library is not aware of these changes -
* so it may produce invalid plans if you don't know what you're doing.
* @param call The [[FunctionCall]] to execute
"""
if (call.fragment.outputs and len(call.fragment.outputs) != 1) or call.fragment.outputs[0] != "bytes[]":
raise ValueError("Function replacing state must return a bytes[]")
self.commands.append(Command(call, CommandType.RAWCALL))
def _preplan(self, commandVisibility, literalVisibility, seen=None, planners=None):
if seen is None:
seen: set[Command] = set()
if planners is None:
planners: set[WeirollPlanner] = set()
if self in planners:
raise ValueError("A planner cannot contain itself")
planners.add(self)
# Build visibility maps
for command in self.commands:
inargs = command.call.args
if command.call.flags & CommandFlags.CALLTYPE_MASK == CommandFlags.CALL_WITH_VALUE:
if not command.call.callvalue:
raise ValueError("Call with value must have a value parameter")
inargs = [command.call.callvalue] + inargs
for arg in inargs:
if isinstance(arg, ReturnValue):
if not arg.command in seen:
raise ValueError(f"Return value from '{arg.command.call.fragment.name}' is not visible here")
commandVisibility[arg.command] = command
elif isinstance(arg, LiteralValue):
literalVisibility[arg.value] = command
elif isinstance(arg, SubplanValue):
subplanSeen = seen # do not copy
if not command.call.fragment.outputs:
# Read-only subplan; return values aren't visible externally
subplanSeen = set(seen)
arg.planner._preplan(commandVisibility, literalVisibility, subplanSeen, planners)
elif not isinstance(arg, StateValue):
raise ValueError(f"Unknown function argument type '{arg}'")
seen.add(command)
return commandVisibility, literalVisibility
def _buildCommandArgs(self, command: Command, returnSlotMap, literalSlotMap, state):
# Build a list of argument value indexes
inargs = command.call.args
if command.call.flags & CommandFlags.CALLTYPE_MASK == CommandFlags.CALL_WITH_VALUE:
if not command.call.callvalue:
raise ValueError("Call with value must have a value parameter")
inargs = [command.call.callvalue] + inargs
args: list[int] = []
for arg in inargs:
if isinstance(arg, ReturnValue):
slot = returnSlotMap[arg.command]
elif isinstance(arg, LiteralValue):
slot = literalSlotMap[arg.value]
elif isinstance(arg, StateValue):
slot = 0xFE
elif isinstance(arg, SubplanValue):
# buildCommands has already built the subplan and put it in the last state slot
slot = len(state) - 1
else:
raise ValueError(f"Unknown function argument type {arg}")
if isDynamicType(arg.param):
slot |= 0x80
args.append(slot)
return args
def _buildCommands(self, ps: PlannerState) -> list[str]:
encodedCommands = []
for command in self.commands:
if command.type == CommandType.SUBPLAN:
# find the subplan
subplanner = next(arg for arg in command.call.args if isinstance(arg, SubplanValue)).planner
subcommands = subplanner._buildCommands(ps)
ps.state.append(HexBytes(eth_abi.encode_single("bytes32[]", subcommands))[32:])
# The slot is no longer needed after this command
ps.freeSlots.append(len(ps.state) - 1)
flags = command.call.flags
args = self._buildCommandArgs(command, ps.returnSlotMap, ps.literalSlotMap, ps.state)
if len(args) > 6:
flags |= CommandFlags.EXTENDED_COMMAND
# Add any newly unused state slots to the list
ps.freeSlots.extend(ps.stateExpirations[command])
ret = 0xFF
if command in ps.commandVisibility:
if command.type in [CommandType.RAWCALL, CommandType.SUBPLAN]:
raise ValueError(
f"Return value of {command.call.fragment.name} cannot be used to replace state and in another function"
)
ret = len(ps.state)
if len(ps.freeSlots) > 0:
ret = ps.freeSlots.pop()
# store the slot mapping
ps.returnSlotMap[command] = ret
# make the slot available when it's not needed
expiryCommand = ps.commandVisibility[command]
ps.stateExpirations[expiryCommand].append(ret)
if ret == len(ps.state):
ps.state.append(b"")
if (
command.call.fragment.outputs and isDynamicType(command.call.fragment.outputs[0])
) or command.call.flags & CommandFlags.TUPLE_RETURN != 0:
ret |= 0x80
elif command.type in [CommandType.RAWCALL, CommandType.SUBPLAN]:
if command.call.fragment.outputs and len(command.call.fragment.outputs) == 1:
ret = 0xFE
if flags & CommandFlags.EXTENDED_COMMAND == CommandFlags.EXTENDED_COMMAND:
# extended command
encodedCommands.extend(
[
hexConcat(
command.call.fragment.signature,
flags,
[0, 0, 0, 0, 0, 0],
ret,
command.call.contract.address,
),
hexConcat(padArray(args, 32, 0xFF)),
]
)
else:
# standard command
encodedCommands.append(
hexConcat(
command.call.fragment.signature,
flags,
padArray(args, 6, 0xFF),
ret,
command.call.contract.address,
)
)
return encodedCommands
def plan(self) -> tuple[list[str], list[str]]:
# Tracks the last time a literal is used in the program
literalVisibility: dict[str, Command] = {}
# Tracks the last time a command's output is used in the program
commandVisibility: dict[Command, Command] = {}
self._preplan(commandVisibility, literalVisibility)
# Maps from commands to the slots that expire on execution (if any)
stateExpirations: dict[Command, list[int]] = defaultdict(list)
# Tracks the state slot each literal is stored in
literalSlotMap: dict[str, int] = {}
state: list[str] = []
# Prepopulate the state and state expirations with literals
for (literal, lastCommand) in literalVisibility.items():
slot = len(state)
state.append(literal)
literalSlotMap[literal] = slot
stateExpirations[lastCommand].append(slot)
ps: PlannerState = PlannerState(
returnSlotMap={},
literalSlotMap=literalSlotMap,
freeSlots=[],
stateExpirations=stateExpirations,
commandVisibility=commandVisibility,
state=state,
)
encodedCommands = self._buildCommands(ps)
return encodedCommands, state
@BlinkyStitt
Copy link
Author

I'm sharing this code as MIT licensed.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment