#!/usr/bin/env python3

import json
import pathlib
import re
import sys
from enum import Enum
from typing import Dict, Iterator, Optional

import requests
import solv
from absl import app, flags
from attrs import define, field, frozen, asdict

# Command-line flags (absl). The returned FlagHolder objects expose the parsed
# value via `.value` once app.run() has parsed argv.
# NOTE(review): _USERNAME/_TOKEN are not read by any code visible here.
_USERNAME = flags.DEFINE_string(
    name="username",
    default="",
    help="Factorio username",
)
_TOKEN = flags.DEFINE_string(
    name="token",
    default="",
    help="Factorio token",
)
# Version of the installed game; used as the `base` solvable's version.
_FACTORIO_VERSION = flags.DEFINE_string(
    name="factorio_version",
    default="1.1.77",
    help="Factorio version to filter to",
)


class DependencyType(Enum):
    """Kind of a mod dependency, keyed by its literal prefix in the dependency string."""

    INCOMPATIBLE_WITH = "!"
    OPTIONAL = "?"
    HIDDEN_OPTIONAL = "(?)"
    LOAD_ORDER_IGNORED = "~"
    HARD = ""

    @property
    def solv_flags(self):
        """libsolv dependency-array key under which this kind of dependency is recorded."""
        if self is DependencyType.INCOMPATIBLE_WITH:
            return solv.SOLVABLE_CONFLICTS
        if self is DependencyType.OPTIONAL:
            return solv.SOLVABLE_RECOMMENDS
        if self is DependencyType.HIDDEN_OPTIONAL:
            return solv.SOLVABLE_SUGGESTS
        # HARD and LOAD_ORDER_IGNORED both become plain requirements; the
        # load-order aspect of "~" is not modelled.
        return solv.SOLVABLE_REQUIRES


class VersionConstraintType(Enum):
    """Comparison operator of a dependency's version constraint ("none" if absent)."""

    NONE = "none"
    LESS = "<"
    LESS_EQUAL = "<="
    EQUAL = "="
    GREATER_EQUAL = ">="
    GREATER = ">"

    @property
    def solv_flags(self):
        """libsolv REL_* bitmask for this operator, or None for NONE."""
        if self is VersionConstraintType.NONE:
            return None
        # Each operator character contributes exactly one REL_* bit.
        bits = 0
        if "<" in self.value:
            bits |= solv.REL_LT
        if ">" in self.value:
            bits |= solv.REL_GT
        if "=" in self.value:
            bits |= solv.REL_EQ
        return bits


_DEPENDENCY_RE = re.compile(
    r"^(?:(?P<dependency_type>[!?~]|\(\?\))\s*)?(?P<mod_name>[^><=]+?)(?:\s*(?P<version_constraint_type>[<>]=?|=)\s*(?P<version_constraint>[^\s]+))?$"
)


class DependencySpecificationError(ValueError):
    """Raised when a dependency string does not match the expected syntax.

    The offending string is carried as the exception's argument.
    """


@frozen
class Dependency:
    """One parsed entry of a mod's dependency list.

    Parsed from strings such as ``"? some-mod >= 1.2.0"``: an optional prefix
    selecting the :class:`DependencyType`, the mod name, and an optional
    version constraint.
    """

    dependency_type: DependencyType
    dependent_on: str  # name of the mod depended upon
    version_constraint_type: VersionConstraintType
    version_constraint: Optional[str]  # None when version_constraint_type is NONE

    @classmethod
    def from_str(cls, dep_str):
        """Parse *dep_str* into a Dependency.

        Leading/trailing whitespace is ignored, so a padded entry such as
        ``" ? foo"`` is parsed as an optional dependency on ``foo`` rather
        than as a mod literally named ``"? foo"``.

        Raises:
            DependencySpecificationError: if *dep_str* does not match the
                dependency syntax (e.g. it is empty).
        """
        match = _DEPENDENCY_RE.match(dep_str.strip())
        if not match:
            raise DependencySpecificationError(dep_str)
        d = match.groupdict()
        return cls(
            dependency_type=DependencyType(d["dependency_type"] or ""),
            dependent_on=d["mod_name"],
            version_constraint_type=VersionConstraintType(
                d["version_constraint_type"] or "none"
            ),
            version_constraint=d["version_constraint"],
        )

    def to_solv(self):
        """Render as ``"name"`` or ``"name OP version"`` (the prefix is dropped)."""
        if self.version_constraint_type == VersionConstraintType.NONE:
            return self.dependent_on
        return f"{self.dependent_on} {self.version_constraint_type.value} {self.version_constraint}"


