141 lines
4.6 KiB
Python
Executable file
141 lines
4.6 KiB
Python
Executable file
#!@python3@
|
|
# SPDX-FileCopyrightText: 2022 Luke Granger-Brown <depot@lukegb.com>
|
|
#
|
|
# SPDX-License-Identifier: Apache-2.0
|
|
"""
|
|
USAGE: %s [flags] <zip to patch>
|
|
"""
|
|
|
|
import pathlib
|
|
import shutil
|
|
import tempfile
|
|
import zipfile
|
|
|
|
from absl import app
|
|
from absl import flags
|
|
from absl import logging
|
|
|
|
FLAGS = flags.FLAGS

# When set, main() copies the input save next to itself (with a
# ".pre-multiworld-infect" suffix) before rewriting it.
flags.DEFINE_boolean(
    name="backup",
    default=True,
    help="Back up world before modifying it to file with .pre-multiworld-infect suffix",
)

# When set, the rewritten zip is left in the temporary directory instead of
# replacing the input file.
flags.DEFINE_boolean(
    name="dry_run",
    default=False,
    help="Don't actually overwrite the input file with the output",
)

# Inverts the operation: strip the injected line instead of adding it.
flags.DEFINE_boolean(
    name="uninfect",
    default=False,
    help=(
        "Removes the multiworld Lua rather than injecting it; note that this won't "
        "remove any of the other persisted save elements, and may leave players "
        "stranded on other surfaces."
    ),
)

# Line inserted into control.lua (after the last existing handler.add_lib call)
# so the save requires the bundled multiworld.lua module.
_CONTROL_LUA_INFECT_LINE = b'handler.add_lib(require("multiworld"))'
|
|
|
|
|
|
def for_each_in_zip(in_zip, out_zip, callback, skip_files=None):
    """Stream each member of ``in_zip`` into ``out_zip`` through ``callback``.

    Members whose basename appears in ``skip_files`` are left out of the
    output entirely. Every other member's ``ZipInfo`` is reused to open the
    matching entry in ``out_zip``, and ``callback(zi, fsrc, fdst)`` is
    responsible for writing that entry's content.

    Args:
      in_zip: ``zipfile.ZipFile`` opened for reading.
      out_zip: ``zipfile.ZipFile`` opened for writing.
      callback: called as ``callback(zi, fsrc, fdst)`` for each kept member.
      skip_files: optional collection of basenames to omit.

    Returns:
      The first truthy value returned by ``callback``; if none was truthy,
      the last callback's result (``None`` when no member was processed).
    """
    skipped_names = skip_files or frozenset()
    result = None
    for info in in_zip.infolist():
        basename = pathlib.PurePosixPath(info.filename).name
        if basename in skipped_names:
            logging.info("Found %s at %s: skipping", basename, info.filename)
            continue
        with in_zip.open(info, "r") as fsrc:
            # The ZipInfo is reused for the output entry. Clear the offset it
            # inherited from in_zip, but only after fsrc has been opened
            # (opening it reads that offset).
            info.header_offset = None
            with out_zip.open(info, "w") as fdst:
                outcome = callback(info, fsrc, fdst)
        if not result:
            result = outcome
    return result
|
|
|
|
|
|
def handle_control_lua(fsrc, fdst):
    """Rewrite a save's control.lua, adding or removing the multiworld hook.

    Reads all of ``fsrc`` and splits it on ``b"\\n"``. With ``--uninfect``,
    every line equal to ``_CONTROL_LUA_INFECT_LINE`` is removed. Otherwise the
    line is inserted directly after the last line that starts with
    ``handler.add_lib(``, unless it is already present. All other bytes are
    copied verbatim, including whether the file ends in a newline. Running the
    tool again on an already-processed file therefore yields identical output.

    Args:
      fsrc: binary file-like object holding the original control.lua.
      fdst: binary file-like object receiving the rewritten control.lua.

    Raises:
      ValueError: when infecting and no ``handler.add_lib(`` line exists.
    """
    # Look to see if we've already infected it...
    buf = fsrc.read()
    lines = buf.split(b"\n")

    if FLAGS.uninfect:
        if _CONTROL_LUA_INFECT_LINE in lines:
            logging.info("Removing infection line from control.lua")
            # Drop every copy, not just the first, so the file ends up clean.
            lines = [line for line in lines if line != _CONTROL_LUA_INFECT_LINE]
        else:
            logging.info("control.lua not infected, carrying on")
    elif _CONTROL_LUA_INFECT_LINE not in lines:
        logging.info("control.lua not yet infected, adding our line...")
        add_lib_indices = [
            n for n, line in enumerate(lines) if line.startswith(b"handler.add_lib(")
        ]
        if not add_lib_indices:
            raise ValueError("Can't find handler.add_lib( lines in control.lua!")
        lines.insert(add_lib_indices[-1] + 1, _CONTROL_LUA_INFECT_LINE)
    else:
        logging.info("control.lua already infected, carrying on")

    # split() yields one more element than there are separators, so join()
    # restores the input exactly (apart from our edit). Writing each element
    # followed by b"\n" instead would append an extra newline on every run.
    fdst.write(b"\n".join(lines))
|
|
|
|
|
|
def handle_file(zi, fsrc, fdst):
    """Per-member callback for ``for_each_in_zip``.

    A member whose basename is ``control.lua`` is rewritten by
    ``handle_control_lua``, and its directory within the archive is returned.
    Every other member is copied through unchanged, and ``None`` is returned.
    """
    member = pathlib.PurePosixPath(zi.filename)
    if member.name != "control.lua":
        logging.info("Copying %s", member)
        shutil.copyfileobj(fsrc, fdst, 8 * 1024)
        return None

    logging.info("Handling control.lua at %s", member)
    handle_control_lua(fsrc, fdst)
    return member.parent
|
|
|
|
|
|
def _add_multiworld_lua(out_zip, parent):
    """Write the bundled multiworld.lua into ``out_zip`` under ``parent``."""
    source = "@multiworldlua@"
    arcname = parent / "multiworld.lua"
    logging.info("Adding multiworld.lua from %s to %s", source, arcname)
    zi = zipfile.ZipInfo.from_file(
        filename=source,
        arcname=arcname,
        # strict_timestamps=False clamps pre-1980 mtimes instead of raising.
        # NOTE(review): presumably the substituted source path can carry such
        # an mtime (e.g. a Nix store file); confirm.
        strict_timestamps=False,
    )
    # ZipInfo defaults to ZIP_STORED and ZipFile.open() does not apply the
    # archive's default compression, so set it explicitly.
    zi.compress_type = out_zip.compression
    with open(source, "rb") as src, out_zip.open(zi, "w") as dest:
        shutil.copyfileobj(src, dest, 1024 * 8)


def main(argv):
    """Infect (or with --uninfect, disinfect) the save zip named by argv[1].

    Optionally backs the zip up first (--backup). A rewritten copy is then
    built in a temporary directory: every member passes through
    ``handle_file``, any existing multiworld.lua is dropped, and (unless
    uninfecting) a fresh multiworld.lua is added next to control.lua. Finally
    the copy replaces the input file, unless --dry_run is set.

    Raises:
      app.UsageError: if not given exactly one positional argument.
      ValueError: when infecting and the zip contains no control.lua, or
        control.lua has no handler.add_lib( line.
    """
    if len(argv) != 2:
        raise app.UsageError("Need exactly one argument")

    file_path = pathlib.Path(argv[1])

    if FLAGS.backup:
        backup_path = file_path.with_name(f"{file_path.name}.pre-multiworld-infect")
        logging.info("Backing up %s to %s", file_path, backup_path)
        shutil.copy(file_path, backup_path)

    with tempfile.TemporaryDirectory() as tmpd:
        tmp_file_path = pathlib.Path(tmpd) / file_path.name
        with zipfile.ZipFile(
            tmp_file_path, "w", compression=zipfile.ZIP_DEFLATED
        ) as out_zip, zipfile.ZipFile(file_path, "r") as in_zip:
            # Directory (inside the zip) of the first control.lua seen, or None.
            parent = for_each_in_zip(
                in_zip, out_zip, handle_file, skip_files={"multiworld.lua"}
            )

            if not FLAGS.uninfect:
                if parent is None:
                    raise ValueError(f"Can't find control.lua in {file_path}!")
                _add_multiworld_lua(out_zip, parent)

        # out_zip is closed (central directory written) before it is moved.
        if FLAGS.dry_run:
            logging.warning(
                "In dry-run mode: not overwriting output file %s with %s",
                file_path,
                tmp_file_path,
            )
        else:
            shutil.move(tmp_file_path, file_path)
|
|
|
|
|
|
if __name__ == "__main__":
    # Hand control to absl, which invokes main() with the command-line arguments.
    app.run(main)
|