@singingwolfboy
Last active September 6, 2024 03:28
Want to run your Flask tests with CSRF protections turned on, to make sure that CSRF works properly in production as well? Here's an excellent way to do it!
# Want to run your Flask tests with CSRF protections turned on, to make sure
# that CSRF works properly in production as well? Here's an excellent way
# to do it!
# First some imports. I'm assuming you're using Flask-WTF for CSRF protection.
import flask
from flask.testing import FlaskClient as BaseFlaskClient
from flask_wtf.csrf import generate_csrf
# Flask's assumptions about an incoming request don't quite match up with
# what the test client provides in terms of manipulating cookies, and the
# CSRF system depends on cookies working correctly. This little class is a
# fake request that forwards along requests to the test client for setting
# cookies.
class RequestShim(object):
    """
    A fake request that proxies cookie-related methods to a Flask test client.
    """
    def __init__(self, client):
        self.client = client

    def set_cookie(self, key, value='', *args, **kwargs):
        "Set the cookie on the Flask test client."
        server_name = flask.current_app.config["SERVER_NAME"] or "localhost"
        return self.client.set_cookie(
            server_name, key=key, value=value, *args, **kwargs
        )

    def delete_cookie(self, key, *args, **kwargs):
        "Delete the cookie on the Flask test client."
        server_name = flask.current_app.config["SERVER_NAME"] or "localhost"
        return self.client.delete_cookie(
            server_name, key=key, *args, **kwargs
        )
# We're going to extend Flask's built-in test client class, so that it knows
# how to look up CSRF tokens for you!
class FlaskClient(BaseFlaskClient):
    @property
    def csrf_token(self):
        # First, we'll wrap our request shim around the test client, so that
        # it will work correctly when Flask asks it to set a cookie.
        request = RequestShim(self)
        # Next, we need to look up any cookies that might already exist on
        # this test client, such as the secure cookie that powers `flask.session`,
        # and make a test request context that has those cookies in it.
        environ_overrides = {}
        self.cookie_jar.inject_wsgi(environ_overrides)
        with flask.current_app.test_request_context(
            "/login", environ_overrides=environ_overrides,
        ):
            # Now, we call Flask-WTF's method of generating a CSRF token...
            csrf_token = generate_csrf()
            # ...which also sets a value in `flask.session`, so we need to
            # ask Flask to save that value to the cookie jar in the test
            # client. This is where we actually use that request shim we made!
            flask.current_app.save_session(flask.session, request)
            # And finally, return that CSRF token we got from Flask-WTF.
            return csrf_token

    # Feel free to define other methods on this test client. You can even
    # use the `csrf_token` property we just defined, like we're doing here!
    def login(self, email, password):
        return self.post("/login", data={
            "email": email,
            "password": password,
            "csrf_token": self.csrf_token,
        }, follow_redirects=True)

    def logout(self):
        return self.get("/logout", follow_redirects=True)
# To hook up this extended test client class to your Flask application,
# assign it to the `test_client_class` property, like this:
app = flask.Flask(__name__)
app.test_client_class = FlaskClient
# Now in your tests, you can request a test client the same way
# that you normally do:
client = app.test_client()
# But now, `client` is an instance of the class we defined!
# In your tests, you can call the methods you defined, like this:
client.login('user@example.com', 'passw0rd')
# And any time you need to pass a CSRF token, just use the `csrf_token`
# property, like this:
client.post("/user/1", data={
    "favorite_color": "blue",
    "csrf_token": client.csrf_token,
})
@QQmberling

QQmberling commented Nov 5, 2021

Thank you for that solution!

My experience with it:

In newer versions of Flask, I guess the save_session method was moved to another class, so I got an error:
E AttributeError: 'Flask' object has no attribute 'save_session'
Solved by replacing line
flask.current_app.save_session(flask.session, request)
with
flask.current_app.session_interface.save_session(flask.current_app, flask.session, request)
I also got this error:
E AttributeError: 'RequestShim' object has no attribute 'vary'
and @youngsoul's comment helped me:

Just an FYI - when I tried to use this in my test I received the following error message:

    # Add a "Vary: Cookie" header if the session was accessed at all.
    if session.accessed:
        response.vary.add('Cookie')
    E AttributeError: 'RequestShim' object has no attribute 'vary'

I added the following to the init method of RequestShim:
self.vary = set({})
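
A minimal sketch of why that one attribute is enough, assuming the check quoted above is the only place the shim's `vary` is touched. The `mark_session_accessed` helper is hypothetical and only mirrors the quoted lines; it is not a Flask function, and cookie forwarding is left out to keep the sketch short.

```python
class RequestShim:
    """Trimmed-down shim: only the part that the `vary` fix touches."""

    def __init__(self, client):
        self.client = client
        # A plain set supports .add(), which is all the quoted check needs.
        self.vary = set()


def mark_session_accessed(response, accessed):
    # Hypothetical helper that mirrors the check quoted in the comment above.
    if accessed:
        response.vary.add("Cookie")


shim = RequestShim(client=None)  # no real test client is needed for this sketch
mark_session_accessed(shim, accessed=True)
print(shim.vary)
```

Without the `self.vary` line, the call to `mark_session_accessed` raises the same AttributeError that is reported above.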

@TomGoBravo

I applied the fixes in https://gist.github.com/singingwolfboy/2fca1de64950d5dfed72?permalink_comment_id=3952788#gistcomment-3952788 and still got an error where the csrf_token property calls flask.current_app (lib/python3.7/site-packages/flask/globals.py:47 RuntimeError: Working outside of application context). I fixed this by replacing flask.current_app with self.application. I don't know why the flask.current_app calls in RequestShim work.

@TomGoBravo

Fixing this after upgrading werkzeug from 2.1 to 2.3, which removes the undocumented cookie_jar, seemed too hard, so I dug around and found flask.g.csrf_token, which I access with:

    with test_app.app_context(), test_app.test_client(user=user) as c:
        c.get('/tourist/delete/place/3')  # GET to create the CSRF token
        response = c.post(f'/tourist/delete/place/3', data=dict(confirm=True,
                                                             csrf_token=flask.g.csrf_token))

@whitgroves

Fixing this after upgrading werkzeug from 2.1 to 2.3, which removes the undocumented cookie_jar, seemed too hard, so I dug around and found flask.g.csrf_token, which I access with:

    with test_app.app_context(), test_app.test_client(user=user) as c:
        c.get('/tourist/delete/place/3')  # GET to create the CSRF token
        response = c.post(f'/tourist/delete/place/3', data=dict(confirm=True,
                                                             csrf_token=flask.g.csrf_token))

Thank you @TomGoBravo!

@jkittner

jkittner commented Mar 5, 2024

I am still getting: The CSRF session token is missing....

@karabomaila

The csrf token is still missing.

@jkittner , were you able to figure it out?

@jkittner

@karabomaila, I think I ended up only testing the case where the token is missing.

@brihall

brihall commented Aug 6, 2024

Flask 3.0.3
Werkzeug 3.0.3
pytest 8.3.2

After applying all the changes from here https://gist.github.com/singingwolfboy/2fca1de64950d5dfed72?permalink_comment_id=3952788#gistcomment-3952788
and here https://gist.github.com/singingwolfboy/2fca1de64950d5dfed72?permalink_comment_id=3952788#gistcomment-3952788

I was getting TypeError: Client.set_cookie got multiple values for argument 'key'

These are the changes I made to RequestShim.set_cookie and RequestShim.delete_cookie

def set_cookie(self, key, value="", *args, **kwargs):
    "Set the cookie on the Flask test client."
    kwargs["domain"] = self.client.application.config["SERVER_NAME"] or "localhost"
    return self.client.set_cookie(
        key=key, value=value, *args, **kwargs
    )

def delete_cookie(self, key, *args, **kwargs):
    "Delete the cookie on the Flask test client."
    kwargs["domain"] = self.client.application.config["SERVER_NAME"] or "localhost"
    return self.client.delete_cookie(key=key, *args, **kwargs)

@mconigliaro

mconigliaro commented Sep 6, 2024

I'm sorry everyone thinks this is so complicated. This is all you have to do (in 2024):

url = url_for(".whatever")
client.get(url)  # set g.csrf_token
response = client.post(
    url,
    data={
        "foo": "bar",
        "csrf_token": g.csrf_token,
    },
)

Turn that pattern into a helper:

class BetterFlaskLoginClient(FlaskLoginClient):
    def post_with_csrf(self, url: str, **kwargs):
        self.get(url)
        return self.post(url, data={**{"csrf_token": g.csrf_token}, **kwargs})

Then you can just:

response = client.post_with_csrf(url_for(".whatever"), foo="bar")
