Last active
February 11, 2022 23:07
-
-
Save 1f604/7516d2f3b32df4aeb1b40f06d276943f to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
#!/usr/bin/env python | |
# coding=utf-8 | |
# License: Public domain (CC0) | |
# Isaac Turner 2016/12/05 | |
# 1f604 2022/02/11 | |
from __future__ import print_function | |
import difflib | |
import re | |
# Unified-diff hunk header, e.g. "@@ -12,3 +14,0 @@".  Groups:
#   1: old start line (1-based)   2: old line count, or None when omitted
#   3: new start line (1-based)   4: new line count, or None when omitted
# (difflib omits the count when it is 1.)  A raw string is used so that
# "\d" and "\+" reach the regex engine instead of being treated as
# (invalid, deprecated) string escapes.
_hdr_pat = re.compile(r"^@@ -(\d+)(?:,(\d+))? \+(\d+)(?:,(\d+))? @@$")
def fix_missing_newline(s):
    """
    Return ``s`` with a guaranteed trailing '\\n'.

    Falsy input (e.g. the empty string) becomes a lone "\\n"; text that already
    ends in '\\n' is returned unchanged; anything else gets '\\n' appended.
    """
    if not s:
        return "\n"
    return s if s.endswith('\n') else s + '\n'
def make_patch(a, b):
    """
    Return a zero-context unified diff that turns string ``a`` into ``b``.

    The ``---``/``+++`` file-header lines emitted by difflib are discarded, so
    the result begins at the first ``@@`` hunk header.  Identical inputs give
    an empty string.  Both inputs are split with ``str.splitlines(True)``;
    they should be newline-terminated, because difflib adds no line break
    after a final unterminated line, which then runs into the next diff line.
    """
    hunks = difflib.unified_diff(a.splitlines(True), b.splitlines(True), n=0)
    # Drop the two file-header lines; next(..., None) tolerates an empty diff.
    next(hunks, None)
    next(hunks, None)
    return ''.join(hunks)
def apply_patch(old, patch):
    """
    Apply the unified-diff ``patch`` to the string ``old`` and return the result.

    ``patch`` is expected to look like ``make_patch`` output: optional leading
    ``---``/``+++`` header lines, then hunks.  Each hunk is an
    ``@@ -s[,n] +s[,n] @@`` header followed by lines starting with '-'
    (removed), '+' (added) or ' ' (context, as produced by
    ``difflib.unified_diff`` with ``n > 0``).  Removed and context lines are
    NOT compared against ``old``; each hunk is placed purely by the old-start
    number in its header.

    Both strings are split with ``str.splitlines(True)``, so every line
    boundary that method recognises (carriage return, form feed, U+2028, ...)
    ends a line.

    Raises:
        ValueError: if a hunk header cannot be parsed, a hunk starts before
            the previous one ended or past the end of ``old``, or a hunk line
            begins with an unexpected character.
    """
    old_lines = old.splitlines(True)
    patch_lines = patch.splitlines(True)
    n_patch = len(patch_lines)
    pieces = []   # output fragments, joined once at the end (avoids repeated str +=)
    pos = 0       # index of the next unread patch line
    old_pos = 0   # index of the first old line not yet copied or consumed

    # Skip the optional file header lines.
    while pos < n_patch and patch_lines[pos].startswith(("---", "+++")):
        pos += 1

    while pos < n_patch:
        header = patch_lines[pos]
        m = _hdr_pat.match(header)
        if not m:
            raise ValueError("Cannot process diff")
        pos += 1
        # Header line numbers are 1-based.  An empty old range ("-N,0") means
        # an insertion *after* old line N, so the 0-based insertion point is N.
        start = int(m.group(1)) - 1 + (m.group(2) == '0')
        if not old_pos <= start <= len(old_lines):
            raise ValueError("Hunk out of order or beyond end of input: %r" % header)
        # Copy the unchanged old lines that precede this hunk.
        pieces.extend(old_lines[old_pos:start])
        old_pos = start

        # Consume hunk body lines up to the next hunk header (or end of patch).
        while pos < n_patch and not patch_lines[pos].startswith('@'):
            line = patch_lines[pos]  # never empty: splitlines() yields no ''
            pos += 1
            tag = line[0]
            if tag == '+':
                pieces.append(line[1:])
            elif tag == '-':
                old_pos += 1
            elif tag == ' ':
                pieces.append(line[1:])
                old_pos += 1
            else:
                raise ValueError(
                    "line does not begin with expected symbol: %d %r" % (ord(tag), line))

    # Copy whatever follows the last hunk.
    pieces.extend(old_lines[old_pos:])
    return ''.join(pieces)
