Created
March 26, 2016 11:51
-
-
Save bencharb/2ab44e028fcab62b209a to your computer and use it in GitHub Desktop.
safe overwrite file
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
'''
Safely overwrite a file.

``overwrite_file`` moves an existing file aside to a backup before the body of
a ``with`` statement runs. If the body raises, the backup is moved back into
place; if the body succeeds, the backup is deleted. ``create_backup_path``
picks the backup name, and the remainder of the module is a small self-test.
'''
import contextlib
import os
import shutil
class FakeException(Exception):
    '''Error raised on purpose by the self-test to simulate a failing ``with`` body.'''
def create_backup_path(path, backup_ext='.bak'):
    '''
    Return a backup file name for *path* that does not currently exist.

    *backup_ext* is appended to *path*, repeatedly if needed, until the
    resulting name is not taken, e.g. ``f.txt`` -> ``f.txt.bak`` ->
    ``f.txt.bak.bak``. Nothing is created on disk, so another process could
    still claim the name before the caller uses it.

    Raises ValueError if *backup_ext* is empty: appending an empty suffix can
    never yield a name different from *path*, so the search would loop forever.
    '''
    import os

    if not backup_ext:
        raise ValueError('backup_ext must be a non-empty string')
    # Always append at least once so the backup never equals *path* itself.
    backup_path = path + backup_ext
    while os.path.exists(backup_path):
        backup_path += backup_ext
    return backup_path
import contextlib


@contextlib.contextmanager
def overwrite_file(path, backup_ext='.bak'):
    '''
    Move *path* aside to a backup for the duration of a ``with`` block.

    If *path* exists, it is moved to a fresh backup name (chosen by
    ``create_backup_path``) before the body runs. Then:

    * if the body raises, the backup is moved back to *path* and the
      exception is re-raised;
    * if the body succeeds, the backup is deleted.

    The body is expected to write the new contents of *path*. If it does not,
    the original file is still gone once the block exits successfully.

    Yields the backup path, so the body can read the previous contents. Yields
    ``None`` when *path* did not exist; in that case nothing is moved or
    restored.
    '''
    import os
    import shutil

    if not os.path.exists(path):
        yield None
        return

    backup_path = create_backup_path(path, backup_ext=backup_ext)
    # Keep the move outside the ``try``: if it fails there is no backup to
    # restore, and a restore attempt would raise a second error that hides
    # the real one.
    shutil.move(path, backup_path)
    try:
        yield backup_path
    except BaseException:
        # Catch everything (including KeyboardInterrupt) so the original file
        # is always put back; the exception is then re-raised unchanged.
        shutil.move(backup_path, path)
        raise
    else:
        os.remove(backup_path)
def print_status(testfunc):
    '''
    Decorate a test function so that it prints whether it passed or failed.

    Prints ``passed: <name>`` when *testfunc* returns normally and
    ``failed: <name>`` when it raises AssertionError; the AssertionError is
    then re-raised. Other exceptions propagate without a status line. The
    wrapped function's return value is passed through to the caller.
    '''
    import functools

    @functools.wraps(testfunc)
    def decorator(*args, **kwargs):
        try:
            result = testfunc(*args, **kwargs)
        except AssertionError:
            print('failed: %s' % testfunc.__name__)
            raise
        print('passed: %s' % testfunc.__name__)
        return result
    return decorator
import os
import tempfile


class TestOverwriteFile(object):
    '''
    Self-test for ``overwrite_file``.

    Each check writes ``data1`` to ``filepath``, overwrites it with ``data2``
    inside ``overwrite_file`` and then verifies the backup/restore behaviour.
    Call the instance (or ``run()``) to execute every check.
    '''
    # Scratch files go in the platform temp dir instead of a hard-coded
    # ``/tmp`` so the self-test also runs where ``/tmp`` does not exist.
    filepath = os.path.join(tempfile.gettempdir(), 'testfile.txt')
    # Name overwrite_file's backup is expected to get with the default '.bak'.
    filepathbackup = filepath + '.bak'
    data1 = 'data1'
    data2 = 'data2'

    def _deletefiles(self):
        '''Remove the scratch file and its backup if either exists.'''
        for p in (self.filepath, self.filepathbackup):
            if os.path.exists(p):
                os.remove(p)

    def setup(self):
        '''Start from a clean slate: only ``filepath`` exists, holding ``data1``.'''
        self._deletefiles()
        with open(self.filepath, 'w') as fout:
            fout.write(self.data1)

    def _readfile(self, path):
        '''Return the full text contents of *path*.'''
        with open(path, 'r') as fin:
            return fin.read()

    def _data_is_original(self):
        '''Return True if ``filepath`` still holds ``data1``.'''
        return self._readfile(self.filepath) == self.data1

    def file_exists(self, f):
        '''Return True if *f* exists on disk.'''
        return os.path.exists(f)

    @print_status
    def test_basic(self):
        '''On success the new data stays and the backup is removed.'''
        self.setup()
        try:
            assert self._data_is_original()
            assert not self.file_exists(self.filepathbackup)
            with overwrite_file(self.filepath):
                with open(self.filepath, 'w') as fout:
                    fout.write(self.data2)
                assert self.file_exists(self.filepathbackup)
                assert not self._data_is_original()
            assert not self.file_exists(self.filepathbackup)
            assert not self._data_is_original()
        finally:
            self._deletefiles()

    @print_status
    def test_fail(self):
        '''When the body raises, the original data is restored and the backup removed.'''
        self.setup()
        try:
            assert self._data_is_original()
            raised = False
            try:
                with overwrite_file(self.filepath):
                    with open(self.filepath, 'w') as fout:
                        fout.write(self.data2)
                    raise FakeException('ohno')
            except FakeException:
                raised = True
            # overwrite_file must re-raise; otherwise the checks below would
            # pass even if it silently swallowed the failure.
            assert raised
            assert self._data_is_original()
            assert not self.file_exists(self.filepathbackup)
        finally:
            self._deletefiles()

    def run(self):
        '''Run every check in order.'''
        self.test_basic()
        self.test_fail()

    def __call__(self):
        return self.run()
def test_overwritefile():
    '''Build the overwrite_file self-test and run all of its checks.'''
    TestOverwriteFile()()
if __name__ == '__main__':
    # Run the self-test only when executed as a script, not on import: it
    # creates and deletes scratch files and prints pass/fail lines.
    test_overwritefile()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment