Visu3d - Intro (go/v3d-intro)#

3D geometry is hard. Visu3d makes it easier by providing:

  • Easy-to-use & powerful primitives

  • Visualization as a first-class citizen

  • Same code works everywhere (native TensorFlow, Jax, numpy support)

Imports#

Let’s install the required deps:

  • TF and Jax are optional dependencies, so they have to be installed separately.

  • sunds and tfds are used to load the dataset.

  • etils[ecolab] installs some Colab utils.

!pip install visu3d etils[ecolab] jax[cpu] tf-nightly tfds-nightly sunds

Visu3d is imported with: import visu3d as v3d.

from etils import ecolab  # Util to auto-display images

# Jax, Numpy, Tensorflow APIs
import jax.numpy as jnp
import numpy as np
import tensorflow.experimental.numpy as tnp

# Dataset utils (to test on real data)
import sunds

# Visu3d import
import visu3d as v3d

On Colab, we use from etils.lazy_imports import * to lazily import everything on first usage.

from etils.lazy_imports import *
ecolab.auto_plot_array()  # Display np.array as images/video
tnp.experimental_enable_numpy_behavior()  # Activate TF numpy behavior
Display big np/tf/jax arrays as image for nicer IPython display

Ray#

Batching & slicing#

Everything in v3d is a v3d.DataclassArray. Dataclass arrays are like dataclasses but support numpy-like indexing, vectorization,…

For example, let’s start with a simple ray:

ray = v3d.Ray(pos=[0, 0, 0], dir=[1, -1, 1])
ray
Ray(
    pos=array([0., 0., 0.], dtype=float32),
    dir=array([ 1., -1.,  1.], dtype=float32),
)

Note that inputs are automatically cast to arrays of the correct dtype.

All v3d objects have a .fig property to visualize them interactively:

ray.fig

Because v3d.DataclassArray objects behave like numpy arrays, they support batching of arbitrary dimensions (e.g. (batch_size, h, w),…).

Let’s batch 3 rays together:

ray = v3d.Ray(
    pos=jnp.zeros((3, 3)),  # Note we can use jnp, tf or np interchangeably
    dir=jnp.eye(3),
)
ray.fig

v3d.DataclassArray supports all standard numpy transformations (ray[..., :], ray.flatten(), ray.broadcast_to((3, 1)),…).

print(f'ray={ray}')
print(f'ray.shape is {ray.shape}')
print(f'ray.pos.shape is {ray.pos.shape}')
print(f'ray.dir.shape is {ray.dir.shape}')
print('\nNumpy indexing works as expected:')
print(f'ray[..., 0]={ray[..., 0]}')
print('\nShape manipulation:')
print(f'ray.reshape((3, 1)).shape == {ray.reshape((3, 1)).shape}')
print(f'ray.flatten().shape == {ray.flatten().shape}')
ray=Ray(
    pos=DeviceArray([[0., 0., 0.],
                 [0., 0., 0.],
                 [0., 0., 0.]], dtype=float32),
    dir=DeviceArray([[1., 0., 0.],
                 [0., 1., 0.],
                 [0., 0., 1.]], dtype=float32),
)
ray.shape is (3,)
ray.pos.shape is (3, 3)
ray.dir.shape is (3, 3)

Numpy indexing works as expected:
ray[..., 0]=Ray(
    pos=DeviceArray([0., 0., 0.], dtype=float32),
    dir=DeviceArray([1., 0., 0.], dtype=float32),
)

Shape manipulation:
ray.reshape((3, 1)).shape == (3, 1)
ray.flatten().shape == (3,)

With numpy slicing, it becomes trivial to apply masking/filtering:

ray = ray[ray.norm() > 0]  # Filter out rays with |dir| == 0

v3d also has native einops support:

ray = ray[..., None, None]
ray = ray.reshape('b h w -> (b h w)')

TensorFlow, Jax, Numpy support#

v3d natively supports tf.Tensor and jnp.ndarray without any code change!

ray = ray.as_tf()  # Convert to tf.Tensor (`.as_jax()`, `.as_np()` also exists)
ray[-1]
Ray(
    pos=<tf.Tensor: shape=(3,), dtype=float32, numpy=array([0., 0., 0.], dtype=float32)>,
    dir=<tf.Tensor: shape=(3,), dtype=float32, numpy=array([0., 0., 1.], dtype=float32)>,
)

You can get the current numpy module (np, jnp, tnp) used by the dataclass with .xnp:

ray.xnp.__name__
'tensorflow.experimental.numpy'
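
This makes it easy to write backend-agnostic helpers. A minimal, hypothetical sketch (only .xnp is documented above; the dataclasses-style .replace() method is an assumption):

def recenter(ray: v3d.Ray) -> v3d.Ray:
  xnp = ray.xnp  # np, jnp or tnp, depending on how `ray` was built
  # `.replace()` (dataclasses-style copy/update) is assumed here
  return ray.replace(pos=ray.pos - xnp.mean(ray.pos, axis=0))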

Visualization#

.fig is just an alias for v3d.make_fig. v3d.make_fig allows you to display multiple objects together:

v3d.make_fig([
    # Addition `+` translates the rays
    ray + [3, 1, 2.5],

    # Average of the y and z axis rays
    ray[1:].mean().normalize().scale_dir(4.),

    # Point clouds (..., 3)
    np.random.default_rng(0).random((5, 5, 3)) * 4,
])

To reduce verbosity, you can also auto-display tuple[Visualizable, ...] directly as a figure using v3d.auto_plot_figs:

v3d.auto_plot_figs()

ray, ray[1:].mean()
Display `tuple[v3d.Visualizable, ...]` as figure

Camera#

Parameters#

Cameras are defined by two attributes:

  • spec: Camera intrinsics parameters (v3d.CameraSpec)

  • cam2world: Camera position in the world (v3d.Transform of camera to world coordinates)

H, W = 124, 256

# CameraSpec specifications
spec = v3d.PinholeCamera.from_focal(
    resolution=(H, W),
    focal_in_px=120,
)

# Create a Camera looking at the center
cam = v3d.Camera.from_look_at(
    spec=spec,
    pos=[1, 6, 3.5],
    target=[0, 0, 0],
)
cam.fig
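
For reference, a camera can also be built directly from those two attributes instead of from_look_at. A hedged sketch (the cam2world attribute name follows the bullets above, and v3d.Transform(R=..., t=...) is an assumption; check the API for the exact signature):

# Hedged sketch: direct construction from `spec` + a camera-to-world transform.
cam2world = v3d.Transform(
    R=np.eye(3),     # camera orientation in world coordinates (assumed field name)
    t=[1, 6, 3.5],   # camera position in the world (assumed field name)
)
cam = v3d.Camera(spec=spec, cam2world=cam2world)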

Trajectory#

An easy way to generate camera trajectories is to use v3d.Ray to store the camera origin and direction (assuming the camera is parallel to the ground).

Let’s generate a camera trajectory:

# Trajectory parameters
N_STEPS = 100
RADIUS = 7.
MEAN_Z = 3.
NUM_UP_DOWN = 3

# Generate `N_STEPS` rays looking at the center
t = np.arange(N_STEPS)

pos = np.stack([
  np.cos(2 * np.pi * t / N_STEPS) * RADIUS,
  np.sin(2 * np.pi * t / N_STEPS) * RADIUS,
  MEAN_Z + np.sin(NUM_UP_DOWN * 2 * np.pi * t / N_STEPS),
], axis=-1)

trajectory = v3d.Ray.from_look_at(
  pos=pos,
  target=[0, 0, 0],
)
trajectory = trajectory.normalize()
trajectory.fig

Create the associated cameras:

Because v3d.Camera is a v3d.DataclassArray, it supports batching too, so all 100 cameras are stored in a single object.

cams = v3d.Camera.from_look_at(
    spec=spec,
    pos=trajectory.pos,
    target=trajectory.end,
)
assert cams.shape == (N_STEPS,)  # <= 100 cameras batched together

Generate the associated rays:

rays = cams.rays()  # Generate the rays for all cameras
assert rays.shape == (N_STEPS, H, W)

Let’s play with a concrete use-case#

Let’s see how v3d can be used in a concrete use-case.

Dataset#

Let’s load a dataset. Here we load a sunds dataset but this could be replaced by any dataset of your choice.

builder = sunds.builder(
    'kubric/multi_shapenet_conditional',
    data_dir='gs://kubric-public/tfds',
)
with ecolab.collapse('Dataset info:'):
  print(builder.frame_builder.info)
Dataset info:
tfds.core.DatasetInfo(
    name='kubric_frames',
    full_name='kubric_frames/multi_shapenet_conditional/2.8.0',
    description="""
    
    """,
    config_description="""
    Conditional MultiShapenet dataset.
    """,
    homepage='http://n/a',
    data_path='gs://kubric-public/tfds/kubric_frames/multi_shapenet_conditional/2.8.0',
    file_format=tfrecord,
    download_size=Unknown size,
    dataset_size=871.35 GiB,
    features=FeaturesDict({
        'cameras': FeaturesDict({
            'camera_0': FeaturesDict({
                'category_image': LabeledImage(shape=(128, 128, 1), dtype=tf.uint8, num_classes=55),
                'color_image': Image(shape=(128, 128, 3), dtype=tf.uint8),
                'depth_image': Image(shape=(128, 128, 1), dtype=tf.float32),
                'extrinsics': FeaturesDict({
                    'R': Tensor(shape=(3, 3), dtype=tf.float32),
                    't': Tensor(shape=(3,), dtype=tf.float32),
                }),
                'instance_image': LabeledImage(shape=(128, 128, 1), dtype=tf.uint8, num_classes=None),
                'intrinsics': FeaturesDict({
                    'K': Tensor(shape=(3, 3), dtype=tf.float32),
                    'distortion': FeaturesDict({
                        'radial': Tensor(shape=(3,), dtype=tf.float32),
                        'tangential': Tensor(shape=(2,), dtype=tf.float32),
                    }),
                    'image_height': tf.int32,
                    'image_width': tf.int32,
                    'type': Text(shape=(), dtype=tf.string),
                }),
            }),
            'camera_1': FeaturesDict({
                'category_image': LabeledImage(shape=(128, 128, 1), dtype=tf.uint8, num_classes=55),
                'color_image': Image(shape=(128, 128, 3), dtype=tf.uint8),
                'depth_image': Image(shape=(128, 128, 1), dtype=tf.float32),
                'extrinsics': FeaturesDict({
                    'R': Tensor(shape=(3, 3), dtype=tf.float32),
                    't': Tensor(shape=(3,), dtype=tf.float32),
                }),
                'instance_image': LabeledImage(shape=(128, 128, 1), dtype=tf.uint8, num_classes=None),
                'intrinsics': FeaturesDict({
                    'K': Tensor(shape=(3, 3), dtype=tf.float32),
                    'distortion': FeaturesDict({
                        'radial': Tensor(shape=(3,), dtype=tf.float32),
                        'tangential': Tensor(shape=(2,), dtype=tf.float32),
                    }),
                    'image_height': tf.int32,
                    'image_width': tf.int32,
                    'type': Text(shape=(), dtype=tf.string),
                }),
            }),
            'camera_2': FeaturesDict({
                'category_image': LabeledImage(shape=(128, 128, 1), dtype=tf.uint8, num_classes=55),
                'color_image': Image(shape=(128, 128, 3), dtype=tf.uint8),
                'depth_image': Image(shape=(128, 128, 1), dtype=tf.float32),
                'extrinsics': FeaturesDict({
                    'R': Tensor(shape=(3, 3), dtype=tf.float32),
                    't': Tensor(shape=(3,), dtype=tf.float32),
                }),
                'instance_image': LabeledImage(shape=(128, 128, 1), dtype=tf.uint8, num_classes=None),
                'intrinsics': FeaturesDict({
                    'K': Tensor(shape=(3, 3), dtype=tf.float32),
                    'distortion': FeaturesDict({
                        'radial': Tensor(shape=(3,), dtype=tf.float32),
                        'tangential': Tensor(shape=(2,), dtype=tf.float32),
                    }),
                    'image_height': tf.int32,
                    'image_width': tf.int32,
                    'type': Text(shape=(), dtype=tf.string),
                }),
            }),
            'camera_3': FeaturesDict({
                'category_image': LabeledImage(shape=(128, 128, 1), dtype=tf.uint8, num_classes=55),
                'color_image': Image(shape=(128, 128, 3), dtype=tf.uint8),
                'depth_image': Image(shape=(128, 128, 1), dtype=tf.float32),
                'extrinsics': FeaturesDict({
                    'R': Tensor(shape=(3, 3), dtype=tf.float32),
                    't': Tensor(shape=(3,), dtype=tf.float32),
                }),
                'instance_image': LabeledImage(shape=(128, 128, 1), dtype=tf.uint8, num_classes=None),
                'intrinsics': FeaturesDict({
                    'K': Tensor(shape=(3, 3), dtype=tf.float32),
                    'distortion': FeaturesDict({
                        'radial': Tensor(shape=(3,), dtype=tf.float32),
                        'tangential': Tensor(shape=(2,), dtype=tf.float32),
                    }),
                    'image_height': tf.int32,
                    'image_width': tf.int32,
                    'type': Text(shape=(), dtype=tf.string),
                }),
            }),
            'camera_4': FeaturesDict({
                'category_image': LabeledImage(shape=(128, 128, 1), dtype=tf.uint8, num_classes=55),
                'color_image': Image(shape=(128, 128, 3), dtype=tf.uint8),
                'depth_image': Image(shape=(128, 128, 1), dtype=tf.float32),
                'extrinsics': FeaturesDict({
                    'R': Tensor(shape=(3, 3), dtype=tf.float32),
                    't': Tensor(shape=(3,), dtype=tf.float32),
                }),
                'instance_image': LabeledImage(shape=(128, 128, 1), dtype=tf.uint8, num_classes=None),
                'intrinsics': FeaturesDict({
                    'K': Tensor(shape=(3, 3), dtype=tf.float32),
                    'distortion': FeaturesDict({
                        'radial': Tensor(shape=(3,), dtype=tf.float32),
                        'tangential': Tensor(shape=(2,), dtype=tf.float32),
                    }),
                    'image_height': tf.int32,
                    'image_width': tf.int32,
                    'type': Text(shape=(), dtype=tf.string),
                }),
            }),
            'camera_5': FeaturesDict({
                'category_image': LabeledImage(shape=(128, 128, 1), dtype=tf.uint8, num_classes=55),
                'color_image': Image(shape=(128, 128, 3), dtype=tf.uint8),
                'depth_image': Image(shape=(128, 128, 1), dtype=tf.float32),
                'extrinsics': FeaturesDict({
                    'R': Tensor(shape=(3, 3), dtype=tf.float32),
                    't': Tensor(shape=(3,), dtype=tf.float32),
                }),
                'instance_image': LabeledImage(shape=(128, 128, 1), dtype=tf.uint8, num_classes=None),
                'intrinsics': FeaturesDict({
                    'K': Tensor(shape=(3, 3), dtype=tf.float32),
                    'distortion': FeaturesDict({
                        'radial': Tensor(shape=(3,), dtype=tf.float32),
                        'tangential': Tensor(shape=(2,), dtype=tf.float32),
                    }),
                    'image_height': tf.int32,
                    'image_width': tf.int32,
                    'type': Text(shape=(), dtype=tf.string),
                }),
            }),
            'camera_6': FeaturesDict({
                'category_image': LabeledImage(shape=(128, 128, 1), dtype=tf.uint8, num_classes=55),
                'color_image': Image(shape=(128, 128, 3), dtype=tf.uint8),
                'depth_image': Image(shape=(128, 128, 1), dtype=tf.float32),
                'extrinsics': FeaturesDict({
                    'R': Tensor(shape=(3, 3), dtype=tf.float32),
                    't': Tensor(shape=(3,), dtype=tf.float32),
                }),
                'instance_image': LabeledImage(shape=(128, 128, 1), dtype=tf.uint8, num_classes=None),
                'intrinsics': FeaturesDict({
                    'K': Tensor(shape=(3, 3), dtype=tf.float32),
                    'distortion': FeaturesDict({
                        'radial': Tensor(shape=(3,), dtype=tf.float32),
                        'tangential': Tensor(shape=(2,), dtype=tf.float32),
                    }),
                    'image_height': tf.int32,
                    'image_width': tf.int32,
                    'type': Text(shape=(), dtype=tf.string),
                }),
            }),
            'camera_7': FeaturesDict({
                'category_image': LabeledImage(shape=(128, 128, 1), dtype=tf.uint8, num_classes=55),
                'color_image': Image(shape=(128, 128, 3), dtype=tf.uint8),
                'depth_image': Image(shape=(128, 128, 1), dtype=tf.float32),
                'extrinsics': FeaturesDict({
                    'R': Tensor(shape=(3, 3), dtype=tf.float32),
                    't': Tensor(shape=(3,), dtype=tf.float32),
                }),
                'instance_image': LabeledImage(shape=(128, 128, 1), dtype=tf.uint8, num_classes=None),
                'intrinsics': FeaturesDict({
                    'K': Tensor(shape=(3, 3), dtype=tf.float32),
                    'distortion': FeaturesDict({
                        'radial': Tensor(shape=(3,), dtype=tf.float32),
                        'tangential': Tensor(shape=(2,), dtype=tf.float32),
                    }),
                    'image_height': tf.int32,
                    'image_width': tf.int32,
                    'type': Text(shape=(), dtype=tf.string),
                }),
            }),
            'camera_8': FeaturesDict({
                'category_image': LabeledImage(shape=(128, 128, 1), dtype=tf.uint8, num_classes=55),
                'color_image': Image(shape=(128, 128, 3), dtype=tf.uint8),
                'depth_image': Image(shape=(128, 128, 1), dtype=tf.float32),
                'extrinsics': FeaturesDict({
                    'R': Tensor(shape=(3, 3), dtype=tf.float32),
                    't': Tensor(shape=(3,), dtype=tf.float32),
                }),
                'instance_image': LabeledImage(shape=(128, 128, 1), dtype=tf.uint8, num_classes=None),
                'intrinsics': FeaturesDict({
                    'K': Tensor(shape=(3, 3), dtype=tf.float32),
                    'distortion': FeaturesDict({
                        'radial': Tensor(shape=(3,), dtype=tf.float32),
                        'tangential': Tensor(shape=(2,), dtype=tf.float32),
                    }),
                    'image_height': tf.int32,
                    'image_width': tf.int32,
                    'type': Text(shape=(), dtype=tf.string),
                }),
            }),
            'camera_9': FeaturesDict({
                'category_image': LabeledImage(shape=(128, 128, 1), dtype=tf.uint8, num_classes=55),
                'color_image': Image(shape=(128, 128, 3), dtype=tf.uint8),
                'depth_image': Image(shape=(128, 128, 1), dtype=tf.float32),
                'extrinsics': FeaturesDict({
                    'R': Tensor(shape=(3, 3), dtype=tf.float32),
                    't': Tensor(shape=(3,), dtype=tf.float32),
                }),
                'instance_image': LabeledImage(shape=(128, 128, 1), dtype=tf.uint8, num_classes=None),
                'intrinsics': FeaturesDict({
                    'K': Tensor(shape=(3, 3), dtype=tf.float32),
                    'distortion': FeaturesDict({
                        'radial': Tensor(shape=(3,), dtype=tf.float32),
                        'tangential': Tensor(shape=(2,), dtype=tf.float32),
                    }),
                    'image_height': tf.int32,
                    'image_width': tf.int32,
                    'type': Text(shape=(), dtype=tf.string),
                }),
            }),
        }),
        'frame_name': Text(shape=(), dtype=tf.string),
        'pose': FeaturesDict({
            'R': Tensor(shape=(3, 3), dtype=tf.float32),
            't': Tensor(shape=(3,), dtype=tf.float32),
        }),
        'scene_name': Text(shape=(), dtype=tf.string),
    }),
    supervised_keys=None,
    disable_shuffling=False,
    splits={
        'test': <SplitInfo num_examples=104263, num_shards=1024>,
        'train': <SplitInfo num_examples=1026570, num_shards=1024>,
        'val': <SplitInfo num_examples=103090, num_shards=1024>,
    },
    citation="""""",
)

Load a single scene, stacking all cameras in a single example (shape=(num_cams, h, w, ...)).

ds = builder.as_dataset(
    split='train[:1]',
    task=sunds.tasks.Nerf(
        yield_mode='stacked',
        additional_camera_specs={'depth_image'},
    ),
)
ex, = ds.as_numpy_iterator()

ds.element_spec  # Shape/dtype structure
{'color_image': TensorSpec(shape=(10, 128, 128, 3), dtype=tf.uint8, name=None),
 'depth_image': TensorSpec(shape=(10, 128, 128, 1), dtype=tf.float32, name=None),
 'ray_directions': TensorSpec(shape=(10, 128, 128, 3), dtype=tf.float32, name=None),
 'ray_origins': TensorSpec(shape=(10, 128, 128, 3), dtype=tf.float32, name=None)}

Scene inspection#

Let’s extract the color, depth, and rays from the dataset.

With ecolab.auto_plot_array(), 4d arrays are auto-displayed as images:

rgb = ex['color_image']
rgb
# The original Kubric dataset has a bug with infinite depth values, so manually clip them
depth = np.array(ex['depth_image'])
depth[depth > 100] = 0
depth
rays = v3d.Ray(
    pos=ex['ray_origins'],
    dir=ex['ray_directions'],
)
rays.fig

Note that in the above plot, we try to display 10*128*128 == 163_840 rays. v3d dynamically & deterministically sub-samples rays to keep rendering time reasonable.
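
If you need more (or fewer) points in a plot, the sub-sampling can be overridden per object with .replace_fig_config (detailed in the customization section below). A sketch, where the 50_000 value is arbitrary:

# Override the default sub-sampling for this object only
# (`num_samples=None` would disable sub-sampling entirely).
rays.replace_fig_config(num_samples=50_000).fig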

With the rays and depth, we can project the point cloud into the scene:

We could directly plot the (..., 3) np.ndarray, but we can also use v3d.Point3d to store the colors along with the point cloud. Both rays & images have the same batch shape (num_cams, h, w, ...), so no transformation is required.

point_cloud = v3d.Point3d(
    p=rays.scale_dir(depth).end,  # Project depth
    rgb=ex['color_image'],
)
point_cloud.fig

Rendering & vectorization#

Let’s visualize everything together: the point cloud, the original camera rays, and the trajectory cameras:

print(f'point_cloud.shape={point_cloud.shape}')
print(f'rays.shape={rays.shape}')
print(f'cams.shape={cams.shape}')

point_cloud, rays, cams
point_cloud.shape=(10, 128, 128)
rays.shape=(10, 128, 128)
cams.shape=(100,)

We can now render our trajectory by projecting the point cloud back to the cameras:

imgs = cams.render(point_cloud[None, ...])
imgs

A few explanations of what happened here: because the 100 cameras are batched together, we only need a single call to render the 100 frames.

v3d supports vectorization, which means calling a function on a batched object is similar to wrapping it inside jax.vmap.

Here our cameras have cams.shape == (100,), so vectorization expects a point cloud with point_cloud.shape == (100, ...). v3d auto-broadcasts dimensions of size 1, so point_cloud[None, ...] creates point_cloud.shape == (1, ...), which is then broadcast (each camera renders the frame on the full point cloud).

  • () @ (...,) -> (H, W)

  • (100,) @ (100, ...) -> (100, H, W)

  • (100,) @ (1, ...) -> (100, H, W)
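
Concretely, reusing cams and point_cloud from the cells above, a sketch of these rules (the comments describe the intended shapes, not exact output dtypes):

img0 = cams[0].render(point_cloud)          # ()     @ (...)    -> a single (H, W) image
imgs = cams.render(point_cloud[None, ...])  # (100,) @ (1, ...) -> 100 (H, W) images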

More info on vectorization in the recipes tutorial.

Customizing the plots (rename, subsampling,…)#

Figures can be customized. Options can be set:

  • At the global level (v3d.fig_config):

v3d.fig_config.show_zero = False

  • At the individual figure level (v3d.make_fig()):

fig = v3d.make_fig(cams, show_zero=False)

  • At the object level (.replace_fig_config()):

rays = rays.replace_fig_config(name='My rays', num_samples=None)

Global options (see code):

  • show_zero: Whether to always show the (0, 0, 0) origin

  • num_samples_xxx (with _xxx being _ray, _point3d,…): Globally controls the sub-sampling for the matching class (v3d.Ray, v3d.Point3d,…).

  • cam_scale: Factor to scale the cameras up/down

Local options:

  • name: Object name (defaults to Ray, Camera,…)

  • num_samples: When too many objects are displayed at once, they are dynamically sub-sampled to this value (None to disable)

  • (for cameras) scale: Factor to scale the cameras up/down
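
Putting these together, a sketch combining global and local options (the argument names follow the bullets above; check the linked code for the exact spelling):

v3d.fig_config.show_zero = False                      # Global: hide the (0, 0, 0) origin
small_cams = cams.replace_fig_config(scale=0.5)       # Local (cameras): shrink the camera glyphs
named_rays = rays.replace_fig_config(name='My rays')  # Local: rename the trace
v3d.make_fig([small_cams, named_rays])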

What’s next#

If you’re ready to go to the next level, you can switch to the