Visu3d - DataclassArray (go/v3d-dataclass)#

If you’re new to v3d, please look at the intro first.

v3d was designed to be extensible and allow you to:

  • Create your custom primitives

  • Integrate your custom primitives with the rest of v3d

As more people share their custom primitives, you can pick & choose existing primitives (camera models, COLMAP visualization,…) and compose them with your custom ones.

Installation#

We use the same installation and imports as in the intro.

!pip install visu3d etils[ecolab] jax[cpu] tf-nightly tfds-nightly sunds
from __future__ import annotations
from etils.ecolab.lazy_imports import *

Without lazy_imports, the explicit imports are:

import dataclass_array as dca
import visu3d as v3d

DataclassArray API#

dca.DataclassArray#

With the dataclass_array module, you can augment any @dataclasses.dataclass to make it behave like an array (with slicing, reshaping,…).

For this:

  • Inherit from dca.DataclassArray

  • Use etils.array_types to annotate your array fields (or explicitly use dca.field(shape=, dtype=) instead of dataclasses.field)

from etils.array_types import FloatArray

class MyRay(dca.DataclassArray):
  pos: FloatArray[..., 3]
  dir: FloatArray[..., 3]

Note: Once PEP 681 – Data Class Transforms is supported by pytype, @dataclasses.dataclass(frozen=True) won’t be required anymore.

dca.DataclassArray provides:

  • All slicing/reshape (including einops support) operations

  • TensorFlow/Jax/Numpy conversions

ray = MyRay(pos=[0, 0, 0], dir=[1, 1, 1])

# Slicing, reshape ops
ray = ray.broadcast_to((2, 3))
ray = ray.reshape('h w -> w h')
ray = ray[..., 0]

# TensorFlow/Jax/Numpy conversions
ray = ray.as_jax()

ray.shape
(3,)
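
To make the shape bookkeeping above concrete, here is a minimal NumPy-only sketch of what the same sequence of operations does to a single field. It is not the dataclass_array implementation: it assumes that the batch shape is everything except the trailing `3` of `FloatArray[..., 3]`, and that the einops-style pattern `'h w -> w h'` acts as a transposition of the two batch axes.

```python
import numpy as np

# One field of the ray: batch shape `()`, inner shape `(3,)`.
pos = np.zeros((3,))

# `ray.broadcast_to((2, 3))`: the batch shape becomes (2, 3), the inner (3,) is kept.
pos = np.broadcast_to(pos, (2, 3) + (3,))

# `ray.reshape('h w -> w h')`: assumed here to swap the two batch axes.
pos = pos.transpose(1, 0, 2)

# `ray[..., 0]`: index the last batch axis, leaving the inner axis untouched.
pos = pos[..., 0, :]

batch_shape = pos.shape[:-1]  # Everything except the trailing inner dimension.
print(batch_shape, pos.shape)
```

The batch shape computed this way is what `ray.shape` reports in the cell above, while each field keeps its extra trailing dimension.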

Dataclass array fields are not restricted to xnp.Array; they can also be:

  • dca.DataclassArray for nested dataclasses

  • Static fields, which won’t be batched

class MyScene(dca.DataclassArray):
  experiment_name: str
  # TODO(epot): Support `shape=(None,)`
  rays: MyRay = dca.field(shape=(3,), dtype=MyRay)


scene = MyScene(
    experiment_name='some_experiment',
    rays=ray,
)
scene = scene.broadcast_to((5,))  # duplicate to 5 scenes
assert scene.rays.shape == (5, 3)
scene.experiment_name  # Static field is not affected by batching & co.
'some_experiment'

Vectorization#

@dca.vectorize_method allows your dataclass methods to automatically support batching:

  1. Implement method as if self.shape == ()

  2. Decorate the method with dca.vectorize_method

class MyRay(dca.DataclassArray):
  pos: FloatArray[..., 3]
  dir: FloatArray[..., 3]

  @dca.vectorize_method
  def get_xyz(self):
    # Inside `@dca.vectorize_method`, shape is always guaranteed to be `()`
    assert self.shape == ()
    assert self.pos.shape == (3,)

    x, y, z = self.pos
    return x, y, z


ray = MyRay(pos=np.zeros((5, 6, 3)), dir=np.ones((5, 6, 3)))
x, y, z = ray.get_xyz()
x.shape
(5, 6)
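
Conceptually, the decorated method runs once per element of the batch and the results are stacked back into the batch shape. The helper below is a NumPy-only sketch of that idea, not the library's implementation; `map_over_batch` and its `inner_ndim` parameter are hypothetical names introduced for illustration.

```python
import numpy as np


def map_over_batch(fn, arr, inner_ndim=1):
  """Applies `fn` to each element of the batch and restores the batch shape."""
  split = arr.ndim - inner_ndim
  batch_shape, inner_shape = arr.shape[:split], arr.shape[split:]
  # Flatten all batch axes into one, similar to calling `flatten()` first.
  flat = arr.reshape((-1,) + inner_shape)
  # `fn` only ever sees a single, un-batched element.
  out = np.stack([fn(elem) for elem in flat])
  # Restore the original batch axes in front of the per-element output shape.
  return out.reshape(batch_shape + out.shape[1:])


pos = np.zeros((5, 6, 3))
x = map_over_batch(lambda p: p[0], pos)  # `p.shape == (3,)` inside `fn`
print(x.shape)
```

A Python loop like this is only meant to show the semantics; it says nothing about how the real decorator is optimized.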

@dca.vectorize_method is similar to jax.vmap but:

  • It only works on dca.DataclassArray methods

  • Instead of vectorizing a single axis, @dca.vectorize_method vectorizes over *self.shape (not just self.shape[0]). This is as if vmap were applied to self.flatten()

  • When there are multiple arguments, axes of dimension 1 are broadcast.

For example, with __matmul__(self, x: T) -> T:

() @ (*x,) -> (*x,)
(b,) @ (b, *x) -> (b, *x)
(b,) @ (1, *x) -> (b, *x)
(1,) @ (b, *x) -> (b, *x)
(b, h, w) @ (b, h, w, *x) -> (b, h, w, *x)
(1, h, w) @ (b, 1, 1, *x) -> (b, h, w, *x)
(a, *x) @ (b, *x) -> Error: Incompatible a != b
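
The shape rules listed above can be sketched as a small pure-Python function. This only illustrates the table and is not how dataclass_array implements it; `vectorized_output_shape` is a hypothetical helper, and it assumes the leading `len(self_shape)` axes of the argument are the ones matched against `self.shape`.

```python
def vectorized_output_shape(self_shape, arg_shape):
  """Returns the output shape of `self @ arg` under the rules listed above."""
  n = len(self_shape)
  if len(arg_shape) < n:
    raise ValueError(f'Argument needs at least {n} batch axes: {arg_shape}')
  batch, inner = arg_shape[:n], arg_shape[n:]
  out = []
  for s, a in zip(self_shape, batch):
    if s != a and 1 not in (s, a):
      raise ValueError(f'Incompatible {s} != {a}')
    # An axis of size 1 is broadcast to the size of the other operand.
    out.append(a if s == 1 else s)
  return (*out, *inner)


print(vectorized_output_shape((), (3,)))
print(vectorized_output_shape((4,), (1, 3)))
print(vectorized_output_shape((1, 5, 6), (4, 1, 1, 3)))
```

Calling it with mismatched batch sizes, for example `vectorized_output_shape((2,), (3, 3))`, raises a `ValueError`, mirroring the last row of the table.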

Support v3d features#

Your custom objects can easily opt in to v3d features by implementing the corresponding protocol:

  • Make your object visualizable (my_obj.fig, v3d.make_fig([my_obj])): Implement the visualization protocol

  • Make your object compatible with v3d.Transform (cam_from_world @ my_obj): Implement the transform protocol

  • Make your object compatible with 2d<>3d projection: Implement the camera projection protocol

  • Make your object compatible with camera.render(my_obj): Implement the rendering protocol

Visualization protocol#

You can make your dataclass visualizable by implementing the protocol:

def make_traces(self) -> list[<plotly traces>]:

plotly traces can be any go.Scatter3d, go.Mesh3d,…

Additionally, inheriting from v3d.Visualizable adds the .fig property. v3d.DataclassArray simply combines dca.DataclassArray and v3d.Visualizable.

class MyRay(dca.DataclassArray, v3d.Visualizable):  # Could inherit from v3d.DataclassArray instead
  pos: FloatArray[..., 3]
  dir: FloatArray[..., 3]

  def make_traces(self) -> list[plotly.basedatatypes.BaseTraceType]:
    return v3d.plotly.make_lines_traces(
        start=self.pos,
        end=self.pos + self.dir,
        end_marker='diamond',
    )

ray = MyRay(pos=[0, 0, 0], dir=[-1, 1, 1])
ray.fig

The v3d.plotly module is a small wrapper around plotly to simplify building plotly.graph_objects traces (e.g. to subsample with v3d.plotly.subsample).

Note: You can make any Python object visualizable (not only dca.DataclassArray) by inheriting from v3d.Visualizable.

Transform protocol#

You can make your dataclass composable with v3d.Transform by implementing the protocol:

def apply_transform(self, tr: v3d.Transform) -> Self:

This method is called during tr @ my_obj.

The protocol automatically supports vectorization.

class MyRay(v3d.DataclassArray):
  pos: FloatArray[..., 3]
  dir: FloatArray[..., 3]

  def apply_transform(self, tr: v3d.Transform) -> MyRay:
    return self.replace(
        pos=tr @ self.pos,
        # apply_to_dir applies `tr.R` but not `tr.t`
        dir=tr.apply_to_dir(self.dir),
    )

tr = v3d.Transform.identity()
ray = MyRay(pos=[0, 0, 0], dir=[-1, 1, 1])
tr @ ray  # Composing with identity is a no-op
MyRay(
    pos=array([0., 0., 0.]),
    dir=array([-1.,  1.,  1.]),
)
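
The reason `dir` goes through `apply_to_dir` rather than `tr @` is that a position is both rotated and translated, while a direction is only rotated. The NumPy sketch below illustrates this with a hand-written rotation `R` and translation `t`; these stand in for `tr.R` and `tr.t`, and the code does not use v3d itself.

```python
import numpy as np

# Hypothetical rigid transform: rotate 90 degrees around z, then shift along x.
R = np.array([
    [0.0, -1.0, 0.0],
    [1.0, 0.0, 0.0],
    [0.0, 0.0, 1.0],
])
t = np.array([1.0, 0.0, 0.0])

pos = np.array([1.0, 0.0, 0.0])
direction = np.array([1.0, 0.0, 0.0])

new_pos = R @ pos + t      # Positions: rotation and translation.
new_dir = R @ direction    # Directions: rotation only, the offset is ignored.

print(new_pos, new_dir)
```

Although `pos` and `direction` start out identical, they end up different after the transform, which is why the two fields are handled separately in `apply_transform`.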

Camera projection protocol#

You can make your dataclass support pixel <> camera 3d coordinates projection by implementing the protocols:

def apply_px_from_cam(self, spec: camera_spec_lib.CameraSpec) -> MyPoint2d:

def apply_cam_from_px(self, spec: camera_spec_lib.CameraSpec) -> MyPoint3d:

Look at the v3d.Point3d implementation for an example.
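
To give an idea of the math such methods typically wrap, here is a NumPy-only sketch of a pixel <> camera round trip. It assumes a plain pinhole model with a hypothetical intrinsics matrix `K`; it does not use `CameraSpec`, and the function names and the explicit `depth` argument are illustrative choices, not the v3d API.

```python
import numpy as np

# Hypothetical pinhole intrinsics: focal length 100 px, principal point (64, 48).
K = np.array([
    [100.0, 0.0, 64.0],
    [0.0, 100.0, 48.0],
    [0.0, 0.0, 1.0],
])


def px_from_cam(points3d):
  """Projects camera-frame 3d points `(..., 3)` to pixels `(..., 2)`."""
  uvw = points3d @ K.T
  return uvw[..., :2] / uvw[..., 2:]


def cam_from_px(points2d, depth):
  """Lifts pixels `(..., 2)` at the given z-depth `(...)` back to 3d points."""
  ones = np.ones(points2d.shape[:-1] + (1,))
  homogeneous = np.concatenate([points2d, ones], axis=-1)
  rays = homogeneous @ np.linalg.inv(K).T  # Points on the z == 1 plane.
  return rays * depth[..., None]


points3d = np.array([[0.5, -0.2, 2.0], [0.0, 0.0, 1.0]])
px = px_from_cam(points3d)
round_trip = cam_from_px(px, depth=points3d[..., 2])
print(px)
```

Projection discards depth, so the inverse direction needs it supplied separately; this sketch passes it explicitly.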

Rendering protocol#

The rendering protocol is not supported at the moment. Please open an issue if you need this feature.