| """ | |
| Utilities for instrumenting a torch model. | |
| Trace will hook one layer at a time. | |
| TraceDict will hook multiple layers at once. | |
| subsequence slices intervals from Sequential modules. | |
| get_module, replace_module, get_parameter resolve dotted names. | |
| set_requires_grad recursively sets requires_grad in module parameters. | |
| Script from memit source code | |
| MIT License | |
| Copyright (c) 2022 Kevin Meng | |
| Permission is hereby granted, free of charge, to any person obtaining a copy | |
| of this software and associated documentation files (the "Software"), to deal | |
| in the Software without restriction, including without limitation the rights | |
| to use, copy, modify, merge, publish, distribute, sublicense, and/or sell | |
| copies of the Software, and to permit persons to whom the Software is | |
| furnished to do so, subject to the following conditions: | |
| The above copyright notice and this permission notice shall be included in all | |
| copies or substantial portions of the Software. | |
| THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR | |
| IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, | |
| FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE | |
| AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER | |
| LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, | |
| OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE | |
| SOFTWARE. | |
| """ | |
| import contextlib | |
| import copy | |
| import inspect | |
| from collections import OrderedDict | |
| import torch | |
class Trace(contextlib.AbstractContextManager):
    """
    To retain the output of the named layer during the computation of
    the given network:

        with Trace(net, 'layer.name') as ret:
            _ = net(inp)
            representation = ret.output

    A layer module can be passed directly without a layer name, and
    its output will be retained.  By default, a direct reference to
    the output object is returned, but options can control this:

        clone=True - retains a copy of the output, which can be
            useful if you want to see the output before it might
            be modified by the network in-place later.
        detach=True - retains a detached reference or copy.  (By
            default the value would be left attached to the graph.)
        retain_grad=True - request gradient to be retained on the
            output.  After backward(), ret.output.grad is populated.
        retain_input=True - also retains the input.
        retain_output=False - can disable retaining the output.
        edit_output=fn - calls the function to modify the output
            of the layer before passing it to the rest of the model.
            fn can optionally accept (output, layer) arguments
            for the original output and the layer name.
        stop=True - throws a StopForward exception after the layer
            is run, which allows running just a portion of a model.
    """
    def __init__(
        self,
        module,
        layer=None,
        retain_output=True,
        retain_input=False,
        clone=False,
        detach=False,
        retain_grad=False,
        edit_output=None,
        stop=False,
    ):
        """
        Registers a forward hook on the layer that intercepts the call,
        and tracks the hook so that it can be reverted.
        """
        retainer = self
        self.layer = layer
        if layer is not None:
            module = get_module(module, layer)

        def retain_hook(m, inputs, output):
            if retain_input:
                retainer.input = recursive_copy(
                    inputs[0] if len(inputs) == 1 else inputs,
                    clone=clone,
                    detach=detach,
                    retain_grad=False,
                )  # retain_grad applies to output only.
            if edit_output:
                output = invoke_with_optional_args(
                    edit_output, output=output, layer=self.layer
                )
            if retain_output:
                retainer.output = recursive_copy(
                    output, clone=clone, detach=detach, retain_grad=retain_grad
                )
                # When retain_grad is set, also insert a trivial
                # copy operation.  That allows in-place operations
                # to follow without error.
                if retain_grad:
                    output = recursive_copy(retainer.output, clone=True, detach=False)
            if stop:
                raise StopForward()
            return output

        self.registered_hook = module.register_forward_hook(retain_hook)
        self.stop = stop
    def __enter__(self):
        return self

    def __exit__(self, type, value, traceback):
        self.close()
        # Guard against type being None (normal exit, no exception raised).
        if self.stop and type is not None and issubclass(type, StopForward):
            return True

    def close(self):
        self.registered_hook.remove()
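# A minimal usage sketch of Trace on a hypothetical toy model (the
# Sequential and the layer name "1" below are just for illustration, and
# the helper is unused so that importing this module stays side-effect free).
def _example_trace():
    net = torch.nn.Sequential(torch.nn.Linear(4, 4), torch.nn.ReLU())
    inp = torch.randn(2, 4)
    # Layer "1" is the ReLU; retain both its input and its output.
    with Trace(net, "1", retain_input=True) as ret:
        _ = net(inp)
    print(ret.input.shape, ret.output.shape)  # both torch.Size([2, 4])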
class TraceDict(OrderedDict, contextlib.AbstractContextManager):
    """
    To retain the output of multiple named layers during the computation
    of the given network:

        with TraceDict(net, ['layer1.name1', 'layer2.name2']) as ret:
            _ = net(inp)
            representation = ret['layer1.name1'].output

    If edit_output is provided, it should be a function that takes
    two arguments: output, and the layer name; and then it returns the
    modified output.

    Other arguments are the same as Trace.  If stop is True, then the
    execution of the network will be stopped after the last layer
    listed (even if it would not have been the last to be executed).
    """
    def __init__(
        self,
        module,
        layers=None,
        retain_output=True,
        retain_input=False,
        clone=False,
        detach=False,
        retain_grad=False,
        edit_output=None,
        stop=False,
    ):
        self.stop = stop

        def flag_last_unseen(it):
            # Yields (is_last, item) pairs, marking the final
            # previously-unseen item so stop can fire on that layer.
            try:
                it = iter(it)
                prev = next(it)
                seen = set([prev])
            except StopIteration:
                return
            for item in it:
                if item not in seen:
                    yield False, prev
                    seen.add(item)
                    prev = item
            yield True, prev

        for is_last, layer in flag_last_unseen(layers):
            self[layer] = Trace(
                module=module,
                layer=layer,
                retain_output=retain_output,
                retain_input=retain_input,
                clone=clone,
                detach=detach,
                retain_grad=retain_grad,
                edit_output=edit_output,
                stop=stop and is_last,
            )

    def __enter__(self):
        return self

    def __exit__(self, type, value, traceback):
        self.close()
        # Guard against type being None (normal exit, no exception raised).
        if self.stop and type is not None and issubclass(type, StopForward):
            return True

    def close(self):
        for layer, trace in reversed(self.items()):
            trace.close()
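# A usage sketch of TraceDict with stop=True on a hypothetical toy model;
# the layer names "0" and "1" are the default Sequential indices.  The
# forward pass halts after the last listed layer, so layer "2" never runs.
def _example_tracedict():
    net = torch.nn.Sequential(
        torch.nn.Linear(4, 8), torch.nn.ReLU(), torch.nn.Linear(8, 2)
    )
    with TraceDict(net, ["0", "1"], stop=True) as ret:
        net(torch.randn(2, 4))  # runs only up to layer "1"
    print(ret["0"].output.shape, ret["1"].output.shape)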
class StopForward(Exception):
    """
    If the only output needed from running a network is the retained
    submodule then Trace(submodule, stop=True) will stop execution
    immediately after the retained submodule by raising the StopForward()
    exception.  When Trace is used as context manager, it catches that
    exception and can be used as follows:

        with Trace(net, layername, stop=True) as tr:
            net(inp)  # Only runs the network up to layername
        print(tr.output)
    """

    pass
def recursive_copy(x, clone=None, detach=None, retain_grad=None):
    """
    Copies a reference to a tensor, or an object that contains tensors,
    optionally detaching and cloning the tensor(s).  If retain_grad is
    true, the original tensors are marked to have grads retained.
    """
    if not clone and not detach and not retain_grad:
        return x
    if isinstance(x, torch.Tensor):
        if retain_grad:
            if not x.requires_grad:
                x.requires_grad = True
            x.retain_grad()
        elif detach:
            x = x.detach()
        if clone:
            x = x.clone()
        return x
    # Only dicts, lists, and tuples (and subclasses) can be copied.
    # Forward the flags on recursion so that nested tensors are treated
    # the same way as top-level ones.
    if isinstance(x, dict):
        return type(x)(
            {
                k: recursive_copy(v, clone=clone, detach=detach, retain_grad=retain_grad)
                for k, v in x.items()
            }
        )
    elif isinstance(x, (list, tuple)):
        return type(x)(
            [
                recursive_copy(v, clone=clone, detach=detach, retain_grad=retain_grad)
                for v in x
            ]
        )
    else:
        assert False, f"Unknown type {type(x)} cannot be broken into tensors."
def subsequence(
    sequential,
    first_layer=None,
    last_layer=None,
    after_layer=None,
    upto_layer=None,
    single_layer=None,
    share_weights=False,
):
    """
    Creates a subsequence of a pytorch Sequential model, copying over
    modules together with parameters for the subsequence.  Only
    modules from first_layer to last_layer (inclusive) are included,
    or modules between after_layer and upto_layer (exclusive).
    Handles descent into dotted layer names as long as all references
    are within nested Sequential models.

    If share_weights is True, then references the original modules
    and their parameters without copying them.  Otherwise, by default,
    makes a separate brand-new copy.
    """
    assert (single_layer is None) or (
        first_layer is last_layer is after_layer is upto_layer is None
    )
    if single_layer is not None:
        first_layer = single_layer
        last_layer = single_layer
    first, last, after, upto = [
        None if d is None else d.split(".")
        for d in [first_layer, last_layer, after_layer, upto_layer]
    ]
    return hierarchical_subsequence(
        sequential,
        first=first,
        last=last,
        after=after,
        upto=upto,
        share_weights=share_weights,
    )
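# A usage sketch of subsequence on a hypothetical toy Sequential; the
# layer names "0", "1", "2" are the default Sequential indices.
def _example_subsequence():
    net = torch.nn.Sequential(
        torch.nn.Linear(4, 8), torch.nn.ReLU(), torch.nn.Linear(8, 2)
    )
    head = subsequence(net, upto_layer="2")  # layers "0" and "1", deep-copied
    tail = subsequence(net, first_layer="2", share_weights=True)  # shared refs
    print(len(head), len(tail))  # 2 1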
def hierarchical_subsequence(
    sequential, first, last, after, upto, share_weights=False, depth=0
):
    """
    Recursive helper for subsequence() to support descent into dotted
    layer names.  In this helper, first, last, after, and upto are
    arrays of names resulting from splitting on dots.  Can only
    descend into nested Sequentials.
    """
    assert (last is None) or (upto is None)
    assert (first is None) or (after is None)
    if first is last is after is upto is None:
        return sequential if share_weights else copy.deepcopy(sequential)
    assert isinstance(sequential, torch.nn.Sequential), (
        ".".join((first or last or after or upto)[:depth] or ["arg"])
        + " not Sequential"
    )
    including_children = (first is None) and (after is None)
    included_children = OrderedDict()
    # A = current level short name of A.
    # AN = full name for recursive descent if not innermost.
    (F, FN), (L, LN), (A, AN), (U, UN) = [
        (d[depth], (None if len(d) == depth + 1 else d))
        if d is not None
        else (None, None)
        for d in [first, last, after, upto]
    ]
    for name, layer in sequential._modules.items():
        if name == F:
            first = None
            including_children = True
        if name == A and AN is not None:  # just like F if not a leaf.
            after = None
            including_children = True
        if name == U and UN is None:
            upto = None
            including_children = False
        if including_children:
            # AR = full name for recursive descent if name matches.
            FR, LR, AR, UR = [
                n if n is None or n[depth] == name else None for n in [FN, LN, AN, UN]
            ]
            chosen = hierarchical_subsequence(
                layer,
                first=FR,
                last=LR,
                after=AR,
                upto=UR,
                share_weights=share_weights,
                depth=depth + 1,
            )
            if chosen is not None:
                included_children[name] = chosen
        if name == L:
            last = None
            including_children = False
        if name == U and UN is not None:  # just like L if not a leaf.
            upto = None
            including_children = False
        if name == A and AN is None:
            after = None
            including_children = True
    for name in [first, last, after, upto]:
        if name is not None:
            raise ValueError("Layer %s not found" % ".".join(name))
    # Omit empty subsequences except at the outermost level,
    # where we should not return None.
    if not len(included_children) and depth > 0:
        return None
    result = torch.nn.Sequential(included_children)
    result.training = sequential.training
    return result
def set_requires_grad(requires_grad, *models):
    """
    Sets requires_grad true or false for all parameters within the
    models passed.
    """
    for model in models:
        if isinstance(model, torch.nn.Module):
            for param in model.parameters():
                param.requires_grad = requires_grad
        elif isinstance(model, (torch.nn.Parameter, torch.Tensor)):
            model.requires_grad = requires_grad
        else:
            assert False, "unknown type %r" % type(model)
def get_module(model, name):
    """
    Finds the named module within the given model.
    """
    for n, m in model.named_modules():
        if n == name:
            return m
    raise LookupError(name)


def get_parameter(model, name):
    """
    Finds the named parameter within the given model.
    """
    for n, p in model.named_parameters():
        if n == name:
            return p
    raise LookupError(name)
def replace_module(model, name, new_module):
    """
    Replaces the named module within the given model.
    """
    if "." in name:
        parent_name, attr_name = name.rsplit(".", 1)
        model = get_module(model, parent_name)
    else:
        # Handle an undotted name; attr_name was previously unbound here.
        attr_name = name
    # original_module = getattr(model, attr_name)
    setattr(model, attr_name, new_module)
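# A sketch of the dotted-name helpers on a hypothetical nested toy model;
# the names "0.1" and "0.0.weight" come from default Sequential indices.
def _example_dotted_names():
    net = torch.nn.Sequential(
        torch.nn.Sequential(torch.nn.Linear(4, 4), torch.nn.ReLU())
    )
    relu = get_module(net, "0.1")
    weight = get_parameter(net, "0.0.weight")
    replace_module(net, "0.1", torch.nn.GELU())  # swap the ReLU for a GELU
    print(type(relu).__name__, weight.shape, type(net[0][1]).__name__)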
def invoke_with_optional_args(fn, *args, **kwargs):
    """
    Invokes a function with only the arguments that it
    is written to accept, giving priority to arguments
    that match by-name, using the following rules.
    (1) arguments with matching names are passed by name.
    (2) remaining non-name-matched args are passed by order.
    (3) extra caller arguments that the function cannot
        accept are not passed.
    (4) extra required function arguments that the caller
        cannot provide cause a TypeError to be raised.

    Ordinary python calling conventions are helpful for
    supporting a function that might be revised to accept
    extra arguments in a newer version, without requiring the
    caller to pass those new arguments.  This function helps
    support function callers that might be revised to supply
    extra arguments, without requiring the callee to accept
    those new arguments.
    """
    argspec = inspect.getfullargspec(fn)
    pass_args = []
    used_kw = set()
    unmatched_pos = []
    used_pos = 0
    defaulted_pos = len(argspec.args) - (
        0 if not argspec.defaults else len(argspec.defaults)
    )
    # Pass positional args that match name first, then by position.
    for i, n in enumerate(argspec.args):
        if n in kwargs:
            pass_args.append(kwargs[n])
            used_kw.add(n)
        elif used_pos < len(args):
            pass_args.append(args[used_pos])
            used_pos += 1
        else:
            unmatched_pos.append(len(pass_args))
            pass_args.append(
                None if i < defaulted_pos else argspec.defaults[i - defaulted_pos]
            )
    # Fill unmatched positional args with unmatched keyword args in order.
    if len(unmatched_pos):
        for k, v in kwargs.items():
            if k in used_kw or k in argspec.kwonlyargs:
                continue
            pass_args[unmatched_pos[0]] = v
            used_kw.add(k)
            unmatched_pos = unmatched_pos[1:]
            if len(unmatched_pos) == 0:
                break
        else:
            if unmatched_pos[0] < defaulted_pos:
                unpassed = ", ".join(
                    argspec.args[u] for u in unmatched_pos if u < defaulted_pos
                )
                raise TypeError(f"{fn.__name__}() cannot be passed {unpassed}.")
    # Pass remaining kw args if they can be accepted (i.e. the function
    # declares them keyword-only or accepts **kwargs via varkw).
    pass_kw = {
        k: v
        for k, v in kwargs.items()
        if k not in used_kw and (k in argspec.kwonlyargs or argspec.varkw is not None)
    }
    # Pass remaining positional args if they can be accepted.
    if argspec.varargs is not None:
        pass_args += list(args[used_pos:])
    return fn(*pass_args, **pass_kw)
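# A sketch of why Trace uses this for edit_output callbacks: a callback
# declaring only one of the two offered arguments still works.  The scale
# function below is hypothetical.
def _example_invoke():
    def scale(output):  # accepts output but not layer
        return output * 2

    print(invoke_with_optional_args(scale, output=torch.ones(2), layer="x"))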