@afrendeiro
Last active March 13, 2017 11:28
Class to perform map-reduce-style operations split into jobs across a high-performance computing cluster
import os
import random
import re
import subprocess
import textwrap
import time
import cPickle as pickle
from collections import Counter, OrderedDict
class DivideAndSlurm(object):
"""
DivideAndSlurm is a class to handle a map-reduce style submission of jobs to a Slurm cluster.
Add a particular task to the object (through a specific method) and it will divide the input data
into pools, which are submitted (with the submit() method) in parallel to the cluster.
Tasks can also further process their input in parallel, taking advantage of all processors on a node.
See the usage sketch after the class definition.
"""
def __init__(self, tmpDir="/fhgfs/scratch/users/user/", logDir="/home/user/logs", queue="shortq", userMail=""):
super(DivideAndSlurm, self).__init__()
self.tasks = dict()
self.name = time.strftime("%Y%m%d%H%M%S", time.localtime())
self.tmpDir = os.path.abspath(tmpDir)
self.logDir = os.path.abspath(logDir)
self.queue = queue
self.userMail = userMail
def __repr__(self):
return "DivideAndSlurm object " + self.name
def __str__(self):
return "DivideAndSlurm object " + self.name
def _slurmHeader(self, jobName, output, queue="shortq", ntasks=1, time="10:00:00", cpusPerTask=16, memPerCpu=2000, nodes=1, userMail=""):
command = """ #!/bin/bash
#SBATCH --partition={0}
#SBATCH --ntasks={1}
#SBATCH --time={2}
#SBATCH --cpus-per-task={3}
#SBATCH --mem-per-cpu={4}
#SBATCH --nodes={5}
#SBATCH --job-name={6}
#SBATCH --output={7}
#SBATCH --mail-type=end
#SBATCH --mail-user={8}
# Activate virtual environment
source /home/arendeiro/venv/bin/activate
# Start running the job
hostname
date
""".format(queue, ntasks, time, cpusPerTask, memPerCpu, nodes, jobName, output, userMail)
return command
def _slurmFooter(self):
command = """
# Deactivate virtual environment
deactivate
# Job end
date
"""
return command
def _slurmSubmitJob(self, jobFile):
"""
Submit command to shell.
"""
command = "sbatch %s" % jobFile
p = subprocess.Popen(command, stdout=subprocess.PIPE, shell=True)
return p.communicate()
def _split_data(self, taskName, data, fractions):
"""
Split data in fractions and create pickle objects with them.
"""
chunkify = lambda lst,n: [lst[i::n] for i in xrange(n)]
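# e.g. chunkify(range(7), 3) -> [[0, 3, 6], [1, 4], [2, 5]] (round-robin split into n roughly equal chunks)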
groups = chunkify(data, fractions)
ids = [taskName + "_" + str(i) for i in xrange(len(groups))]
files = [os.path.join(self.tmpDir, ID) for ID in ids]
# keep track of groups in self
groups = zip(ids, groups, files)
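# each element of "groups" is now a tuple: (chunk_id, chunk_data, chunk_file_prefix);
# the chunk data itself is written to chunk_file_prefix + ".pickle" below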
# serialize groups
for i in xrange(len(groups)):
pickle.dump(groups[i][1], # actual group of objects
open(groups[i][2] + ".pickle", 'wb'), # group pickle file
protocol=pickle.HIGHEST_PROTOCOL
)
return groups
def _rm_temps(self, taskNumber):
"""
Remove the temporary files (input/output pickles and job scripts) of a task.
"""
if taskNumber not in self.tasks:
raise KeyError("Task number not in object's tasks.")
if "output" in self.tasks[taskNumber]:
groups = self.tasks[taskNumber]["groups"]
to_rm = list()
[to_rm.append(groups[i][2] + ".pickle") for i in xrange(len(groups))]
[to_rm.append(groups[i][2] + "_count_distances.sh") for i in xrange(len(groups))]
[to_rm.append(groups[i][2] + ".output.pickle") for i in xrange(len(groups))]
for fl in to_rm:
p = subprocess.Popen("rm {0}".format(fl), stdout=subprocess.PIPE, shell=True)
def task(self, data, fractions, bam_file, strand_wise=True, fragment_size=1):
"""
Add task to be performed with data.
"""
now = "_".join([time.strftime("%Y%m%d%H%M%S", time.localtime()), str(random.randint(1, 1000))])
taskName = "task_name_{0}".format(now)
log = taskName + ".log"
# check data is iterable
if type(data) == dict or type(data) == OrderedDict:
data = data.items()
# split data in fractions
groups = self._split_data(taskName, data, fractions)
# make jobs with groups of data
jobs = list()
jobFiles = list()
for i in xrange(len(groups)):
jobFile = groups[i][2] + "_task_name.sh"
input_pickle = groups[i][2] + ".pickle"
output_pickle = groups[i][2] + ".output.pickle"
# assemble command for job
task = " python perform_task_parallel.py {0} {1} {2} ".format(input_pickle, output_pickle, bam_file)
if strand_wise:
task += "--strand-wise "
task += "--fragment-size {0}".format(fragment_size)
# assemble job file
job = self._slurmHeader(groups[i][0], log, queue=self.queue, userMail=self.userMail) + task + self._slurmFooter()
# keep track of jobs and their files
jobs.append(job)
jobFiles.append(jobFile)
# write job file to disk
with open(jobFile, 'w') as handle:
handle.write(textwrap.dedent(job))
# save task in object
taskNumber = len(self.tasks)
self.tasks[taskNumber] = { # don't keep track of data
"name" : taskName,
"groups" : groups,
"jobs" : jobs,
"jobFiles" : jobFiles,
"log" : log
}
# return taskNumber so that it can be used later
return taskNumber
def submit(self, taskNumber):
"""
Submit slurm jobs with each fraction of data.
"""
if taskNumber not in self.tasks:
raise KeyError("Task number not in object's tasks.")
jobIDs = list()
for i in xrange(len(self.tasks[taskNumber]["jobs"])):
output, err = self._slurmSubmitJob(self.tasks[taskNumber]["jobFiles"][i])
jobIDs.append(re.sub("\D", "", output))
self.tasks[taskNumber]["submission_time"] = time.time()
self.tasks[taskNumber]["jobIDs"] = jobIDs
def cancel_jobs(self, taskNumber):
"""
Cancel the Slurm jobs belonging to a task.
"""
if taskNumber not in self.tasks:
raise KeyError("Task number not in object's tasks.")
if not "jobIDs" in self.tasks[taskNumber]:
return False
for jobID in self.tasks[taskNumber]["jobIDs"]:
command = "scancel %s" % jobID
p = subprocess.Popen(command, stdout=subprocess.PIPE, shell=True)
def _is_running(self, taskNumber):
# check if all job ids are missing from squeue (i.e. jobs are no longer queued or running)
p = subprocess.Popen("squeue -h -o '%i'", stdout=subprocess.PIPE, shell=True)
processes = p.communicate()[0].split("\n")
for ID in self.tasks[taskNumber]["jobIDs"]:
if ID in processes:
return False
return True
def _has_output(self, taskNumber):
# check if all output pickles are produced
outputPickles = [self.tasks[taskNumber]["groups"][i][2] + ".output.pickle" for i in xrange(len(self.tasks[taskNumber]["groups"]))]
for i in outputPickles:
if not os.path.isfile(i):
return False
return True
def is_ready(self, taskNumber):
"""
Check if all submitted jobs have been completed.
"""
if "is_ready" in self.tasks[taskNumber] and self.tasks[taskNumber]["is_ready"]: # if already finished
return True
if "jobIDs" not in self.tasks[taskNumber]: # if not even started
return False
# if is not running and has output
if not self._is_running(taskNumber) and self._has_output(taskNumber):
# save output already
self.tasks[taskNumber]["is_ready"] = True
return True
else:
return False
def collect_output(self, taskNumber):
"""
If self.is_ready(taskNumber), return joined data.
"""
if taskNumber not in self.tasks:
raise KeyError("Task number not in object's tasks.")
if "output" in self.tasks[taskNumber]: # if output is already stored, just return it
return self.tasks[taskNumber]["output"]
if self.is_ready(taskNumber):
# load all pickles into list
groups = self.tasks[taskNumber]["groups"]
outputs = [pickle.load(open(groups[i][2] + ".output.pickle", 'rb')) for i in xrange(len(groups))]
# if all outputs are Counters, sum them into a single Counter
if all([type(outputs[i]) == Counter for i in range(len(outputs))]):
output = reduce(lambda x, y: x + y, outputs) # reduce
if type(output) == Counter:
self.tasks[taskNumber]["output"] = output # store output in object
self._rm_temps(taskNumber) # delete tmp files
return self.tasks[taskNumber]["output"]
else:
raise TypeError("Task is not ready yet.")
def remove_task(self, taskNumber):
"""
Remove task from object.
"""
del self.tasks[taskNumber]
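# ----------------------------------------------------------------------------
# Minimal usage sketch of the class above (illustrative only): "my_intervals"
# and "sample.bam" are hypothetical placeholders, not defined in this gist.
#
#   slurm = DivideAndSlurm()
#   taskNumber = slurm.task(my_intervals, 20, "sample.bam")  # data, fractions, bam
#   slurm.submit(taskNumber)
#   while not slurm.is_ready(taskNumber):
#       time.sleep(10)                              # poll until all chunk jobs have finished
#   counts = slurm.collect_output(taskNumber)       # Counters from all chunks, summed
# ----------------------------------------------------------------------------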
# regions = {
# "H3K27me3" : H3K27me3,
# "H3K4me3" : H3K4me3,
# "H3K27me3_only" : H3K27me3_only,
# "H3K4me3_only" : H3K4me3_only,
# "H3K27me3_H3K4me3" : H3K27me3_H3K4me3
# }
regions = pickle.load(open(os.path.join(args.results_dir, "genomic_regions.pickle"), "rb"))
# Initialize Slurm object
slurm = DivideAndSlurm()
tasks = dict()
# Submit tasks for combinations of regions and bam files
for regionName, region in regions.items():
for sampleName, sampleFile in samples.items():
print("Sample " + sampleName, regionName)
exportName = os.path.join(args.results_dir, sampleName + "_" + regionName)
if not os.path.isfile(os.path.join(exportName + ".output-slurm.pickle")):
# Add new task
taskNumber = slurm.task(region, 20, os.path.abspath(sampleFile)) # syntax: data, fractions, bam
# Submit new task
slurm.submit(taskNumber)
# Keep track
tasks[taskNumber] = (sampleName, regionName)
# After submitting, one is free to continue working...
ready = list()
while len(ready) < len(tasks): # loop until every task's output has been collected
for taskNumber, (sampleName, regionName) in tasks.items(): # loop through tasks, see if ready
if slurm.is_ready(taskNumber) and taskNumber not in ready: # if yes, collect output and save
print("""\
Task {0} is now ready! {1}, {2}
Time to completion was: {3} minutes.
""".format(taskNumber, sampleName, regionName, int(time.time() - slurm.tasks[taskNumber]["submission_time"])/ 60.))
exportName = os.path.join(args.results_dir, sampleName + "_" + regionName)
dists = slurm.collect_output(taskNumber)
pickle.dump(dists, open(os.path.join(exportName + ".output-slurm.pickle"), "wb"), protocol=pickle.HIGHEST_PROTOCOL)
ready.append(taskNumber)
#!/usr/bin/env python
from argparse import ArgumentParser
import os
import HTSeq
import cPickle as pickle
import multiprocessing
import parmap
from collections import Counter
import itertools
def task(feature, bam, fragment_size, strand_wise=True):
"""
Computes something with reads present in a single, specific interval. Returns Counter.
feature=HTSeq.GenomicInterval object.
bam=HTSeq.BAM_Reader object.
fragment_size=int.
strand_wise=bool.
"""
counts = Counter()
# Fetch all alignments in feature window
for aln in bam[feature]:
# compute something
counts[aln.iv.start] += 1
return counts
if __name__ == '__main__':
parser = ArgumentParser(
description = 'perform_task_parallel.py',
usage = 'python perform_task_parallel.py input_pickle output_pickle bam_file [--strand-wise] [--fragment-size INT]'
)
### Global options
# positional arguments
parser.add_argument(dest='input_pickle', type=str, help='Pickle file to load.')
parser.add_argument(dest='output_pickle', type=str, help='Pickle file to save to.')
parser.add_argument(dest='bam_file', type=str, help = 'Bam file.')
# optional arguments
parser.add_argument('--strand-wise', dest='strand_wise', action='store_true')
parser.add_argument('--fragment-size', dest='fragment_size', type=int, default=1)
args = parser.parse_args()
### Read pickle with windows
windows = pickle.load(open(args.input_pickle, 'rb')) # list of (key, value) tuples
# keep only the values (the genomic intervals)
windows = [tup[1] for tup in windows]
### Make bam object
bam = HTSeq.BAM_Reader(os.path.abspath(args.bam_file))
### Process in parallel and serialize result
# parallel process and reduce to Counter
output = reduce(
lambda x, y: x + y,
parmap.map(task, windows, bam, args.fragment_size,
strand_wise=args.strand_wise
)
)
### Serialize
pickle.dump(output, open(args.output_pickle, "wb"), protocol=pickle.HIGHEST_PROTOCOL)
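# Example invocation, as assembled by DivideAndSlurm.task() for one data chunk
# (file names are illustrative placeholders):
#   python perform_task_parallel.py task_name_X_0.pickle task_name_X_0.output.pickle \
#       sample.bam --strand-wise --fragment-size 1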