diff --git a/app.py b/app.py index 99d0b38..53e1dc5 100644 --- a/app.py +++ b/app.py @@ -1,316 +1,178 @@ import time import jwt +import json import httpx -import asyncio import configparser -from flask import Flask, request, jsonify -from flask_sqlalchemy import SQLAlchemy +from flask import Flask, request, jsonify, abort +from sqlalchemy import create_engine, Column, Integer, String +from sqlalchemy.orm import declarative_base, sessionmaker config = configparser.ConfigParser() config.read("config.ini") -REGISTER_AUTH = config.get("AUTH", "REGISTER_AUTH", fallback=None) -MANAGE_AUTH = config.get("AUTH", "MANAGE_AUTH", fallback=None) +REGISTER_AUTH = config["AUTH"]["REGISTER_AUTH"] +MANAGE_AUTH = config["AUTH"]["MANAGE_AUTH"] -SQLALCHEMY_DATABASE_URI = config.get( - "DATABASE", - "SQLALCHEMY_DATABASE_URI", - fallback="sqlite:///device_tokens.db" -) +DATABASE_URI = config["DATABASE"]["SQLALCHEMY_DATABASE_URI"] -auth_key_path = config.get( - "APNS", - "auth_key_path", - fallback="./APNSAuthKey.p8" -) +APNS_KEY_PATH = config["APNS"]["auth_key_path"] +APNS_KEY_ID = config["APNS"]["auth_key_id"] +APNS_TEAM_ID = config["APNS"]["team_id"] +APNS_TOPIC = config["APNS"]["topic"] -auth_key_id = config.get("APNS", "auth_key_id", fallback=None) -team_id = config.get("APNS", "team_id", fallback=None) -topic = config.get("APNS", "topic", fallback=None) +APNS_URL = "https://api.push.apple.com/3/device/" -APNS_URL = "https://api.push.apple.com" -if not all([ - REGISTER_AUTH, - MANAGE_AUTH, - auth_key_id, - team_id, - topic -]): - raise ValueError("Fehlende Pflichtfelder in der Konfiguration!") +Base = declarative_base() +engine = create_engine(DATABASE_URI, echo=False) +SessionLocal = sessionmaker(bind=engine) + + +class DeviceToken(Base): + __tablename__ = "device_tokens" + + id = Column(Integer, primary_key=True) + token = Column(String, unique=True, nullable=False) + + +Base.metadata.create_all(engine) app = Flask(__name__) -app.config["SQLALCHEMY_DATABASE_URI"] = SQLALCHEMY_DATABASE_URI -app.config["SQLALCHEMY_TRACK_MODIFICATIONS"] = False - -db = SQLAlchemy(app) - -class DeviceToken(db.Model): - id = db.Column(db.Integer, primary_key=True) - - device_token = db.Column( - db.String(256), - nullable=False, - unique=True - ) - -def authenticate(req, required_token): - auth_header = req.headers.get("Authorization") - return auth_header == f"Bearer {required_token}" - -with open(auth_key_path, "r") as f: - APNS_PRIVATE_KEY = f.read() -def generate_apns_token(): - return jwt.encode( - { - "iss": team_id, - "iat": int(time.time()) - }, - APNS_PRIVATE_KEY, +def get_auth_header(): + auth = request.headers.get("Authorization", "") + if not auth.startswith("Bearer "): + return None + return auth.replace("Bearer ", "", 1) + + +def require_register_auth(): + if get_auth_header() != REGISTER_AUTH: + abort(401) + + +def require_manage_auth(): + if get_auth_header() != MANAGE_AUTH: + abort(401) + + +def load_apns_key(): + with open(APNS_KEY_PATH, "r") as f: + return f.read() + + +def create_apns_jwt(): + headers = { + "alg": "ES256", + "kid": APNS_KEY_ID + } + + payload = { + "iss": APNS_TEAM_ID, + "iat": int(time.time()) + } + + token = jwt.encode( + payload, + load_apns_key(), algorithm="ES256", - headers={ - "alg": "ES256", - "kid": auth_key_id - } + headers=headers ) + return token -async def send_notification( - device_token, - title, - body -): - try: - jwt_token = generate_apns_token() +def send_apns_notification(device_token: str, payload: dict): + jwt_token = create_apns_jwt() - payload = { - "aps": { - "alert": { - "title": title, - "body": body - }, - "sound": "default", - "badge": 1 - } - } + headers = { + "authorization": f"bearer {jwt_token}", + "apns-topic": APNS_TOPIC, + "apns-push-type": "alert" + } - headers = { - "authorization": f"bearer {jwt_token}", - "apns-topic": topic, - "apns-push-type": "alert" - } + url = APNS_URL + device_token - url = f"{APNS_URL}/3/device/{device_token}" + with httpx.Client(http2=True) as client: + response = client.post(url, headers=headers, json=payload) - async with httpx.AsyncClient( - http2=True, - timeout=30.0 - ) as client: + return response.status_code, response.text - response = await client.post( - url, - json=payload, - headers=headers - ) - if response.status_code != 200: - - print( - f"APNS Fehler " - f"{response.status_code} " - f"für {device_token}: " - f"{response.text}" - ) - - else: - - print( - f"Push erfolgreich " - f"an {device_token}" - ) - - except Exception as e: - - print( - f"Fehler beim Senden " - f"an {device_token}: {e}" - ) @app.route("/api/registerDeviceToken", methods=["POST"]) -def register_device_token(): +def register_device(): + require_register_auth() - if not authenticate(request, REGISTER_AUTH): - return jsonify({ - "error": "Unauthorized" - }), 401 + data = request.get_json() + if not data or "device_token" not in data: + return jsonify({"error": "device_token required"}), 400 - try: + token = data["device_token"] - data = request.get_json() + session = SessionLocal() + exists = session.query(DeviceToken).filter_by(token=token).first() - if not data or "device_token" not in data: - return jsonify({ - "error": "device_token fehlt" - }), 400 + if not exists: + session.add(DeviceToken(token=token)) + session.commit() - device_token = data["device_token"] + session.close() - existing = DeviceToken.query.filter_by( - device_token=device_token - ).first() - - if existing: - return jsonify({ - "message": "Bereits registriert" - }), 200 - - db.session.add( - DeviceToken(device_token=device_token) - ) - - db.session.commit() - - return jsonify({ - "message": "Device Token registriert" - }), 200 - - except Exception as e: - - db.session.rollback() - - return jsonify({ - "error": "Interner Serverfehler", - "details": str(e) - }), 500 + return jsonify({"status": "registered"}) @app.route("/api/notify", methods=["POST"]) def notify(): + require_manage_auth() - if not authenticate(request, MANAGE_AUTH): - return jsonify({ - "error": "Unauthorized" - }), 401 + data = request.get_json() - try: + if not data: + return jsonify({"error": "invalid json"}), 400 - data = request.get_json() + severity = data.get("severity") + message = data.get("notification") - severity = data.get("severity") - notification_text = data.get("notification") + if severity not in ["info", "warning", "urgent", "danger"]: + return jsonify({"error": "invalid severity"}), 400 - if not severity or not notification_text: - return jsonify({ - "error": ( - "severity und notification " - "sind erforderlich" - ) - }), 400 + if not message: + return jsonify({"error": "notification required"}), 400 - severity_icons = { - "info": "", - "warning": "⚠️ ", - "urgent": "❗️ ", - "danger": "❌ " - } + payload = { + "aps": { + "alert": message, + "sound": "default" + }, + "severity": severity + } - if severity not in severity_icons: - return jsonify({ - "error": "Ungültige severity" - }), 400 + session = SessionLocal() + tokens = session.query(DeviceToken).all() + session.close() - title = ( - f"{severity_icons[severity]}" - f"{notification_text}" - ) + results = [] - async def send_all(): - - tasks = [] - - for token in DeviceToken.query.all(): - - tasks.append( - send_notification( - device_token=token.device_token, - title=title, - body=notification_text - ) - ) - - await asyncio.gather(*tasks) - - asyncio.run(send_all()) - - return jsonify({ - "message": "Benachrichtigungen gesendet" - }), 200 - - except Exception as e: - - return jsonify({ - "error": "Interner Serverfehler", - "details": str(e) - }), 500 - - -@app.route("/showdevicetokens", methods=["GET"]) -def show_device_tokens(): - - tokens = DeviceToken.query.all() + for t in tokens: + status, resp = send_apns_notification(t.token, payload) + results.append({ + "token": t.token, + "status": status, + "response": resp + }) return jsonify({ - "device_tokens": [ - token.device_token - for token in tokens - ] + "sent": len(results), + "results": results }) -@app.route( - "/api/deleteAllDeviceTokens", - methods=["POST"] -) -def delete_all_device_tokens(): - - if not authenticate(request, MANAGE_AUTH): - return jsonify({ - "error": "Unauthorized" - }), 401 - - try: - - db.session.query(DeviceToken).delete() - - db.session.commit() - - return jsonify({ - "message": "Alle Device Tokens gelöscht" - }), 200 - - except Exception as e: - - db.session.rollback() - - return jsonify({ - "error": "Interner Serverfehler", - "details": str(e) - }), 500 - if __name__ == "__main__": - - with app.app_context(): - db.create_all() - - app.run( - host="0.0.0.0", - port=3000 - ) \ No newline at end of file + app.run(host="0.0.0.0", port=3000, debug=True) \ No newline at end of file