Skip to content

Instantly share code, notes, and snippets.

@notbanker
Last active March 10, 2024 17:51
Show Gist options
  • Star 8 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save notbanker/2be3ed34539c86e22ffdd88fd95ad8bc to your computer and use it in GitHub Desktop.
Save notbanker/2be3ed34539c86e22ffdd88fd95ad8bc to your computer and use it in GitHub Desktop.
Context manager to temporarily set the pandas chained assignment mode to None, 'warn' or 'raise', then revert it
import pandas as pd
class _InvalidChainedOption(ValueError, AssertionError):
    """Raised when ``chained`` is not one of the accepted values.

    Subclasses ``AssertionError`` as well as ``ValueError`` because the
    check used to be a bare ``assert``; callers that caught
    ``AssertionError`` keep working.
    """


class ChainedAssignent:
    """Context manager to temporarily set pandas' chained assignment mode.

    On entry, ``pd.options.mode.chained_assignment`` is set to *chained*;
    on exit (normally or through an exception) the previously active value
    is restored. Usage::

        with ChainedAssignment():
            blah
        with ChainedAssignment('raise'):
            run my code and figure out which line causes the error!

    Args:
        chained: ``None``, ``'warn'`` or ``'raise'`` (default ``None``).

    Raises:
        ValueError: if *chained* is not one of the accepted values (the
            exception is also an ``AssertionError`` for backward
            compatibility).
    """

    def __init__(self, chained=None):
        acceptable = [None, 'warn', 'raise']
        # Explicit check instead of ``assert``: asserts are stripped under
        # ``python -O``, which would let an invalid mode through silently.
        if chained not in acceptable:
            raise _InvalidChainedOption(
                "chained must be in " + str(acceptable) + ", got " + repr(chained)
            )
        self.swcw = chained

    def __enter__(self):
        # Remember the current mode so __exit__ can restore it.
        self.saved_swcw = pd.options.mode.chained_assignment
        pd.options.mode.chained_assignment = self.swcw
        return self

    def __exit__(self, *args):
        pd.options.mode.chained_assignment = self.saved_swcw
        # Falsy return: exceptions raised inside the ``with`` body propagate.
        return False


# Correctly spelled name, as used in the docstring; the original misspelled
# ``ChainedAssignent`` is kept so existing imports keep working.
ChainedAssignment = ChainedAssignent
@shaybensasson
Copy link

Great!
Here is my small contribution (pytest unit tests):

import pytest
import pandas as pd
import numpy as np
from pandas.core.common import SettingWithCopyWarning, SettingWithCopyError

from pandas_chained_assignment_warn_handler import PandasChainedAssignmentWarnHandler


def reproduce_warn():
    """Perform a write that pandas may flag as chained assignment.

    ``X`` is taken from ``df`` by column selection and then written to via
    ``X.loc``. The tests below expect this to emit ``SettingWithCopyWarning``
    (or raise ``SettingWithCopyError`` in ``'raise'`` mode). The asserts pin
    that the write lands in ``X`` and leaves ``df`` unchanged.
    """
    # ``col2`` is float from the start: writing NaN into an int column would
    # upcast it, and recent pandas versions may emit an extra dtype warning
    # for that, which is unrelated to chained assignment.
    df = pd.DataFrame(data={'col1': np.arange(11), 'col2': [1.] * 11, 'label': [1] * 11})
    df.loc[0, 'col2'] = np.nan

    X = df[['col1', 'col2']]
    X.loc[0, 'col2'] = 1.

    assert X.loc[0, 'col2'] == 1.
    assert np.isnan(df.loc[0, 'col2'])  # original frame still holds NaN

def test_reproduce_warn_raises_warn():
    """Without any handler, the reproduction emits SettingWithCopyWarning."""
    expected_warning = SettingWithCopyWarning
    with pytest.warns(expected_warning):
        reproduce_warn()

def test_context_handler_blocks_warn():
    """With ``chained=None`` the handler suppresses SettingWithCopyWarning.

    ``pytest.warns(None)`` was deprecated in pytest 7 and removed in pytest 8,
    so warnings are recorded with the standard library instead.
    """
    import warnings  # local import: the snippet's import block lacks it

    with warnings.catch_warnings(record=True) as record:
        # "always" so a warning already emitted earlier in the session is not
        # de-duplicated away and missed here.
        warnings.simplefilter("always")
        with PandasChainedAssignmentWarnHandler(chained=None):
            reproduce_warn()
    # Only the chained-assignment warning matters; unrelated warnings
    # (deprecations, etc.) must not make this test fail.
    chained = [w for w in record if issubclass(w.category, SettingWithCopyWarning)]
    assert not chained, [str(w.message) for w in chained]

def test_context_handler_display_warn():
    """With ``chained='warn'`` the SettingWithCopyWarning is still emitted."""
    handler = PandasChainedAssignmentWarnHandler(chained='warn')
    with pytest.warns(SettingWithCopyWarning), handler:
        reproduce_warn()

def test_context_handler_display_error():
    """With ``chained='raise'`` the warning becomes SettingWithCopyError."""
    handler = PandasChainedAssignmentWarnHandler(chained='raise')
    with pytest.raises(SettingWithCopyError), handler:
        reproduce_warn()

@microprediction
Copy link

Thanks!

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment