Skip to content

Instantly share code, notes, and snippets.

@tag1216
Created May 7, 2017 07:45
Show Gist options
  • Star 2 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save tag1216/2358da45274a6ac781c33b0cc980fa14 to your computer and use it in GitHub Desktop.
Save tag1216/2358da45274a6ac781c33b0cc980fa14 to your computer and use it in GitHub Desktop.
巨大ファイルのソート
import heapq
import os
import re
from argparse import ArgumentParser
from contextlib import contextmanager
from operator import itemgetter
from tempfile import TemporaryDirectory, mktemp
import sys
from typing import IO, Callable, List
def large_sort(input_file: IO, output_file: IO, key: Callable=None, reverse: bool=False, limit_chars: int=1024*1024*64):
    """Sort the lines of a text stream that may be too large for memory.

    External merge sort: the input is read in chunks (``limit_chars`` is the
    size hint handed to ``readlines``), each chunk is sorted in memory and
    written to its own temporary file, and the sorted part files are then
    k-way merged into ``output_file``.

    :param input_file: text stream the lines are read from.
    :param output_file: text stream the sorted lines are written to.
    :param key: optional key function applied to each line as read (trailing
        newline included), as for :func:`sorted`.
    :param reverse: sort in descending order when true.
    :param limit_chars: approximate number of characters read per chunk.
    """
    # Local import keeps this function self-contained; ExitStack guarantees
    # every part file is closed, even if opening a later one fails.
    from contextlib import ExitStack

    with TemporaryDirectory() as tmp_dir:
        part_names = []
        for lines in _read_parts(input_file, limit_chars):
            # The chunk list is freshly built for us, so sort it in place
            # rather than allocating a sorted copy.
            lines.sort(key=key, reverse=reverse)
            part_names.append(_write_part(lines, tmp_dir))

        with ExitStack() as stack:
            # Open the parts in the order they were written. heapq.merge
            # favours earlier inputs on ties and each chunk sort is stable, so
            # lines with equal keys keep their input order. Discovering the
            # parts via os.listdir would make tie order arbitrary.
            parts = [stack.enter_context(open(name, "r")) for name in part_names]
            output_file.writelines(heapq.merge(*parts, key=key, reverse=reverse))
def _read_parts(input_file, limit_chars):
    """Yield successive non-empty lists of lines read from *input_file*.

    Each list comes from one ``input_file.readlines(limit_chars)`` call, so
    ``limit_chars`` acts as a size hint for the chunk. Every yielded line ends
    with ``"\\n"``.
    """
    while True:
        lines = input_file.readlines(limit_chars)
        if not lines:
            return
        # Only the very last line of the input can lack its terminator. Left
        # as-is, it would be glued onto whichever line follows it once the
        # chunk is sorted and written out, silently merging two records.
        if not lines[-1].endswith("\n"):
            lines[-1] += "\n"
        yield lines
def _write_part(lines, tmp_dir):
    """Write *lines* verbatim to a new uniquely named file in *tmp_dir*.

    Returns the path of the created file.
    """
    # mkstemp creates the file atomically under a unique name, unlike the
    # deprecated mktemp, which only picks a name and leaves a window before
    # the file is actually created.
    from tempfile import mkstemp

    fd, tmp_filename = mkstemp(dir=tmp_dir)
    # open() takes ownership of the descriptor and closes it with the file.
    with open(fd, "w") as tmp_file:
        tmp_file.writelines(lines)
    return tmp_filename
@contextmanager
def _open_tmp_files(tmp_dir):
    """Open every file in *tmp_dir* for reading; yield the list of file objects.

    Files are opened in sorted-name order and are all closed when the
    ``with`` block exits. They are also closed if opening one of them fails
    partway through.
    """
    # ExitStack registers each file as soon as it is opened, so a failure on
    # a later open() no longer leaks the ones already opened.
    from contextlib import ExitStack

    with ExitStack() as stack:
        yield [
            stack.enter_context(open(os.path.join(tmp_dir, filename), "r"))
            for filename in sorted(os.listdir(tmp_dir))
        ]
def key_func(keys: List[str]=None, separator: str=" "):
    """Build a sort key function from ``sort -k``-style field specs.

    Each spec is a 1-based field number, optionally followed by ``n`` to
    compare that field as an integer, e.g. ``["1", "2n"]``.

    :param keys: field specs, applied in order. ``None`` or an empty list
        means "no key"; ``None`` is returned so whole lines are compared.
    :param separator: string passed to ``str.split`` to cut a line into fields.
    :returns: a function that strips surrounding ``"\\n"`` characters from a
        line, splits it, and returns the list of selected field values; or
        ``None`` when *keys* is empty.
    :raises ValueError: if a spec is not a positive integer optionally
        followed by ``n``.
    """
    if not keys:
        return None
    pattern = re.compile(r"([0-9]+)(n?)")
    getters = []
    for key in keys:
        # fullmatch rejects specs such as "2x", which match() accepted while
        # silently ignoring the junk. A spec that did not match at all used to
        # fail with an AttributeError on None.
        m = pattern.fullmatch(key)
        if m is None:
            raise ValueError("invalid key specification: {!r}".format(key))
        column = int(m.group(1)) - 1
        if column < 0:
            # Field 0 would become index -1 and silently pick the last field.
            raise ValueError("field numbers start at 1: {!r}".format(key))
        number = bool(m.group(2))
        getters.append(_itemgetter_int(column) if number else itemgetter(column))

    def func(row):
        values = row.strip("\n").split(separator)
        return [getter(values) for getter in getters]

    return func
def _itemgetter_int(index):
    """Return a callable mapping a sequence to ``int(sequence[index])``."""
    return lambda fields: int(fields[index])
def _parse_args():
    """Parse this script's command line from ``sys.argv``.

    Returns an ``argparse.Namespace`` with the attributes ``separator``,
    ``keys`` (list of key specs, or ``None``), ``reverse``, ``limit`` and
    ``file`` (input path, or ``None``).
    """
    cli = ArgumentParser()
    option = cli.add_argument
    # Help: "specify the field separator character"
    option("-t", "--field-separator",
           dest="separator",
           default=" ",
           help="フィールド区切り文字を指定する")
    # Help: "specify the field(s) to sort on, e.g. -k 1 -k 2n"
    option("-k", "--key",
           dest="keys",
           action="append",
           help="ソート対象フィールド指定する 例) -k 1 -k 2n")
    # Help: "sort in descending order"
    option("-r", "--reverse",
           dest="reverse",
           action="store_true",
           default=False,
           help="降順ソート")
    # Help: "limit on the number of characters read at once"
    option("-l", "--limit",
           dest="limit",
           type=int,
           default=64 * 1024 * 1024,
           help="一度に読み込む文字数の制限")
    # Help: "input file"
    option("file",
           nargs="?",
           help="入力ファイル")
    return cli.parse_args()
def main():
    """Command-line entry point: sort the named file (or stdin) to stdout."""
    args = _parse_args()
    from_path = bool(args.file)
    if from_path:
        source = open(args.file, "r")
    else:
        source = sys.stdin
    try:
        large_sort(
            source,
            sys.stdout,
            key=key_func(args.keys, args.separator),
            reverse=args.reverse,
            limit_chars=args.limit,
        )
    finally:
        # Only close what we opened ourselves; stdin stays open.
        if from_path:
            source.close()
# Run the command-line entry point only when executed as a script,
# not when this file is imported as a module.
if __name__ == "__main__":
    main()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment