Skip to content

Instantly share code, notes, and snippets.

@gamapat
Last active August 7, 2024 08:40
Show Gist options
  • Select an option

  • Save gamapat/f6a41312e58f721d9906a4886fb7959e to your computer and use it in GitHub Desktop.

Select an option

Save gamapat/f6a41312e58f721d9906a4886fb7959e to your computer and use it in GitHub Desktop.
python_for_backend_gists
def login(user: User, session: Session):
    """Verify that *user*'s credentials match the stored record.

    ``user.password`` must already be hashed the same way the stored password
    was (callers in this project hash with SHA-256 hex before calling).

    :param user: transient ``User`` carrying the ``username``/``password`` to check.
    :param session: open SQLAlchemy session.
    :raises RuntimeError: if the user does not exist or the password differs.
    """
    import hmac  # local import: constant-time comparison helper

    # One lookup instead of check_user_exists() followed by a second identical query.
    user_from_db = session.query(User).filter(User.username == user.username).first()
    if user_from_db is None:
        raise RuntimeError('User does not exist')
    # compare_digest does not short-circuit on the first differing character;
    # compare UTF-8 bytes because compare_digest rejects non-ASCII str.
    # The stored password is non-null (nullable=False on the model column).
    if user.password is None or not hmac.compare_digest(
            user_from_db.password.encode('utf-8'), user.password.encode('utf-8')):
        raise RuntimeError('Password is incorrect')
def check_admin(user: User, session: Session):
    """Raise unless *user* exists in the database and has ``is_admin == 1``.

    Only ``user.username`` is read from the passed object.

    :raises RuntimeError: if the user is missing or is not an admin. Callers
        catch RuntimeError to choose between admin and regular behaviour.
    """
    user_from_db = session.query(User).filter(User.username == user.username).first()
    # An unknown user used to fail with AttributeError on None.is_admin, which
    # the RuntimeError handlers in the callers did not catch.
    if user_from_db is None or user_from_db.is_admin != 1:
        raise RuntimeError('User is not admin')
def add_user(user: User, session: Session):
    """Store *user* as a regular (non-admin) account and commit.

    Whatever ``is_admin`` value the caller set is overwritten with 0, so this
    function never creates an administrator.

    :raises RuntimeError: if a user with the same username already exists.
    """
    already_present = check_user_exists(user, session)
    if already_present:
        raise RuntimeError(f'User {user} already exists')
    user.is_admin = 0
    session.add(user)
    session.commit()
def remove_user(user: User, session: Session):
    """Delete the account whose username equals *user*.username and commit.

    :raises RuntimeError: if no such user exists.
    """
    exists = check_user_exists(user, session)
    if not exists:
        raise RuntimeError(f'User {user} does not exist')
    matching = session.query(User).filter(User.username == user.username)
    matching.delete()
    session.commit()
def list_users(session: Session) -> 'list[User]':
    """Return every row of the users table."""
    return session.query(User).all()
def check_user_exists(user: User, session: Session):
    """Return True if a row with *user*.username is present in the users table."""
    first_match = session.query(User).filter(User.username == user.username).first()
    return first_match is not None
def add_packet(packet: Packet, session: Session):
    """Give *packet* the next sequential ``packet_id``, store it and commit.

    The id is one more than the highest existing id, or 1 for an empty table.
    NOTE(review): the read-then-insert is not atomic; two concurrent writers
    could compute the same id.
    """
    newest = session.query(Packet).order_by(Packet.packet_id.desc()).first()
    packet.packet_id = 1 if newest is None else newest.packet_id + 1
    session.add(packet)
    session.commit()
def query_packets_user(user: User, size_range, time_range, session: Session) -> 'list[Packet]':
    """Return *user*'s packets whose size and time fall inside the given ranges.

    :param size_range: ``"min,max"`` string; bounds are inclusive (SQL BETWEEN).
    :param time_range: ``"min,max"`` string; bounds are inclusive.
    :raises ValueError: if a range string does not split into exactly two parts.
    """
    lowest_size, highest_size = size_range.split(',')
    earliest, latest = time_range.split(',')
    conditions = (
        Packet.user == user,
        Packet.size.between(lowest_size, highest_size),
        Packet.time.between(earliest, latest),
    )
    return session.query(Packet).filter(*conditions).all()
def query_packets_admin(size_range, time_range, session: Session) -> 'list[Packet]':
    """Return packets of all users whose size and time fall inside the given ranges.

    :param size_range: ``"min,max"`` string; bounds are inclusive (SQL BETWEEN).
    :param time_range: ``"min,max"`` string; bounds are inclusive.
    :raises ValueError: if a range string does not split into exactly two parts.
    """
    lowest_size, highest_size = size_range.split(',')
    earliest, latest = time_range.split(',')
    conditions = (
        Packet.size.between(lowest_size, highest_size),
        Packet.time.between(earliest, latest),
    )
    return session.query(Packet).filter(*conditions).all()
def get_total(session: Session):
    """Return ``(packet_count, sum_of_packet_sizes)`` over all stored packets.

    SQL ``SUM`` over zero rows yields NULL (``None``); it is normalised to 0
    so callers always receive a number.
    """
    total_packets = session.query(Packet).count()
    total_size = session.query(functions.sum(Packet.size)).scalar()
    return total_packets, total_size or 0
def get_average(session: Session):
    """Return the mean packet size, or 0.0 when no packets are stored.

    Without the guard, an empty table divided the NULL sum (``None``) by a
    count of 0 and raised instead of returning a value.
    """
    total_packets, total_size = get_total(session)
    if not total_packets:
        return 0.0
    return total_size / total_packets
def get_packet_plot(session: Session) -> plt:
    """Scatter-plot the summed packet size per timestamp on the current pyplot figure.

    Returns the ``pyplot`` module itself so callers can resize or save the
    current figure.
    NOTE(review): pyplot state is global, so concurrent callers share one figure.
    """
    rows = session.query(Packet.size, Packet.time).all()
    frame = pd.DataFrame(rows, columns=['packet_size', 'packet_time'])
    # Index by timestamp and add up the sizes of packets sharing a timestamp.
    per_time = frame.set_index('packet_time').groupby('packet_time').sum()
    plt.clf()
    seaborn.scatterplot(x='packet_time', y='packet_size', data=per_time)
    return plt
def get_throughput(session: Session, window: int = 1000) -> plt:
    """Line-plot a rolling sum of packet sizes on the current pyplot figure.

    Packet sizes are summed per timestamp. Timestamps with no traffic between
    the first and last packet are filled with 0. A rolling sum over *window*
    consecutive timestamps is plotted as ``throughput bytes/s``.

    :param window: number of consecutive time units per rolling sum. The
        default of 1000 was previously hard-coded; presumably timestamps are
        milliseconds, so this gives bytes per second — TODO confirm units.
    :returns: the ``pyplot`` module, with the plot on its current figure
        (a cleared figure when there are no packets).
    """
    packets = session.query(Packet.size, Packet.time).all()
    df = pd.DataFrame(packets, columns=['packet_size', 'packet_time'])
    if df.empty:
        # With no packets the index min()/max() are NaN and range() raises.
        plt.clf()
        return plt
    # Sum the sizes of all packets that share a timestamp.
    df = df.set_index('packet_time').groupby('packet_time').sum()
    # Insert zero rows for silent timestamps so the window spans time units, not rows.
    df = df.reindex(list(range(df.index.min(), df.index.max() + 1)), fill_value=0.0)
    df_rolling = df.rolling(window).sum()
    df_rolling = df_rolling.rename(columns={'packet_size': 'throughput bytes/s'})
    plt.clf()
    seaborn.lineplot(x='packet_time', y='throughput bytes/s', data=df_rolling)
    return plt
def main():
    """Parse CLI arguments, authenticate the caller, then run the chosen subcommand.

    Every subcommand handler is bound to one shared database session with
    ``functools.partial`` and exposed by argparse as ``args.func``.
    """
    # global login credentials, required for every subcommand
    parser = argparse.ArgumentParser()
    # login username: short option -lu, long option --login_username
    parser.add_argument("-lu", "--login_username", help="username", required=True)
    # login password: short option -lp, long option --login_password
    parser.add_argument("-lp", "--login_password", help="password", required=True)
    # exactly one subcommand must be given; first one: add a regular user
    subparsers = parser.add_subparsers(required=True)
    add_user_parser = subparsers.add_parser(name="add_user", help='add regular user')
    add_user_parser.add_argument("-u", "--username", help="username", required=True)
    add_user_parser.add_argument("-p", "--password", help="password", required=True)
    # ... (creation of the remaining *_parser subparsers is elided in this view)
    with backend.get_session() as session:
        # bind each handler to the shared session; argparse stores it as args.func
        add_user_parser.set_defaults(func=partial(add_user, session=session))
        remove_user_parser.set_defaults(func=partial(remove_user, session=session))
        list_users_parser.set_defaults(func=partial(list_users, session=session))
        add_packet_parser.set_defaults(func=partial(add_packet, session=session))
        query_packets_parser.set_defaults(func=partial(query_packets, session=session))
        get_total_parser.set_defaults(func=partial(get_total, session=session))
        get_average_parser.set_defaults(func=partial(get_average, session=session))
        get_throughput_parser.set_defaults(func=partial(get_throughput, session=session))
        get_visualized_packets_parser.set_defaults(func=partial(get_packet_plot, session=session))
        args = parser.parse_args()
        # login() delegates to backend.login; an exception there aborts before dispatch
        login(args, session)
        args.func(args)
def login(args, session: Session):
    """Check the CLI ``--login_username``/``--login_password`` against the database.

    The password is hashed with unsalted SHA-256 (hex), the same way the CLI
    ``add_user`` stores it, and then verified by ``backend.login``.
    NOTE(review): unsalted SHA-256 is a weak password-storage scheme.
    """
    digest = hashlib.sha256(args.login_password.encode('utf-8')).hexdigest()
    candidate = User(username=args.login_username, password=digest)
    return backend.login(candidate, session)
def add_user(args, session: Session):
    """Create a regular user from ``--username``/``--password``; admins only.

    The logged-in user must pass ``backend.check_admin``. The new password is
    stored as its SHA-256 hex digest.
    """
    caller = User(username=args.login_username, password='', is_admin=0)
    backend.check_admin(caller, session)
    new_user = User(
        username=args.username,
        password=hashlib.sha256(args.password.encode('utf-8')).hexdigest(),
    )
    backend.add_user(new_user, session)
    print("User added")
# ...
def list_users(args, session: Session):
    """Print all users as a tab-separated table; admins only.

    NOTE(review): the table includes the stored password hashes.
    """
    caller = User(username=args.login_username, password='')
    backend.check_admin(caller, session)
    rows = backend.list_users(session)
    lines = ['username\tpassword\tis_admin']
    lines += [f"{u.username}\t{u.password}\t{u.is_admin}" for u in rows]
    for line in lines:
        print(line)
def add_packet(args, session: Session):
    """Record a packet with ``args.size``/``args.time`` owned by the logged-in user."""
    new_packet = Packet(
        size=args.size,
        time=args.time,
        username=args.login_username,
    )
    backend.add_packet(new_packet, session)
    print("Packet added")
def query_packets(args, session: Session):
    """Print packets inside ``size_range``/``time_range`` as a tab-separated table.

    Admins see every user's packets. When ``check_admin`` raises RuntimeError,
    only the caller's own packets are listed.
    """
    caller = User(username=args.login_username, password='')
    sizes, times = args.size_range, args.time_range
    try:
        backend.check_admin(caller, session)
        found = backend.query_packets_admin(sizes, times, session)
    except RuntimeError:
        found = backend.query_packets_user(caller, sizes, times, session)
    print('packet_id\tpacket_size\tpacket_time\tuser')
    for pkt in found:
        print(f"{pkt.packet_id}\t{pkt.size}\t{pkt.time}\t{pkt.username}")
def get_total(args, session: Session):
    """Print the number of stored packets and the sum of their sizes."""
    count, size_sum = backend.get_total(session)
    for label, value in (('total packets', count), ('total size', size_sum)):
        print(f'{label}: {value}')
def get_average(args, session: Session):
    """Print the mean packet size reported by the backend."""
    print(f'average packet size: {backend.get_average(session)}')
def get_throughput(args, session: Session):
    """Render the throughput plot into ``cli_throughput.png`` in the working directory."""
    backend.get_throughput(session).savefig('cli_throughput.png')
def get_packet_plot(args, session: Session):
    """Render the packet scatter plot into ``cli_packets.png`` in the working directory."""
    backend.get_packet_plot(session).savefig('cli_packets.png')
import os

app = Flask(__name__)
# Signing key for JWTs. Set JWT_SECRET_KEY in the environment for any real
# deployment; "test" remains only as a development fallback.
app.config["JWT_SECRET_KEY"] = os.environ.get("JWT_SECRET_KEY", "test")
app.register_blueprint(user, url_prefix='/user')
app.register_blueprint(packet, url_prefix='/packet')

if __name__ == '__main__':
    parser = argparse.ArgumentParser(prog="flask_interface")
    parser.add_argument("-c", "--certfile", help="path to tls certificate", type=str, required=False)
    parser.add_argument("-k", "--keyfile", help="path to tls key", type=str, required=False)
    options = parser.parse_args()
    # argparse always defines both attributes (None when omitted); HTTPS needs both files.
    use_https = bool(options.certfile and options.keyfile)
    with backend.get_session() as session:
        backend.create_tables()
        backend.add_admin(session)
    jwt.init_app(app)
    if use_https:
        http_server = WSGIServer("0.0.0.0:5000", app, keyfile=options.keyfile, certfile=options.certfile)
    else:
        http_server = WSGIServer("0.0.0.0:5000", app)
    http_server.serve_forever()
@user.route('/login', methods=['POST'])
def login():
    """Exchange a JSON ``{"username", "password"}`` body for a JWT access token.

    Returns 400 when the body is not a JSON object or either field is missing
    or not a string, 401 on bad credentials, and 200 with ``access_token`` on
    success.
    """
    # silent=True yields None for a non-JSON body instead of aborting the request.
    body = request.get_json(silent=True)
    if not isinstance(body, dict):
        body = {}
    username = body.get('username', None)
    password = body.get('password', None)
    if not isinstance(username, str) or not isinstance(password, str):
        return jsonify({"msg": "Missing username or password"}), HTTPStatus.BAD_REQUEST
    # Credentials are compared as unsalted SHA-256 hex digests.
    password = hashlib.sha256(password.encode('utf-8')).hexdigest()
    user_obj = model.User(username=username, password=password, is_admin=None)
    with backend.get_session() as session:
        try:
            backend.login(user_obj, session)
        except RuntimeError as ex:
            logger.error(ex)
            return jsonify({"msg": "Bad username or password"}), HTTPStatus.UNAUTHORIZED
    # The logout blacklist holds token jti values and every new token gets a
    # fresh jti, so nothing has to be removed from it here.
    access_token = create_access_token(identity=username)
    return jsonify(access_token=access_token), HTTPStatus.OK
@user.route('/logout', methods=['POST'])
def logout():
    """Revoke the caller's JWT by adding its ``jti`` to the in-memory blacklist.

    Responds 200 both on the first logout and when the token was already revoked.
    """
    verify_jwt_in_request()
    token_id = get_jwt()['jti']
    already_revoked = token_id in User.blacklist
    if not already_revoked:
        User.blacklist.add(token_id)
        return jsonify({"msg": "Successfully logged out"}), HTTPStatus.OK
    return jsonify({"msg": "Already logged out"}), HTTPStatus.OK
class User:
    """Holder for state shared by the user endpoints."""

    # jti claims of JWTs revoked via logout; kept in process memory only.
    blacklist = set()
@user.route('', methods=['GET'])
def get():
    """Return all users as JSON; 403 unless the token's identity is an admin."""
    verify_jwt_in_request()
    with backend.get_session() as session:
        requester = model.User(username=get_jwt_identity(), password=None, is_admin=None)
        try:
            backend.check_admin(requester, session)
        except RuntimeError:
            return jsonify({"msg": "You are not admin"}), HTTPStatus.FORBIDDEN
        payload = [record.to_dict() for record in backend.list_users(session)]
        return jsonify(payload), HTTPStatus.OK
@user.route('', methods=['POST'])
def post():
    """Create a user from a JSON body with ``username``, ``password`` and optional ``is_admin``.

    Returns 400 if username or password is missing, otherwise 200 once
    ``backend.add_user`` returns.
    """
    # ... (part of this handler is elided in this view)
    username = request.json.get('username', None)
    password = request.json.get('password', None)
    # NOTE(review): backend.add_user as shown in this file forces is_admin to 0,
    # so this value appears to be ignored — confirm intended behaviour.
    is_admin = request.json.get('is_admin', 0)
    if username is None or password is None:
        return jsonify({"msg": "Missing username or password"}), HTTPStatus.BAD_REQUEST
    # NOTE(review): unlike /login and the CLI, the password is not SHA-256
    # hashed here. Unless the elided code hashes it, users created through this
    # endpoint cannot log in — verify.
    user_obj = model.User(username=username, password=password, is_admin=is_admin)
    # NOTE(review): `session` is not bound in the visible code; presumably it is
    # opened in the elided part — verify.
    backend.add_user(user_obj, session)
    return jsonify({"msg": "User added"}), HTTPStatus.OK
# ...
@packet.route('/plot', methods=['GET'])
def plot():
    """Return the packet scatter plot as a 12x8-inch PNG image; requires a valid JWT."""
    verify_jwt_in_request()
    with backend.get_session() as session:
        pyplot = backend.get_packet_plot(session)
        pyplot.gcf().set_size_inches(12, 8)
        image = BytesIO()
        pyplot.savefig(image, format='png')
        image.seek(0)
        return send_file(image, mimetype='image/png')
# ...
import jwt
import datetime
import os

# HS256 signing key for the tokens issued and verified in this module.
# Override it with the JWT_SECRET_KEY environment variable; 'test' is only a
# development fallback.
SECRET_KEY = os.environ.get('JWT_SECRET_KEY', 'test')
def get_token(username, lifetime=datetime.timedelta(minutes=30)):
    """Issue an HS256-signed JWT carrying *username*.

    :param username: stored in the ``username`` claim.
    :param lifetime: how long the token stays valid; defaults to 30 minutes.
    :returns: whatever ``jwt.encode`` returns for the payload.
    """
    # Timezone-aware "now": datetime.utcnow() is deprecated since Python 3.12.
    expires_at = datetime.datetime.now(datetime.timezone.utc) + lifetime
    payload = {'username': username, 'exp': expires_at}
    return jwt.encode(payload, SECRET_KEY, algorithm='HS256')
def check_token_blacklisted(token):
    """Return True if *token* is still usable, i.e. it is NOT in the blacklist.

    Despite the name, True means "not blacklisted"; callers negate the result
    to detect a revoked token.
    """
    revoked = check_token_blacklisted.blacklist
    return token not in revoked


# Raw token strings revoked by blacklist_token(); stored on the function
# object, so the set lives only in this process's memory.
check_token_blacklisted.blacklist = set()
def blacklist_token(token):
    """Revoke *token*; check_token_blacklisted() then reports it as unusable."""
    revoked = check_token_blacklisted.blacklist
    revoked.add(token)
def get_user(request: 'django.http.request.HttpRequest'):
    """Return the ``username`` claim of the bearer token in *request*.

    :raises RuntimeError: if the Authorization header is missing or malformed,
        or the token is blacklisted, invalid or expired. The Django views
        catch RuntimeError, so every failure is mapped to it here.
    """
    parts = (request.META.get('HTTP_AUTHORIZATION') or '').split()
    if len(parts) < 2:
        raise RuntimeError('Missing or malformed Authorization header')
    token = parts[1]
    if token in check_token_blacklisted.blacklist:
        raise RuntimeError('Token blacklisted')
    try:
        payload = jwt.decode(token, SECRET_KEY, algorithms=['HS256'])
    except jwt.ExpiredSignatureError as err:
        raise RuntimeError('Token expired') from err
    except jwt.InvalidTokenError as err:
        raise RuntimeError('Invalid token') from err
    # Compare POSIX seconds directly. The old check compared local-time
    # fromtimestamp(exp) with utcnow(), which was off by the host's UTC offset.
    exp = payload.get('exp')
    now = datetime.datetime.now(datetime.timezone.utc).timestamp()
    if exp is None or now > exp:
        raise RuntimeError('Token expired')
    return payload.get('username')
