Shortcuts

Source code for mmagic.models.editors.eg3d.ray_sampler

# Copyright (c) OpenMMLab. All rights reserved.
from typing import Tuple

import torch


def sample_rays(cam2world: torch.Tensor, intrinsics: torch.Tensor,
                resolution: int) -> Tuple[torch.Tensor, torch.Tensor]:
    """Sample origin and direction vectors of rays with passed camera-to-world
    matrix and intrinsics matrix.

    One ray is cast through the centre of every pixel of a
    ``resolution x resolution`` grid placed at depth 1 in camera space. Pixel
    centres are expressed in normalised coordinates
    ``(index + 0.5) / resolution``, and the focal lengths and principal point
    read from ``intrinsics`` are applied directly to those normalised
    coordinates. Note that the skew coefficient is not considered in this
    function.

    Rays are ordered row-major: ray ``k`` corresponds to row
    ``k // resolution`` (y) and column ``k % resolution`` (x).

    Args:
        cam2world (torch.Tensor): The camera-to-world matrix in homogeneous
            coordinates. Shape like (bz, 4, 4).
        intrinsics (torch.Tensor): The intrinsic matrix. Shape like
            (bz, 3, 3).
        resolution (int): The expected resolution of the render output.

    Returns:
        Tuple[torch.Tensor, torch.Tensor]: Origins and unit-length view
            directions for rays. Both shape like (bz, resolution^2, 3).
    """
    batch_size = cam2world.shape[0]
    device = cam2world.device

    # The camera centre in world space is the translation column.
    cam_in_world = cam2world[:, :3, 3]

    # Keep a trailing singleton axis so the intrinsics broadcast over rays.
    fx = intrinsics[:, 0, 0].unsqueeze(-1)
    fy = intrinsics[:, 1, 1].unsqueeze(-1)
    cx = intrinsics[:, 0, 2].unsqueeze(-1)
    cy = intrinsics[:, 1, 2].unsqueeze(-1)

    # Normalised pixel-centre coordinates along one image axis. The grid is
    # square, so the same 1-D vector serves both x and y. Building the grid
    # from it directly avoids torch.meshgrid, whose ``indexing`` argument
    # differs between torch versions.
    pixel_centers = torch.arange(
        resolution, dtype=torch.float32, device=device)
    pixel_centers = pixel_centers * (1. / resolution) + (0.5 / resolution)

    # x varies fastest (tiled), y is constant within a row (interleaved).
    x_cam = pixel_centers.repeat(resolution)
    y_cam = pixel_centers.repeat_interleave(resolution)
    x_cam = x_cam.unsqueeze(0).expand(batch_size, -1)
    y_cam = y_cam.unsqueeze(0).expand(batch_size, -1)

    # Un-project the pixel centres onto the z = 1 plane of the camera frame.
    x_lift = (x_cam - cx) / fx
    y_lift = (y_cam - cy) / fy
    # ones_like keeps dtype/device tied to the lifted coordinates instead of
    # the process-wide default dtype.
    z_cam = torch.ones_like(x_lift)
    points_in_cam = torch.stack(
        (x_lift, y_lift, z_cam, torch.ones_like(z_cam)), dim=-1)

    # camera coordinate to world coordinate
    points_in_world = torch.bmm(cam2world, points_in_cam.permute(0, 2, 1))
    points_in_world = points_in_world.permute(0, 2, 1)[:, :, :3]

    ray_dirs = points_in_world - cam_in_world[:, None, :]
    ray_dirs = torch.nn.functional.normalize(ray_dirs, dim=2)

    # Every ray of a camera starts at that camera's centre.
    ray_origins = cam_in_world.unsqueeze(1).repeat(1, ray_dirs.shape[1], 1)

    return ray_origins, ray_dirs