Skip to content

Instantly share code, notes, and snippets.

@entwanne
Created October 9, 2024 17:18
OpenID Connect identity provider & consumer (relying party)
import asyncio
from aiohttp import web
from . import website, identity_provider
async def run_app(app, port, host='localhost'):
    """Serve *app* on ``host:port`` until the surrounding task is cancelled.

    The runner is cleaned up whenever setup succeeded, including when binding
    the socket fails (e.g. the port is already in use), not only when the
    serving loop is cancelled.

    Args:
        app: the aiohttp application to serve.
        port: TCP port to listen on.
        host: interface to bind; defaults to ``'localhost'`` as before.
    """
    runner = web.AppRunner(app)
    await runner.setup()
    # The try starts right after setup(): if site.start() raises, the runner
    # that was just set up still has to be released.
    try:
        site = web.TCPSite(runner, host, port)
        await site.start()
        # Block forever; cancellation (e.g. asyncio.run shutting down after
        # Ctrl+C) unwinds through the finally clause.
        await asyncio.Event().wait()
    finally:
        await runner.cleanup()
async def run():
    """Serve the website and both demo identity providers concurrently.

    Ports: 8000 for the website, 8001 and 8002 for the two providers.
    """
    servers = [
        run_app(website.app, 8000),
        run_app(identity_provider.app1, 8001),
        run_app(identity_provider.app2, 8002),
    ]
    await asyncio.gather(*servers)
if __name__ == '__main__':
    # Ctrl+C is the expected way to stop the demo servers; swallowing the
    # interrupt lets the process exit without a traceback.
    main_coro = run()
    try:
        asyncio.run(main_coro)
    except KeyboardInterrupt:
        pass
from aiohttp import web
def html_response(body, status=200):
    """Wrap an HTML *body* string in an aiohttp response.

    Args:
        body: HTML markup, sent as-is (callers are responsible for escaping
            any untrusted values they interpolate into it).
        status: HTTP status code, 200 by default.

    Returns:
        A ``web.Response`` with ``Content-Type: text/html``.
    """
    return web.Response(
        text=body,
        status=status,
        content_type='text/html',
    )
import time
import uuid
from urllib.parse import urlencode
import jwt
from aiohttp import web
from .helpers import html_response
# Route table shared by every identity-provider App instance (App.__init__
# registers it on each one).
routes = web.RouteTableDef()
@routes.get('/.well-known/openid-configuration')
async def get_config(request):
    """Serve a minimal OpenID Connect discovery document.

    Endpoint URLs are derived from the ``Host`` header of the incoming
    request, so the same handler works on whichever port the provider is
    served. Only the authorization and token endpoints are advertised.
    """
    base_url = 'http://' + request.host
    endpoints = {
        'authorization_endpoint': base_url + '/login',
        'token_endpoint': base_url + '/token',
    }
    return web.json_response(endpoints)
