Last active
February 2, 2018 16:32
-
-
Save zurk/09952f90b395653b87e8556f46873914 to your computer and use it in GitHub Desktop.
Example: how to get random walks from a UAST via sourced-ml.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import os | |
from sourced.ml.transformers import Engine, UastExtractor, UastDeserializer, Transformer, HeadFiles | |
from sourced.ml.utils import create_engine | |
from sourced.ml.algorithms.uast_struct_to_bag import Uast2RandomWalks | |
# Sourced-ml is based on transformers. | |
# You create a pipeline from transformers and then run it. | |
# We do not have a transformer, that just returns random walks, so we need to create it first. | |
class UastRandomWalk2Transformer(Transformer):
    """Pipeline stage that turns each deserialized UAST row into random walks.

    sourced-ml ships no transformer that simply emits random walks, so this
    one wraps ``Uast2RandomWalks`` and flat-maps its output over the rows it
    receives.
    """

    def __init__(self, p_explore_neighborhood=0.5, q_leave_neighborhood=0.5,
                 n_walks=10, n_steps=80, seed=42, **kwargs):
        """Configure the underlying random-walk generator.

        The five named arguments are forwarded unchanged to
        ``Uast2RandomWalks`` (see that class for what each one controls);
        any remaining keyword arguments go to the ``Transformer`` base class.
        """
        super().__init__(**kwargs)
        walk_options = dict(
            p_explore_neighborhood=p_explore_neighborhood,
            q_leave_neighborhood=q_leave_neighborhood,
            n_walks=n_walks,
            n_steps=n_steps,
            seed=seed,
        )
        self.uast2walks = Uast2RandomWalks(**walk_options)

    def __call__(self, rows):
        """Expand every row of ``rows`` into its individual walks.

        ``rows`` must provide ``flatMap`` (presumably a Spark RDD produced by
        the previous pipeline stage -- confirm against the caller).
        """
        return rows.flatMap(self._process_row)

    def _process_row(self, row):
        """Yield each walk that ``Uast2RandomWalks`` produces for ``row.uast``."""
        yield from self.uast2walks(row.uast)
# Put your dataset directory path here. A relative path is resolved against
# the current working directory, not against this file.
repos_dir = os.path.abspath("../dataset/")
# You can choose any of the languages supported by bblfsh:
# https://doc.bblf.sh/languages.html (if you are going to use Roles instead of
# the internal node types, check that the language's Annotations cell has a ✓)
languages = ["Python"]
# "standard" means `repos_dir` simply contains ordinary repository folders.
# The `siva` format (https://github.com/src-d/go-siva) is also supported,
# but you most likely do not need it here.
repository_format = "standard"
# Create the engine.
engine = create_engine(session_name="walks",
                       repositories=repos_dir,
                       repository_format=repository_format)
# Build the pipeline. Parenthesized chaining is used instead of backslash
# continuations: a trailing "\" on the last link would silently swallow the
# next physical line.
pipeline = (
    Engine(engine)
    .link(HeadFiles())
    .link(UastExtractor(languages=languages))
    .link(UastDeserializer())
    .link(UastRandomWalk2Transformer())
)
# ...and execute it. NOTE(review): collect() presumably materializes every
# walk in this (driver) process, so keep the dataset small enough to fit in
# memory.
walks = pipeline.execute().collect()
print(walks)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment