Created
September 13, 2021 02:31
-
-
Save jeetu7/377e4bd7056cb222d9e114154ea345e6 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
""" | |
following is requirements.txt | |
============ | |
starlette | |
sqlalchemy | |
databases[sqlite] | |
uvicorn | |
imia | |
passlib | |
sqlalchemy_utils | |
============ | |
""" | |
import dataclasses
import html
import pathlib
from dataclasses import dataclass

import databases
import sqlalchemy
from imia import AuthenticationMiddleware, InMemoryProvider, LoginManager, SessionAuthenticator
from passlib.hash import pbkdf2_sha1
from sqlalchemy import create_engine
from sqlalchemy_utils import create_database, database_exists
from starlette.applications import Starlette
from starlette.config import Config
from starlette.middleware import Middleware
from starlette.middleware.sessions import SessionMiddleware
from starlette.requests import Request
from starlette.responses import HTMLResponse, RedirectResponse
from starlette.routing import Route
# Settings come from the process environment, falling back to the .env file.
config = Config('.env')
DATABASE_URL = config(key='DATABASE_URL')

metadata = sqlalchemy.MetaData()

# Schema of the ``users`` table. Nothing in this file queries it; login is
# handled by the in-memory user provider.
# NOTE(review): a ``completed`` flag on a users table looks unusual -- confirm
# it is intended.
users = sqlalchemy.Table(
    "users",
    metadata,
    sqlalchemy.Column("id", sqlalchemy.Integer(), primary_key=True),
    sqlalchemy.Column("fullname", sqlalchemy.String()),
    sqlalchemy.Column("email", sqlalchemy.String()),
    sqlalchemy.Column("hashed_password", sqlalchemy.String()),
    sqlalchemy.Column("completed", sqlalchemy.Boolean()),
)

# Async database handle; the application's startup/shutdown hooks connect
# and disconnect it.
db = databases.Database(url=DATABASE_URL)
# Bootstrap the schema on startup. In production, manage the schema with
# migrations (e.g. Alembic) instead.
def create_tables():
    """Make sure the database exists and contains every table in ``metadata``.

    ``metadata.create_all`` checks for each table first and only creates the
    missing ones, so calling this on every startup is safe: an existing
    database keeps its data, while a missing or empty one gets the schema.
    """
    if not database_exists(DATABASE_URL):
        # SQLite creates its file on first connect, but server databases
        # (PostgreSQL, MySQL, ...) must be created explicitly before any
        # CREATE TABLE can run.
        create_database(DATABASE_URL)
    engine = create_engine(DATABASE_URL)
    try:
        metadata.create_all(engine)
    finally:
        # This engine is only used for DDL; the app talks to the database
        # through ``db``, so release the pooled connections right away.
        engine.dispose()
@dataclasses.dataclass
class User:
    """Demo account record implementing imia's ``UserLike`` protocol.

    The defaults describe the single built-in user ``root@localhost`` whose
    password is ``pa$$word``. ``scopes`` gets a fresh list per instance.
    """

    identifier: str = 'root@localhost'
    # pbkdf2_sha1 hash of the demo password "pa$$word".
    password: str = '$pbkdf2$131000$xfhfaw1hrNU6ByAkBKA0Zg$qT.ZZYscSAUS4Btk/Q2rkAZQc5E'
    scopes: list[str] = dataclasses.field(default_factory=list)

    def get_id(self):
        """Return the identifier this user is looked up by."""
        return self.identifier

    def get_display_name(self):
        """Return the label shown in the UI (a fixed string, not the identifier)."""
        return 'User'

    def get_scopes(self):
        """Return the permission scopes granted to this user."""
        return self.scopes

    def get_hashed_password(self):
        """Return the stored password hash checked at login."""
        return self.password
# Key used to sign the session cookie and handed to the LoginManager. Read it
# from the environment / .env so deployments do not ship the demo value;
# 'key!' stays as the fallback so the example still runs out of the box.
secret_key = config('SECRET_KEY', default='key!')

# Looks users up by identifier. A real app could supply its own provider,
# e.g. one backed by the database, instead of this single in-memory user.
user_provider = InMemoryProvider({'root@localhost': User()})

# Password checker; must satisfy imia's PasswordVerifier protocol.
password_verifier = pbkdf2_sha1

# Core object driving the login/logout flow.
login_manager = LoginManager(user_provider, password_verifier, secret_key)
def index_view(request: Request) -> HTMLResponse:
    """Serve the landing page linking to the login form and the protected area."""
    links = '<a href="/authz/login">Login</a> | <a href="/app/protected1">P1</a> '
    return HTMLResponse(links)
async def login_view(request: Request):
    """Show the login form (GET) or try to log the user in (POST).

    On POST, a successful login redirects to ``/app/protected1``. A failed
    one redirects back to this page with ``?error=invalid_credentials`` so
    the form is shown again with an error banner.
    """
    error = ''
    if 'error' in request.query_params:
        error = '<span style="color:red">invalid credentials</span>'

    if request.method == 'POST':
        form = await request.form()
        # Use .get() so a POST missing a field is treated as bad credentials
        # (redirect with an error) instead of raising KeyError -> HTTP 500.
        # str() guards against a file upload being submitted in place of text.
        email = str(form.get('email', ''))
        password = str(form.get('password', ''))
        user_token = await login_manager.login(request, email, password)
        if user_token:
            return RedirectResponse('/app/protected1', status_code=302)
        return RedirectResponse('/authz/login?error=invalid_credentials', status_code=302)

    # ``error`` is always one of the two fixed strings above, never user
    # input, so interpolating it unescaped is safe.
    return HTMLResponse(
        """
        %s
        <form method="post">
            <label>email <input name="email" value="root@localhost"></label>
            <label>password <input name="password" type="password" value="pa$$word"></label>
            <button type="submit">submit</button>
        </form>
        """
        % error
    )
async def logout_view(request: Request) -> RedirectResponse:
    """Log the current user out and send them to the login page.

    Only a POST logs out. Any other method is redirected to the protected
    area untouched; the route is registered for POST only, so this fallback
    is defensive.
    """
    if request.method != 'POST':
        return RedirectResponse('/app/protected1', status_code=302)
    await login_manager.logout(request)
    return RedirectResponse('/authz/login', status_code=302)
async def protected1_view(request: Request) -> HTMLResponse:
    """Render the protected area, with a logout button, for the signed-in user.

    Access control is not done here: the AuthenticationMiddleware configured
    on the app redirects unauthenticated requests for the ``/app`` paths to
    the login page.
    """
    # The display name comes from the user provider. Escape it so that a
    # user-controlled name (e.g. from a database) cannot inject HTML/script.
    user = html.escape(str(request.auth.display_name))
    return HTMLResponse(
        """
        Hi %s! This is protected 1 app area.
        <form action="/authz/logout" method="post">
            <button>logout</button>
        </form>
        """
        % user
    )
# ASGI application: routes, session and authentication middleware, and the
# database lifecycle hooks.
app = Starlette(
    # Debug mode renders tracebacks in responses, so it must be off outside
    # development: set DEBUG=false in the environment / .env. The default
    # keeps the demo's original behaviour (debug on).
    debug=config('DEBUG', cast=bool, default=True),
    routes=[
        Route('/', index_view),
        Route('/authz/login', login_view, methods=['GET', 'POST']),
        Route('/authz/logout', logout_view, methods=['POST']),
        Route('/app/protected1', protected1_view),
    ],
    middleware=[
        # Listed first = outermost in Starlette, so the session is decoded
        # before AuthenticationMiddleware runs (SessionAuthenticator
        # presumably reads the user from it).
        Middleware(SessionMiddleware, secret_key=secret_key),
        Middleware(
            AuthenticationMiddleware,
            authenticators=[SessionAuthenticator(user_provider)],
            on_failure='redirect',
            redirect_to='/authz/login',
            # Protect the /app path: only requests matching this pattern
            # require authentication.
            include_patterns=[r'\/app'],
        ),
    ],
    on_startup=[create_tables, db.connect],
    on_shutdown=[db.disconnect],
)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment