1
0
mirror of https://github.com/laurivosandi/certidude synced 2024-11-10 15:10:35 +00:00
certidude/certidude/decorators.py

82 lines
2.8 KiB
Python
Raw Normal View History

import click
import ipaddress
import json
import logging
import os
import types
from datetime import date, time, datetime, timedelta
from urllib.parse import urlparse
# Module-level logger named "api"; csrf_protection and serialize below
# report rejected requests through it.
logger = logging.getLogger("api")
def csrf_protection(func):
    """
    Protect a Falcon responder from common CSRF attacks by checking the
    user agent and the referrer.

    Requests whose User-Agent starts with ``curl/`` or ``python-requests/``
    are assumed to be intentional (scripted) and are passed through. For any
    other client the Referer header must name the same host the request was
    addressed to (``req.host``, compared case-insensitively); otherwise the
    attempt is logged and ``falcon.HTTPForbidden`` is raised.

    :param func: responder with signature ``(self, req, resp, *args, **kwargs)``
    :returns: wrapped responder with the same signature and metadata
    :raises falcon.HTTPForbidden: when neither check passes
    """
    from functools import wraps
    import falcon

    @wraps(func)
    def wrapped(self, req, resp, *args, **kwargs):
        # Falcon's req.user_agent is None when the header is absent; treat
        # that as "" so a missing header yields 403 rather than a 500.
        user_agent = req.user_agent or ""

        # Assume curl and python-requests are used intentionally
        if user_agent.startswith(("curl/", "python-requests/")):
            return func(self, req, resp, *args, **kwargs)

        # For everything else assert referrer
        referrer = req.headers.get("REFERER")
        if referrer:
            try:
                # .hostname drops userinfo, port and IPv6 brackets and
                # lowercases the name, unlike a naive split on ":".
                referrer_host = urlparse(referrer).hostname
            except ValueError:
                # Malformed referrer (e.g. unbalanced IPv6 brackets):
                # fall through to the 403 below instead of crashing.
                referrer_host = None
            if referrer_host and referrer_host == req.host.lower():
                return func(self, req, resp, *args, **kwargs)

        # Kaboom!
        logger.warning("Prevented clickbait from '%s' with user agent '%s'",
            referrer or "-", req.user_agent)
        raise falcon.HTTPForbidden("Forbidden",
            "No suitable UA or referrer provided, cross-site scripting disabled")
    return wrapped
class MyEncoder(json.JSONEncoder):
    """
    JSON encoder for the extra types API responders return.

    * IP addresses, interfaces and networks -> their string form
    * sets and generators -> JSON arrays
    * datetimes -> UTC ``YYYY-MM-DDTHH:MM:SS.mmmZ`` (millisecond precision)
    * dates -> ``YYYY-MM-DD``
    * timedeltas -> number of seconds (float)
    * ``certidude.user.User`` -> dict with name, given_name, surname, mail

    Anything else is delegated to ``json.JSONEncoder.default``, which raises
    ``TypeError``.
    """

    # Public classes only; the Interface types subclass the Address types.
    _IP_TYPES = (ipaddress.IPv4Address, ipaddress.IPv6Address,
                 ipaddress.IPv4Network, ipaddress.IPv6Network)

    def default(self, obj):
        if isinstance(obj, self._IP_TYPES):
            return str(obj)
        if isinstance(obj, (set, types.GeneratorType)):
            return tuple(obj)
        # datetime is a subclass of date, so it must be tested first
        if isinstance(obj, datetime):
            offset = obj.utcoffset()
            if offset is not None:
                # Aware datetime: shift to UTC so the "Z" suffix is truthful.
                # Naive datetimes are emitted unchanged (presumably already
                # UTC, as the unconditional "Z" implies).
                obj = (obj - offset).replace(tzinfo=None)
            return obj.strftime("%Y-%m-%dT%H:%M:%S.%f")[:-3] + "Z"
        if isinstance(obj, date):
            return obj.strftime("%Y-%m-%d")
        if isinstance(obj, timedelta):
            return obj.total_seconds()
        # Imported lazily and only when needed, so encoding plain stdlib
        # types never touches the project module.
        from certidude.user import User
        if isinstance(obj, User):
            return dict(name=obj.name, given_name=obj.given_name,
                surname=obj.surname, mail=obj.mail)
        return json.JSONEncoder.default(self, obj)
def serialize(func):
    """
    Falcon response serialization.

    Calls the wrapped responder. Unless the responder already set
    ``resp.body`` or ``resp.location`` itself, its return value is encoded
    with ``json.dumps(..., cls=MyEncoder)`` into ``resp.body``, and headers
    disabling client/proxy caching are added.

    :param func: responder with signature ``(instance, req, resp, *args, **kwargs)``
    :returns: wrapped responder with the same signature and metadata
    :raises falcon.HTTPNotAcceptable: (406) when the client's Accept header
        does not allow ``application/json``
    """
    from functools import wraps
    import falcon

    @wraps(func)
    def wrapped(instance, req, resp, *args, **kwargs):
        retval = func(instance, req, resp, *args, **kwargs)
        # The responder produced its own body or redirect; leave it alone.
        if resp.body or resp.location:
            return
        if not req.client_accepts("application/json"):
            logger.debug("Client did not accept application/json")
            # 406 is the status for an unsatisfiable Accept header; 415
            # refers to the media type of the *request* body.
            raise falcon.HTTPNotAcceptable(
                description="Client did not accept application/json")
        resp.set_header("Cache-Control", "no-cache, no-store, must-revalidate")
        resp.set_header("Pragma", "no-cache")
        resp.set_header("Expires", "0")
        resp.body = json.dumps(retval, cls=MyEncoder)
    return wrapped