#!@python3@
# SPDX-FileCopyrightText: 2022 Luke Granger-Brown <depot@lukegb.com>
#
# SPDX-License-Identifier: Apache-2.0
"""
       USAGE: %s [flags] <zip to patch>
"""

import pathlib
import shutil
import tempfile
import zipfile

from absl import app
from absl import flags
from absl import logging

# absl's global flag registry; app.run() parses the command line into it
# before main() is invoked.
FLAGS = flags.FLAGS

# Order of definition is kept stable so --help output is unchanged.
flags.DEFINE_boolean(
    name="backup",
    default=True,
    help=(
        "Back up world before modifying it to file with "
        ".pre-multiworld-infect suffix"
    ),
)
flags.DEFINE_boolean(
    name="dry_run",
    default=False,
    help="Don't actually overwrite the input file with the output",
)
flags.DEFINE_boolean(
    name="uninfect",
    default=False,
    help=(
        "Removes the multiworld Lua rather than injecting it; note that this "
        "won't remove any of the other persisted save elements, and may leave "
        "players stranded on other surfaces."
    ),
)

# The exact control.lua line that handle_control_lua() inserts (infect) or
# strips (uninfect).
_CONTROL_LUA_INFECT_LINE = b'handler.add_lib(require("multiworld"))'


def for_each_in_zip(in_zip, out_zip, callback, skip_files=None):
    """Stream every entry of ``in_zip`` into ``out_zip`` through ``callback``.

    Entries whose basename appears in ``skip_files`` are logged and not copied.
    For each remaining entry, ``callback(zipinfo, src, dst)`` is called with
    the entry opened for reading in ``in_zip`` and for writing in ``out_zip``.
    The ``ZipInfo`` object from ``in_zip`` is reused for the output entry.

    Args:
      in_zip: ``zipfile.ZipFile`` opened for reading.
      out_zip: ``zipfile.ZipFile`` opened for writing.
      callback: callable ``(zipinfo, src, dst)`` that writes the entry's new
        contents to ``dst``.
      skip_files: optional collection of basenames to drop from the output.

    Returns:
      The first truthy value returned by ``callback``, or ``None`` if none was.
    """
    skipped = skip_files or frozenset()
    result = None
    for info in in_zip.infolist():
        name = pathlib.PurePosixPath(info.filename).name
        if name in skipped:
            logging.info("Found %s at %s: skipping", name, info.filename)
        else:
            with in_zip.open(info, "r") as src:
                # Cleared only after the source entry has been opened, since
                # opening it reads this offset; the output writer assigns a
                # fresh one.
                info.header_offset = None
                with out_zip.open(info, "w") as dst:
                    outcome = callback(info, src, dst)
                    result = result or outcome
    return result


def handle_control_lua(fsrc, fdst):
    """Add or remove the multiworld ``handler.add_lib`` line in control.lua.

    Reads all of ``fsrc`` and writes the (possibly modified) contents to
    ``fdst``. The direction depends on ``FLAGS.uninfect``:

    * infecting: if the infection line is absent, it is inserted directly
      after the last line that starts with ``handler.add_lib(``;
    * uninfecting: every occurrence of the infection line is dropped.

    Content that isn't changed is written back byte-for-byte. In particular,
    the original trailing newline (or lack of one) is preserved, so repeated
    runs don't keep appending blank lines.

    Args:
      fsrc: binary file-like object holding the original control.lua.
      fdst: binary file-like object to receive the patched control.lua.

    Raises:
      ValueError: when infecting and no ``handler.add_lib(`` line exists to
        anchor the insertion.
    """
    lines = fsrc.read().split(b"\n")

    def is_infect_line(line):
        # Ignore surrounding whitespace so CRLF files (whose lines keep a
        # trailing b"\r" after splitting on b"\n") are still recognised.
        return line.strip() == _CONTROL_LUA_INFECT_LINE

    if FLAGS.uninfect:
        kept = [line for line in lines if not is_infect_line(line)]
        if len(kept) == len(lines):
            logging.info("control.lua not infected, carrying on")
        else:
            logging.info("Removing infection line from control.lua")
        lines = kept
    elif any(is_infect_line(line) for line in lines):
        logging.info("control.lua already infected, carrying on")
    else:
        logging.info("control.lua not yet infected, adding our line...")
        last_handler_add_lib = None
        for n, line in enumerate(lines):
            if line.startswith(b"handler.add_lib("):
                last_handler_add_lib = n
        if last_handler_add_lib is None:
            raise ValueError("Can't find handler.add_lib( lines in control.lua!")
        # Match the anchor line's CRLF ending (if any) to avoid mixing styles.
        anchor = lines[last_handler_add_lib]
        line_ending = b"\r" if anchor.endswith(b"\r") else b""
        lines.insert(
            last_handler_add_lib + 1, _CONTROL_LUA_INFECT_LINE + line_ending
        )

    # Rejoin instead of appending b"\n" to every line: the latter added an
    # extra newline at the end of the file on every run.
    fdst.write(b"\n".join(lines))


def handle_file(zi, fsrc, fdst):
    """Per-entry callback for for_each_in_zip.

    Any entry whose basename is ``control.lua`` is rewritten by
    handle_control_lua; every other entry is copied verbatim.

    Args:
      zi: ``zipfile.ZipInfo`` describing the entry.
      fsrc: binary stream to read the entry from.
      fdst: binary stream to write the output entry to.

    Returns:
      The directory containing ``control.lua`` (as a ``PurePosixPath``) when
      this entry is ``control.lua``, otherwise ``None``.
    """
    path = pathlib.PurePosixPath(zi.filename)
    if path.name != "control.lua":
        logging.info("Copying %s", path)
        shutil.copyfileobj(fsrc, fdst, 8 * 1024)
        return None

    logging.info("Handling control.lua at %s", path)
    handle_control_lua(fsrc, fdst)
    return path.parent


def main(argv):
    """Patch the zip named in ``argv`` in place (unless --dry_run).

    If --backup is set (the default), first copies the zip to a sibling file
    with a ``.pre-multiworld-infect`` suffix. Then it rebuilds the archive into
    a temporary directory:

    * every entry goes through handle_file;
    * any existing ``multiworld.lua`` is dropped;
    * unless --uninfect is set, a fresh ``multiworld.lua`` is added next to
      ``control.lua``.

    Finally, the rebuilt archive is moved over the original. With --dry_run the
    rebuilt archive is not moved and is discarded with the temporary directory.

    Args:
      argv: absl-style argv; ``argv[1]`` is the zip to patch.

    Raises:
      app.UsageError: if not exactly one positional argument was given.
      ValueError: if infecting and the archive has no ``control.lua``. The
        original file is left untouched in that case.
    """
    if len(argv) != 2:
        raise app.UsageError("Need exactly one argument")

    file_path = pathlib.Path(argv[1])

    if FLAGS.backup:
        backup_path = file_path.with_name(f"{file_path.name}.pre-multiworld-infect")
        logging.info("Backing up %s to %s", file_path, backup_path)
        shutil.copy(file_path, backup_path)

    with tempfile.TemporaryDirectory() as tmpd:
        tmp_file_path = pathlib.Path(tmpd) / file_path.name
        with zipfile.ZipFile(
            tmp_file_path, "w", compression=zipfile.ZIP_DEFLATED
        ) as out_zip, zipfile.ZipFile(file_path, "r") as in_zip:
            # Directory containing control.lua, or None if there wasn't one.
            parent = for_each_in_zip(
                in_zip, out_zip, handle_file, skip_files={"multiworld.lua"}
            )

            if not FLAGS.uninfect:
                if parent is None:
                    # Without this check, `None / "multiworld.lua"` below
                    # fails with an unhelpful TypeError.
                    raise ValueError(f"Can't find control.lua in {file_path}!")

                # Add our multiworld.lua alongside control.lua.
                arcname = parent / "multiworld.lua"
                logging.info(
                    "Adding multiworld.lua from %s to %s", "@multiworldlua@", arcname
                )
                zi = zipfile.ZipInfo.from_file(
                    filename="@multiworldlua@",
                    arcname=arcname,
                    strict_timestamps=False,
                )
                # from_file() defaults to ZIP_STORED; honour the archive's
                # compression just like ZipFile.write() would.
                zi.compress_type = out_zip.compression
                with open("@multiworldlua@", "rb") as src, out_zip.open(
                    zi, "w"
                ) as dest:
                    shutil.copyfileobj(src, dest, 1024 * 8)
        if FLAGS.dry_run:
            logging.warning(
                "In dry-run mode: not overwriting output file %s with %s",
                file_path,
                tmp_file_path,
            )
        else:
            shutil.move(tmp_file_path, file_path)


if __name__ == "__main__":
    # app.run parses the absl flags defined above, then calls main(argv)
    # with the remaining positional arguments.
    app.run(main)