|
import gc |
|
from collections import defaultdict |
|
|
|
import torch |
|
import torch.nn as nn |
|
import torch.nn.functional as F |
|
from torch.autograd import Function |
|
from torch.cuda.amp import custom_bwd, custom_fwd |
|
|
|
import tinycudann as tcnn |
|
|
|
|
|
def chunk_batch(func, chunk_size, move_to_cpu, *args, **kwargs):
    """Run ``func`` over the inputs in slices along dim 0 and merge the outputs.

    The batch size is read from ``shape[0]`` of the first positional argument
    that is a ``torch.Tensor``. For every chunk, each positional tensor is
    sliced as ``arg[i:i + chunk_size]``; non-tensor positionals and all keyword
    arguments are forwarded unchanged.

    Args:
        func: Callable returning a tensor, a tuple/list of tensors, a dict
            whose values are tensors, or ``None`` (such chunks are skipped).
        chunk_size: Number of leading-dimension entries per call.
        move_to_cpu: If true, every chunk output is moved with ``.cpu()``
            before being accumulated.
        *args: Positional arguments for ``func``.
        **kwargs: Keyword arguments for ``func`` (never sliced).

    Returns:
        The per-chunk outputs concatenated along dim 0, in the container kind
        of the last non-``None`` chunk: a tensor, a tuple/list (subclasses such
        as namedtuples are rebuilt), or a plain ``dict``. ``None`` if no chunk
        produced a value.

    Raises:
        TypeError: If no positional argument is a tensor, or if ``func``
            returns an unsupported type.
    """
    batch_size = next(
        (arg.shape[0] for arg in args if isinstance(arg, torch.Tensor)),
        None,
    )
    if batch_size is None:
        raise TypeError(
            'chunk_batch needs at least one positional torch.Tensor argument '
            'to determine the batch size.'
        )

    collected = defaultdict(list)
    out_type = None
    seq_length = 0
    for start in range(0, batch_size, chunk_size):
        stop = start + chunk_size
        sliced_args = [
            arg[start:stop] if isinstance(arg, torch.Tensor) else arg
            for arg in args
        ]
        out_chunk = func(*sliced_args, **kwargs)
        if out_chunk is None:
            continue

        out_type = type(out_chunk)
        # Normalise every supported return kind to a mapping key -> tensor.
        if isinstance(out_chunk, torch.Tensor):
            out_chunk = {0: out_chunk}
        elif isinstance(out_chunk, (tuple, list)):
            seq_length = len(out_chunk)
            out_chunk = dict(enumerate(out_chunk))
        elif not isinstance(out_chunk, dict):
            # Raise instead of print + exit(1): library code must not
            # terminate the host process.
            raise TypeError(
                f'Return value of func must be in type [torch.Tensor, list, tuple, dict], get {type(out_chunk)}.'
            )

        for key, value in out_chunk.items():
            # Without autograd there is no reason to keep graph references.
            if not torch.is_grad_enabled():
                value = value.detach()
            if move_to_cpu:
                value = value.cpu()
            collected[key].append(value)

    if out_type is None:
        return None

    out = {key: torch.cat(parts, dim=0) for key, parts in collected.items()}

    # Use issubclass so the dispatch agrees with the isinstance checks above;
    # identity checks silently dropped results for subclasses.
    if issubclass(out_type, torch.Tensor):
        return out[0]
    if issubclass(out_type, (tuple, list)):
        items = [out[index] for index in range(seq_length)]
        if hasattr(out_type, '_fields'):
            # namedtuple constructors take their fields positionally.
            return out_type(*items)
        return out_type(items)
    return out