# | |
# Testing | |
# | |
import random | |
import string | |
import traceback | |
import sys | |
import codecs | |
def test_diff(a, b):
    """
    Round-trip check: applying ``make_patch(a, b)`` to ``a`` must yield ``b``.

    Both inputs are first newline-terminated with ``fix_missing_newline``.
    On failure, prints the inputs, the patch and the patched result, writes
    the error (or the exception traceback) to stderr, and exits with status -1.

    An explicit comparison is used instead of ``assert`` so the check still
    runs under ``python -O``.
    """
    a = fix_missing_newline(a)
    b = fix_missing_newline(b)
    mp = make_patch(a, b)
    try:
        patched = apply_patch(a, mp)
    except Exception:
        # Capture the traceback now; apply_patch is not re-run for the report,
        # which previously re-raised before the diagnostics were finished.
        patched = None
        failure = traceback.format_exc()
    else:
        if patched == b:
            return
        failure = "AssertionError: apply_patch(a, mp) != b\n"
    print("=== a ===")
    print([a])
    print("=== b ===")
    print([b])
    print("=== mp ===")
    print([mp])
    print("=== a->b ===")
    print("<apply_patch raised>" if patched is None else patched)
    sys.stderr.write(failure)
    sys.exit(-1)
def randomly_interleave(*args):
    """
    Lazily merge several iterables by repeatedly drawing from a random one.

    Every item of every argument is yielded exactly once, and items from the
    same argument keep their relative order.  One ``random.randrange`` draw
    is made per step, including steps that find a source already exhausted.
    """
    sources = list(map(iter, args))
    while len(sources) > 0:
        pick = random.randrange(len(sources))
        try:
            item = next(sources[pick])
        except StopIteration:
            # Remove the exhausted source in O(1): move the last one into its slot.
            sources[pick] = sources[-1]
            sources.pop()
            continue
        yield item
def rand_ascii():
    """Return a single character chosen uniformly from ``string.printable``."""
    # string.printable includes whitespace such as '\n', '\r', '\x0b', '\x0c'.
    alphabet = string.printable
    return random.choice(alphabet)
def rand_unicode():
    """
    Return one random character from the Basic Multilingual Plane (U+0000..U+FFFF).

    The character may be a control character, a line boundary such as U+2028,
    or a lone surrogate (U+D800..U+DFFF), which stresses the splitlines-based
    diff/patch round trip.
    """
    escape = u"\\u%04x" % random.randrange(0x10000)
    # Decode the "\uXXXX" escape text into the character it names.  Encoding it
    # to bytes and calling str() instead yields ASCII such as "b'\\\\u1234'"
    # on Python 3, so no Unicode was ever generated.
    return codecs.decode(escape, 'unicode_escape')
def generate_test(nlines=10, linelen=10, randchar=rand_ascii):
    """
    Build two random strings with common material and round-trip them through test_diff.

    Three counts are drawn independently from ``range(nlines)`` (so ``nlines``
    must be >= 1): chunks unique to the first string, chunks unique to the
    second, and chunks shared by both.  Each chunk is ``linelen`` characters
    from ``randchar``.  Chunks are concatenated without separators, so line
    breaks occur only where ``randchar`` produced line-boundary characters.
    Each string is a random interleaving of its own chunks with the shared
    ones; shared chunks keep the same relative order in both.
    """
    def chunk():
        return ''.join(randchar() for _ in range(linelen))

    only_a_count, only_b_count, shared_count = [random.randrange(nlines) for _ in range(3)]
    only_a = [chunk() for _ in range(only_a_count)]
    only_b = [chunk() for _ in range(only_b_count)]
    shared = [chunk() for _ in range(shared_count)]
    first = ''.join(randomly_interleave(only_a, shared))
    second = ''.join(randomly_interleave(only_b, shared))
    test_diff(first, second)
def std_tests():
    """Run a fixed list of (old, new) pairs through test_diff, in order."""
    cases = [
        ("asdf\nhamster\nmole\nwolf\ndog\ngiraffe",
         "asdf\nhampster\nmole\nwolf\ndooog\ngiraffe\n"),
        ("asdf\nhamster\nmole\nwolf\ndog\ngiraffe",
         "hampster\nmole\nwolf\ndooog\ngiraffe\n"),
        ("hamster\nmole\nwolf\ndog",
         "asdf\nhampster\nmole\nwolf\ndooog\ngiraffe\n"),
        ("", ""),
        ("", "asdf\nasf"),
        ("asdf\nasf", "xxx"),
        # Unusual line boundaries: str.splitlines also splits on form feed,
        # file separator, vertical tab and \r\n, so any input must survive.
        # See https://docs.python.org/3/library/stdtypes.html
        ("\x0c", "\n\r\n"),
        ("\x1c\v", "\f\r\n"),
    ]
    for old_text, new_text in cases:
        test_diff(old_text, new_text)
def main():
    """Run the fixed regression cases, then 50 random ASCII and 50 random Unicode round trips."""
    print("Testing...")
    std_tests()
    suites = (
        ("Testing random ASCII...", rand_ascii),
        ("Testing random unicode...", rand_unicode),
    )
    for banner, char_source in suites:
        print(banner)
        for _ in range(50):
            generate_test(50, 50, char_source)
    print("Passed ✓")


if __name__ == '__main__':
    main()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
I renamed some variables in the apply_patch function to make it easier to understand.
I also removed the end-of-line handling; I think the code works as long as the input files are newline-terminated.