from django.urls import path
from . import views
# URL routes -> class-based views; each view dispatches on the HTTP method.
urlpatterns = [
    path(route, view.as_view(), name=route_name)
    for route, view, route_name in (
        ('login', views.LoginView, 'login'),
        ('user', views.User, 'user'),
        ('packet', views.Packet, 'packet'),
        ('packet/total', views.Total, 'total'),
        ('packet/average', views.Average, 'average'),
        ('packet/throughput', views.Throughput, 'throughput'),
        ('packet/plot', views.PacketPlot, 'plot'),
    )
]
def check_authorization(request):
    """Return an error response if *request* lacks a usable bearer token, else None.

    - no ``Authorization`` header, or one without a token part: 401, empty body
    - token already blacklisted (logged out): 400 with a JSON message
    """
    header = request.META.get('HTTP_AUTHORIZATION')
    parts = header.split() if header else []
    # A header such as "Bearer" alone used to raise IndexError (HTTP 500).
    if len(parts) < 2:
        return HttpResponse(status=HTTPStatus.UNAUTHORIZED)
    token = parts[1]
    # check_token_blacklisted() returns True when the token is NOT blacklisted.
    if not auth_with_jwt.check_token_blacklisted(token):
        return HttpResponse(json.dumps({"msg": "Token already blacklisted"}), status=HTTPStatus.BAD_REQUEST)
    return None
@method_decorator(csrf_exempt, name='dispatch')
class LoginView(View):
    """``POST`` logs in and returns a JWT; ``DELETE`` logs out by blacklisting the token."""

    def post(self, request):
        """Validate JSON credentials and respond with ``{"access_token": ...}``.

        Returns 400 for a malformed body or missing username/password, and 401
        for bad credentials.
        """
        try:
            data = json.loads(request.body)
        except ValueError:
            # json.JSONDecodeError and UnicodeDecodeError are both ValueErrors.
            data = None
        if not isinstance(data, dict):
            data = {}
        username = data.get('username')
        password = data.get('password')
        if not isinstance(username, str) or not isinstance(password, str):
            return HttpResponse(json.dumps({"msg": "Missing username or password"}), status=HTTPStatus.BAD_REQUEST)
        hashed_password = hashlib.sha256(password.encode('utf-8')).hexdigest()
        with get_session() as session:
            try:
                user = model.User(username=username, password=hashed_password)
                backend.login(user, session)
                token = auth_with_jwt.get_token(username)
                return HttpResponse(json.dumps({'access_token': token}))
            except RuntimeError as ex:
                logger.error(ex)
                return HttpResponse(json.dumps({"msg": "Bad username or password"}), status=HTTPStatus.UNAUTHORIZED)

    def delete(self, request):
        """Blacklist the bearer token of *request* (logout)."""
        res = check_authorization(request)
        if res is not None:
            return res
        token = request.META['HTTP_AUTHORIZATION'].split()[1]
        auth_with_jwt.blacklist_token(token)
        return HttpResponse(json.dumps({"msg": "Successfully logged out"}), status=HTTPStatus.OK)
@method_decorator(csrf_exempt, name='dispatch')
class User(View):
    """``GET`` lists all users; admins only."""

    def get(self, request):
        """Return every user as JSON; 401 without a valid token, 403 for non-admins."""
        res = check_authorization(request)
        if res is not None:
            return res
        with get_session() as session:
            try:
                username = auth_with_jwt.get_user(request)
            except RuntimeError as ex:
                return HttpResponse(json.dumps({"msg": str(ex)}), status=HTTPStatus.UNAUTHORIZED)
            try:
                # get_user() returns a username string, but check_admin reads
                # ``.username`` from its argument, so wrap it in a model object.
                backend.check_admin(model.User(username=username), session)
            except RuntimeError:
                return HttpResponse(json.dumps({"msg": "You are not admin"}), status=HTTPStatus.FORBIDDEN)
            users = [record.to_dict() for record in backend.list_users(session)]
            return HttpResponse(json.dumps(users), status=HTTPStatus.OK)
