quantized-custom
Resource Overview

This resource is a Python module providing ONNX export helpers for float/QAT and quantized models; it exposes ``export_to_onnx`` and ``export_quantized_onnx``.
import contextlib
import logging
import os
from distutils.version import LooseVersion
from functools import wraps

import torch.onnx
from torch import _C  # noqa: N814
from torch.onnx import OperatorExportTypes

# from horizon_plugin_pytorch.utils import _log_first_n
# from . import _register_onnx_ops  # noqa: F401
from . import _register_quantized_onnx_ops  # noqa: F401

TrainingMode = _C._onnx.TrainingMode

__all__ = ["export_to_onnx", "export_quantized_onnx"]


def _preprocess_graph(func):
    """Remove dead custom registered ops.

    Custom registered quantized functions that return a tuple are usually
    followed by a prim::TupleUnpack node in the traced graph. However, if
    the results of such a function are not used by other nodes, there is no
    prim::TupleUnpack node to unpack them. The dead code elimination pass in
    torch cannot remove such a node (possibly because the custom op is
    traced as a block and marked as live), so we find unused custom ops and
    delete them from the graph here.
    """

    @wraps(func)
    def _preprocess(*args, **kwargs):
        graph, *args = args
        assert type(graph) == torch._C.Graph
        dead_vslz_node = []
        for node in graph.nodes():
            if node.kind() == "prim::PythonOp" and not node.hasUses():
                dead_vslz_node.append(node)
        for node in dead_vslz_node:
            node.destroy()
        return func(graph, *args, **kwargs)

    return _preprocess


# torch 1.10.2 added some logic to ONNX shape inference that uses std::cerr
# to print warnings in custom registered ops.
# We redirect stderr to /dev/null to avoid a warning from each custom op,
# run torch.onnx.export, and then redirect stderr back.
@contextlib.contextmanager
def _redirect_stderr():
    # Note: directly using sys.stderr.fileno() causes a 'Tee' error in CI/CD
    # stderr_fd = sys.stderr.fileno()
    stderr_fd = 2
    fd = os.open("/dev/null", os.O_WRONLY)
    dup_stderr_fd = os.dup(stderr_fd)
    try:
        yield os.dup2(fd, stderr_fd)
    finally:
        os.dup2(dup_stderr_fd, stderr_fd)
        os.close(fd)
        os.close(dup_stderr_fd)


# Replace torch.onnx.utils._optimize_graph in torch 1.13 to avoid
# processing the inner implementation of autograd functions.
@contextlib.contextmanager
def _redirect_opt_graph():
    _torch_optimize_graph = torch.onnx.utils._optimize_graph
    try:
        if LooseVersion(torch.__version__) >= LooseVersion("1.13"):
            from ._optimize_graph_helper import _optimize_graph

            torch.onnx.utils._optimize_graph = _preprocess_graph(
                _optimize_graph
            )
            yield True
        else:
            torch.onnx.utils._optimize_graph = _preprocess_graph(
                _torch_optimize_graph
            )
            yield False
    finally:
        torch.onnx.utils._optimize_graph = _torch_optimize_graph


@contextlib.contextmanager
def _set_is_in_onnx_export_false():
    origin_f = torch.onnx.utils.is_in_onnx_export
    try:
        if LooseVersion(torch.__version__) >= LooseVersion("1.13"):
            torch.onnx.utils.is_in_onnx_export = False
        yield
    finally:
        torch.onnx.utils.is_in_onnx_export = origin_f


def export_to_onnx(
    model,
    args,
    f,
    export_params=True,
    verbose=False,
    training=TrainingMode.EVAL,
    input_names=None,
    output_names=None,
    operator_export_type=OperatorExportTypes.ONNX_FALLTHROUGH,
    opset_version=11,
    do_constant_folding=True,
    dynamic_axes=None,
    keep_initializers_as_inputs=None,
    custom_opsets=None,
):
    r"""Export a (float or QAT) model into ONNX format.

    Args:
        model (torch.nn.Module/torch.jit.ScriptModule/ScriptFunction):
            the model to be exported.
        args (tuple or torch.Tensor): args can be structured either as:

            1. ONLY A TUPLE OF ARGUMENTS::

                args = (x, y, z)

            The tuple should contain model inputs such that ``model(*args)``
            is a valid invocation of the model. Any non-Tensor arguments
            will be hard-coded into the exported model; any Tensor arguments
            will become inputs of the exported model, in the order they
            occur in the tuple.

            2. A TENSOR::

                args = torch.Tensor([1])

            This is equivalent to a 1-ary tuple of that Tensor.

            3. A TUPLE OF ARGUMENTS ENDING WITH A DICTIONARY OF NAMED
            ARGUMENTS::

                args = (x, {'y': input_y, 'z': input_z})

            All but the last element of the tuple will be passed as
            non-keyword arguments, and named arguments will be set from the
            last element. If a named argument is not present in the
            dictionary, it is assigned the default value, or None if a
            default value is not provided.

        f: a file-like object or a string containing a file name. A binary
            protocol buffer will be written to this file.
        export_params (bool, default True): if True, all parameters will
            be exported.
        verbose (bool, default False): if True, prints a description of the
            model being exported to stdout, and a doc_string is added to
            the graph. The doc_string may contain a mapping of module scope
            to node name in future torch onnx versions.
        training (enum, default TrainingMode.EVAL):
            * ``TrainingMode.EVAL``: export the model in inference mode.
            * ``TrainingMode.PRESERVE``: export the model in inference mode
              if model.training is False and in training mode if
              model.training is True.
            * ``TrainingMode.TRAINING``: export the model in training mode.
              Disables optimizations which might interfere with training.
        input_names (list of str, default empty list): names to assign to
            the input nodes of the graph, in order.
        output_names (list of str, default empty list): names to assign to
            the output nodes of the graph, in order.
        operator_export_type (enum, default ONNX_FALLTHROUGH):
            * ``OperatorExportTypes.ONNX``: Export all ops as regular ONNX
              ops (in the default opset domain).
            * ``OperatorExportTypes.ONNX_FALLTHROUGH``: Try to convert all
              ops to standard ONNX ops in the default opset domain.
            * ``OperatorExportTypes.ONNX_ATEN``: All ATen ops (in the
              TorchScript namespace "aten") are exported as ATen ops.
            * ``OperatorExportTypes.ONNX_ATEN_FALLBACK``: Try to export
              each ATen op (in the TorchScript namespace "aten") as a
              regular ONNX op. If we are unable to do so, fall back to
              exporting an ATen op.
        opset_version (int, default 11): by default we export the model to
            the opset version of the onnx submodule.
        do_constant_folding (bool, default True): Apply the
            constant-folding optimization. Constant-folding will replace
            some of the ops that have all constant inputs with pre-computed
            constant nodes.
        dynamic_axes (dict<str, list(int)/dict<int, str>>, default empty
            dict): By default the exported model will have the shapes of
            all input and output tensors set to exactly match those given
            in ``args`` (and ``example_outputs`` when that arg is
            required). To specify axes of tensors as dynamic (i.e. known
            only at run-time), set ``dynamic_axes`` to a dict with schema:

            * KEY (str): an input or output name. Each name must also be
              provided in ``input_names`` or ``output_names``.
            * VALUE (dict or list): If a dict, keys are axis indices and
              values are axis names. If a list, each element is an axis
              index.

        keep_initializers_as_inputs (bool, default None): If True, all the
            initializers (typically corresponding to parameters) in the
            exported graph will also be added as inputs to the graph.
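
For quick reference, a minimal usage sketch of ``export_to_onnx`` follows. The ``TinyNet`` model, the input shape, and the output file name are illustrative assumptions, not part of this resource; the call itself uses only parameters from the signature above, which mirrors ``torch.onnx.export``.

import torch
import torch.nn as nn

# Hypothetical toy model, defined only to illustrate the export call.
class TinyNet(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv = nn.Conv2d(3, 8, kernel_size=3, padding=1)

    def forward(self, x):
        return self.conv(x)

model = TinyNet().eval()
# args given as a single Tensor (form 2 in the docstring above)
dummy_input = torch.randn(1, 3, 224, 224)

# export_to_onnx is the function defined above; the import path depends on
# where this module lives in your installation.
export_to_onnx(
    model,
    dummy_input,
    "tiny_net.onnx",  # f: output file name
    input_names=["input"],
    output_names=["output"],
    opset_version=11,
    dynamic_axes={"input": {0: "batch"}},  # mark the batch dim as dynamic
)

Note that with the default ``operator_export_type=ONNX_FALLTHROUGH``, ops that cannot be converted to standard ONNX ops are left in the graph as-is rather than failing the export, which is what lets custom quantized ops pass through to the output file.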