@routes.get('/login')
async def login(request):
    """Render the provider's login form, pre-filling the email from ``login_hint``.

    The form has no ``action``, so it posts back to the current URL and the
    OAuth query parameters (client_id, redirect_uri, state, nonce, ...) reach
    the POST handler unchanged.
    """
    from html import escape

    # login_hint comes straight from the query string; escape it (quotes
    # included) so a crafted value cannot break out of the attribute and
    # inject markup or script.
    email = escape(request.query.get('login_hint', ''), quote=True)
    return html_response(f'''
<h1>Login SSO</h1>
<form method="post">
<input type="text" name="email" value="{email}" />
<input type="submit" value="Login" />
</form>
''')
@routes.post('/login')
async def login_post(request):
    """Handle the login form and redirect back to the client with an auth code.

    Validates the OAuth authorization request carried in the query string,
    records ``(email, nonce)`` in ``request.app.cache`` under
    ``(client_id, redirect_uri, code)`` for a fresh random code, then
    redirects to ``redirect_uri`` with ``state``, ``code`` and ``scope``
    appended.

    Returns a 403 "Invalid login" page when a required field is missing or
    the request does not match what this provider supports.
    """
    query = request.query
    try:
        data = await request.post()
        email = data['email']
        client_id = query['client_id']
        uri = query['redirect_uri']
        state = query['state']
        nonce = query['nonce']
    except (KeyError, ValueError):
        return html_response('<h1>Invalid login</h1>', status=403)
    # Explicit checks instead of `assert`: asserts are stripped under -O,
    # which would silently accept unknown clients and unsupported flows.
    if (
        client_id not in request.app.clients
        or query.get('response_type') != 'code'
        or query.get('scope') != 'openid email'
    ):
        return html_response('<h1>Invalid login</h1>', status=403)
    # NOTE(review): redirect_uri is not checked against a per-client
    # allow-list, so a code can be delivered to any URI the caller names.
    # Registering redirect URIs next to the client secrets would close this.

    code = str(uuid.uuid4())
    request.app.cache[client_id, uri, code] = (email, nonce)
    params = {
        'state': state,
        'code': code,
        'scope': 'openid email',
    }
    # Extend an existing query string instead of producing "...?a=1?state=...".
    separator = '&' if '?' in uri else '?'
    raise web.HTTPFound(uri + separator + urlencode(params))
@routes.post('/token')
async def get_token(request):
    """Exchange an authorization code for a signed ID token (token endpoint).

    Expects form fields ``code``, ``client_id``, ``client_secret``,
    ``redirect_uri`` and ``grant_type=authorization_code``. The code must
    have been stored by the login handler for the same client and redirect
    URI; it is removed from the cache on exchange so it cannot be replayed.

    Returns a JSON body whose ``id_token`` is an RS256 JWT carrying the
    user's id, email, roles and the nonce from the authorization request,
    or a 403 "Invalid login" page on any validation failure.
    """
    import hmac

    app = request.app
    try:
        data = await request.post()
        code = data['code']
        client_id = data['client_id']
        client_secret = data['client_secret']
        uri = data['redirect_uri']
        grant_type = data['grant_type']
    except (KeyError, ValueError):
        return html_response('<h1>Invalid login</h1>', status=403)

    expected_secret = app.clients.get(client_id)
    # Explicit checks (asserts vanish under -O) and a constant-time secret
    # comparison so response timing does not reveal how much of it matched.
    if (
        expected_secret is None
        or not hmac.compare_digest(expected_secret.encode(), client_secret.encode())
        or grant_type != 'authorization_code'
    ):
        return html_response('<h1>Invalid login</h1>', status=403)

    # Authorization codes are single-use (RFC 6749 §4.1.2): pop, don't get.
    grant = app.cache.pop((client_id, uri, code), None)
    if grant is None:
        return html_response('<h1>Invalid login</h1>', status=403)
    email, nonce = grant
    user = app.users.get(email)
    if user is None:
        return html_response('<h1>Invalid login</h1>', status=403)

    now = int(time.time())  # one clock read so exp is exactly iat + 3600
    payload = {
        'iss': f'http://{request.host}',
        'aud': client_id,
        'sub': user['user_id'],
        'email': email,
        'iat': now,
        'exp': now + 3600,
        'nonce': nonce,
        'roles': user['roles'],
    }
    token = jwt.encode(payload, app.secret_key, algorithm='RS256')
    return web.json_response({
        'access_token': '',
        'expires_in': 3600,
        'id_token': token,
        'scope': 'openid email',
        'token_type': 'Bearer',
    })
class App(web.Application):
    """A toy OpenID Connect identity-provider application.

    Args:
        secret_key: PEM-encoded RSA private key used to sign ID tokens (RS256).
        clients: mapping of registered client_id -> client_secret.
        users: mapping of email -> {'user_id': ..., 'roles': [...]}.
    """

    def __init__(self, *, secret_key, clients, users):
        super().__init__()
        self.secret_key = secret_key
        self.clients = clients
        self.users = users
        # Pending authorization codes issued by POST /login, keyed by
        # (client_id, redirect_uri, code) -> (email, nonce).
        self.cache = {}
        self.add_routes(routes)
