Based on upstream commit 66b7fb630903fdcf3e83b6b6d56d82e904264a20, adjusted to apply to 1.15.0 and to avoid implicitly pulling in changes from other intermediate commits. The change moves the external-data location validation out of check_tensor() into a new checker::resolve_external_data_location() (exposed to Python as onnx.onnx_cpp2py_export.checker._resolve_external_data_location) and uses it in load_external_data_for_tensor() in place of the removed _sanitize_path(), so absolute locations and locations that escape the base directory are rejected when loading external data. diff --git a/onnx/checker.cc b/onnx/checker.cc index fac56f56..c9fda9b2 100644 --- a/onnx/checker.cc +++ b/onnx/checker.cc @@ -13,7 +13,6 @@ #include #include "onnx/common/file_utils.h" -#include "onnx/common/path.h" #include "onnx/defs/schema.h" #include "onnx/defs/tensor_proto_util.h" #include "onnx/proto_utils.h" @@ -135,85 +134,7 @@ void check_tensor(const TensorProto& tensor, const CheckerContext& ctx) { for (const StringStringEntryProto& entry : tensor.external_data()) { if (entry.has_key() && entry.has_value() && entry.key() == "location") { has_location = true; -#ifdef _WIN32 - auto file_path = std::filesystem::path(utf8str_to_wstring(entry.value())); - if (file_path.is_absolute()) { - fail_check( - "Location of external TensorProto ( tensor name: ", - tensor.name(), - ") should be a relative path, but it is an absolute path: ", - entry.value()); - } - auto relative_path = file_path.lexically_normal().make_preferred().wstring(); - // Check that normalized relative path contains ".." on Windows.
- if (relative_path.find(L"..", 0) != std::string::npos) { - fail_check( - "Data of TensorProto ( tensor name: ", - tensor.name(), - ") should be file inside the ", - ctx.get_model_dir(), - ", but the '", - entry.value(), - "' points outside the directory"); - } - std::wstring data_path = path_join(utf8str_to_wstring(ctx.get_model_dir()), relative_path); - struct _stat64 buff; - if (_wstat64(data_path.c_str(), &buff) != 0) { - fail_check( - "Data of TensorProto ( tensor name: ", - tensor.name(), - ") should be stored in ", - entry.value(), - ", but it doesn't exist or is not accessible."); - } -#else // POSIX - if (entry.value().empty()) { - fail_check("Location of external TensorProto ( tensor name: ", tensor.name(), ") should not be empty."); - } else if (entry.value()[0] == '/') { - fail_check( - "Location of external TensorProto ( tensor name: ", - tensor.name(), - ") should be a relative path, but it is an absolute path: ", - entry.value()); - } - std::string relative_path = clean_relative_path(entry.value()); - // Check that normalized relative path contains ".." 
on POSIX - if (relative_path.find("..", 0) != std::string::npos) { - fail_check( - "Data of TensorProto ( tensor name: ", - tensor.name(), - ") should be file inside the ", - ctx.get_model_dir(), - ", but the '", - entry.value(), - "' points outside the directory"); - } - std::string data_path = path_join(ctx.get_model_dir(), relative_path); - // use stat64 to check whether the file exists -#if defined(__APPLE__) || defined(__wasm__) || !defined(__GLIBC__) - struct stat buffer; // APPLE, wasm and non-glic stdlibs do not have stat64 - if (stat((data_path).c_str(), &buffer) != 0) { -#else - struct stat64 buffer; // All POSIX under glibc except APPLE and wasm have stat64 - if (stat64((data_path).c_str(), &buffer) != 0) { -#endif - fail_check( - "Data of TensorProto ( tensor name: ", - tensor.name(), - ") should be stored in ", - data_path, - ", but it doesn't exist or is not accessible."); - } - // Do not allow symlinks or directories. - if (!S_ISREG(buffer.st_mode)) { - fail_check( - "Data of TensorProto ( tensor name: ", - tensor.name(), - ") should be stored in ", - data_path, - ", but it is not regular file."); - } -#endif + resolve_external_data_location(ctx.get_model_dir(), entry.value(), tensor.name()); } } if (!has_location) { @@ -1054,6 +975,93 @@ void check_model(const ModelProto& model, bool full_check, bool skip_opset_compa } } +std::string resolve_external_data_location( + const std::string& base_dir, + const std::string& location, + const std::string& tensor_name) { +#ifdef _WIN32 + auto file_path = std::filesystem::path(utf8str_to_wstring(location)); + if (file_path.is_absolute()) { + fail_check( + "Location of external TensorProto ( tensor name: ", + tensor_name, + ") should be a relative path, but it is an absolute path: ", + location); + } + auto relative_path = file_path.lexically_normal().make_preferred().wstring(); + // Check that normalized relative path contains ".." on Windows. 
+ if (relative_path.find(L"..", 0) != std::string::npos) { + fail_check( + "Data of TensorProto ( tensor name: ", + tensor_name, + ") should be file inside the ", + base_dir, + ", but the '", + location, + "' points outside the directory"); + } + std::wstring data_path = path_join(utf8str_to_wstring(base_dir), relative_path); + struct _stat64 buff; + if (_wstat64(data_path.c_str(), &buff) != 0) { + fail_check( + "Data of TensorProto ( tensor name: ", + tensor_name, + ") should be stored in ", + location, + ", but it doesn't exist or is not accessible."); + } + return wstring_to_utf8str(data_path); +#else // POSIX + if (location.empty()) { + fail_check("Location of external TensorProto ( tensor name: ", tensor_name, ") should not be empty."); + } else if (location[0] == '/') { + fail_check( + "Location of external TensorProto ( tensor name: ", + tensor_name, + ") should be a relative path, but it is an absolute path: ", + location); + } + std::string relative_path = clean_relative_path(location); + // Check that normalized relative path contains ".." on POSIX + if (relative_path.find("..", 0) != std::string::npos) { + fail_check( + "Data of TensorProto ( tensor name: ", + tensor_name, + ") should be file inside the ", + base_dir, + ", but the '", + location, + "' points outside the directory"); + } + std::string data_path = path_join(base_dir, relative_path); + // use stat64 to check whether the file exists +#if defined(__APPLE__) || defined(__wasm__) || !defined(__GLIBC__) + struct stat buffer; // APPLE, wasm and non-glic stdlibs do not have stat64 + if (stat((data_path).c_str(), &buffer) != 0) { +#else + struct stat64 buffer; // All POSIX under glibc except APPLE and wasm have stat64 + if (stat64((data_path).c_str(), &buffer) != 0) { +#endif + fail_check( + "Data of TensorProto ( tensor name: ", + tensor_name, + ") should be stored in ", + data_path, + ", but it doesn't exist or is not accessible."); + } + // Do not allow symlinks or directories. 
+ if (!S_ISREG(buffer.st_mode)) { + fail_check( + "Data of TensorProto ( tensor name: ", + tensor_name, + ") should be stored in ", + data_path, + ", but it is not regular file."); + } + return data_path; +#endif +} + std::set experimental_ops = { "ATen", "Affine", diff --git a/onnx/checker.h b/onnx/checker.h index 6796acab..83012213 100644 --- a/onnx/checker.h +++ b/onnx/checker.h @@ -160,7 +160,10 @@ void check_model_local_functions( void check_model(const ModelProto& model, bool full_check = false, bool skip_opset_compatibility_check = false); void check_model(const std::string& model_path, bool full_check = false, bool skip_opset_compatibility_check = false); - +std::string resolve_external_data_location( + const std::string& base_dir, + const std::string& location, + const std::string& tensor_name); bool check_is_experimental_op(const NodeProto& node); } // namespace checker diff --git a/onnx/common/path.h b/onnx/common/path.h index 6eaf5e67..09212747 100644 --- a/onnx/common/path.h +++ b/onnx/common/path.h @@ -31,11 +31,22 @@ inline std::wstring utf8str_to_wstring(const std::string& utf8str) { if (utf8str.size() > INT_MAX) { fail_check("utf8str_to_wstring: string is too long for converting to wstring."); } - int size_required = MultiByteToWideChar(CP_UTF8, 0, utf8str.c_str(), (int)utf8str.size(), NULL, 0); + int size_required = MultiByteToWideChar(CP_UTF8, 0, utf8str.c_str(), static_cast(utf8str.size()), NULL, 0); std::wstring ws_str(size_required, 0); - MultiByteToWideChar(CP_UTF8, 0, utf8str.c_str(), (int)utf8str.size(), &ws_str[0], size_required); + MultiByteToWideChar(CP_UTF8, 0, utf8str.c_str(), static_cast(utf8str.size()), &ws_str[0], size_required); return ws_str; } +inline std::string wstring_to_utf8str(const std::wstring& ws_str) { + if (ws_str.size() > INT_MAX) { + fail_check("wstring_to_utf8str: string is too long for converting to UTF-8."); + } + int size_required = + WideCharToMultiByte(CP_UTF8, 0, ws_str.c_str(), static_cast(ws_str.size()), 
NULL, 0, NULL, NULL); + std::string utf8str(size_required, 0); + WideCharToMultiByte( + CP_UTF8, 0, ws_str.c_str(), static_cast(ws_str.size()), &utf8str[0], size_required, NULL, NULL); + return utf8str; +} #else std::string path_join(const std::string& origin, const std::string& append); diff --git a/onnx/cpp2py_export.cc b/onnx/cpp2py_export.cc index bc2594db..83cea68f 100644 --- a/onnx/cpp2py_export.cc +++ b/onnx/cpp2py_export.cc @@ -545,6 +545,8 @@ PYBIND11_MODULE(onnx_cpp2py_export, onnx_cpp2py_export) { "full_check"_a = false, "skip_opset_compatibility_check"_a = false); + checker.def("_resolve_external_data_location", &checker::resolve_external_data_location); + // Submodule `version_converter` auto version_converter = onnx_cpp2py_export.def_submodule("version_converter"); version_converter.doc() = "VersionConverter submodule"; diff --git a/onnx/external_data_helper.py b/onnx/external_data_helper.py index bbc2717f..05c486c6 100644 --- a/onnx/external_data_helper.py +++ b/onnx/external_data_helper.py @@ -8,6 +8,7 @@ import uuid from itertools import chain from typing import Callable, Iterable, Optional +import onnx.onnx_cpp2py_export.checker as c_checker from onnx.onnx_pb import AttributeProto, GraphProto, ModelProto, TensorProto @@ -39,9 +40,9 @@ def load_external_data_for_tensor(tensor: TensorProto, base_dir: str) -> None: base_dir: directory that contains the external data. 
""" info = ExternalDataInfo(tensor) - file_location = _sanitize_path(info.location) - external_data_file_path = os.path.join(base_dir, file_location) - + external_data_file_path = c_checker._resolve_external_data_location( # type: ignore[attr-defined] + base_dir, info.location, tensor.name + ) with open(external_data_file_path, "rb") as data_file: if info.offset: data_file.seek(info.offset) @@ -259,14 +260,6 @@ def _get_attribute_tensors(onnx_model_proto: ModelProto) -> Iterable[TensorProto yield from _get_attribute_tensors_from_graph(onnx_model_proto.graph) -def _sanitize_path(path: str) -> str: - """Remove path components which would allow traversing up a directory tree from a base path. - - Note: This method is currently very basic and should be expanded. - """ - return path.lstrip("/.") - - def _is_valid_filename(filename: str) -> bool: """Utility to check whether the provided filename is valid.""" exp = re.compile('^[^<>:;,?"*|/]+$') diff --git a/onnx/test/test_external_data.py b/onnx/test/test_external_data.py index 63f6b4ef..bb14d279 100644 --- a/onnx/test/test_external_data.py +++ b/onnx/test/test_external_data.py @@ -3,6 +3,7 @@ # SPDX-License-Identifier: Apache-2.0 from __future__ import annotations +import itertools import os import pathlib import tempfile @@ -204,6 +205,52 @@ class TestLoadExternalDataSingleFile(TestLoadExternalDataBase): attribute_tensor = new_model.graph.node[0].attribute[0].t np.testing.assert_allclose(to_array(attribute_tensor), self.attribute_value) + @parameterized.parameterized.expand(itertools.product((True, False), (True, False))) + def test_save_external_invalid_single_file_data_and_check( + self, use_absolute_path: bool, use_model_path: bool + ) -> None: + model = onnx.load_model(self.model_filename, self.serialization_format) + + model_dir = os.path.join(self.temp_dir, "save_copy") + os.mkdir(model_dir) + + traversal_external_data_dir = os.path.join( + self.temp_dir, "invlid_external_data" + ) + 
os.mkdir(traversal_external_data_dir) + + if use_absolute_path: + traversal_external_data_location = os.path.join( + traversal_external_data_dir, "tensors.bin" + ) + else: + traversal_external_data_location = "../invlid_external_data/tensors.bin" + + external_data_dir = os.path.join(self.temp_dir, "external_data") + os.mkdir(external_data_dir) + new_model_filepath = os.path.join(model_dir, "model.onnx") + + def convert_model_to_external_data_no_check(model: ModelProto, location: str): + for tensor in model.graph.initializer: + if tensor.HasField("raw_data"): + set_external_data(tensor, location) + + convert_model_to_external_data_no_check( + model, + location=traversal_external_data_location, + ) + + onnx.save_model(model, new_model_filepath, self.serialization_format) + if use_model_path: + with self.assertRaises(onnx.checker.ValidationError): + _ = onnx.load_model(new_model_filepath, self.serialization_format) + else: + onnx_model = onnx.load_model( + new_model_filepath, self.serialization_format, load_external_data=False + ) + with self.assertRaises(onnx.checker.ValidationError): + load_external_data_for_model(onnx_model, external_data_dir) + @parameterized.parameterized_class( [