Shortcuts

Source code for mmagic.models.editors.eg3d.renderer

# Copyright (c) OpenMMLab. All rights reserved.
"""The renderer is a module that takes in rays, decides where to sample along
each ray, and computes pixel colors using the volume rendering equation."""

from typing import Any, Optional, Tuple, Union

import torch
import torch.nn as nn
import torch.nn.functional as F
from mmengine.model import BaseModule
from mmengine.utils import digit_version
from mmengine.utils.dl_utils import TORCH_VERSION

from ..stylegan3.stylegan3_modules import FullyConnectedLayer
from .eg3d_utils import (get_ray_limits_box, inverse_transform_sampling,
                         linspace_batch)


class EG3DRenderer(BaseModule):
    """Renderer for EG3D.

    This class samples render points on each input ray and interpolates the
    triplane feature corresponding to the points' coordinates. Then it
    predicts each point's RGB feature and density (sigma) with a neural
    network and computes the RGB feature of each ray by integration.
    Different from typical NeRF models, the decoder of EG3DRenderer takes the
    triplane feature of each point as input instead of a positional encoding
    of the coordinates.

    Args:
        decoder_cfg (dict): The config to build the neural renderer. It is
            unpacked as keyword arguments of :class:`EG3DDecoder`.
        ray_start (float): The start position of all rays. May also be the
            string ``'auto'`` (together with ``ray_end``), in which case
            per-ray limits are computed from the triplane box in
            :meth:`forward`.
        ray_end (float): The end position of all rays. See ``ray_start``.
        box_warp (float): The side length of the cube spanned by the
            triplanes. The box is axis-aligned and centered at the origin,
            so each axis covers `[-box_warp/2, box_warp/2]`. E.g. with
            `box_warp=1.8` each axis covers `[-0.9, 0.9]`. Defaults to 1.
        depth_resolution (int): Resolution of depth, as well as the number
            of points per ray. Defaults to 64.
        depth_resolution_importance (int): Resolution of depth in
            hierarchical sampling. Defaults to 64.
        density_noise (float): Strength of the Gaussian noise added to the
            predicted density. Noise is only added when the value is
            greater than 0. Defaults to 0.
        clamp_mode (str): The clamp mode for density predicted by the
            neural renderer. Only 'softplus' is supported. Defaults to
            'softplus'.
        white_back (bool): Whether to render a white background. Defaults
            to True.
        projection_mode (str): The projection method used to map the
            coordinates of render points to plane features. Compared
            case-insensitively. For usage please refer to
            :meth:`self.project_onto_planes` and
            https://github.com/NVlabs/eg3d/issues/67.
            Defaults to 'Official'.
    """

    def __init__(self,
                 decoder_cfg: dict,
                 ray_start: float,
                 ray_end: float,
                 box_warp: float = 1,
                 depth_resolution: int = 64,
                 depth_resolution_importance: int = 64,
                 density_noise: float = 0,
                 clamp_mode: str = 'softplus',
                 white_back: bool = True,
                 projection_mode: str = 'Official'):
        super().__init__()
        self.decoder = EG3DDecoder(**decoder_cfg)

        self.ray_start = ray_start
        self.ray_end = ray_end
        self.box_warp = box_warp
        self.depth_resolution = depth_resolution
        self.depth_resolution_importance = depth_resolution_importance
        self.density_noise = density_noise
        self.clamp_mode = clamp_mode
        self.white_back = white_back
        self.projection_mode = projection_mode

    def get_value(self,
                  target: str,
                  render_kwargs: Optional[dict] = None) -> Any:
        """Get value of target field.

        Args:
            target (str): The key of the target field.
            render_kwargs (Optional[dict], optional): The input key word
                arguments dict. Defaults to None.

        Returns:
            Any: ``render_kwargs[target]`` if the key is present, otherwise
                the attribute of the same name on this renderer.
        """
        if render_kwargs is None:
            return getattr(self, target)
        return render_kwargs.get(target, getattr(self, target))

    def forward(
        self,
        planes: torch.Tensor,
        ray_origins: torch.Tensor,
        ray_directions: torch.Tensor,
        render_kwargs: Optional[dict] = None
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Render 2D RGB feature, weighed depth and weights with the passed
        triplane features and rays. 'weights' denotes `w` in Equation 5 of the
        NeRF's paper.

        Args:
            planes (torch.Tensor): The triplane features shape like
                (bz, 3, TriPlane_feat, TriPlane_res, TriPlane_res).
            ray_origins (torch.Tensor): The origin of each ray to render,
                shape like (bz, NeRF_res * NeRF_res, 3).
            ray_directions (torch.Tensor): The direction vector of each ray
                to render, shape like (bz, NeRF_res * NeRF_res, 3).
            render_kwargs (Optional[dict], optional): Per-call overrides for
                `ray_start`, `ray_end`, `box_warp`, `depth_resolution`,
                `depth_resolution_importance` and `density_noise`.
                Defaults to None.

        Returns:
            Tuple[torch.Tensor]: Rendered RGB feature, weighted depths and
                the per-ray sum of the rendering weights.
        """
        ray_start = self.get_value('ray_start', render_kwargs)
        ray_end = self.get_value('ray_end', render_kwargs)
        box_warp = self.get_value('box_warp', render_kwargs)
        depth_resolution = self.get_value('depth_resolution', render_kwargs)
        depth_resolution_importance = self.get_value(
            'depth_resolution_importance', render_kwargs)
        density_noise = self.get_value('density_noise', render_kwargs)

        if ray_start == ray_end == 'auto':
            ray_start, ray_end = get_ray_limits_box(
                ray_origins, ray_directions, box_side_length=box_warp)
            is_ray_valid = ray_end > ray_start
            if torch.any(is_ray_valid).item():
                # Give invalid rays a usable range taken from the valid ones.
                # NOTE(review): the end of invalid rays is filled from
                # `ray_start`, not `ray_end`; kept unchanged -- confirm this
                # is intentional.
                ray_start[~is_ray_valid] = ray_start[is_ray_valid].min()
                ray_end[~is_ray_valid] = ray_start[is_ray_valid].max()
        else:
            assert (isinstance(ray_start, float) and isinstance(
                ray_end, float)), (
                    '\'ray_start\' and \'ray_end\' must be both float type or '
                    f'both \'auto\'. But receive {ray_start} and {ray_end}.')
            assert ray_start < ray_end, (
                '\'ray_start\' must less than \'ray_end\'.')

        # Create stratified depth samples
        depths_coarse = self.sample_stratified(ray_origins, ray_start,
                                               ray_end, depth_resolution)

        # Coarse Pass
        colors_coarse, densities_coarse = self._render_samples(
            planes, ray_origins, ray_directions, depths_coarse,
            density_noise, box_warp)

        N_importance = depth_resolution_importance
        if N_importance is None or N_importance <= 0:
            rgb_final, depth_final, weights = self.volume_rendering(
                colors_coarse, densities_coarse, depths_coarse)
            return rgb_final, depth_final, weights.sum(2)

        # Fine Pass: resample depths where the coarse weights are large.
        _, _, weights = self.volume_rendering(colors_coarse,
                                              densities_coarse, depths_coarse)
        depths_fine = self.sample_importance(depths_coarse, weights,
                                             N_importance)
        colors_fine, densities_fine = self._render_samples(
            planes, ray_origins, ray_directions, depths_fine, density_noise,
            box_warp)

        all_depths, all_colors, all_densities = self.unify_samples(
            depths_coarse, colors_coarse, densities_coarse, depths_fine,
            colors_fine, densities_fine)

        # Aggregate
        rgb_final, depth_final, weights = self.volume_rendering(
            all_colors, all_densities, all_depths)
        return rgb_final, depth_final, weights.sum(2)

    def _render_samples(
            self, planes: torch.Tensor, ray_origins: torch.Tensor,
            ray_directions: torch.Tensor, depths: torch.Tensor,
            density_noise: float,
            box_warp: float) -> Tuple[torch.Tensor, torch.Tensor]:
        """Decode colors and densities of the points at `depths` on each ray.

        Args:
            planes (torch.Tensor): The triplane features.
            ray_origins (torch.Tensor): Ray origins, (bz, N_rays, 3).
            ray_directions (torch.Tensor): Ray directions, (bz, N_rays, 3).
            depths (torch.Tensor): Depths along each ray, shape like
                (bz, N_rays, N_samples, 1).
            density_noise (float): Strength of noise added to the density.
            box_warp (float): The side length of the triplane cube.

        Returns:
            Tuple[torch.Tensor]: Colors shape like
                (bz, N_rays, N_samples, N_feat) and densities shape like
                (bz, N_rays, N_samples, 1).
        """
        batch_size, num_rays, num_samples, _ = depths.shape
        # point = origin + depth * direction, flattened over rays and samples
        sample_coordinates = (
            ray_origins.unsqueeze(-2) +
            depths * ray_directions.unsqueeze(-2)).reshape(batch_size, -1, 3)

        out = self.neural_rendering(planes, sample_coordinates, density_noise,
                                    box_warp)
        colors = out['rgb']
        colors = colors.reshape(batch_size, num_rays, num_samples,
                                colors.shape[-1])
        densities = out['sigma'].reshape(batch_size, num_rays, num_samples, 1)
        return colors, densities

    def sample_stratified(self, ray_origins: torch.Tensor,
                          ray_start: Union[float, torch.Tensor],
                          ray_end: Union[float, torch.Tensor],
                          depth_resolution: int) -> torch.Tensor:
        """Return depths of approximately uniformly spaced samples along rays.

        Each of the `depth_resolution` evenly spaced depths is jittered by a
        uniform random offset in `[0, depth_delta)`, where `depth_delta` is
        the spacing between two neighbouring depths.

        Args:
            ray_origins (torch.Tensor): The origin of each ray, shape like
                (bz, NeRF_res * NeRF_res, 3). Only used to provide device
                and shape info.
            ray_start (Union[float, torch.Tensor]): The start position of
                rays. If a float is passed, all rays will have the same
                start distance.
            ray_end (Union[float, torch.Tensor]): The end position of rays.
                If a float is passed, all rays will have the same end
                distance.
            depth_resolution (int): Resolution of depth, as well as the
                number of points per ray.

        Returns:
            torch.Tensor: The sampled coarse depth shape like
                (bz, NeRF_res * NeRF_res, depth_resolution, 1).
        """
        N, M, _ = ray_origins.shape
        depth_delta = (ray_end - ray_start) / (depth_resolution - 1)

        if isinstance(ray_start, torch.Tensor):
            # perform linspace for batch of tensor
            depths_coarse = linspace_batch(ray_start, ray_end,
                                           depth_resolution)
            # presumably moves the sample axis behind (bz, N_rays) --
            # TODO confirm against `linspace_batch`
            depths_coarse = depths_coarse.permute(1, 2, 0, 3)
            depths_coarse += torch.rand_like(depths_coarse) * depth_delta[
                ..., None]
        else:
            depths_coarse = torch.linspace(
                ray_start,
                ray_end,
                depth_resolution,
                device=ray_origins.device)
            depths_coarse = depths_coarse.reshape(1, 1, depth_resolution, 1)
            depths_coarse = depths_coarse.repeat(N, M, 1, 1)
            depths_coarse += torch.rand_like(depths_coarse) * depth_delta

        return depths_coarse

    def neural_rendering(self, planes: torch.Tensor,
                         sample_coordinates: torch.Tensor,
                         density_noise: float, box_warp: float) -> dict:
        """Predict RGB features and densities of the coordinates by neural
        renderer model and the triplane input.

        Args:
            planes (torch.Tensor): Triplane feature shape like
                (bz, 3, TriPlane_feat, TriPlane_res, TriPlane_res).
            sample_coordinates (torch.Tensor): Coordinates of the sampling
                points, shape like (bz, N_depth * NeRF_res * NeRF_res, 3).
            density_noise (float): Strength of noise added to the predicted
                density. No noise is added unless it is greater than 0.
            box_warp (float): The side length of the cube spanned by the
                triplanes.

        Returns:
            dict: A dict contains RGB features ('rgb') and densities
                ('sigma').
        """
        sampled_features = self.sample_from_planes(
            planes, sample_coordinates, box_warp=box_warp)
        out = self.decoder(sampled_features)
        if density_noise > 0:
            # Out-of-place add: do not mutate the decoder's output tensor.
            noise = torch.randn_like(out['sigma']) * density_noise
            out['sigma'] = out['sigma'] + noise
        return out

    def sample_from_planes(self,
                           plane_features: torch.Tensor,
                           coordinates: torch.Tensor,
                           interp_mode: str = 'bilinear',
                           box_warp: Optional[float] = None) -> torch.Tensor:
        """Sample feature from triplane feature with the passed coordinates
        of render points.

        Args:
            plane_features (torch.Tensor): The triplane feature, shape like
                (bz, n_planes, C, H, W).
            coordinates (torch.Tensor): The coordinates of points to render,
                shape like (bz, N_points, 3).
            interp_mode (str): The interpolation mode to sample feature from
                triplane. Defaults to 'bilinear'.
            box_warp (float, optional): The side length of the cube spanned
                by the triplanes. If None, `self.box_warp` is used.
                Defaults to None.

        Returns:
            torch.Tensor: The sampled triplane feature of the render points,
                shape like (bz, n_planes, N_points, C).
        """
        if box_warp is None:
            # The default used to crash on `2 / None`; fall back to the
            # renderer-level setting instead.
            box_warp = self.box_warp

        N, n_planes, C, H, W = plane_features.shape
        _, M, _ = coordinates.shape
        plane_features = plane_features.reshape(N * n_planes, C, H, W)

        # Rescale so that the box [-box_warp/2, box_warp/2] maps to the
        # [-1, 1] range expected by `grid_sample`.
        coordinates = (2 / box_warp) * coordinates

        # NOTE: do not support change projection_mode for specific renderer,
        # use self.projection_mode
        projected_coordinates = self.project_onto_planes(coordinates)
        projected_coordinates = projected_coordinates.unsqueeze(1)

        output_features = F.grid_sample(
            plane_features,
            projected_coordinates.float(),
            mode=interp_mode,
            padding_mode='zeros',
            align_corners=False)
        # (bz * n_planes, C, 1, M) -> (bz, n_planes, M, C)
        output_features = output_features.permute(0, 3, 2, 1).reshape(
            N, n_planes, M, C)
        return output_features

    def project_onto_planes(self, coordinates: torch.Tensor) -> torch.Tensor:
        """Project 3D points to plane formed by coordinate axes.

        In this function, we use indexing operation to replace matrix
        multiplication to achieve higher calculation performance.

        In the original implementation, the mapping matrix is incorrect.
        Therefore we support users to define `projection_mode` to control
        projection behavior in the initialization function of
        :class:`~EG3DRenderer`. If you want to run inference with the
        official pretrained model, please remember to set
        `projection_mode = 'official'` (case-insensitive).
        More information please refer to
        https://github.com/NVlabs/eg3d/issues/67.

        If the project mode is `official`, the equivalent projection matrix
        is the inverse matrix of:

            [[[1, 0, 0], [0, 1, 0], [0, 0, 1]],
             [[1, 0, 0], [0, 0, 1], [0, 1, 0]],
             [[0, 0, 1], [1, 0, 0], [0, 1, 0]]]

        Otherwise, the equivalent projection matrix is the inverse matrix
        of:

            [[[1, 0, 0], [0, 1, 0], [0, 0, 1]],
             [[1, 0, 0], [0, 0, 1], [0, 1, 0]],
             [[0, 0, 1], [0, 1, 0], [1, 0, 0]]]

        Args:
            coordinates (torch.Tensor): The coordinates of the render points.
                shape like (bz, NeRF_res * NeRF_res * N_depth, 3).

        Returns:
            torch.Tensor: The projected coordinates, shape like
                (bz * 3, NeRF_res * NeRF_res * N_depth, 2). Along the first
                axis the three planes of one sample are adjacent:
                [xy_0, xz_0, yz_0, xy_1, xz_1, yz_1, ...].
        """
        N = coordinates.shape[0]
        xy_coord = coordinates[:, :, [0, 1]]  # (bz, N_points, 2)
        xz_coord = coordinates[:, :, [0, 2]]  # (bz, N_points, 2)
        if self.projection_mode.upper() == 'OFFICIAL':
            yz_coord = coordinates[:, :, [2, 0]]  # actually zx_coord
        else:
            yz_coord = coordinates[:, :, [2, 1]]

        # Stacking on a new plane axis and merging it into the batch axis
        # interleaves the planes per sample without a Python-level loop.
        coord_projected = torch.stack([xy_coord, xz_coord, yz_coord], dim=1)
        return coord_projected.reshape(N * 3, *coord_projected.shape[2:])

    def unify_samples(
        self, depths_c: torch.Tensor, colors_c: torch.Tensor,
        densities_c: torch.Tensor, depths_f: torch.Tensor,
        colors_f: torch.Tensor, densities_f: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Sort and merge coarse samples and fine samples.

        Args:
            depths_c (torch.Tensor): Coarse depths shape like
                (bz, NeRF_res * NeRF_res, N_depth, 1).
            colors_c (torch.Tensor): Coarse color features shape like
                (bz, NeRF_res * NeRF_res, N_depth, N_feat).
            densities_c (torch.Tensor): Coarse densities shape like
                (bz, NeRF_res * NeRF_res, N_depth, 1).
            depths_f (torch.Tensor): Fine depths shape like
                (bz, NeRF_res * NeRF_res, N_depth_fine, 1).
            colors_f (torch.Tensor): Fine colors features shape like
                (bz, NeRF_res * NeRF_res, N_depth_fine, N_feat).
            densities_f (torch.Tensor): Fine densities shape like
                (bz, NeRF_res * NeRF_res, N_depth_fine, 1).

        Returns:
            Tuple[torch.Tensor]: Unified depths, color features and
                densities, sorted by ascending depth. The third dimension
                of returns are `N_depth + N_depth_fine`.
        """
        all_depths = torch.cat([depths_c, depths_f], dim=-2)
        all_colors = torch.cat([colors_c, colors_f], dim=-2)
        all_densities = torch.cat([densities_c, densities_f], dim=-2)

        # Reorder every quantity with the permutation that sorts the depths.
        _, indices = torch.sort(all_depths, dim=-2)
        all_depths = torch.gather(all_depths, -2, indices)
        all_colors = torch.gather(
            all_colors, -2, indices.expand(-1, -1, -1, all_colors.shape[-1]))
        all_densities = torch.gather(all_densities, -2,
                                     indices.expand(-1, -1, -1, 1))

        return all_depths, all_colors, all_densities

    def volume_rendering(
        self, colors: torch.Tensor, densities: torch.Tensor,
        depths: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Volume rendering.

        Colors, densities and depths are averaged between neighbouring
        samples, so the integration runs over `N_depth - 1` intervals.

        Args:
            colors (torch.Tensor): Color feature for each points. Shape like
                (bz, N_points, N_depth, N_feature).
            densities (torch.Tensor): Density for each points. Shape like
                (bz, N_points, N_depth, 1).
            depths (torch.Tensor): Depths for each points. Shape like
                (bz, N_points, N_depth, 1).

        Returns:
            Tuple[torch.Tensor]: A tuple of color feature
                `(bz, N_points, N_feature)`, weighted depth
                `(bz, N_points, 1)` and weight
                `(bz, N_points, N_depth-1, 1)`.

        Raises:
            AssertionError: If `self.clamp_mode` is not 'softplus'.
        """
        # Raised explicitly (not via `assert`) so that the check survives
        # `python -O`; otherwise raw densities would be used silently.
        if self.clamp_mode != 'softplus':
            raise AssertionError(
                'EG3DRenderer only supports \'softplus\' for \'clamp_mode\', '
                f'but receive \'{self.clamp_mode}\'.')

        deltas = depths[:, :, 1:] - depths[:, :, :-1]
        colors_mid = (colors[:, :, :-1] + colors[:, :, 1:]) / 2
        densities_mid = (densities[:, :, :-1] + densities[:, :, 1:]) / 2
        depths_mid = (depths[:, :, :-1] + depths[:, :, 1:]) / 2

        # activation bias of -1 makes things initialize better
        densities_mid = F.softplus(densities_mid - 1)

        density_delta = densities_mid * deltas
        alpha = 1 - torch.exp(-density_delta)

        # Transmittance before each interval: cumulative product of
        # (1 - alpha) over the preceding intervals, starting from 1.
        alpha_shifted = torch.cat(
            [torch.ones_like(alpha[:, :, :1]), 1 - alpha + 1e-10], -2)
        weights = alpha * torch.cumprod(alpha_shifted, -2)[:, :, :-1]

        composite_rgb = torch.sum(weights * colors_mid, -2)
        weight_total = weights.sum(2)
        composite_depth = torch.sum(weights * depths_mid, -2) / weight_total

        # NaN (from a zero total weight) becomes +inf and is then clipped
        # to the min/max range of depths.
        if digit_version(TORCH_VERSION) < digit_version('1.8.0'):
            composite_depth[torch.isnan(composite_depth)] = float('inf')
        else:
            composite_depth = torch.nan_to_num(composite_depth, float('inf'))
        composite_depth = torch.clamp(composite_depth, torch.min(depths),
                                      torch.max(depths))

        if self.white_back:
            # Fill the remaining transmittance with white.
            composite_rgb = composite_rgb + 1 - weight_total

        composite_rgb = composite_rgb * 2 - 1  # Scale to (-1, 1)

        return composite_rgb, composite_depth, weights

    @torch.no_grad()
    def sample_importance(self, z_vals: torch.Tensor, weights: torch.Tensor,
                          N_importance: int) -> torch.Tensor:
        """Return depths of importance sampled points along rays.

        Args:
            z_vals (torch.Tensor): Coarse Z value (depth). Shape like
                (bz, N_points, N_depth, 1).
            weights (torch.Tensor): Weights of the coarse samples. Shape like
                (bz, N_points, N_depths-1, 1).
            N_importance (int): Number of samples to resample.

        Returns:
            torch.Tensor: The resampled depths, shape like
                (bz, N_points, N_importance, 1).
        """
        batch_size, num_rays, samples_per_ray, _ = z_vals.shape
        z_vals = z_vals.reshape(batch_size * num_rays, samples_per_ray)
        # `weights` holds one entry per interval between two samples, i.e.
        # one fewer than `samples_per_ray`.
        weights = weights.reshape(batch_size * num_rays, -1)

        # smooth weights as MipNeRF
        # max(weights[:-1], weights[1:])
        weights = F.max_pool1d(weights.unsqueeze(1).float(), 2, 1, padding=1)
        # 0.5 * (weights[:-1] + weights[1:])
        # Only drop the channel axis added above: a bare `squeeze()` would
        # also remove the ray axis when there is a single ray in total.
        weights = F.avg_pool1d(weights, 2, 1).squeeze(1)
        weights = weights + 0.01  # add resampling padding

        z_vals_mid = 0.5 * (z_vals[:, :-1] + z_vals[:, 1:])
        importance_z_vals = inverse_transform_sampling(z_vals_mid,
                                                       weights[:, 1:-1],
                                                       N_importance).detach()
        importance_z_vals = importance_z_vals.reshape(batch_size, num_rays,
                                                      N_importance, 1)
        return importance_z_vals
class EG3DDecoder(BaseModule):
    """Decoder for EG3D model.

    A two-layer fully connected network with a Softplus in between. It maps
    the plane-averaged triplane feature of each point to one density value
    and `out_channels` RGB feature values.

    Args:
        in_channels (int): The number of input channels.
        out_channels (int): The number of output channels. Defaults to 32.
        hidden_channels (int): The number of channels of hidden layer.
            Defaults to 64.
        lr_multiplier (float, optional): Equalized learning rate multiplier.
            Defaults to 1.
        rgb_padding (float): Padding for RGB output. Defaults to 0.001.
    """

    def __init__(self,
                 in_channels: int,
                 out_channels: int = 32,
                 hidden_channels: int = 64,
                 lr_multiplier: float = 1,
                 rgb_padding: float = 0.001):
        super().__init__()
        self.rgb_padding = rgb_padding

        hidden_layer = FullyConnectedLayer(
            in_channels, hidden_channels, lr_multiplier=lr_multiplier)
        # One extra output channel carries the density (sigma).
        output_layer = FullyConnectedLayer(
            hidden_channels, 1 + out_channels, lr_multiplier=lr_multiplier)
        self.net = nn.Sequential(hidden_layer, nn.Softplus(), output_layer)

    def forward(self, sampled_features: torch.Tensor) -> dict:
        """Forward function.

        Args:
            sampled_features (torch.Tensor): The sampled triplane feature for
                each points. Shape like (batch_size, n_planes, n_points,
                n_ch).

        Returns:
            dict: A dict contains rgb feature ('rgb') and sigma value
                ('sigma') for each point.
        """
        # Aggregate the features of the planes by averaging over dim 1.
        plane_mean = sampled_features.mean(1)
        batch_size, num_points, channels = plane_mean.shape

        decoded = self.net(plane_mean.view(-1, channels))
        decoded = decoded.view(batch_size, num_points, -1)

        # Channel 0 is the density; the remaining channels are the color.
        sigma = decoded[..., 0:1]
        # Widen the sigmoid range to
        # [-rgb_padding, 1 + rgb_padding].
        rgb = torch.sigmoid(
            decoded[..., 1:]) * (1 + 2 * self.rgb_padding) - self.rgb_padding
        return {'rgb': rgb, 'sigma': sigma}
Read the Docs v: latest
Versions
latest
stable
0.x
Downloads
pdf
epub
On Read the Docs
Project Home
Builds

Free document hosting provided by Read the Docs.