# First demo identity provider (the launcher serves it on port 8001): one
# registered client 'local' and one admin user. The RSA key and client secret
# are hard-coded demo values -- never reuse them outside this example.
app1 = App(
# private key generated just for this gist
secret_key=b'''
-----BEGIN RSA PRIVATE KEY-----
MIICWwIBAAKBgQDP7omnmWm8oYcTvU68LEIyfgL2RPUcb0Jb7zHc8T5ciL4oFCPa
MDCUac+8y1KwZ2F+7XjhGHHtQ2VzvG7FUgo3F1jt9AI01yEVC/itKfK9dF0SRE3/
hkc4QIVclen4ehUtm9fOlrTmPKJwkHMLgS+fpFSQSYCZJokV6tp6IegbBwIDAQAB
AoGAIeT1ZW+Zj9kYxrv2KLBiPg7Sdsh42CC+xXOxQh3FF5pmvOvDiF6QnLC+3zyZ
hw3jE6isKq0sbQuubvZJbDraVSNglCwreYD/861qzDGFfj0yNkIoM1JboE39PKVH
YG14albUk/xcTIFIWuLV2ZkJnX6XFUBU4Z6AtmFrr/AvtpECQQDvlOlp0cSQf+RF
WD4ROVKkzewAbBtYlZfLSjkV1PzSvzsWZhNnF4HZPoTp+5v/ECyBWLJQ51MbvHEa
Qdl5ln9pAkEA3i5gr0r5SWYCRznzxhUVJj9jgaw5257rZoErD3QyPuhLuxN86CY/
8zivyQ1Krew0FZMUT1e8Mi/nEfdq9XDo7wJAaqwm8VLQ7nXDMY0Eta6MOYzmU5nO
4zZYC8/UcwAOQdebe3GfmftBeXT4wSWiY4LUDhNr8dHNkz/oLAS9zAFSYQJANXFK
aF/MXs9etOtHq9neYHjW4pYexKkRTnK7fNgiTedqb/FcFis9Lq/v+Jht4i/yx029
kOg1qIzFNBkjbS5fmQJAa0aMr1oGRczw1U+pLI0gLMDwedXstjBA3T6xsROffJAD
VeD0nxu6mrE3HRGTQXg5bmR9RHiSnoslBgHa7tjXYg==
-----END RSA PRIVATE KEY-----
''',
# client_id -> client_secret
clients={'local': 'mysecret'},
# email -> profile copied into the ID token ('sub' and 'roles' claims)
users={
'user@foo.com': {'user_id': 'user_xxx', 'roles': ['admin']},
},
)
# Second demo identity provider (the launcher serves it on port 8002): same
# client_id 'local' with a different secret, and one read-only user. The RSA
# key and client secret are hard-coded demo values -- never reuse them.
app2 = App(
# private key generated just for this gist
secret_key=b'''
-----BEGIN RSA PRIVATE KEY-----
MIICXAIBAAKBgQDqbTkftXn7zc4RRjFiktvj9kp/ynFCsCinrDMQr538kBzTAjMM
Mnc65SRsr12vsmoKMJafw9WoZHQ+/YtFAaqLrWhuoDii4PvxP8UKZjLV3hsl3jcv
SEPCUHTeuCC0DJOCzYWklUDQ74EZaB+frApdhdURYMXCU7xA4vKfOQ8J3QIDAQAB
AoGAd/HClJLKAyheE0CS7BiwIYdWvuPZ54Eyi/fzeMoiT6N83An4DHmVv7CWXnWr
vcXPs78AkmYk36/mOoSHyZr8otgk+lhEx/VPfFVR99eWMN7jrcrRhWVP2EgEDveE
LaWaniv1e5ZOsQQI/8x2NQYsHZ8XVl6/w9vwho0/VVJtSIECQQD7dERAdLi6Xbn9
72tbkPsisIGJwCsA/WBVZML5VSfXLfKwpREcjNuqDKxDi4aTGnYZgnyyu68zQjpH
2dxSgmR5AkEA7qonKHe4eug6NpQzMzOjqfHUEoIR4HGsZyJgu963d6Gxp4W6n3Az
DzQ0ygSeg9eJR9KGGOfxoQIVhn/t+evPhQJBAJY/D3UQCTzaOgMAJr7MrzI4Ev6z
Az8RJpl/nnyKpEaPnt5rkwDjLXG7JscpPFzzR7PQYyAMww+2t8os49etL8kCQCKt
Ejds/Z7N8vdPHQlJbXQZsMhnoZVMrtHlSszBFUMKeTdv1Keps6F34lufxDA/trB6
xAUGXjMYjGwYVklDm8kCQCl7yAX8PcpZm9Sz41r6yomDgq8toSTg2hYePcgoeDLo
iAdV29SGrQGTcIUC2WTMxDV5FvncsnazqkeJ96HMwro=
-----END RSA PRIVATE KEY-----
''',
# client_id -> client_secret
clients={'local': 'mybettersecret'},
# email -> profile copied into the ID token ('sub' and 'roles' claims)
users={
'user@bar.com': {'user_id': 'user_yyy', 'roles': ['readonly']},
},
)
aiohttp==3.*
pyjwt==2.*
cryptography
import dataclasses
import uuid
from urllib.parse import urlencode
import jwt
from aiohttp import ClientSession, web
from .helpers import html_response
@dataclasses.dataclass
class User:
    """An authenticated user as stored in the website session.

    Attributes:
        id: ``'<provider name>:<subject>'``, so subjects coming from
            different providers cannot collide.
        email: email claim from the provider's ID token.
        roles: role names from the ID token's ``roles`` claim.
    """

    id: str  # provider-qualified identifier
    email: str
    roles: list[str]