|
|
|
|
|
class _TruncExp(Function): |
|
|
|
|
|
@staticmethod |
|
@custom_fwd(cast_inputs=torch.float32) |
|
def forward(ctx, x): |
|
ctx.save_for_backward(x) |
|
return torch.exp(x) |
|
|
|
@staticmethod |
|
@custom_bwd |
|
def backward(ctx, g): |
|
x = ctx.saved_tensors[0] |
|
return g * torch.exp(torch.clamp(x, max=15)) |
|
|
|
trunc_exp = _TruncExp.apply |
|
|
|
|
|
def get_activation(name):
    """Return the activation callable described by ``name``.

    ``name`` is matched case-insensitively. Recognised forms:

    * ``None`` / ``'none'``: identity.
    * ``'scale<f>'``: clamp to ``[0, f]`` and divide by ``f``.
    * ``'clamp<f>'``: clamp to ``[0, f]``.
    * ``'mul<f>'``: multiply by ``f``.
    * ``'lin2srgb'``: piecewise linear-to-sRGB curve, clamped to ``[0, 1]``.
    * ``'trunc_exp'``: the module-level ``trunc_exp`` function.
    * ``'+<f>'`` / ``'-<f>'``: add the signed constant.
    * ``'sigmoid'`` / ``'tanh'``: the corresponding ``torch`` function.
    * anything else: looked up as an attribute of ``torch.nn.functional``.

    Args:
        name: Activation specifier string, or ``None``.

    Returns:
        A callable taking a tensor and returning a tensor.

    Raises:
        ValueError: If the numeric part of a parameterised name is not a
            valid float.
        AttributeError: If ``name`` matches no form above and is not an
            attribute of ``torch.nn.functional``.
    """
    if name is None:
        return lambda x: x

    name = name.lower()
    if name == 'none':
        return lambda x: x

    if name.startswith('scale'):
        scale_factor = float(name[5:])
        return lambda x: x.clamp(0., scale_factor) / scale_factor

    if name.startswith('clamp'):
        clamp_max = float(name[5:])
        return lambda x: x.clamp(0., clamp_max)

    if name.startswith('mul'):
        mul_factor = float(name[3:])
        return lambda x: x * mul_factor

    if name == 'lin2srgb':
        return lambda x: torch.where(
            x > 0.0031308,
            # Clamp before pow so the unselected branch never sees values
            # at or below zero.
            torch.pow(torch.clamp(x, min=0.0031308), 1.0 / 2.4) * 1.055 - 0.055,
            12.92 * x,
        ).clamp(0., 1.)

    if name == 'trunc_exp':
        return trunc_exp

    if name.startswith(('+', '-')):
        # Parse once here rather than on every call; a malformed offset now
        # fails when the activation is built instead of at first use.
        offset = float(name)
        return lambda x: x + offset

    if name == 'sigmoid':
        return lambda x: torch.sigmoid(x)

    if name == 'tanh':
        return lambda x: torch.tanh(x)

    return getattr(F, name)
|
|
|
|
|
def dot(x, y):
    """Return the dot product of ``x`` and ``y`` along the last dimension.

    The reduced dimension is kept with size 1.
    """
    elementwise = x * y
    return elementwise.sum(dim=-1, keepdim=True)
|
|
|
|
|
def reflect(x, n):
    """Return ``2 * dot(x, n) * n - x``.

    NOTE(review): this equals the mirror image of ``x`` about the direction
    ``n`` only when ``n`` has unit length along the last dimension; confirm
    that callers normalise ``n``.
    """
    projection = dot(x, n)
    return 2 * projection * n - x
|
|
|
|
|
def scale_anything(dat, inp_scale, tgt_scale):
    """Linearly remap ``dat`` from the range ``inp_scale`` to ``tgt_scale``.

    Args:
        dat: Value(s) to rescale.
        inp_scale: ``(low, high)`` of the source range, or ``None`` to use
            ``dat.min()`` and ``dat.max()``.
        tgt_scale: ``(low, high)`` of the target range.

    Returns:
        ``dat`` normalised by the source range and mapped onto the target
        range.
    """
    if inp_scale is None:
        inp_scale = [dat.min(), dat.max()]
    src_low, src_high = inp_scale[0], inp_scale[1]
    tgt_low, tgt_high = tgt_scale[0], tgt_scale[1]

    normalised = (dat - src_low) / (src_high - src_low)
    return normalised * (tgt_high - tgt_low) + tgt_low
|
|
|
|
|
def cleanup():
    """Release memory held by Python, the CUDA allocator cache, and tcnn."""
    # Collect first so objects that are already unreachable are dropped
    # before the cache-release calls below run.
    gc.collect()
    torch.cuda.empty_cache()
    tcnn.free_temporary_memory()
|
|