@frozen
class Mod:
    """A single resolved mod release, as stored in the lock file."""

    name: str  # mod name as reported by the portal
    version: str  # release version string
    file_name: str  # release archive file name from the portal
    download_url: str  # download URL/path exactly as the portal reports it
    sha1: str  # checksum string from the release's "sha1" field


@define
class ModFile:
    """Lock-file contents: the resolved mod releases keyed by mod name."""

    mods: Dict[str, Mod]  # mod name -> chosen release


class ModAPI:
    """Client for the Factorio mod portal API with an on-disk JSON cache.

    Every successful response is written to ``<cache_dir>/<cache_key>.json``
    and served from there on later calls; cached entries never expire.
    """

    def __init__(self, sess=None, *, cache_dir="cache", timeout=60):
        """Create a client.

        Args:
            sess: session object used for GET requests; a new
                ``requests.Session`` is created when omitted.
            cache_dir: directory holding cached responses (created on demand).
            timeout: default per-request timeout in seconds, applied unless a
                caller passes ``timeout`` explicitly.
        """
        self.sess = sess or requests.Session()
        self._base = "https://mods.factorio.com/api/mods"
        self._cache_dir = pathlib.Path(cache_dir)
        self._timeout = timeout

    def _get_json_or_cached(self, cache_key, *args, **kwargs):
        """Return the JSON for *cache_key*, fetching it via ``self.sess.get`` on a miss.

        ``*args``/``**kwargs`` are forwarded to ``self.sess.get``.

        Raises:
            requests.exceptions.HTTPError: from ``raise_for_status`` on an
                error status; nothing is cached in that case.
        """
        path = self._cache_dir / f"{cache_key}.json"
        if path.exists():
            with open(path, "rt", encoding="utf-8") as f:
                return json.load(f)

        # Without a timeout a stalled connection can block forever.
        kwargs.setdefault("timeout", self._timeout)
        resp = self.sess.get(*args, **kwargs)
        resp.raise_for_status()
        data = resp.json()

        # Create the cache directory on first use, and write through a
        # temporary file so an interrupted run never leaves a truncated entry
        # that later runs would fail to parse.
        path.parent.mkdir(parents=True, exist_ok=True)
        tmp_path = path.with_name(path.name + ".tmp")
        with open(tmp_path, "wt", encoding="utf-8") as f:
            json.dump(data, f)
        tmp_path.replace(path)
        return data

    def all_mods(self, factorio_version):
        """Yield every non-deprecated mod listing for *factorio_version*.

        Only the ``major.minor`` part of the version is sent to the portal.
        Pages are followed until ``pagination.page_count`` is reached.
        """
        if factorio_version.count(".") >= 2:
            factorio_version = ".".join(factorio_version.split(".")[:2])
        page = 1
        while True:
            data = self._get_json_or_cached(
                f"all_mods.page{page}",
                self._base,
                params={
                    "hide_deprecated": "true",
                    "page": page,
                    "page_size": "max",
                    "version": factorio_version,
                },
            )
            yield from data["results"]
            page_count = (data.get("pagination") or {}).get("page_count", 0)
            if page >= page_count:
                break
            page += 1

    def fetch_mod_info(self, name):
        """Return the full portal record (including all releases) for mod *name*."""
        return self._get_json_or_cached(f"mod.{name}", f"{self._base}/{name}/full")

    def fetch_mod(self, name, version):
        """Return the :class:`Mod` for release *version* of mod *name*.

        Raises:
            KeyError: if the mod has no release with exactly that version string.
        """
        all_data = self.fetch_mod_info(name)
        for release in all_data["releases"]:
            if release["version"] == version:
                return Mod(
                    name=all_data["name"],
                    version=release["version"],
                    file_name=release["file_name"],
                    download_url=release["download_url"],
                    sha1=release["sha1"],
                )
        raise KeyError(version)