@dataclasses.dataclass
class Provider:
    """Client-side settings for one OpenID Connect identity provider.

    Attributes:
        url: base URL of the provider; also compared with the token's ``iss``.
        client_id: this website's client id at the provider (also the
            expected token audience).
        client_secret: secret sent to the provider's token endpoint.
        public_key: PEM-encoded RSA public key used to verify ID tokens.
    """

    url: str
    client_id: str
    client_secret: str
    public_key: bytes

    async def get_config(self):
        """Fetch the provider's discovery document and return the parsed JSON.

        A new HTTP client session is opened on every call; nothing is cached.
        """
        discovery_url = f'{self.url}/.well-known/openid-configuration'
        async with ClientSession() as client:
            async with client.get(discovery_url) as response:
                return await response.json()
# In-memory, process-local session store: session_id (the cookie value) ->
# per-session dict (login state/nonce, then the logged-in User under 'user').
# Entries are only removed by close_session; suitable for a demo only.
sessions: dict[str, dict] = {}
# Identity providers this website can log in with, keyed by the name shown in
# the login form and used in the callback path /login/{provider}/finalize.
# Each client_secret must match the provider's `clients` entry, and each
# public_key must pair with the private key that provider signs tokens with.
openid_providers = {
'foo': Provider(
url='http://localhost:8001',
client_id='local',
client_secret='mysecret',
public_key=b'''
-----BEGIN RSA PUBLIC KEY-----
MIGJAoGBAM/uiaeZabyhhxO9TrwsQjJ+AvZE9RxvQlvvMdzxPlyIvigUI9owMJRp
z7zLUrBnYX7teOEYce1DZXO8bsVSCjcXWO30AjTXIRUL+K0p8r10XRJETf+GRzhA
hVyV6fh6FS2b186WtOY8onCQcwuBL5+kVJBJgJkmiRXq2noh6BsHAgMBAAE=
-----END RSA PUBLIC KEY-----
''',
),
'bar': Provider(
url='http://localhost:8002',
client_id='local',
client_secret='mybettersecret',
public_key=b'''
-----BEGIN RSA PUBLIC KEY-----
MIGJAoGBAOptOR+1efvNzhFGMWKS2+P2Sn/KcUKwKKesMxCvnfyQHNMCMwwydzrl
JGyvXa+yagowlp/D1ahkdD79i0UBqoutaG6gOKLg+/E/xQpmMtXeGyXeNy9IQ8JQ
dN64ILQMk4LNhaSVQNDvgRloH5+sCl2F1RFgxcJTvEDi8p85DwndAgMBAAE=
-----END RSA PUBLIC KEY-----
''',
),
}
def create_session(response):
    """Start a new, empty server-side session and bind it to *response*.

    Generates a random session id, registers an empty dict for it in
    ``sessions`` and sets it as the ``session_id`` cookie on *response*.

    The cookie is ``HttpOnly``, so page scripts cannot read it, which limits
    what an XSS bug can steal. It is also ``SameSite=Lax``: it is withheld
    from cross-site sub-requests and POSTs, but still sent on the top-level
    GET redirect back from the identity provider.

    Returns:
        The new (empty) session dict, for the caller to fill.
    """
    session_id = str(uuid.uuid4())
    session = sessions[session_id] = {}
    response.set_cookie('session_id', session_id, httponly=True, samesite='Lax')
    return session
