Summarize mutations from VCF files, assuming one sample per VCF; alternatively, it can take a tab-delimited file with the same header columns as a VCF.
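
# Example invocations (the script name is illustrative; the flags match the
# argparse definitions in main() below):
#   python summarize_mutations.py -i vcf_dir/ -o mutation_summary.csv
#   python summarize_mutations.py -i sample1.vcf.gz -o summary.csv \
#       --find-clusters --cluster-output clusters.csv --cluster-size 46415
#   python summarize_mutations.py -t variants.tab -s SAMPLE1 -o summary.csv \
#       --parental-origin-dir parental_origin/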
import os
import csv
import glob
import argparse
from collections import defaultdict, Counter

import pysam

def get_complementary_mutations():
    """
    Returns a dictionary mapping each pair of complementary SNVs to its
    strand-collapsed mutation label.
    """
    return {
        ('C>A', 'G>T'): 'C:G>A:T',
        ('C>G', 'G>C'): 'C:G>G:C',
        ('C>T', 'G>A'): 'C:G>T:A',
        ('T>A', 'A>T'): 'T:A>A:T',
        ('T>C', 'A>G'): 'T:A>C:G',
        ('T>G', 'A>C'): 'T:A>G:C'
    }
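
# For example, "C>A" and its reverse complement "G>T" describe the same event
# on opposite strands, so both are tallied under the shared key 'C:G>A:T'
# when pair counts and proportions are computed.
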
def analyze_tab_file(tab_file, sample_name=None, find_clusters=False):
    """
    Analyze a tab-delimited file with VCF-like columns.
    """
    stats = {
        'total_variants': 0,
        'total_snvs': 0,
        'total_indels': 0,
        'mutation_types': defaultdict(int),
        'mutation_pairs': defaultdict(int)
    }
    comp_mutations = get_complementary_mutations()
    variants = [] if find_clusters else None

    try:
        with open(tab_file, 'r') as f:
            # Skip any comment lines and find the header
            header = None
            for line in f:
                if line.startswith('#CHROM'):
                    header = line.strip().split('\t')
                    break

            if not header:
                print(f"Error: Could not find #CHROM header line in {tab_file}")
                return None if not find_clusters else (None, None)

            # Find required column indices
            try:
                ref_idx = header.index('REF')
                alt_idx = header.index('ALT')
                # The header starts with '#CHROM', so fall back to columns 0/1
                chrom_idx = header.index('CHROM') if 'CHROM' in header else 0
                pos_idx = header.index('POS') if 'POS' in header else 1
                # Get sample name from the last column if not provided
                if not sample_name:
                    sample_name = header[-1]
            except ValueError:
                print(f"Error: Required column (REF or ALT) not found in {tab_file}")
                return None if not find_clusters else (None, None)

            stats['sample_name'] = sample_name

            # Process each variant
            for line in f:
                if line.startswith('#'):
                    continue
                fields = line.strip().split('\t')
                if len(fields) <= max(ref_idx, alt_idx):
                    continue

                ref = fields[ref_idx]
                alt = fields[alt_idx]
                chrom = fields[chrom_idx]
                pos = int(fields[pos_idx])
                stats['total_variants'] += 1

                if len(ref) == len(alt) == 1:  # SNV
                    stats['total_snvs'] += 1
                    mutation = f"{ref}>{alt}"
                    stats['mutation_types'][mutation] += 1
                    # Count mutation pairs
                    for pair, pair_name in comp_mutations.items():
                        if mutation in pair:
                            stats['mutation_pairs'][pair_name] += 1
                            break
                else:  # indel
                    stats['total_indels'] += 1

                if find_clusters:
                    variants.append({'sample': sample_name, 'chrom': chrom,
                                     'pos': pos, 'ref': ref, 'alt': alt})
    except Exception as e:
        print(f"Error processing {tab_file}: {str(e)}")
        return None if not find_clusters else (None, None)

    # Calculate proportions
    total_snvs = stats['total_snvs']
    stats['proportions'] = {
        mut_type: count / total_snvs if total_snvs > 0 else 0
        for mut_type, count in stats['mutation_types'].items()
    }
    stats['pair_proportions'] = {
        pair_name: count / total_snvs if total_snvs > 0 else 0
        for pair_name, count in stats['mutation_pairs'].items()
    }

    return stats if not find_clusters else (stats, variants)
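
# The tab input is expected to look like a VCF body; an illustrative header:
#   #CHROM  POS  ID  REF  ALT  QUAL  FILTER  INFO  FORMAT  SAMPLE1
# Only REF and ALT are required by name; CHROM and POS fall back to the first
# two columns, and the sample name defaults to the last header column when
# -s/--sample is not given.
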
def analyze_vcf(vcf_path, find_clusters=False):
    """
    Analyze a VCF file for mutation statistics and optionally collect
    variant positions for clustering.
    """
    stats = {
        'total_variants': 0,
        'total_snvs': 0,
        'total_indels': 0,
        'mutation_types': defaultdict(int),
        'mutation_pairs': defaultdict(int)
    }
    comp_mutations = get_complementary_mutations()
    variants = [] if find_clusters else None
    # Sample name is derived from the file name (one sample per VCF)
    sample_name = os.path.basename(vcf_path).split('.')[0]

    try:
        vcf = pysam.VariantFile(vcf_path)
        stats['sample_name'] = sample_name
        for rec in vcf:
            ref = rec.ref
            # Only the first ALT allele of each record is considered
            alt = rec.alts[0] if rec.alts else None
            chrom = str(rec.chrom)
            pos = int(rec.pos)
            if not alt:
                continue
            stats['total_variants'] += 1
            if len(ref) == len(alt) == 1:
                stats['total_snvs'] += 1
                mutation = f"{ref}>{alt}"
                stats['mutation_types'][mutation] += 1
                for pair, pair_name in comp_mutations.items():
                    if mutation in pair:
                        stats['mutation_pairs'][pair_name] += 1
                        break
            else:
                stats['total_indels'] += 1
            if find_clusters:
                variants.append({'sample': sample_name, 'chrom': chrom,
                                 'pos': pos, 'ref': ref, 'alt': alt})
    except Exception as e:
        print(f"Error processing {vcf_path}: {str(e)}")
        return None if not find_clusters else (None, None)

    total_snvs = stats['total_snvs']
    stats['proportions'] = {
        mut_type: count / total_snvs if total_snvs > 0 else 0
        for mut_type, count in stats['mutation_types'].items()
    }
    stats['pair_proportions'] = {
        pair_name: count / total_snvs if total_snvs > 0 else 0
        for pair_name, count in stats['mutation_pairs'].items()
    }

    return stats if not find_clusters else (stats, variants)
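
# Standalone usage sketch (the file name is hypothetical):
#   stats = analyze_vcf("sample1.vcf.gz")
#   print(stats['total_snvs'], dict(stats['mutation_types']))
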
def cluster_variants(variants, distance=50000):
    """
    Group variants into clusters if each is within `distance` bp of the
    previous variant on the same chromosome for the same sample.
    Returns a list of dicts with cluster_id and variant info, and a dict of
    sample -> number of clusters (only clusters with at least 2 variants).
    """
    clustered = []
    cluster_counts = defaultdict(int)

    # Group by sample and chromosome
    by_sample_chrom = defaultdict(list)
    for v in variants:
        by_sample_chrom[(v['sample'], v['chrom'])].append(v)

    for (sample, chrom), var_list in by_sample_chrom.items():
        var_list.sort(key=lambda x: x['pos'])
        cluster_id = 1
        cluster = [var_list[0]]
        for prev, curr in zip(var_list, var_list[1:]):
            if curr['pos'] - prev['pos'] <= distance:
                cluster.append(curr)
            else:
                # Assign cluster ID to all in cluster
                for v in cluster:
                    clustered.append({**v, 'cluster_id': f'{sample}_{chrom}_{cluster_id}'})
                cluster_counts[sample] += 1
                cluster_id += 1
                cluster = [curr]
        # Last cluster
        for v in cluster:
            clustered.append({**v, 'cluster_id': f'{sample}_{chrom}_{cluster_id}'})
        cluster_counts[sample] += 1

    # Filter clusters with <2 variants: count variants per cluster_id and
    # only keep clusters with >=2 variants
    cluster_id_counts = Counter(v['cluster_id'] for v in clustered)
    valid_cluster_ids = {cid for cid, count in cluster_id_counts.items() if count >= 2}
    filtered_clustered = [v for v in clustered if v['cluster_id'] in valid_cluster_ids]

    # Count number of valid clusters per sample
    # Note: cluster IDs are '<sample>_<chrom>_<n>', so this split assumes
    # sample names contain no underscores
    sample_to_cluster_count = defaultdict(int)
    for cid in valid_cluster_ids:
        sample = cid.split('_')[0]
        sample_to_cluster_count[sample] += 1

    return filtered_clustered, sample_to_cluster_count
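
# Worked example with distance=50000: same-sample variants at chr1:1000 and
# chr1:40000 are 39 kb apart, so they join one cluster ('SAMPLE_chr1_1');
# a lone variant at chr1:500000 starts cluster 2, which is then dropped
# because it holds fewer than 2 variants.
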
def process_files(input_files, output_file, tab_file=None, sample_name=None,
                  find_clusters=False, cluster_output=None,
                  parental_origin_dir=None, cluster_size=46415):
    """
    Process VCF files and/or tab file and output results to CSV.
    """
    results = []
    all_mutation_types = set()
    mutation_pairs = list(get_complementary_mutations().values())
    all_variants = [] if find_clusters else None
    sample_to_cluster_count = {}
    parental_origin_counts = {}
    # Initialized up front so the summary section below is safe even when
    # clustering or parent-of-origin data is absent
    sample_cluster_type_counts = {}
    sample_cluster_variant_counts = {}

    # Process tab file if provided
    if tab_file:
        print(f"Processing tab file: {tab_file}")
        if find_clusters:
            tab_stats, variants = analyze_tab_file(tab_file, sample_name, find_clusters)
            if tab_stats:
                all_mutation_types.update(tab_stats['mutation_types'].keys())
                results.append(tab_stats)
                all_variants.extend(variants)
        else:
            tab_stats = analyze_tab_file(tab_file, sample_name, find_clusters)
            if tab_stats:
                all_mutation_types.update(tab_stats['mutation_types'].keys())
                results.append(tab_stats)

    # Process VCF files
    vcf_files = []
    if input_files:  # Only process VCF files if input_files is provided
        if os.path.isdir(input_files):
            vcf_files.extend(glob.glob(os.path.join(input_files, "*.vcf.gz")))
            vcf_files.extend(glob.glob(os.path.join(input_files, "*.vcf")))
        elif input_files.endswith(('.vcf', '.vcf.gz')):
            vcf_files = [input_files]

    if vcf_files:
        print(f"Found {len(vcf_files)} VCF files to process")
        for vcf_path in vcf_files:
            print(f"Processing {vcf_path}...")
            if find_clusters:
                stats, variants = analyze_vcf(vcf_path, find_clusters)
                if stats:
                    all_mutation_types.update(stats['mutation_types'].keys())
                    results.append(stats)
                    all_variants.extend(variants)
            else:
                stats = analyze_vcf(vcf_path, find_clusters)
                if stats:
                    all_mutation_types.update(stats['mutation_types'].keys())
                    results.append(stats)

    if not results:
        print("No data was processed successfully")
        return
    # If clustering, perform clustering and write cluster file
    if find_clusters and all_variants:
        clustered, cluster_counts = cluster_variants(all_variants, distance=cluster_size)

        # Build parent-of-origin lookup if needed
        po_lookup = dict()
        if parental_origin_dir:
            for stats in results:
                sample = stats['sample_name']
                fname = os.path.join(parental_origin_dir, f"{sample}.parental_origin.tab")
                if os.path.exists(fname):
                    with open(fname) as f:
                        header = None
                        po_idx = None
                        chrom_idx = pos_idx = ref_idx = alt_idx = variantid_idx = None
                        for line in f:
                            if line.startswith('#'):
                                continue
                            if header is None:
                                header = line.strip().split('\t')
                                # Try to get indices for required columns
                                try:
                                    if 'VariantID' in header:
                                        variantid_idx = header.index('VariantID')
                                    chrom_idx = header.index('CHROM') if 'CHROM' in header else 0
                                    pos_idx = header.index('POS') if 'POS' in header else 1
                                    ref_idx = header.index('REF') if 'REF' in header else None
                                    alt_idx = header.index('ALT') if 'ALT' in header else None
                                    po_idx = header.index('ParentalOrigin')
                                except ValueError:
                                    print(f"Warning: Required columns not found in {fname}")
                                    break
                                continue
                            fields = line.strip().split('\t')
                            if po_idx is None or po_idx >= len(fields):
                                continue
                            if variantid_idx is not None and variantid_idx < len(fields):
                                # Parse VariantID: CHROM:POS:REF:ALT
                                variant_id = fields[variantid_idx]
                                parts = variant_id.split(':')
                                if len(parts) == 4:
                                    chrom, pos, ref, alt = parts[0], int(parts[1]), parts[2], parts[3]
                                    key = (sample, chrom, pos, ref, alt)
                                    po_lookup[key] = fields[po_idx]
                            elif (None not in (chrom_idx, pos_idx, ref_idx, alt_idx)
                                  and max(chrom_idx, pos_idx, ref_idx, alt_idx) < len(fields)):
                                key = (sample, fields[chrom_idx], int(fields[pos_idx]),
                                       fields[ref_idx], fields[alt_idx])
                                po_lookup[key] = fields[po_idx]

        # Assign a parent of origin to each cluster from its member variants
        cluster_to_po = dict()
        if parental_origin_dir:
            cluster_to_po_vals = defaultdict(list)
            for v in clustered:
                key = (v['sample'], v['chrom'], v['pos'], v['ref'], v['alt'])
                po = po_lookup.get(key, 'NA')
                cluster_to_po_vals[v['cluster_id']].append(po)
            for cid, po_list in cluster_to_po_vals.items():
                po_set = set(p.lower() for p in po_list if p != 'NA')
                if not po_set or po_set == {'nd'}:
                    cluster_to_po[cid] = 'NotDetermined'
                elif po_set <= {'maternal', 'nd'}:
                    cluster_to_po[cid] = 'maternal'
                elif po_set <= {'paternal', 'nd'}:
                    cluster_to_po[cid] = 'paternal'
                elif 'maternal' in po_set and 'paternal' in po_set:
                    cluster_to_po[cid] = 'mixed'
                else:
                    cluster_to_po[cid] = 'NotDetermined'
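            # e.g. ['maternal', 'nd', 'maternal'] -> 'maternal';
            #      ['maternal', 'paternal']       -> 'mixed';
            #      ['nd', 'NA']                   -> 'NotDetermined'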

            # Count clusters by type per sample
            sample_cluster_type_counts = defaultdict(
                lambda: {'maternal': 0, 'paternal': 0, 'mixed': 0, 'NotDetermined': 0})
            for cid, po_type in cluster_to_po.items():
                sample = cid.split('_')[0]
                sample_cluster_type_counts[sample][po_type] += 1

        # Write cluster output file (all writing inside this block)
        with open(cluster_output, 'w', newline='') as csvfile:
            writer = csv.writer(csvfile)
            header = ['Sample', 'Cluster_ID', 'Chromosome', 'Position', 'Reference', 'Alternative']
            if parental_origin_dir:
                header.append('variant_parentOfOrigin')
                header.append('cluster_parentOfOrigin')
            writer.writerow(header)
            for v in clustered:
                row = [v['sample'], v['cluster_id'], v['chrom'], v['pos'], v['ref'], v['alt']]
                if parental_origin_dir:
                    key = (v['sample'], v['chrom'], v['pos'], v['ref'], v['alt'])
                    po = po_lookup.get(key, 'NA')
                    row.append(po)
                    row.append(cluster_to_po.get(v['cluster_id'], 'NA'))
                writer.writerow(row)

        # Saved for the summary output below
        sample_to_cluster_count = cluster_counts

    # Parental origin counts
    if parental_origin_dir:
        for stats in results:
            sample = stats['sample_name']
            fname = os.path.join(parental_origin_dir, f"{sample}.parental_origin.tab")
            paternal = maternal = nd = 0
            if os.path.exists(fname):
                with open(fname) as f:
                    header = None
                    for line in f:
                        if line.startswith('#'):
                            continue
                        if header is None:
                            header = line.strip().split('\t')
                            try:
                                po_idx = header.index('ParentalOrigin')
                            except ValueError:
                                print(f"Warning: ParentalOrigin column not found in {fname}")
                                break
                            continue
                        fields = line.strip().split('\t')
                        if len(fields) <= po_idx:
                            continue
                        val = fields[po_idx].strip().lower()
                        if val == 'paternal':
                            paternal += 1
                        elif val == 'maternal':
                            maternal += 1
                        elif val == 'nd':
                            nd += 1
            parental_origin_counts[sample] = {'Paternal': paternal, 'Maternal': maternal, 'ND': nd}

    # Count number of variants, SNVs, and indels in clusters per sample
    # (guarded, since `clustered` only exists when clustering actually ran)
    if find_clusters and all_variants:
        for v in clustered:
            sample = v['sample']
            ref = v['ref']
            alt = v['alt']
            if sample not in sample_cluster_variant_counts:
                sample_cluster_variant_counts[sample] = {'variants': 0, 'snvs': 0, 'indels': 0}
            sample_cluster_variant_counts[sample]['variants'] += 1
            if len(ref) == len(alt) == 1:
                sample_cluster_variant_counts[sample]['snvs'] += 1
            else:
                sample_cluster_variant_counts[sample]['indels'] += 1
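
    # The *_NoCluster summary columns below subtract these in-cluster counts
    # from each sample's totals.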

    # Sort mutation types for consistent column ordering
    all_mutation_types = sorted(all_mutation_types)

    # Prepare CSV headers
    headers = [
        'Sample',
        'Total_Variants',
        'Total_Variants_NoCluster',
        'Total_SNVs',
        'Total_SNVs_NoCluster',
        'Total_Indels',
        'Total_Indels_NoCluster'
    ]
    for mut_type in all_mutation_types:
        headers.extend([f'{mut_type}_count', f'{mut_type}_percent'])
    for pair in mutation_pairs:
        headers.extend([f'{pair}_count', f'{pair}_percent'])
    if find_clusters:
        headers.append('Num_Clusters')
        if parental_origin_dir:
            headers.extend(['maternal_clusters', 'paternal_clusters',
                            'mixed_clusters', 'NotDetermined_clusters'])
    if parental_origin_dir:
        headers.extend(['Paternal', 'Maternal', 'ND'])
        # Add new columns for phased proportions and estimated counts
        headers.extend([
            'Proportion_Paternal_Phased',
            'Proportion_Maternal_Phased',
            'Estimated_Paternal_Variants',
            'Estimated_Maternal_Variants'
        ])

    # Write to CSV
    with open(output_file, 'w', newline='') as csvfile:
        writer = csv.writer(csvfile)
        writer.writerow(headers)
        for stats in results:
            sample = stats['sample_name']
            total_variants = stats['total_variants']
            total_snvs = stats['total_snvs']
            total_indels = stats['total_indels']

            # Compute no-cluster values if clustering was done
            if find_clusters:
                cluster_counts_dict = sample_cluster_variant_counts.get(
                    sample, {'variants': 0, 'snvs': 0, 'indels': 0})
                variants_no_cluster = total_variants - cluster_counts_dict['variants']
                snvs_no_cluster = total_snvs - cluster_counts_dict['snvs']
                indels_no_cluster = total_indels - cluster_counts_dict['indels']
            else:
                variants_no_cluster = total_variants
                snvs_no_cluster = total_snvs
                indels_no_cluster = total_indels

            row = [
                sample,
                total_variants,
                variants_no_cluster,
                total_snvs,
                snvs_no_cluster,
                total_indels,
                indels_no_cluster
            ]
            for mut_type in all_mutation_types:
                count = stats['mutation_types'].get(mut_type, 0)
                percentage = stats['proportions'].get(mut_type, 0)
                row.extend([count, f"{percentage:.4f}"])
            for pair in mutation_pairs:
                count = stats['mutation_pairs'].get(pair, 0)
                percentage = stats['pair_proportions'].get(pair, 0)
                row.extend([count, f"{percentage:.4f}"])
            if find_clusters:
                row.append(sample_to_cluster_count.get(sample, 0))
                if parental_origin_dir:
                    ccounts = sample_cluster_type_counts.get(
                        sample, {'maternal': 0, 'paternal': 0, 'mixed': 0, 'NotDetermined': 0})
                    row.extend([
                        ccounts['maternal'],
                        ccounts['paternal'],
                        ccounts['mixed'],
                        ccounts['NotDetermined']
                    ])
            if parental_origin_dir:
                poc = parental_origin_counts.get(sample, {'Paternal': 0, 'Maternal': 0, 'ND': 0})
                row.extend([poc['Paternal'], poc['Maternal'], poc['ND']])

                # Calculate phased proportions and estimated counts
                phased_total = poc['Paternal'] + poc['Maternal']
                if phased_total > 0:
                    prop_paternal = poc['Paternal'] / phased_total
                    prop_maternal = poc['Maternal'] / phased_total
                else:
                    prop_paternal = 0.0
                    prop_maternal = 0.0
                est_paternal = total_variants * prop_paternal
                est_maternal = total_variants * prop_maternal
                row.extend([
                    f"{prop_paternal:.4f}",
                    f"{prop_maternal:.4f}",
                    f"{est_paternal:.2f}",
                    f"{est_maternal:.2f}"
                ])
            writer.writerow(row)

    print(f"Results have been written to {output_file}")
    if find_clusters:
        print(f"Clustered mutations have been written to {cluster_output}")

def main():
    parser = argparse.ArgumentParser(
        description='Analyze VCF files or tab-delimited files for mutation statistics')
    group = parser.add_mutually_exclusive_group(required=True)
    group.add_argument('-i', '--input',
                       help='Input VCF file or directory containing VCF files')
    group.add_argument('-t', '--tab',
                       help='Input tab-delimited file with VCF-like columns')
    parser.add_argument('-o', '--output', required=True,
                        help='Output CSV file path')
    parser.add_argument('-s', '--sample',
                        help='Sample name for tab-delimited file (optional)')
    parser.add_argument('--find-clusters', action='store_true',
                        help='Enable detection of clustered mutations within a specified cluster size (default: 46415 bp)')
    parser.add_argument('--cluster-output',
                        help='Output file path for clustered mutations (required if --find-clusters is set)')
    parser.add_argument('--parental-origin-dir',
                        help='Directory containing <sampleID>.parental_origin.tab files for parent-of-origin counts (optional)')
    parser.add_argument('--cluster-size', type=int, default=46415,
                        help='Cluster size in base pairs for defining clustered mutations (default: 46415)')
    args = parser.parse_args()

    # Check if input files exist
    if args.input and not os.path.exists(args.input):
        print(f"Error: Input path {args.input} does not exist!")
        return
    if args.tab and not os.path.exists(args.tab):
        print(f"Error: Tab file {args.tab} does not exist!")
        return
    if args.parental_origin_dir and not os.path.isdir(args.parental_origin_dir):
        print(f"Error: Parental origin directory {args.parental_origin_dir} does not exist!")
        return
    if args.find_clusters and not args.cluster_output:
        print("Error: --cluster-output is required when --find-clusters is set!")
        return

    process_files(args.input, args.output, args.tab, args.sample,
                  args.find_clusters, args.cluster_output,
                  args.parental_origin_dir, args.cluster_size)

if __name__ == "__main__":
    main()