@FoobarProtocol · Created October 21, 2023
Very comprehensive dataset preprocessing for Solidity smart contracts. Immaculately commented too. You're welcome if you've stumbled upon this #givingbacktothecommunity; it also explains the logic behind all of the decisions I've made along the way.
# Importing necessary libraries for data preprocessing and visualization
import matplotlib.pyplot as plt
import numpy as np
from tqdm.notebook import trange
import pandas as pd
import random
import torch
import re
from datasets import load_dataset
from hexbytes import HexBytes
from simplet5 import SimpleT5
from sklearn.model_selection import train_test_split
import gc
# Defining separators for different types of smart contracts
SEPARATORS = ('\nabstract contract', '\ncontract', '\nlibrary', '\ninterface', '\nstruct')
# Function to remove extra new lines in the source code
def _remove_extra_new_line(src_in):
    src_in = src_in.strip()
    src_in = re.sub(r"(\s)+(\n)", "\n", src_in)
    src_in = re.sub(r"(\n)+", "\n", src_in)
    return src_in
# Function to replace Ethereum addresses in the source code with a placeholder
def _replace_addr(src_in):
    return re.sub(r"0x[A-Fa-f0-9]{40}", "YOUR_ADDR", src_in)
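# A quick sanity check of `_replace_addr` (the address below is illustrative, not from the dataset):
assert _replace_addr("transfer(0x52908400098527886E0F7030069857D2E4169EE7);") == "transfer(YOUR_ADDR);"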
# The function `_format_src` takes in a string of source code and applies various regular expression substitutions and string replacements to format the code according to specific rules.
# The `src_in` parameter is a string that represents the source code input.
# This function returns the modified input string `src_in`.
def _format_src(src_in):
    # remove extra space before new line
    src_in = re.sub(r"\s+\n", "\n", src_in)
    # format the method or class declaration so each { has exactly one space before it
    src_in = re.sub(r"(.){", r"\1 {", src_in)
    src_in = re.sub(r"\s+{", r" {", src_in)
    src_in = src_in.replace("( ", "(")
    src_in = src_in.replace(" )", ")")
    src_in = src_in.replace("[ ", "[")
    src_in = src_in.replace(" ]", "]")
    # Remove unnecessary spaces in method declarations
    src_in = re.sub(r"\n\s+external\s", r" external ", src_in)
    src_in = re.sub(r"\n\s+internal\s", r" internal ", src_in)
    src_in = re.sub(r"\n\s+public\s", r" public ", src_in)
    src_in = re.sub(r"\s+poolOnly\s", r" poolOnly ", src_in)
    src_in = re.sub(r"\s+returns\(", r" returns(", src_in)
    # Put each top-level declaration ('abstract contract', 'contract', 'library', 'interface', 'struct') on its own line
    src_in = re.sub(r"}\s+abstract contract ", "}\nabstract contract ", src_in)
    src_in = re.sub(r"}\s+contract ", "}\ncontract ", src_in)
    src_in = re.sub(r"}\s+library ", "}\nlibrary ", src_in)
    src_in = re.sub(r"}\s+interface ", "}\ninterface ", src_in)
    src_in = re.sub(r"}\s+struct ", "}\nstruct ", src_in)
    src_in = re.sub(r";\s+abstract contract ", ";\nabstract contract ", src_in)
    src_in = re.sub(r";\s+contract ", ";\ncontract ", src_in)
    src_in = re.sub(r";\s+library ", ";\nlibrary ", src_in)
    src_in = re.sub(r";\s+interface ", ";\ninterface ", src_in)
    src_in = re.sub(r";\s+struct ", ";\nstruct ", src_in)
    # Special case for the source typo "ontract"
    src_in = re.sub(r"}\s+ntract ", "}\ncontract ", src_in)
    src_in = src_in.replace("}contract ", "}\ncontract ")
    src_in = src_in.replace("}interface ", "}\ninterface ")
    src_in = src_in.replace("}struct ", "}\nstruct ")
    return src_in
# The clean function takes a source code string as input, removes extra new lines, replaces addresses, and formats the source code before returning it.
# The 'src' parameter is a string that represents the source code that needs to be cleaned.
# The function eventually returns the cleaned source code.
def clean(src):
    src = _remove_extra_new_line(src)
    src = _replace_addr(src)
    src = _format_src(src)
    return src
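# A quick end-to-end illustration of `clean` on a made-up snippet (not a dataset sample):
# blank lines collapse, the address literal is masked, and braces get exactly one leading space.
_demo_src = "pragma solidity ^0.6.0;\n\n\ncontract Demo{\naddress owner = 0x52908400098527886E0F7030069857D2E4169EE7;\n}"
assert clean(_demo_src) == "pragma solidity ^0.6.0;\ncontract Demo {\naddress owner = YOUR_ADDR;\n}"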
# The function `_extract_information` extracts information from a given segment using a specified pattern and prefix.
# The `seg` parameter is a string that represents a segment of text or code.
# The pattern parameter is a string that is used to match and extract information from the seg string.
# The prefix parameter is a string that is used to filter the matches found by the regular expression pattern. It is used to remove any unwanted characters or prefixes from the matches before returning them.
# This function returns a list of matches that were found in the input segment (seg) using the provided pattern.
# The matches are then filtered to exclude any strings that start or end with an underscore.
def _extract_information(seg, pattern, prefix):
    matches = re.findall(pattern, seg)
    if matches:
        matches = [s[len(prefix):] for s in matches if not s[len(prefix):].startswith('_') and not s[len(prefix):].endswith('_')]
    return matches
# The function `_extract_pub_funcs` extracts public functions from a given segment of code.
# The `seg` parameter is a string that represents a segment of code.
# This function returns a list of public function names that were found in the input segment (`seg`).
def _extract_pub_funcs(seg):
    pattern = r"function [A-Za-z0-9_]+\("
    prefix = 'function '
    return _extract_information(seg, pattern, prefix)
# The function `_extract_constants` extracts constants from a given segment of code.
# The `seg` parameter is a string that represents a segment of code or text.
# This function returns a list of constants that were found in the input segment (`seg`) using the pattern and prefix below.
def _extract_constants(seg):
    pattern = r"constant [A-Za-z0-9_]+"
    prefix = 'constant '
    return _extract_information(seg, pattern, prefix)
# The function `_extract_base_parents` extracts the base class and parent classes from a given segment of code.
# The `seg` parameter is a string that represents a segment of code.
# This function returns a tuple containing the base and parents extracted from the input segment.
def _extract_base_parents(seg):
    base_with_parents = re.findall(r"[A-Za-z0-9]+ is [A-Za-z0-9, \n]+ {", seg)
    base, parents = None, []
    if base_with_parents:
        if len(base_with_parents) != 1:
            raise ValueError("base_with_parents pattern can only have 1 match")
        splits = base_with_parents[0].split(' is ')
        if len(splits) != 2:
            raise ValueError("cannot have more than 2 splits for base extraction")
        base = splits[0]
        parents = [p.strip() for p in splits[1][:-2].split(',')]
    else:
        base_only = re.findall(r"[A-Za-z0-9]+\s+{", seg)
        if base_only:
            base = base_only[0].split()[0]
        parents = []
    return base, parents
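# Illustrative behavior (made-up declarations): a contract with parents yields the base name plus the parent list,
# while a plain declaration yields the base name and no parents.
assert _extract_base_parents("contract Token is ERC20, Ownable {\n}") == ("Token", ["ERC20", "Ownable"])
assert _extract_base_parents("library SafeMath {\n}") == ("SafeMath", [])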
# `DEFAULT_SOL_VERSION` defines a fallback pragma that is prepended when the source code does not contain a pragma statement specifying the Solidity version.
# I originally defaulted to `pragma solidity ^0.8.0;` (i.e., compatible with Solidity 0.8.0 and higher).
# After consideration, I believe that it would be smarter to use Solidity version ^0.6.0 instead, so that's what I'm going to do. Not sure if this is for the better or worse.
# Experimentation to a certain extent at this point.
DEFAULT_SOL_VERSION = "pragma solidity ^0.6.0;"
# The function `_prepare_seg_map` takes a list of segments and returns a dictionary mapping each segment's base contract to its parents, constants, public functions, version, and clean source code.
# The `segs` parameter is a list of strings representing segments of Solidity code.
# Each segment represents a separate contract or library in the code
# This function returns a dictionary called `seg_map`.
def _prepare_seg_map(segs):
    if not segs[0].startswith('pragma solidity'):
        segs.insert(0, DEFAULT_SOL_VERSION)
    seg_map = {}
    for s in segs:
        base, parents = _extract_base_parents(s)
        if base:
            seg_map[base] = {
                'parents': parents,
                'constants': _extract_constants(s),
                'pub_funcs': _extract_pub_funcs(s),
                'v': segs[0],  # version pragma is always the first segment
                'clean_src': s,
            }
    return seg_map
# The function `_split_segments` takes a string `src` and splits it into segments based on certain separators, returning a list of the segments.
# The `src` parameter is a string that represents the source code or text that needs to be split into segments.
# This function returns a list of segments.
def _split_segments(src):
    start = 0
    segments = []
    while True:
        # Find the next closest separator position
        next_sep = len(src) + 1
        seg_keyword = ""
        seg_type = ''
        for sep in SEPARATORS:
            # print("next_sep", next_sep)
            # print("start", start)
            cur_src = src[start:]
            if sep in cur_src:
                sep_ind = cur_src.index(sep)
                if sep_ind > 0 and next_sep > sep_ind:
                    next_sep = sep_ind
                    seg_keyword = cur_src[sep_ind + len(sep) + 1:].split()[0]
                    seg_type = sep[1:]
        if next_sep > len(src):
            if start < len(src) - 1:
                segments.append(src[start:].strip())
            break
        else:
            segments.append(src[start:start + next_sep].strip())
            start += next_sep + 1
    return segments
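# Illustrative split of a made-up two-contract file: the pragma line becomes its own segment
# and each top-level contract starts a new one.
assert _split_segments(
    "pragma solidity ^0.6.0;\ncontract A { }\ncontract B is A { }"
) == ["pragma solidity ^0.6.0;", "contract A { }", "contract B is A { }"]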
# The function `_find_ancestors` takes a segment map as input and returns the same map with an additional key-value pair for each segment,
# where the key is 'ancestors' and the value is a list of all the ancestors of that segment.
# The `seg_map` parameter is a dictionary that represents a segmentation map.
# Each key in the dictionary represents a segment, and the corresponding value is another dictionary that contains information about the segment, including its parents.
# This function returns the updated `seg_map` dictionary with the added 'ancestors' key for each segment.
def _find_ancestors(seg_map):
    for k in seg_map:
        parents = seg_map[k]['parents']
        # Breadth-first walk up the inheritance chain, collecting transitive parents
        ancestors = parents.copy() if parents else []
        idx = 0
        while idx < len(ancestors):
            if ancestors[idx] in seg_map:
                for more_parent in seg_map[ancestors[idx]]['parents']:
                    if more_parent not in ancestors and more_parent != k:
                        ancestors.append(more_parent)
            idx += 1
        seg_map[k]['ancestors'] = ancestors
    return seg_map
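# A tiny hand-built sanity check (only the 'parents' key is read here): with A is B and B is C,
# A's transitive ancestors resolve to ['B', 'C'].
_demo_map = {
    'A': {'parents': ['B']},
    'B': {'parents': ['C']},
    'C': {'parents': []},
}
assert _find_ancestors(_demo_map)['A']['ancestors'] == ['B', 'C']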
def process_single_line(src):
    """Clean text, split to segments, prepare segment map with ancestors."""
    src = clean(src)
    segs = _split_segments(src)
    seg_map = _prepare_seg_map(segs)
    seg_map = _find_ancestors(seg_map)
    return seg_map
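# A small end-to-end check (made-up source): contract B inherits from A, so B's entry
# records A as an ancestor and carries the version pragma from the first segment.
_demo_map_full = process_single_line(
    "pragma solidity ^0.6.0;\ncontract A { }\ncontract B is A { }"
)
assert _demo_map_full['B']['ancestors'] == ['A']
assert _demo_map_full['B']['v'] == "pragma solidity ^0.6.0;"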
# The function `_get_single_ancestor_metadata` returns a string containing metadata about a given ancestor.
# The param `an` is a string representing the ancestor's name.
# The param `seg_map` is a dictionary that contains information about segments.
# Each segment is represented by a key in the dictionary, and the corresponding value is another dictionary that contains information about the segment, including its parents and ancestors.
# This function returns a string that contains metadata about a given ancestor.
# The metadata includes the ancestor's context, a list of public functions associated with the ancestor, and a list of constants associated with the ancestor.
def _get_single_ancestor_metadata(an, seg_map):
    if an not in seg_map:
        return ""
    pub_func_str = " ".join(seg_map[an]['pub_funcs'])
    const_str = " ".join(seg_map[an]['constants'])
    return f"// Context: {an} | Functions: {pub_func_str} | Constants: {const_str}"
# The function `_reduce_out_whitespace` removes extra spaces (excluding indentation) and replaces "; " with ";\n" in the given source code.
# The parameter `out_src` is a string that represents the source code that needs to be modified
# This function returns the modified `out_src` string with extra spaces removed and certain characters replaced with newlines.
# The returned string is also stripped of leading and trailing whitespace.
def _reduce_out_whitespace(out_src):
    # Remove extra spaces (ignoring indentation) and replace "; " with ";\n"
    out_src = re.sub(r"\s+", " ", out_src)
    out_src = out_src.replace("; ", ";\n")
    out_src = out_src.replace("{ ", "{\n")
    out_src = out_src.replace("} ", "}\n")
    return out_src.strip()
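# Illustrative reflow of a made-up body: runs of whitespace collapse to single spaces,
# then statements and braces are split one per line.
assert _reduce_out_whitespace("uint a;   uint b; }") == "uint a;\nuint b;\n}"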
# The code below initializes three module-level variables used for debugging: `my_src`, `my_seg`, and `my_raw`.
# `my_src` is an empty string, `my_seg` is set to `None`, and `my_raw` is an empty string as well.
my_src = ""
my_seg = None
my_raw = ''
# The function `prepare_causal_lm_data` prepares data for a causal language model by extracting relevant code segments from a given source and organizing them into a list.
# The `src` param is a string that represents the source code.
# It is used as input to the `prepare_causal_lm_data` function.
# This function returns a list of strings, where each string represents a piece of code.
def prepare_causal_lm_data(src):
    my_src = src  # debugging snapshot (note: this binds a local, not the global above)
    seg_map = process_single_line(src)
    my_seg = seg_map
    data = []
    for k, v in seg_map.items():
        # Some headers do not have content
        if '{\n' not in v['clean_src']:
            continue
        s = v['v'] + "\n"
        for a in v['ancestors']:
            s += _get_single_ancestor_metadata(a, seg_map) + "\n"
        raw_src_code = v['clean_src']
        my_raw = raw_src_code
        header_split_indx = raw_src_code.index('{\n')
        s += raw_src_code[:header_split_indx + 1]  # include "{"
        o = _reduce_out_whitespace(raw_src_code[header_split_indx + 2:])
        full_code = s + o
        data.append(full_code)
    return data
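# Roughly, each element of `data` stacks the version pragma, one "// Context: ..." line per ancestor,
# the contract header up to and including "{", and then the whitespace-reduced body.
# An illustrative (made-up) element might look like:
#   pragma solidity ^0.6.0;
#   // Context: Ownable | Functions: owner( renounceOwnership( | Constants:
#   contract Token is Ownable {uint256 supply;
#   }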
# The function `remove_comments` removes both single-line and multi-line comments from a given string.
# The param `string` is the input string that may contain comments.
# This function returns a string with all comments removed.
def remove_comments(string):
    pattern = r"(\".*?\"|\'.*?\')|(/\*.*?\*/|//[^\r\n]*$)"
    regex = re.compile(pattern, re.MULTILINE | re.DOTALL)
    def _replacer(match):
        # Group 2 captures an actual comment; group 1 captures a string literal that must be kept intact
        if match.group(2) is not None:
            return ""
        else:
            return match.group(1)
    return regex.sub(_replacer, string)
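# Illustrative comment stripping (made-up snippet): line and block comments vanish,
# while string literals that merely look like comments are preserved.
assert remove_comments('uint a; // note\n/* block */ uint b;') == 'uint a; \n uint b;'
assert remove_comments('string s = "// not a comment";') == 'string s = "// not a comment";'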
# The function `get_lengths` takes an example dictionary as input, removes comments from the source code, and calculates the lengths of the source code and bytecode,
# then adds these lengths to the example dictionary.
# The param `example` is a dictionary that contains (at least) the keys 'source_code' and 'bytecode'.
# This function returns the updated example dictionary with the added keys 'sourcecode_len' and 'bytecode_len'.
def get_lengths(example):
    code = remove_comments(example['source_code'])
    example['sourcecode_len'] = len(code.split())
    example['bytecode_len'] = len(HexBytes(example['bytecode']))
    return example
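# Hypothetical usage sketch with the HuggingFace `map` API (not called in this gist as-is):
# it would have to run *before* the 'bytecode' column is removed below, since `get_lengths` reads example['bytecode'].
# train_set = train_set.map(get_lengths)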
# The code below defines two variables: `HF_DATA_SOURCE` and `DATA_TYPE`.
HF_DATA_SOURCE = "mwritescode/slither-audited-smart-contracts"
DATA_TYPE = "big-multilabel"  # change to 'small-plain-text' for debugging
# Load the train, test, and validation splits from the HuggingFace dataset
train_set = load_dataset(HF_DATA_SOURCE, DATA_TYPE, split="train", revision="main", ignore_verifications=True)
test_set = load_dataset(HF_DATA_SOURCE, DATA_TYPE, split="test", revision="main", ignore_verifications=True)
validation_set = load_dataset(HF_DATA_SOURCE, DATA_TYPE, split="validation", revision="main", ignore_verifications=True)
# Remove 'bytecode' and 'address' columns from the train, test, and validation sets
COLS_TO_REMOVE = ['bytecode', 'address']
train_set = train_set.remove_columns(COLS_TO_REMOVE)
test_set = test_set.remove_columns(COLS_TO_REMOVE)
validation_set = validation_set.remove_columns(COLS_TO_REMOVE)
# Every contract is rated by Slither with a score ranging from 0 to 6.
# The only important thing we need to know is that a score of 4 means safe/secure.
# Thus, this code filters the train_set, test_set, and validation_set down to smart contracts whose 'slither' value is 4,
# which simultaneously filters out contracts that are not 'safe' or 'secure'.
train_set = train_set.filter(lambda example: example['slither'] == 4)
test_set = test_set.filter(lambda example: example['slither'] == 4)
validation_set = validation_set.filter(lambda example: example['slither'] == 4)
# Once we're done with the above filtering, we can remove the 'slither' column from the train, test, and validation sets
train_set = train_set.remove_columns(['slither'])
test_set = test_set.remove_columns(['slither'])
validation_set = validation_set.remove_columns(['slither'])
# Print the number of examples in each split so we know the size of each dataset
print("Train DS size", len(train_set))
print("Test DS size", len(test_set))
print("Validation DS size", len(validation_set))
# Here, we convert each preprocessed dataset split into a pandas DataFrame.
# The loop below builds a list of DataFrames, one per dataset split.
datasets = []
for split in [train_set, test_set, validation_set]:
    split_df = pd.DataFrame(split)
    datasets.append(split_df)
# We tag each DataFrame with a `dataset` column ('train', 'test', or 'val') so we can later tell which split each example came from,
# then concatenate the splits into a single DataFrame.
concatenated = pd.concat([split.assign(dataset=split_name) for split, split_name in zip(datasets, ['train', 'test', 'val'])])
# Finally, we copy the source code into a `text` column and keep only the `text` and `dataset` columns.
concatenated['text'] = concatenated['source_code']
concatenated = concatenated[['text', 'dataset']]
concatenated.head()
@FoobarProtocol (Author)
There's going to be more appended to this code in the immediate future. More preprocessing in general is needed to ensure that we extract the maximum number of features from any Solidity-based dataset (we also need to ensure the input format is one that will be effective; we can go off script or with a tried-and-true method, and I think it's worth going with the latter).
