From 305383c4b85dd6c826cb41faa42fd97015f33067 Mon Sep 17 00:00:00 2001 From: Fulgen301 Date: Sun, 26 Aug 2018 20:08:54 +0200 Subject: Rewrite database system with sqlalchemy, add /api/auth, add /api/uploads/comments --- helpers.py | 82 ++++++++++++++++++++++++++++++++++++++++++++------------------ 1 file changed, 59 insertions(+), 23 deletions(-) (limited to 'helpers.py') diff --git a/helpers.py b/helpers.py index 0e9da70..66c733b 100644 --- a/helpers.py +++ b/helpers.py @@ -15,43 +15,79 @@ import sys import os, re, json, math import requests, hashlib -from bottle import route, run, Bottle, request, static_file, response, hook, HTTPResponse -from bson.objectid import ObjectId +from bottle import route, run, Bottle, request, static_file, response, hook, HTTPResponse, JSONPlugin, install import threading os.chdir(os.path.dirname(__file__)) +from .database import * -io_lock = threading.Lock() +BLOCKSIZE = 1024 ** 2 -def loadJSON(name : str) -> dict: - try: - with open(name, "r") as fobj: - return json.load(fobj) - - except (FileNotFoundError, json.decoder.JSONDecodeError): - return {} +locks = { + "io" : threading.Lock() + } -def saveJSON(obj : dict, name : str) -> None: - with io_lock: - with open(name, "w") as fobj: - json.dump(obj, fobj) +def with_lock(k): + def wrapper(f): + def func(*args, **kwargs): + with locks[k]: + return f(*args, **kwargs) + + return func + return wrapper -database = loadJSON("database.json") -if "entries" not in database: - database["entries"] = {} +def calculateHashForResource(resource : requests.Response) -> object: + hashobj = hashlib.sha1() + + calculateHashForFile(resource.raw, hashobj) + assert(resource.raw.tell()) + if "content-length" not in resource.headers: + resource.headers["Content-Length"] = resource.raw.tell() + return hashobj -def calculateHashForResource(resource : requests.Response, hashobj : object = None) -> object: +def calculateHashForFile(file, hashobj : object = None) -> object: if hashobj is None: hashobj = hashlib.sha1() - l = 0 - for block in resource.iter_content(4096): - l += len(block) + while True: + block = file.read(BLOCKSIZE) + if not block: + break + hashobj.update(block) - if "content-length" not in resource.headers: - resource.headers["Content-Length"] = l return hashobj def notAllowed(): raise HTTPResponse(f"Cannot {request.method} {request.path}") + +@hook('after_request') +def enable_cors(): + response.headers["Access-Control-Allow-Origin"] = "*" + +# Auth + +def calculateUserHash(username : str, password : str) -> object: + return hashlib.sha512(hashlib.sha512(username.encode("utf-8")).digest() + hashlib.sha512(password.encode("utf-8")).digest()) + +def auth_basic(f): + def checkAuth(*args, **kwargs): + session = DBSession() + try: + User.query.filter_by(name=requests.forms["username"], hash=calculateUserHash(request.forms["username"], request.forms["password"]).hexdigest()).first() + except db.orm.exc.NoResultFound: + return HTTPResponse(status=401) + + del request.forms["password"] + return f(*args, **kwargs) + return checkAuth + +class ParryEncoder(json.JSONEncoder): + _default = json.JSONEncoder.default + def default(self, obj): + if isinstance(obj, ObjectId): + return str(obj) + + return self._default(obj) + +install(JSONPlugin(json_dumps=lambda s: json.dumps(s, cls=ParryEncoder))) -- cgit v1.2.3-54-g00ecf