added ics support. and some thread safety
This commit is contained in:
82
API.py
82
API.py
@@ -1 +1,81 @@
|
||||
# TODO
|
||||
from flask import Flask, Response
|
||||
from icalendar import Calendar, Event
|
||||
from datetime import datetime
|
||||
import pytz
|
||||
import uuid
|
||||
|
||||
|
||||
class Webserver:
|
||||
def __init__(self, ip, port, db_instance):
|
||||
self.ip = ip
|
||||
self.port = port
|
||||
self.db = db_instance
|
||||
|
||||
self.app = Flask(__name__)
|
||||
self.tz = pytz.timezone("Europe/Amsterdam")
|
||||
|
||||
self._register_routes()
|
||||
|
||||
def _register_routes(self):
|
||||
self.app.add_url_rule("/", "index", lambda: "Calendar server running ✅")
|
||||
self.app.add_url_rule("/calendar.ics", "calendar", self._calendar_endpoint)
|
||||
|
||||
def _fetch_shifts(self):
|
||||
con = self.db._get_connection()
|
||||
cur = con.cursor()
|
||||
|
||||
cur.execute("""
|
||||
SELECT shift_start, shift_end, department, description
|
||||
FROM shifts
|
||||
ORDER BY shift_start ASC
|
||||
""")
|
||||
|
||||
rows = cur.fetchall()
|
||||
con.close()
|
||||
|
||||
shifts = []
|
||||
for row in rows:
|
||||
start_str, end_str, department, description = row
|
||||
|
||||
# Convert string → datetime
|
||||
start = self._parse_datetime(start_str)
|
||||
end = self._parse_datetime(end_str)
|
||||
|
||||
shifts.append({
|
||||
"start": start,
|
||||
"end": end,
|
||||
"description": f"{description} ",
|
||||
"department": f"{department} "
|
||||
})
|
||||
|
||||
return shifts
|
||||
|
||||
def _parse_datetime(self, dt_str):
|
||||
dt = datetime.fromisoformat(dt_str)
|
||||
if dt.tzinfo is None:
|
||||
dt = self.tz.localize(dt)
|
||||
return dt
|
||||
|
||||
def _calendar_endpoint(self):
|
||||
cal = Calendar()
|
||||
cal.add('prodid', '-//Shift Calendar//example//')
|
||||
cal.add('version', '2.0')
|
||||
|
||||
shifts = self._fetch_shifts()
|
||||
|
||||
for shift in shifts:
|
||||
event = Event()
|
||||
event.add('uid', str(uuid.uuid4()))
|
||||
event.add('dtstart', shift["start"])
|
||||
event.add('dtend', shift["end"])
|
||||
event.add('summary', shift["department"])
|
||||
event.add('description', shift["description"])
|
||||
event.add('dtstamp', datetime.now(tz=self.tz))
|
||||
|
||||
cal.add_component(event)
|
||||
|
||||
return Response(cal.to_ical(), mimetype="text/calendar")
|
||||
|
||||
def run(self):
|
||||
print(f"🚀 Starting server on {self.ip}:{self.port}")
|
||||
self.app.run(host=self.ip, port=self.port)
|
||||
22
DB.py
22
DB.py
@@ -3,14 +3,17 @@ import sqlite3
|
||||
|
||||
class Database:
|
||||
def __init__(self):
|
||||
self._con = sqlite3.connect("shifts.db")
|
||||
self.db_path = "shifts.db"
|
||||
self._setup_tables()
|
||||
pass
|
||||
|
||||
def _get_connection(self):
|
||||
return sqlite3.connect(self.db_path)
|
||||
|
||||
def _setup_tables(self):
|
||||
# create the shifts table.
|
||||
cur = self._con.cursor()
|
||||
con = self._get_connection()
|
||||
cur = con.cursor()
|
||||
cur.execute('''
|
||||
CREATE TABLE IF NOT EXISTS shifts (
|
||||
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
||||
@@ -24,12 +27,14 @@ class Database:
|
||||
UNIQUE(shift_start, department) -- No duplicates
|
||||
);
|
||||
''')
|
||||
self._con.commit()
|
||||
con.commit()
|
||||
con.close()
|
||||
|
||||
|
||||
def insert_shifts(self, shifts):
|
||||
# start db connection
|
||||
cur = self._con.cursor()
|
||||
con = self._get_connection()
|
||||
cur = con.cursor()
|
||||
inserted = 0
|
||||
|
||||
for shift in shifts:
|
||||
@@ -51,11 +56,14 @@ class Database:
|
||||
if cur.rowcount > 0:
|
||||
inserted += 1
|
||||
|
||||
self._con.commit()
|
||||
con.commit()
|
||||
con.close()
|
||||
print(f"✅ Inserted {inserted}/{len(shifts)} new shifts")
|
||||
|
||||
def delete_future_shifts(self):
|
||||
cur = self._con.cursor()
|
||||
con = self._get_connection()
|
||||
cur = con.cursor()
|
||||
cur.execute("DELETE FROM shifts WHERE shift_start > current_timestamp")
|
||||
self._con.commit()
|
||||
con.commit()
|
||||
con.close()
|
||||
print(f"✅ Deleted all future shifts")
|
||||
|
||||
5
main.py
5
main.py
@@ -1,5 +1,6 @@
|
||||
from PMT import PMT
|
||||
from DB import Database
|
||||
from API import Webserver
|
||||
from dotenv import load_dotenv
|
||||
from time import sleep
|
||||
import os
|
||||
@@ -12,7 +13,7 @@ def main():
|
||||
# Set up database, webserver and PMT connection
|
||||
db = Database()
|
||||
pmt = PMT()
|
||||
|
||||
webserver = Webserver("0.0.0.0", 8080, db)
|
||||
|
||||
# fetch PMT shifts
|
||||
pmt.login()
|
||||
@@ -27,5 +28,7 @@ def main():
|
||||
db.insert_shifts(shifts)
|
||||
|
||||
|
||||
webserver.run()
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
|
||||
Reference in New Issue
Block a user