# ...
class Throughput(View):
    """``GET`` returns the rolling-throughput plot as a 12x8-inch PNG."""

    def get(self, request):
        """Render the throughput plot.

        Returns check_authorization's error response when it rejects the
        request, and 401 when get_user() raises RuntimeError.
        """
        denied = check_authorization(request)
        if denied is not None:
            return denied
        with get_session() as session:
            try:
                auth_with_jwt.get_user(request)
            except RuntimeError as ex:
                return HttpResponse(json.dumps({"msg": str(ex)}), status=HTTPStatus.UNAUTHORIZED)
            pyplot = backend.get_throughput(session)
            pyplot.gcf().set_size_inches(12, 8)
            png = BytesIO()
            pyplot.savefig(png, format="png")
            png.seek(0)
            return HttpResponse(png, content_type='image/png')
# ...
import tornado
import tornado.web
def main():
    """Build the Tornado application, listen on port 5001 and run the IO loop.

    The cookie-signing secret comes from the TORNADO_COOKIE_SECRET environment
    variable. The placeholder is kept only as a local-development fallback.
    """
    import os  # local import: only needed to read the secret

    cookie_secret = os.environ.get(
        "TORNADO_COOKIE_SECRET", "__TODO:_GENERATE_YOUR_OWN_RANDOM_VALUE_HERE__")
    application = tornado.web.Application([
        (r"/login", LoginHandler),
        (r"/logout", LogoutHandler),
        (r"/user", UserHandler),
        (r"/packet", PacketHandler),
        (r"/packet/total", GetTotalHandler),
        (r"/packet/average", GetAverageHandler),
        (r"/packet/throughput", GetThroughputHandler),
        (r"/packet/plot", GetPacketPlotHandler),
    ], cookie_secret=cookie_secret)
    # run server
    application.listen(5001)
    tornado.ioloop.IOLoop.current().start()


# NOTE(review): in the visible layout the handler classes are defined after
# this guard. If they live in this same module, running it as a script raises
# NameError inside main(); move the guard to the end of the module.
if __name__ == "__main__":
    main()
class BaseHandler(tornado.web.RequestHandler):
    """Base handler that resolves the current user from the signed ``username`` cookie."""

    # Raw signed cookie values invalidated at logout; kept in process memory only.
    blacklisted_tokens = set()

    def get_current_user(self):
        """Return the logged-in username, or None when nobody is logged in.

        Returns None when the cookie was blacklisted at logout, or when it is
        missing or fails signature verification. ``@tornado.web.authenticated``
        then treats the request as unauthenticated.
        """
        if self.get_cookie("username") in BaseHandler.blacklisted_tokens:
            return None
        value = self.get_signed_cookie("username")
        # get_signed_cookie returns None for a missing or tampered cookie;
        # calling .decode on None used to raise AttributeError (HTTP 500).
        if value is None:
            return None
        return value.decode('utf-8')
class LoginHandler(BaseHandler):
    """``POST`` with JSON credentials sets the signed ``username`` cookie."""

    def post(self):
        """Log in from a JSON body ``{"username": ..., "password": ...}``.

        Responds 400 for a malformed body or missing fields, and 401 for bad
        credentials.
        """
        try:
            data = json.loads(self.request.body.decode('utf-8'))
        except ValueError:
            # Covers both UnicodeDecodeError and json.JSONDecodeError.
            data = None
        if not isinstance(data, dict):
            data = {}
        username = data.get('username', None)
        password = data.get('password', None)
        if not isinstance(username, str) or not isinstance(password, str):
            self.write({"msg": "Missing username or password"})
            self.set_status(HTTPStatus.BAD_REQUEST)
            return
        # hash password with sha256 (unsalted hex digest)
        password = hashlib.sha256(password.encode('utf-8')).hexdigest()
        user = model.User(username=username, password=password)
        with backend.get_session() as session:
            try:
                backend.login(user, session)
                self.write({"msg": "Logined successfully"})
                self.set_signed_cookie("username", username)
            except RuntimeError as ex:
                logger.error(ex)
                self.set_status(HTTPStatus.UNAUTHORIZED)
                self.write({"msg": "Bad username or password"})
class LogoutHandler(BaseHandler):
    """``DELETE`` logs the current user out."""

    @tornado.web.authenticated
    def delete(self):
        """Blacklist the current signed cookie value, then clear the cookie."""
        current_cookie = self.get_cookie("username")
        BaseHandler.blacklisted_tokens.add(current_cookie)
        self.clear_cookie("username")
        self.write({"msg": "Successfully logged out"})
class UserHandler(BaseHandler):
    """``GET`` lists all users as JSON; admins only."""

    @tornado.web.authenticated
    def get(self):
        """Write every user as a JSON array; 403 for non-admins."""
        with backend.get_session() as session:
            try:
                backend.check_admin(model.User(username=self.get_current_user()), session)
            except RuntimeError:
                self.write({"msg": "You are not admin"})
                self.set_status(HTTPStatus.FORBIDDEN)
                return
            users = backend.list_users(session)
            # Tornado sets the JSON content type automatically only for dicts.
            # A list has to be serialised by hand, so declare the type explicitly.
            self.set_header('Content-Type', 'application/json; charset=UTF-8')
            self.write(json.dumps([user.to_dict() for user in users]))
# ...
class GetThroughputHandler(BaseHandler):
    """``GET`` returns the rolling-throughput plot as a 12x8-inch PNG."""

    @tornado.web.authenticated
    def get(self):
        """Render the throughput plot and write the PNG bytes to the response."""
        with backend.get_session() as session:
            pyplot = backend.get_throughput(session)
            pyplot.gcf().set_size_inches(12, 8)
            png = BytesIO()
            pyplot.savefig(png, format='png')
            png_bytes = png.getvalue()
            self.set_header('Content-Type', 'image/png')
            self.write(png_bytes)
# ...
def get_session() -> Session:
    """Return a new session bound to the SQLite ``database.db`` under ``cur_path``.

    The engine and session factory are created lazily on the first call and
    cached on the function object; later calls reuse them.
    """
    maker = get_session.maker
    if maker is None:
        db_file = os.path.join(cur_path, 'database.db')
        engine = create_engine('sqlite:///' + db_file)
        engine.connect()  # opens a connection eagerly; the result is discarded
        maker = get_session.maker = sessionmaker(bind=engine)
    return maker()


# Lazily-created session factory shared by every get_session() call.
get_session.maker = None
def create_tables():
    """Create the ``users`` and ``packets`` tables in ``database.db`` if they are missing.

    ``create_all(checkfirst=True)`` already skips existing tables, so no
    separate existence check is needed. The old check passed the engine as
    the table name to ``Inspector.has_table`` and, on its early return, never
    disposed the engine.
    """
    engine = create_engine('sqlite:///' + os.path.join(cur_path, 'database.db'))
    try:
        User.metadata.create_all(engine, checkfirst=True)
        Packet.metadata.create_all(engine, checkfirst=True)
    finally:
        # Release pooled connections even if table creation fails.
        engine.dispose()
from sqlalchemy import Column, Integer, String, ForeignKey
from sqlalchemy.orm import relationship
from sqlalchemy.ext.declarative import declarative_base
# Declarative base class shared by the ORM models defined below.
Base = declarative_base()
class User(Base):
    """ORM model of a row in the ``users`` table."""

    __tablename__ = 'users'

    username = Column(String(250), primary_key=True)
    # Compared against an unsalted SHA-256 hex digest at login.
    password = Column(String(250), nullable=False)
    # 1 = administrator (the admin check compares against 1), 0 = regular user.
    is_admin = Column(Integer, nullable=False)

    def to_dict(self):
        """Return a JSON-serialisable view of this user, without the password hash.

        The web handlers call ``to_dict()`` when listing users; the model did
        not define it before. NOTE(review): the password is deliberately
        omitted — confirm no client expects it.
        """
        return {'username': self.username, 'is_admin': self.is_admin}
class Packet(Base):
    """ORM model of a row in the ``packets`` table: one packet's size and time, with its owner."""

    __tablename__ = 'packets'

    packet_id = Column(Integer, primary_key=True)
    # Python attributes size/time map to DB columns packet_size/packet_time.
    size = Column(Integer, name='packet_size', nullable=False)
    time = Column(Integer, name='packet_time', nullable=False)
    # Owner's username (DB column "user") plus the ORM relationship to the User row.
    username = Column(String(250), ForeignKey('users.username'), name='user')
    user = relationship(User)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment