import asyncio
import concurrent.futures
import dataclasses
import datetime
import json
import os
import time
from typing import Any, Dict, List, Optional, Set, Union

import aiohttp
import attrs
import icalendar
import icalevents.icalparser
from dateutil.tz import UTC
from quart import Quart, Response, render_template


async def parse_ical(calendar_text: str) -> icalendar.Calendar:
    """Parse iCalendar text into an ``icalendar.Calendar`` off the event loop.

    ``icalendar.Calendar.from_ical`` runs in ``process_pool``; the parsed
    calendar is pickled back from the worker process to this one.
    """
    from_ical = icalendar.Calendar.from_ical
    running_loop = asyncio.get_running_loop()
    return await running_loop.run_in_executor(process_pool, from_ical, calendar_text)


async def serialize_ical(calendar: icalendar.Calendar) -> bytes:
    """Serialize *calendar* to iCalendar wire format in ``process_pool``.

    Returns:
        The encoded calendar as ``bytes``. This is what
        ``icalendar.Calendar.to_ical`` produces; it is not ``str``.

    The calendar (bound to ``to_ical``) is pickled and sent to a worker process.
    """
    loop = asyncio.get_running_loop()
    return await loop.run_in_executor(process_pool, calendar.to_ical)


def maybe_fromisoformat(
    val: Optional[Union[str, datetime.datetime]],
) -> Optional[datetime.datetime]:
    """Convert an optional ISO-8601 string (or a datetime) to an aware datetime.

    Used as the attrs converter for ``Calendar.skip_before``.

    Args:
        val: ``None``/empty string, an ISO-8601 string, or a ``datetime``.

    Returns:
        ``None`` for a falsy *val*. Otherwise a timezone-aware datetime: inputs
        without an offset, whether string or ``datetime``, are interpreted as UTC.
        Aware inputs keep their own offset; an aware ``datetime`` is returned as-is.

    Raises:
        ValueError: if *val* is a string that is not valid ISO-8601.
    """
    if not val:
        return None
    if isinstance(val, datetime.datetime):
        dt = val
    else:
        text = val
        # datetime.fromisoformat only understands a trailing "Z" from Python 3.11.
        if text.endswith("Z"):
            text = text[:-1] + "+00:00"
        dt = datetime.datetime.fromisoformat(text)
    # Normalize naive values for both input kinds. The cutoff is later compared
    # with aware event datetimes, and naive/aware comparisons raise TypeError.
    if dt.tzinfo is None:
        dt = dt.replace(tzinfo=UTC)
    return dt


@attrs.frozen
class Calendar:
    """One source calendar from the config, plus the filters applied to its events.

    Instances are immutable (``attrs.frozen``). Set-valued options are converted
    to ``frozenset``. Config entries are turned into instances by calling
    ``Calendar(**entry)`` (see ``make_calendars``), so the keys in the JSON
    config correspond to the field names below.
    """

    # Label for this source; not otherwise referenced in this module.
    name: str
    # URL fetched by fetch_events().
    source_url: str
    # If non-empty, only events whose summary contains one of these are kept.
    keep_hashtags: Set[str] = attrs.field(converter=frozenset, factory=frozenset)
    # Events whose summary contains any of these are dropped.
    drop_hashtags: Set[str] = attrs.field(converter=frozenset, factory=frozenset)
    # When true, published events are reduced to a fixed set of scheduling properties.
    minimize: bool = True
    # Presumably: events whose summary contains one of these should not be
    # minimized. Check maybe_strip_event for whether this is honoured.
    skip_minimize_hashtags: Set[str] = attrs.field(
        converter=frozenset, factory=frozenset
    )
    # Events with no occurrence after this instant are dropped. Strings are
    # parsed by maybe_fromisoformat; offset-less strings are treated as UTC.
    skip_before: Optional[datetime.datetime] = attrs.field(
        converter=maybe_fromisoformat, default=None
    )
    # Attendee addresses (compared via str(attendee)). If any of them has
    # PARTSTAT=DECLINED, the event is dropped.
    skip_if_declined: Set[str] = attrs.field(converter=frozenset, factory=frozenset)
    # Replacement SUMMARY used when an event is minimized.
    retitle_to: Optional[str] = None

    async def fetch_events(self, session: aiohttp.ClientSession) -> icalendar.Calendar:
        """Download ``source_url`` and parse it into an ``icalendar.Calendar``.

        Raises for HTTP error statuses (``raise_for_status``). Parsing runs in
        the process pool via ``parse_ical``, after the response is closed.
        """
        async with session.get(self.source_url) as response:
            response.raise_for_status()
            text = await response.text()
        return await parse_ical(text)


def make_calendars(
    in_calendars: Union[List[Dict[str, Any]], List[Calendar]]
) -> List[Calendar]:
    """Attrs converter for ``Config.calendars``.

    Args:
        in_calendars: items that are either ready-made ``Calendar`` objects or
            keyword-argument dicts (as loaded from the JSON config).

    Returns:
        A new list of ``Calendar`` objects, in input order. Each item is
        converted on its own, so a mix of dicts and ``Calendar`` objects works.
        The original implementation only inspected the first element.

    Raises:
        TypeError: if a dict has keys that are not ``Calendar`` fields, or is
            missing required ones.
    """
    return [c if isinstance(c, Calendar) else Calendar(**c) for c in in_calendars]


@attrs.frozen
class Config:
    """Top-level configuration: the source calendars to merge into one feed."""

    # Passed through make_calendars, which accepts dicts or Calendar objects.
    calendars: List[Calendar] = attrs.field(converter=make_calendars)

    async def fetch_calendars(
        self, session: aiohttp.ClientSession
    ) -> List[icalendar.Calendar]:
        """Fetch and parse every source calendar concurrently.

        Results come back in the same order as ``self.calendars``. Because
        ``asyncio.gather`` is used without ``return_exceptions``, the first
        failure propagates to the caller.
        """
        pending = (source.fetch_events(session) for source in self.calendars)
        return await asyncio.gather(*pending)


def load_config(fn: str) -> Config:
    """Load the JSON configuration file at *fn* and build a ``Config``.

    The file is decoded as UTF-8, the encoding JSON text is specified to use,
    rather than the platform's locale-dependent default. On some hosts the
    default would mangle or reject non-ASCII calendar names.

    Raises:
        OSError: if the file cannot be opened.
        json.JSONDecodeError: if the file is not valid JSON.
        TypeError: if the top-level keys don't match ``Config``'s fields.
    """
    with open(fn, "rt", encoding="utf-8") as f:
        raw = json.load(f)
    # Build the Config after the file is closed; nothing below needs it open.
    return Config(**raw)


# Quart application object; the routes below are registered on it.
app = Quart(__name__)
# Loaded once, at import time. The path comes from $ICALFILTER_CONFIG and
# defaults to config/config.json (resolved against the current working directory).
config = load_config(os.environ.get("ICALFILTER_CONFIG", "config/config.json"))
# Shared worker pool used by parse_ical, serialize_ical and all_occurrences_before.
# It is defined after those functions, which only look it up when they are called.
# NOTE(review): depending on the multiprocessing start method, worker processes may
# re-import this module (re-running load_config) — confirm that is acceptable.
process_pool = concurrent.futures.ProcessPoolExecutor(max_workers=4)


def contains_any_hashtag(text: str, hashtags: Set[str]) -> bool:
    """Report whether *text* contains at least one of *hashtags*.

    Matching is plain substring containment, so ``"#foo"`` also matches
    ``"#foobar"``. An empty *hashtags* collection never matches.
    """
    for tag in hashtags:
        if tag in text:
            return True
    return False


def _all_occurrences_before_expensive(
    event: icalendar.Event,
    cutoff: datetime.datetime,
    seen_timezones: Dict[str, icalendar.Timezone],
) -> bool:
    """Return True if recurring *event* has no occurrence after *cutoff*.

    Kept at module level so it can be submitted to ``process_pool``; the
    callable and its arguments are pickled and run in a worker process.

    NOTE(review): *seen_timezones* is not used here, yet it is still pickled
    and shipped to the worker on every call.
    NOTE(review): presumably ``parse_rrule`` returns a dateutil rrule/rruleset,
    whose ``after()`` returns the first occurrence *start* strictly after
    *cutoff* (or None). Unlike the non-recurring path, which compares the event
    *end*, an occurrence that starts before *cutoff* but ends after it counts as
    "before" here — confirm this is intended.
    """
    # Recurring events are... more complicated.
    rrule_or_rruleset = icalevents.icalparser.parse_rrule(event)
    try:
        return not rrule_or_rruleset.after(cutoff)
    except TypeError:
        # Comparing aware and naive datetimes raises TypeError. Presumably the
        # rule was built from a floating (naive) DTSTART, so retry with the
        # cutoff stripped of its timezone.
        return not rrule_or_rruleset.after(cutoff.replace(tzinfo=None))


async def all_occurrences_before(
    event: icalendar.Event,
    cutoff: datetime.datetime,
    seen_timezones: Dict[str, icalendar.Timezone],
) -> bool:
    """Return True if *event* has nothing happening after *cutoff*.

    Single events are checked inline: their end must be before *cutoff*.
    Recurring events need their rule expanded, which is delegated to
    ``_all_occurrences_before_expensive`` in ``process_pool``.
    ``cutoff.tzinfo`` is passed to icalevents' ``create_event``, presumably as
    the default zone for floating times.
    """
    summary = icalevents.icalparser.create_event(event, cutoff.tzinfo)
    if summary.recurring:
        running_loop = asyncio.get_running_loop()
        return await running_loop.run_in_executor(
            process_pool,
            _all_occurrences_before_expensive,
            event,
            cutoff,
            seen_timezones,
        )
    return summary.end < cutoff


async def keep_event(
    source_cal: Calendar,
    event: icalendar.Event,
    seen_timezones: Dict[str, icalendar.Timezone],
) -> bool:
    """Decide whether *event* from *source_cal* should appear in the output feed.

    Checks run in this order, and the first failing check drops the event:

    * ``keep_hashtags`` (if set): the summary must contain one of them.
    * ``drop_hashtags`` (if set): the summary must contain none of them.
    * ``skip_before`` (if set): events without DTSTART are dropped, as are
      events for which ``all_occurrences_before`` finds nothing after the cutoff.
    * ``skip_if_declined`` (if set): the event is dropped if any listed
      attendee address has ``PARTSTAT=DECLINED``.

    *seen_timezones* is only forwarded to ``all_occurrences_before``.
    """
    # SUMMARY is optional in RFC 5545. Treat a missing one as empty text instead
    # of raising KeyError, which would abort rendering of the whole feed.
    summary = event.get("summary", "")
    if source_cal.keep_hashtags and not contains_any_hashtag(
        summary, source_cal.keep_hashtags
    ):
        return False
    if source_cal.drop_hashtags and contains_any_hashtag(
        summary, source_cal.drop_hashtags
    ):
        return False
    if source_cal.skip_before:
        if "dtstart" not in event:
            print("no dtstart?", event)
            return False
        if await all_occurrences_before(event, source_cal.skip_before, seen_timezones):
            return False
    if source_cal.skip_if_declined:
        attendees = event.get("attendee", [])
        if not isinstance(attendees, list):
            # A single ATTENDEE comes back as a plain vCalAddress, not a list.
            attendees = [attendees]
        for attendee in attendees:
            if (
                str(attendee) in source_cal.skip_if_declined
                and attendee.params.get("PARTSTAT") == "DECLINED"
            ):
                return False
    return True


def maybe_strip_event(source_cal: Calendar, event: icalendar.Event) -> icalendar.Event:
    """Return the form of *event* to publish, minimized according to *source_cal*.

    * If ``source_cal.minimize`` is false, or the event's summary contains one of
      ``source_cal.skip_minimize_hashtags``, *event* is returned unchanged.
    * Otherwise a new ``icalendar.Event`` is returned. It carries only the
      scheduling properties listed below. If ``retitle_to`` is set, the copy's
      SUMMARY is replaced with it.

    *event* itself is never modified.
    """
    if not source_cal.minimize:
        return event
    # Honour the per-calendar opt-out: tagged events keep their full details.
    if source_cal.skip_minimize_hashtags and contains_any_hashtag(
        event.get("summary", ""), source_cal.skip_minimize_hashtags
    ):
        return event
    out_event = icalendar.Event()
    for prop in (
        "uid",
        "dtstamp",
        "summary",
        "dtstart",
        "dtend",
        "duration",
        "recurrence-id",
        "sequence",
        "rrule",
        "rdate",
        "exdate",
    ):
        if prop in event:
            out_event[prop] = event[prop]
    if source_cal.retitle_to:
        # Retitle the copy only, so the source event is left untouched.
        out_event["summary"] = source_cal.retitle_to
    return out_event


class Cache:
    """Caches the result of an async callback for a fixed time-to-live.

    Only one caller runs the callback at a time ("single-flight"). Concurrent
    callers wait on a condition variable and then reuse the freshly stored
    value. If the callback raises, the exception goes to the caller that ran
    it, and the waiting callers retry: exactly one of them re-runs the callback.
    """

    def __init__(self, timeout, cb):
        """Create an empty cache.

        Args:
            timeout: how long a fetched value stays fresh, as a
                ``datetime.timedelta`` or a number of seconds.
            cb: zero-argument coroutine function that produces the value.
        """
        if isinstance(timeout, datetime.timedelta):
            timeout = timeout.total_seconds()
        self._timeout = float(timeout)
        self._cb = cb
        self._cond = asyncio.Condition()
        self._value = None
        # time.monotonic() deadline after which _value is stale; None = no value.
        # Monotonic time means wall-clock adjustments can't stretch or cut the TTL.
        self._value_ttl = None
        # True while some caller is running self._cb; others wait on self._cond.
        self._acquiring_value = False

    def _check_value(self):
        """Return ``(value, True)`` if a fresh value is cached, else ``(None, False)``.

        Drops an expired value as a side effect. Call with ``self._cond`` held.
        """
        if self._value_ttl is not None and time.monotonic() >= self._value_ttl:
            self._value = None
            self._value_ttl = None
        if self._value_ttl is None:
            return (None, False)
        return (self._value, True)

    async def get_value(self):
        """Return the cached value, running the callback if it is missing or stale.

        Raises whatever the callback raises, but only in the caller that ran it.
        """
        # `async with` guarantees the lock is released even if this task is
        # cancelled while waiting. Condition.wait() re-acquires the lock before
        # propagating CancelledError, so a bare acquire/release would leak it
        # and deadlock every later caller.
        async with self._cond:
            # Loop rather than `if`: after a failed refresh, another woken waiter
            # may already have taken over fetching.
            while self._acquiring_value:
                await self._cond.wait()
            value, ok = self._check_value()
            if ok:
                return value
            self._acquiring_value = True

        # Run the callback without holding the lock so other callers can queue up.
        try:
            value = await self._cb()
        except BaseException:
            # Includes CancelledError: waiters must be released either way.
            async with self._cond:
                self._acquiring_value = False
                self._cond.notify_all()
            raise

        async with self._cond:
            self._value = value
            self._value_ttl = time.monotonic() + self._timeout
            self._acquiring_value = False
            self._cond.notify_all()
        return value


async def render_ical():
    """Build the merged, filtered iCalendar feed from every configured source.

    All sources are fetched first. The output calendar then receives:

    1. each distinct VTIMEZONE, keyed by TZID (the first definition seen wins);
    2. every VEVENT that ``keep_event`` accepts for its source calendar, passed
       through ``maybe_strip_event``.

    Other component types are not copied. Returns ``serialize_ical``'s result.
    """
    async with aiohttp.ClientSession() as session:
        fetched = await config.fetch_calendars(session)

    merged = icalendar.Calendar()
    merged.add("prodid", "-//icalfilter//lukegb.com//")
    merged.add("version", "2.0")

    # Pass 1: timezones, so every event below can reference its TZID.
    timezones_by_id = {}
    for ical in fetched:
        for component in ical.subcomponents:
            if not isinstance(component, icalendar.Timezone):
                continue
            tzid = component["tzid"]
            if tzid in timezones_by_id:
                continue
            timezones_by_id[tzid] = component
            merged.add_component(component)

    # Pass 2: the events that survive their source calendar's filters.
    for source_cal, ical in zip(config.calendars, fetched):
        for component in ical.subcomponents:
            if isinstance(component, icalendar.Event) and await keep_event(
                source_cal, component, timezones_by_id
            ):
                merged.add_component(maybe_strip_event(source_cal, component))

    return await serialize_ical(merged)


# The rendered feed is reused for three hours before render_ical runs again.
ical_cache = Cache(timeout=datetime.timedelta(hours=3), cb=render_ical)


@app.get("/ical.ics")
async def render_ical_view():
    """Serve the (cached) filtered calendar feed with a text/calendar mimetype."""
    body = await ical_cache.get_value()
    response = Response(body, mimetype="text/calendar")
    return response


@app.get("/")
async def index():
    """Render the ``index.html`` template for the site root."""
    page = await render_template("index.html")
    return page


if __name__ == "__main__":
    import sys

    # Direct execution is unsupported: the app is meant to be served by Quart's CLI.
    message = 'This is intended to be run with "quart run".'
    print(message, file=sys.stderr)
    raise SystemExit(1)