diff --git a/API.py b/API.py index f87f5c1..79d0043 100644 --- a/API.py +++ b/API.py @@ -1 +1,81 @@ -# TODO \ No newline at end of file +from flask import Flask, Response +from icalendar import Calendar, Event +from datetime import datetime +import pytz +import uuid + + +class Webserver: + def __init__(self, ip, port, db_instance): + self.ip = ip + self.port = port + self.db = db_instance + + self.app = Flask(__name__) + self.tz = pytz.timezone("Europe/Amsterdam") + + self._register_routes() + + def _register_routes(self): + self.app.add_url_rule("/", "index", lambda: "Calendar server running ✅") + self.app.add_url_rule("/calendar.ics", "calendar", self._calendar_endpoint) + + def _fetch_shifts(self): + con = self.db._get_connection() + cur = con.cursor() + + cur.execute(""" + SELECT shift_start, shift_end, department, description + FROM shifts + ORDER BY shift_start ASC + """) + + rows = cur.fetchall() + con.close() + + shifts = [] + for row in rows: + start_str, end_str, department, description = row + + # Convert string → datetime + start = self._parse_datetime(start_str) + end = self._parse_datetime(end_str) + + shifts.append({ + "start": start, + "end": end, + "description": f"{description} ", + "department": f"{department} " + }) + + return shifts + + def _parse_datetime(self, dt_str): + dt = datetime.fromisoformat(dt_str) + if dt.tzinfo is None: + dt = self.tz.localize(dt) + return dt + + def _calendar_endpoint(self): + cal = Calendar() + cal.add('prodid', '-//Shift Calendar//example//') + cal.add('version', '2.0') + + shifts = self._fetch_shifts() + + for shift in shifts: + event = Event() + event.add('uid', str(uuid.uuid4())) + event.add('dtstart', shift["start"]) + event.add('dtend', shift["end"]) + event.add('summary', shift["department"]) + event.add('description', shift["description"]) + event.add('dtstamp', datetime.now(tz=self.tz)) + + cal.add_component(event) + + return Response(cal.to_ical(), mimetype="text/calendar") + + def run(self): + print(f"🚀 Starting server on {self.ip}:{self.port}") + self.app.run(host=self.ip, port=self.port) \ No newline at end of file diff --git a/DB.py b/DB.py index 2f50a7a..0cd4e6f 100644 --- a/DB.py +++ b/DB.py @@ -3,14 +3,17 @@ import sqlite3 class Database: def __init__(self): - self._con = sqlite3.connect("shifts.db") + self.db_path = "shifts.db" self._setup_tables() pass - + + def _get_connection(self): + return sqlite3.connect(self.db_path) def _setup_tables(self): # create the shifts table. - cur = self._con.cursor() + con = self._get_connection() + cur = con.cursor() cur.execute(''' CREATE TABLE IF NOT EXISTS shifts ( id INTEGER PRIMARY KEY AUTOINCREMENT, @@ -24,12 +27,14 @@ class Database: UNIQUE(shift_start, department) -- No duplicates ); ''') - self._con.commit() + con.commit() + con.close() def insert_shifts(self, shifts): # start db connection - cur = self._con.cursor() + con = self._get_connection() + cur = con.cursor() inserted = 0 for shift in shifts: @@ -51,11 +56,14 @@ class Database: if cur.rowcount > 0: inserted += 1 - self._con.commit() + con.commit() + con.close() print(f"✅ Inserted {inserted}/{len(shifts)} new shifts") def delete_future_shifts(self): - cur = self._con.cursor() + con = self._get_connection() + cur = con.cursor() cur.execute("DELETE FROM shifts WHERE shift_start > current_timestamp") - self._con.commit() + con.commit() + con.close() print(f"✅ Deleted all future shifts") diff --git a/main.py b/main.py index 3c52fb4..28d91e4 100644 --- a/main.py +++ b/main.py @@ -1,5 +1,6 @@ from PMT import PMT from DB import Database +from API import Webserver from dotenv import load_dotenv from time import sleep import os @@ -12,7 +13,7 @@ def main(): # Set up database, webserver and PMT connection db = Database() pmt = PMT() - + webserver = Webserver("0.0.0.0", 8080, db) # fetch PMT shifts pmt.login() @@ -27,5 +28,7 @@ def main(): db.insert_shifts(shifts) + webserver.run() + if __name__ == "__main__": main()