def get_session(request):
    """Return the session dict named by the request's ``session_id`` cookie.

    Returns None when the cookie is absent or refers to an unknown session.
    """
    return sessions.get(request.cookies.get('session_id'))
def close_session(request, response):
    """Forget the request's session server-side and delete its cookie.

    Safe to call without a session: a missing or unknown id is ignored.
    """
    sessions.pop(request.cookies.get('session_id'), None)
    response.del_cookie('session_id')
def get_current_user(request):
    """Return the ``User`` stored in the request's session, or None."""
    session = get_session(request)
    if not session:
        return None
    return session.get('user')
# Route table for the website (relying party); registered on `app` at the end
# of this module.
routes = web.RouteTableDef()
@routes.get('/')
async def home(request):
    """Render the home page; the links offered depend on the login state."""
    logged_in = bool(get_current_user(request))
    if logged_in:
        header, links = 'Home - user logged', {'/me': 'User info', '/logout': 'Logout'}
    else:
        header, links = 'Home', {'/login': 'Login'}
    items = ''.join(
        f'<li><a href="{path}">{name}</a></li>' for path, name in links.items()
    )
    return html_response(f'''
<h1>{header}</h1>
<ul>
{items}
</ul>
''')
@routes.get('/login')
async def login(request):
    """Render the login form: a provider drop-down plus an email field."""
    options = ''.join(
        f'<option value="{name}">{name}</option>' for name in openid_providers
    )
    return html_response(f'''
<form method="post">
<label>Provider:
<select name="provider">
{options}
</select></label><br/>
<label>Email: <input name="email" type="text" /></label><br/>
<input type="submit" value="Login" />
</form>
''')
@routes.post('/login')
async def login_post(request):
    """Start the authorization-code flow with the provider chosen in the form.

    Generates fresh ``state`` (checked on the callback) and ``nonce``
    (checked against the ID token) values, stores them in a brand-new
    session and redirects the browser to the provider's authorization
    endpoint. The typed email is forwarded as ``login_hint``.

    Returns a 403 "Invalid login" page when the form is incomplete or names
    an unknown provider.
    """
    data = await request.post()
    provider_name = data.get('provider')
    provider = openid_providers.get(provider_name)
    email = data.get('email')
    if provider is None or email is None:
        return html_response('<h1>Invalid login</h1>', status=403)

    provider_config = await provider.get_config()
    state = str(uuid.uuid4())
    nonce = str(uuid.uuid4())
    params = {
        'client_id': provider.client_id,
        'response_type': 'code',
        'scope': 'openid email',
        'redirect_uri': f'http://localhost:8000/login/{provider_name}/finalize',
        'state': state,
        'nonce': nonce,
        'login_hint': email,
    }
    response = web.HTTPFound(provider_config['authorization_endpoint'] + '?' + urlencode(params))
    # The new cookie replaces any previous session cookie of this browser;
    # drop that session server-side too so its entry does not leak.
    close_session(request, response)
    session = create_session(response)
    session['state'] = state
    session['nonce'] = nonce
    # Raise rather than return: returning an HTTPException from a handler is
    # deprecated in aiohttp 3.x. Cookies set on it are still sent.
    raise response