class ModAPIRepo:
    """libsolv repository populated on demand from the Factorio mod portal.

    Each release of each loaded mod becomes one solvable carrying its parsed
    dependencies, a self-provide at its exact version, and a ``base`` version
    window derived from the release's ``factorio_version``.
    """

    def __init__(self, api, pool):
        self.api = api
        self.pool = pool
        self.handle = pool.add_repo("Factorio ModAPI")
        self.handle.appdata = self
        self.handle.priority = 99
        # Mod names already loaded (or attempted) into this repo.
        self.populated = set()

    def _rel(self, name, version, rel_flags):
        """Return the pool id of the relation ``name <rel_flags> version``."""
        return self.pool.rel2id(
            self.pool.str2id(name), self.pool.str2id(version), rel_flags
        )

    def _populate_dependencies(self, solvable, dependencies):
        """Record every dependency on *solvable* under its DependencyType's dep array."""
        for dep_obj in dependencies:
            if dep_obj.version_constraint_type == VersionConstraintType.NONE:
                dep_id = self.pool.str2id(dep_obj.dependent_on)
            else:
                dep_id = self._rel(
                    dep_obj.dependent_on,
                    dep_obj.version_constraint,
                    dep_obj.version_constraint_type.solv_flags,
                )
            solvable.add_deparray(dep_obj.dependency_type.solv_flags, dep_id)

    def _populate_release(
        self, all_data, release, *, load_incompatible=False, load_optional=False
    ):
        """Add one release of a mod to the repo as a solvable.

        Every parsed dependency (including incompatible/optional ones) is
        recorded on the solvable. Releases whose dependency list cannot be
        parsed are reported and skipped.

        Returns:
            The names of dependencies that should be loaded recursively:
            incompatibilities are excluded unless *load_incompatible*, and
            (hidden) optional dependencies unless *load_optional*.
        """
        try:
            dependencies = [
                Dependency.from_str(dep)
                for dep in release["info_json"].get("dependencies", [])
            ]
        except DependencySpecificationError:
            print(
                f"couldn't parse dependencies for {all_data['name']} {release['version']}"
            )
            return set()

        known_dependencies = set()
        for dep in dependencies:
            if (
                not load_incompatible
                and dep.dependency_type == DependencyType.INCOMPATIBLE_WITH
            ):
                continue
            if not load_optional and dep.dependency_type in (
                DependencyType.HIDDEN_OPTIONAL,
                DependencyType.OPTIONAL,
            ):
                continue
            known_dependencies.add(dep.dependent_on)

        # NOTE(review): a new repodata is created per release but never used
        # or internalized here; kept so the libsolv repo state is unchanged.
        self.handle.add_repodata(flags=0)
        solvable = self.handle.add_solvable()
        solvable.name = all_data["name"]
        solvable.evr = release["version"]
        self._populate_dependencies(solvable, dependencies)
        solvable.add_deparray(
            solv.SOLVABLE_PROVIDES,
            self._rel(solvable.name, release["version"], solv.REL_EQ),
        )
        # Restrict the release to base versions in [factorio_version, next minor).
        factorio_version = release["info_json"]["factorio_version"]
        solvable.add_deparray(
            solv.SOLVABLE_REQUIRES,
            self._rel("base", factorio_version, solv.REL_GT | solv.REL_EQ),
        )
        solvable.add_deparray(
            solv.SOLVABLE_REQUIRES,
            self._rel("base", _next_factorio_version(factorio_version), solv.REL_LT),
        )
        return known_dependencies

    def _populate_tree(
        self,
        factorio_version,
        mod_name,
        *,
        load_incompatible=False,
        load_optional=False,
    ):
        """Load all releases of *mod_name* and, recursively, its dependencies.

        The load_* options apply to the whole tree, not just *mod_name*.
        HTTP errors while fetching a dependency are printed and that
        dependency is skipped; errors fetching *mod_name* itself propagate.

        Returns:
            True if *mod_name* was newly processed, False if already visited.
        """
        if mod_name in self.populated:
            return False
        # Mark before fetching so dependency cycles terminate.
        self.populated.add(mod_name)
        all_data = self.api.fetch_mod_info(mod_name)
        dependency_names = set()
        for release in all_data["releases"]:
            dependency_names |= self._populate_release(
                all_data,
                release,
                load_incompatible=load_incompatible,
                load_optional=load_optional,
            )
        # Sorted so solvables are added in the same order on every run
        # (str set iteration order varies with hash randomization).
        for dep in sorted(dependency_names):
            try:
                self._populate_tree(
                    factorio_version,
                    dep,
                    load_incompatible=load_incompatible,
                    load_optional=load_optional,
                )
            except requests.exceptions.HTTPError as ex:
                print(mod_name, "->", dep, ex)
        return True

    def populate_tree(
        self,
        factorio_version,
        mod_name,
        *,
        load_incompatible=False,
        load_optional=False,
    ):
        """Load *mod_name*'s dependency tree; create stubs if anything new was added."""
        if self._populate_tree(
            factorio_version,
            mod_name,
            load_incompatible=load_incompatible,
            load_optional=load_optional,
        ):
            self.handle.create_stubs()

    def populate(self, factorio_version):
        """Load the dependency tree of every mod the portal lists for *factorio_version*."""
        for mod in self.api.all_mods(factorio_version):
            self.populate_tree(factorio_version, mod["name"])


def _next_factorio_version(v):
    bits = v.split(".")
    return f"{bits[0]}.{int(bits[1])+1}"


class InstalledRepo:
    """libsolv repository standing in for the installed game.

    It holds a single ``base`` solvable and is registered as the pool's
    installed repository.
    """

    def __init__(self, pool):
        self.pool = pool
        handle = pool.add_repo("Factorio Installed")
        handle.appdata = self
        handle.priority = 99
        self.handle = handle
        pool.installed = handle

    def populate(self, factorio_version):
        """Add ``base`` at *factorio_version* (with an exact self-provide) and create stubs."""
        self.handle.add_repodata()
        base = self.handle.add_solvable()
        base.name = "base"
        base.evr = factorio_version
        exact_base = self.pool.rel2id(
            self.pool.str2id("base"),
            self.pool.str2id(factorio_version),
            solv.REL_EQ,
        )
        base.add_deparray(solv.SOLVABLE_PROVIDES, exact_base)
        self.handle.create_stubs()


def _read_install_jobs(path, pool, repo, factorio_version):
    """Turn each entry of the mod listfile at *path* into libsolv install jobs.

    Blank lines and lines starting with "#" are ignored. Every other line is
    parsed as a dependency string; the mod's tree is loaded into *repo* and
    the matching selection becomes SOLVER_INSTALL jobs. Exits with status 1
    if a line matches nothing.
    """
    # SELECTION_REL lets pool.select() parse the "name OP version" form that
    # Dependency.to_solv() emits for constrained entries.
    sel_flags = (
        solv.Selection.SELECTION_NAME
        | solv.Selection.SELECTION_NOCASE
        | solv.Selection.SELECTION_REL
    )
    jobs = []
    with open(path, "rt") as f:
        print(f"asked to install (for Factorio {factorio_version}):")
        for raw_line in f:
            ln = raw_line.strip()
            if not ln or ln.startswith("#"):
                continue
            dep = Dependency.from_str(ln)
            repo.populate_tree(factorio_version, dep.dependent_on)
            spec = dep.to_solv()
            print(f"  {spec}")
            sel = pool.select(spec, sel_flags)
            if sel.isempty():
                print(f'nothing matches "{ln}" (interpreted as: {spec})')
                sys.exit(1)
            jobs += sel.jobs(solv.Job.SOLVER_INSTALL)
    return jobs


def _print_problems(problems):
    """Print each solver problem together with its proposed solutions."""
    print("problems :(")
    for problem in problems:
        print(f"Problem {problem.id}/{len(problems)}:")
        print(f"  {problem}")
        for solution in problem.solutions():
            print(f"  Solution {solution.id}:")
            for element in solution.elements(True):
                print(f"  - {element.str()}")
            print("")


def main(args):
    """Resolve the mods listed in args[1] and write mods_lock.json.

    Exit status: 1 if a listed entry matches nothing, 2 if the solver
    reports problems, 0 with no lock file if there is nothing to install.
    """
    if len(args) != 2:
        raise app.UsageError("Requires a path to a mod listfile.")
    factorio_version = _FACTORIO_VERSION.value

    pool = solv.Pool()
    api = ModAPI()

    installed_repo = InstalledRepo(pool)
    installed_repo.populate(factorio_version)

    # Mod trees are loaded on demand per listed entry rather than crawling
    # every mod on the portal (ModAPIRepo.populate).
    repo = ModAPIRepo(api, pool)

    jobs = _read_install_jobs(args[1], pool, repo, factorio_version)

    solver = pool.Solver()
    problems = solver.solve(jobs)
    if problems:
        _print_problems(problems)
        sys.exit(2)

    trans = solver.transaction()
    if trans.isempty():
        print("nothing to do.")
        sys.exit(0)
    print()
    print("need to install:")
    mod_file = ModFile(mods={})
    for p in trans.newsolvables():
        print(f" - {p.name} {p.evr}")
        mod = api.fetch_mod(p.name, p.evr)
        mod_file.mods[mod.name] = mod
    with open("mods_lock.json", "wt") as f:
        json.dump(asdict(mod_file), f)


if __name__ == "__main__":
    app.run(main)