depot/ops/factorio/multiworld/infect.py

141 lines
4.6 KiB
Python
Executable file

#!@python3@
# SPDX-FileCopyrightText: 2022 Luke Granger-Brown <depot@lukegb.com>
#
# SPDX-License-Identifier: Apache-2.0
"""
USAGE: %s [flags] <zip to patch>
"""
import pathlib
import shutil
import tempfile
import zipfile
from absl import app
from absl import flags
from absl import logging
FLAGS = flags.FLAGS

# Command-line switches; absl parses these before main() is invoked.
flags.DEFINE_boolean(
    name="backup",
    default=True,
    help="Back up world before modifying it to file with .pre-multiworld-infect suffix",
)
flags.DEFINE_boolean(
    name="dry_run",
    default=False,
    help="Don't actually overwrite the input file with the output",
)
flags.DEFINE_boolean(
    name="uninfect",
    default=False,
    help=(
        "Removes the multiworld Lua rather than injecting it; note that this won't "
        "remove any of the other persisted save elements, and may leave players "
        "stranded on other surfaces."
    ),
)

# The exact control.lua line this tool inserts (or, with --uninfect, removes).
# It passes the `multiworld` module -- the multiworld.lua that main() places
# next to control.lua -- to `handler.add_lib`.
_CONTROL_LUA_INFECT_LINE = b'handler.add_lib(require("multiworld"))'
def for_each_in_zip(in_zip, out_zip, callback, skip_files=None):
    """Stream every entry of `in_zip` into `out_zip` through `callback`.

    Each entry keeps its original ZipInfo metadata. `callback(zinfo, fsrc, fdst)`
    receives the open reader for the input entry and the open writer for the
    output entry and is responsible for copying/transforming the bytes.

    Args:
      in_zip: zipfile.ZipFile opened for reading.
      out_zip: zipfile.ZipFile opened for writing.
      callback: callable(zinfo, fsrc, fdst) -> value.
      skip_files: collection of base names (last path component); entries
        whose base name is in it are omitted from `out_zip` entirely.

    Returns:
      The first truthy value returned by `callback`; if none was truthy, the
      last value it returned (None if it was never called).
    """
    skipped = skip_files or frozenset()
    result = None
    for info in in_zip.infolist():
        base = pathlib.PurePosixPath(info.filename).name
        if base not in skipped:
            with in_zip.open(info, "r") as reader:
                # The offset belongs to the input archive; clear it (only after
                # the reader is open, since opening uses it) before reusing the
                # same ZipInfo for the output entry.
                info.header_offset = None
                with out_zip.open(info, "w") as writer:
                    outcome = callback(info, reader, writer)
            if not result:
                result = outcome
        else:
            logging.info("Found %s at %s: skipping", base, info.filename)
    return result
def handle_control_lua(fsrc, fdst):
    """Copy control.lua from `fsrc` to `fdst`, adding or removing the hook line.

    Default (infect) mode: if `_CONTROL_LUA_INFECT_LINE` is not already present
    as a whole line, it is inserted directly after the last line that starts
    with `handler.add_lib(`. With --uninfect: the line is removed if present.
    Every other byte, including the presence or absence of a trailing newline,
    is passed through unchanged, so infect followed by uninfect reproduces the
    original file exactly.

    Args:
      fsrc: binary file-like object holding the original control.lua.
      fdst: binary file-like object receiving the rewritten control.lua.

    Raises:
      ValueError: when infecting and no `handler.add_lib(` line exists.
    """
    # Look to see if we've already infected it...
    lines = fsrc.read().split(b"\n")
    if FLAGS.uninfect:
        if _CONTROL_LUA_INFECT_LINE in lines:
            logging.info("Removing infection line from control.lua")
            lines.remove(_CONTROL_LUA_INFECT_LINE)
        else:
            logging.info("control.lua not infected, carrying on")
    elif _CONTROL_LUA_INFECT_LINE in lines:
        logging.info("control.lua already infected, carrying on")
    else:
        logging.info("control.lua not yet infected, adding our line...")
        last_handler_add_lib = None
        for n, line in enumerate(lines):
            if line.startswith(b"handler.add_lib("):
                last_handler_add_lib = n
        if last_handler_add_lib is None:
            raise ValueError("Can't find handler.add_lib( lines in control.lua!")
        lines.insert(last_handler_add_lib + 1, _CONTROL_LUA_INFECT_LINE)
    # Rejoin with the separator we split on. Writing `line + b"\n"` per element
    # (as before) appended an extra newline on every pass, because splitting a
    # newline-terminated file yields a trailing empty element.
    fdst.write(b"\n".join(lines))
def handle_file(zi, fsrc, fdst):
    """Per-entry callback for for_each_in_zip.

    Entries named control.lua are rewritten via handle_control_lua; every other
    entry is copied verbatim in 8 KiB chunks.

    Returns:
      The PurePosixPath directory containing control.lua for that entry, or
      None for any other entry.
    """
    path = pathlib.PurePosixPath(zi.filename)
    if path.name != "control.lua":
        logging.info("Copying %s", path)
        shutil.copyfileobj(fsrc, fdst, 1024 * 8)
        return None
    logging.info("Handling control.lua at %s", path)
    handle_control_lua(fsrc, fdst)
    return path.parent
def main(argv):
    """Infect (or, with --uninfect, disinfect) the save zip named on the command line.

    The archive is rebuilt in a temporary directory:
    - every entry is passed through handle_file, which rewrites control.lua;
    - any existing multiworld.lua is dropped;
    - unless --uninfect is set, the bundled multiworld.lua is added next to
      control.lua.

    The rebuilt archive then replaces the input unless --dry_run is set.
    With --backup (the default) the input is first copied to
    `<name>.pre-multiworld-infect`.

    Args:
      argv: positional arguments remaining after absl flag parsing; argv[1]
        is the path of the zip to patch.

    Raises:
      app.UsageError: if not exactly one positional argument was given.
      ValueError: if infecting and the archive contains no control.lua.
    """
    if len(argv) != 2:
        raise app.UsageError("Need exactly one argument")
    file_path = pathlib.Path(argv[1])

    if FLAGS.backup:
        backup_path = file_path.with_name(f"{file_path.name}.pre-multiworld-infect")
        logging.info("Backing up %s to %s", file_path, backup_path)
        shutil.copy(file_path, backup_path)

    with tempfile.TemporaryDirectory() as tmpd:
        tmp_file_path = pathlib.Path(tmpd) / file_path.name
        with zipfile.ZipFile(
            tmp_file_path, "w", compression=zipfile.ZIP_DEFLATED
        ) as out_zip, zipfile.ZipFile(file_path, "r") as in_zip:
            # `parent` is the directory holding control.lua, or None if the
            # archive has no control.lua at all.
            parent = for_each_in_zip(
                in_zip, out_zip, handle_file, skip_files={"multiworld.lua"}
            )
            if not FLAGS.uninfect:
                if parent is None:
                    # Previously this surfaced as an opaque TypeError from
                    # `None / "multiworld.lua"`.
                    raise ValueError(f"Can't find control.lua in {file_path}!")
                # Add our multiworld.lua.
                multiworld_src = "@multiworldlua@"
                arcname = parent / "multiworld.lua"
                logging.info(
                    "Adding multiworld.lua from %s to %s", multiworld_src, arcname
                )
                zi = zipfile.ZipInfo.from_file(
                    filename=multiworld_src,
                    arcname=arcname,
                    strict_timestamps=False,
                )
                # ZipFile.open() with a ZipInfo uses that ZipInfo's own
                # compress_type (a fresh ZipInfo defaults to ZIP_STORED), not
                # the archive's default, so apply it explicitly.
                zi.compress_type = out_zip.compression
                with open(multiworld_src, "rb") as src, out_zip.open(
                    zi, "w"
                ) as dest:
                    shutil.copyfileobj(src, dest, 1024 * 8)
        if FLAGS.dry_run:
            logging.warning(
                "In dry-run mode: not overwriting output file %s with %s",
                file_path,
                tmp_file_path,
            )
        else:
            shutil.move(tmp_file_path, file_path)
if __name__ == "__main__":
app.run(main)