@routes.get('/login/{provider}/finalize')
async def login_finalize(request):
    """Complete the authorization-code flow (the provider redirects here).

    The flow has four steps:

    1. Check the ``state`` echoed by the provider against the session.
    2. Exchange ``code`` for an ID token at the provider's token endpoint.
    3. Verify the token: RS256 signature, audience, issuer, expiry, required
       claims and nonce.
    4. Store the resulting ``User`` in the session.

    Returns a 403 "Invalid login" page on any failure, including a provider
    error redirect without ``code``. On success, redirects to the home page.
    """
    def invalid():
        return html_response('<h1>Invalid login</h1>', status=403)

    session = get_session(request)
    # state is single-use: pop it so the same callback URL cannot be replayed.
    expected_state = session.pop('state', None) if session else None
    state = request.query.get('state')
    code = request.query.get('code')
    provider_name = request.match_info['provider']
    provider = openid_providers.get(provider_name)
    if expected_state is None or state != expected_state or code is None or provider is None:
        return invalid()

    provider_config = await provider.get_config()
    data = {
        'code': code,
        'client_id': provider.client_id,
        'client_secret': provider.client_secret,
        'redirect_uri': f'http://localhost:8000/login/{provider_name}/finalize',
        'grant_type': 'authorization_code',
    }
    async with ClientSession() as cli, cli.post(provider_config['token_endpoint'], data=data) as resp:
        if resp.status != 200:
            return invalid()
        payload = await resp.json()

    expected_nonce = session.pop('nonce', None)
    try:
        # PyJWT checks signature, aud, iss and exp, and that every claim we
        # read below is present; failures raise InvalidTokenError subclasses.
        token = jwt.decode(
            payload['id_token'],
            provider.public_key,
            algorithms=['RS256'],
            audience=provider.client_id,
            issuer=provider.url,
            options={'require': ['exp', 'iss', 'sub', 'nonce', 'email', 'roles']},
        )
    except (KeyError, jwt.InvalidTokenError):
        return invalid()
    if expected_nonce is None or token['nonce'] != expected_nonce:
        return invalid()

    user = User(
        # Double quotes inside the f-string keep this valid before Python 3.12.
        id=f'{provider_name}:{token["sub"]}',
        email=token['email'],
        roles=token['roles'],
    )
    session['user'] = user
    raise web.HTTPFound('/')
@routes.get('/logout')
async def logout(request):
    """Log out: drop the server-side session, delete its cookie, go home."""
    response = web.HTTPFound('/')
    close_session(request, response)
    # Raise rather than return: returning an HTTPException from a handler is
    # deprecated in aiohttp 3.x. The cookie deletion set on it is still sent.
    raise response
@routes.get('/me')
async def user_info(request):
    """Show the logged-in user's id, email and roles, or a 404 page if none."""
    from html import escape

    user = get_current_user(request)
    if not user:
        return html_response('<h1>No user found</h1><a href="/">Home</a>', status=404)
    # These values come from the provider's ID token; escape them before
    # embedding in HTML so a hostile or misconfigured provider (or an email
    # address containing markup) cannot inject script into this page.
    roles = ', '.join(user.roles)
    return html_response(f'''
<h1>User info</h1>
<ul>
<li>Id: {escape(user.id)}</li>
<li>Email: {escape(user.email)}</li>
<li>Roles: {escape(roles)}</li>
</ul>
<a href="/">Home</a>
''')
# The website (relying party) application; the launcher serves it on port 8000.
app = web.Application()
app.add_routes(routes)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment