Skip to content

Instantly share code, notes, and snippets.

@bencharb
Created March 26, 2016 11:51
Show Gist options
  • Save bencharb/2ab44e028fcab62b209a to your computer and use it in GitHub Desktop.
Save bencharb/2ab44e028fcab62b209a to your computer and use it in GitHub Desktop.
safe overwrite file
'''
Move a file to a backup, if the file exists. If context manager statement fails, then restore the backup.
'''
class FakeException(Exception):
    """Raised on purpose by the tests to simulate a failing ``with`` body."""
def create_backup_path(path, backup_ext='.bak'):
    """Return a backup path for *path* that is not already taken.

    *backup_ext* is appended to *path* at least once, and again for as
    long as the candidate names something on disk (``f.bak``,
    ``f.bak.bak``, ...). The result therefore always differs from *path*.

    Args:
        path: the file that will be backed up.
        backup_ext: non-empty suffix appended to build the backup name.

    Returns:
        The first candidate path that does not exist.

    Raises:
        ValueError: if *backup_ext* is empty. Appending it could never
            produce a name different from *path*, so the search would
            loop forever.
    """
    if not backup_ext:
        raise ValueError('backup_ext must be a non-empty string')
    backup_path = path + backup_ext
    # lexists (not exists) so a dangling symlink counts as "taken" and is
    # not silently replaced when the original file is moved onto it.
    while os.path.lexists(backup_path):
        backup_path += backup_ext
    return backup_path
@contextlib.contextmanager
def overwrite_file(path, backup_ext='.bak'):
    """Move *path* aside while the ``with`` body rewrites it; restore it on failure.

    If *path* exists, it is moved to a fresh backup path (chosen by
    ``create_backup_path``) before the body runs:

    * body completes normally -> the backup is deleted;
    * body raises             -> the backup is moved back to *path* and
                                 the exception is re-raised.

    If *path* does not exist, the body simply runs with no backup.

    Args:
        path: file the ``with`` body is expected to overwrite.
        backup_ext: suffix used to build the backup file name.

    Yields:
        None.
    """
    if not os.path.exists(path):
        yield
        return

    backup_path = create_backup_path(path, backup_ext=backup_ext)
    # Done outside the try: if this move fails, no backup exists yet, so
    # there is nothing to restore. The original error must propagate
    # instead of being replaced by a failed "restore" of a missing file.
    shutil.move(path, backup_path)
    try:
        yield
    except BaseException:
        # Any failure in the body (including KeyboardInterrupt) restores
        # the original file before propagating.
        shutil.move(backup_path, path)
        raise
    else:
        os.remove(backup_path)
def print_status(testfunc):
    """Decorate *testfunc* so each call prints whether it passed or failed.

    On success, prints ``passed: <name>`` and returns the wrapped
    function's result. On ``AssertionError``, prints ``failed: <name>``
    and re-raises. Any other exception propagates without a message.
    """
    import functools

    @functools.wraps(testfunc)
    def decorator(*args, **kwargs):
        try:
            result = testfunc(*args, **kwargs)
        except AssertionError:
            # Single-argument print(...) behaves the same on Python 2 and 3.
            print('failed: %s' % testfunc.__name__)
            raise
        print('passed: %s' % testfunc.__name__)
        return result
    return decorator
class TestOverwriteFile(object):
    """Manual test harness for ``overwrite_file``.

    Each test starts from ``filepath`` containing ``data1`` and checks the
    file and its backup on disk during and after an ``overwrite_file``
    block. NOTE(review): the paths are hard-coded under ``/tmp``, so this
    assumes a POSIX-style temp directory.
    """

    filepath = '/tmp/testfile.txt'
    # Backup name overwrite_file is expected to choose; setup() deletes any
    # stale copy so no extra '.bak' suffix gets appended.
    filepathbackup = '/tmp/testfile.txt.bak'
    data1 = 'data1'  # contents written by setup()
    data2 = 'data2'  # contents written inside the with-block

    def _deletefiles(self):
        """Remove the test file and its backup if either is present."""
        for p in (self.filepath, self.filepathbackup):
            if os.path.exists(p):
                os.remove(p)

    def setup(self):
        """Reset to a known state: no backup, and filepath holding data1."""
        self._deletefiles()
        with open(self.filepath, 'w') as fout:
            fout.write(self.data1)

    def _readfile(self, path):
        """Return the full text contents of *path*."""
        with open(path, 'r') as fin:
            return fin.read()

    def _data_is_original(self):
        """Return True if filepath still holds data1."""
        return self._readfile(self.filepath) == self.data1

    def file_exists(self, f):
        """Return whether *f* exists."""
        return os.path.exists(f)

    @print_status
    def test_basic(self):
        """A body that succeeds keeps the new data and removes the backup."""
        self.setup()
        assert self._data_is_original()
        assert not os.path.exists(self.filepathbackup)
        with overwrite_file(self.filepath):
            with open(self.filepath, 'w') as fout:
                fout.write(self.data2)
            # While the block is open, the original is parked in the backup.
            assert os.path.exists(self.filepathbackup)
            assert not self._data_is_original()
        assert not os.path.exists(self.filepathbackup)
        assert not self._data_is_original()

    @print_status
    def test_fail(self):
        """A body that raises restores the original data and removes the backup."""
        self.setup()
        assert self._data_is_original()
        try:
            with overwrite_file(self.filepath):
                with open(self.filepath, 'w') as fout:
                    fout.write(self.data2)
                raise FakeException('ohno')
        except FakeException:
            pass
        else:
            # Without this, a context manager that swallowed the error would
            # skip every check below and the test would pass vacuously.
            raise AssertionError('overwrite_file swallowed FakeException')
        assert self._data_is_original()
        assert not os.path.exists(self.filepathbackup)

    def run(self):
        """Run all tests, then remove the files they created."""
        try:
            self.test_basic()
            self.test_fail()
        finally:
            self._deletefiles()

    def __call__(self):
        return self.run()
def test_overwritefile():
    """Build a fresh TestOverwriteFile harness and invoke it once."""
    TestOverwriteFile()()


test_overwritefile()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment