def replace_layer_recursive(model, old_layer, new_layer):
    """Replace the first occurrence of ``old_layer`` in ``model``'s module tree.

    The tree is walked depth-first through each node's ``_modules`` mapping,
    in mapping order. The walk stops at the first child that compares equal
    to ``old_layer``; that slot is overwritten with ``new_layer``.

    Args:
        model: Root object exposing a ``_modules`` mapping of name -> child.
        old_layer: The child to look for (compared with ``==``).
        new_layer: The object stored in place of ``old_layer``.

    Returns:
        True if a replacement was made, False if ``old_layer`` was not found.
    """
    for name, layer in model._modules.items():
        if layer == old_layer:
            # Re-assigning an existing key does not resize the dict, so this
            # is safe while iterating over it.
            model._modules[name] = new_layer
            return True
        if layer is None:
            # Empty slot (optional submodule that was never set): nothing to
            # descend into.
            continue
        if replace_layer_recursive(layer, old_layer, new_layer):
            return True
    return False


def replace_all_layer_type_recursive(model, old_layer_type, new_layer):
    """Replace every child that is an instance of ``old_layer_type``.

    Each matching slot in the tree receives the *same* ``new_layer`` object;
    no copies are made. The walk also descends through the original child
    after replacing it, so matching descendants of a replaced child are
    replaced as well.

    Args:
        model: Root object exposing a ``_modules`` mapping of name -> child.
        old_layer_type: A type (or tuple of types) accepted by ``isinstance``.
        new_layer: The object stored in every matching slot.

    Returns:
        None. ``model`` is modified in place.
    """
    for name, layer in model._modules.items():
        if layer is None:
            # Empty slot: cannot match a layer type and has no children.
            continue
        if isinstance(layer, old_layer_type):
            # Overwriting an existing key keeps the dict size stable, so the
            # surrounding iteration remains valid.
            model._modules[name] = new_layer
        replace_all_layer_type_recursive(layer, old_layer_type, new_layer)


def find_layer_types_recursive(model, layer_types):
    """Collect every layer whose exact type is listed in ``layer_types``.

    Matching uses ``type(layer) in layer_types``, so subclasses of a listed
    type are *not* matched unless they are listed themselves.

    Args:
        model: Root object exposing a ``_modules`` mapping of name -> child.
        layer_types: Iterable of types to look for.

    Returns:
        List of matching layers in depth-first (pre-order) traversal order.
    """
    # Materialize once: the membership test runs for every visited layer, and
    # a one-shot iterable (e.g. a generator) would be exhausted after the
    # first check.
    wanted_types = tuple(layer_types)

    def predicate(layer):
        return type(layer) in wanted_types

    return find_layer_predicate_recursive(model, predicate)


def find_layer_predicate_recursive(model, predicate):
    """Collect every layer in the tree for which ``predicate`` is truthy.

    Args:
        model: Root object exposing a ``_modules`` mapping of name -> child.
        predicate: Callable taking a layer and returning a truthy value for
            layers that should be included. It is never called with ``None``.

    Returns:
        List of matching layers in depth-first (pre-order) traversal order:
        a matching parent appears before its matching descendants.
    """
    result = []
    for layer in model._modules.values():
        if layer is None:
            # Empty slot: not a layer, and nothing to descend into.
            continue
        if predicate(layer):
            result.append(layer)
        result.extend(find_layer_predicate_recursive(layer, predicate))
    return result