Visu3d - DataclassArray (go/v3d-dataclass)#
If you’re new to v3d, please look at the intro first.
v3d was designed to be extensible and to allow you to:
Create your custom primitives
Integrate your custom primitives with the rest of v3d
As more people share their custom primitives, you can pick & choose existing primitives (camera models, COLMAP visualization,…) and compose them with your custom ones.
Installation#
We use the same installation/imports as in the intro.
!pip install visu3d etils[ecolab] jax[cpu] tf-nightly tfds-nightly sunds
from __future__ import annotations
from etils.ecolab.lazy_imports import *
Without lazy_imports, the explicit imports are:
import dataclass_array as dca
import visu3d as v3d
DataclassArray API#
dca.DataclassArray#
With the dataclass_array module, you can augment any @dataclasses.dataclass to make it behave like an array (with slicing, reshaping,…).
For this:
Inherit from dca.DataclassArray
Use etils.array_types to annotate your array fields (or explicitly use dca.field(shape=, dtype=) instead of dataclasses.field)
from etils.array_types import FloatArray
class MyRay(dca.DataclassArray):
    pos: FloatArray[..., 3]
    dir: FloatArray[..., 3]
Note: Once PEP 681 – Data Class Transforms is supported by pytype, @dataclasses.dataclass(frozen=True)
won’t be required anymore.
dca.DataclassArray provides:
All slicing/reshape operations (including einops support)
TensorFlow/Jax/Numpy conversions
ray = MyRay(pos=[0, 0, 0], dir=[1, 1, 1])
# Slicing, reshape ops
ray = ray.broadcast_to((2, 3))
ray = ray.reshape('h w -> w h')
ray = ray[..., 0]
# TensorFlow/Jax/Numpy conversions
ray = ray.as_jax()
ray.shape
(3,)
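To build intuition for what these batch operations do, here is a toy sketch that uses only the standard dataclasses module and NumPy. It is not how dca implements DataclassArray. ToyRay, _INNER_NDIM and the methods below are hypothetical. The sketch assumes that each field has a fixed inner shape (here (3,), as FloatArray[..., 3] declares) and that every batch op applies to the leading dims of all fields at once.

```python
import dataclasses

import numpy as np

# Hypothetical: every field is annotated FloatArray[..., 3], so 1 trailing inner dim.
_INNER_NDIM = 1


@dataclasses.dataclass(frozen=True)
class ToyRay:
    """Toy stand-in for a DataclassArray (not dca's implementation)."""

    pos: np.ndarray
    dir: np.ndarray

    @property
    def shape(self):
        # Batch shape = field shape without the trailing inner dims.
        return np.shape(self.pos)[: np.ndim(self.pos) - _INNER_NDIM]

    def _map(self, fn):
        # Apply the same array op to every field and rebuild the dataclass.
        return dataclasses.replace(
            self,
            **{f.name: fn(np.asarray(getattr(self, f.name))) for f in dataclasses.fields(self)},
        )

    def broadcast_to(self, shape):
        # Only the batch dims change; the inner (3,) is kept.
        return self._map(
            lambda a: np.broadcast_to(a, tuple(shape) + a.shape[a.ndim - _INNER_NDIM:])
        )

    def __getitem__(self, idx):
        # Index the batch dims only: pad with `slice(None)` for the inner dim.
        # (Toy limitation: supports ints, slices and `...`, not array indices.)
        idx = idx if isinstance(idx, tuple) else (idx,)
        if Ellipsis not in idx:
            idx = idx + (Ellipsis,)
        return self._map(lambda a: a[idx + (slice(None),) * _INNER_NDIM])


ray = ToyRay(pos=np.zeros(3), dir=np.ones(3))
ray = ray.broadcast_to((2, 3))  # batch shape (2, 3), each field (2, 3, 3)
ray = ray[..., 0]               # `...` only spans the batch dims
print(ray.shape, ray.pos.shape)  # (2,) (2, 3)
```

A real DataclassArray also handles reshape (including einops patterns), framework conversions, and shape validation, which this toy omits.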
Dataclass array fields are not restricted to xnp.Array but can also be:
dca.DataclassArray, for nested dataclasses
Static fields, which won’t be batched
class MyScene(dca.DataclassArray):
    experiment_name: str
    # TODO(epot): Support `shape=(None,)`
    rays: MyRay = dca.field(shape=(3,), dtype=MyRay)

scene = MyScene(
    experiment_name='some_experiment',
    rays=ray,
)
scene = scene.broadcast_to((5,))  # duplicate to 5 scenes
assert scene.rays.shape == (5, 3)
scene.experiment_name  # Static field is not affected by batching & co.
'some_experiment'
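The interplay between nested and static fields can be sketched in the same way. In this hypothetical helper, each batchable field declares an inner shape. Broadcasting prepends the new batch dims, recurses into nested dataclasses, and leaves any undeclared field (like the str) unchanged. ToyRay, ToyScene, INNER_SHAPES and toy_broadcast_to are illustrative names, not dca API.

```python
import dataclasses

import numpy as np


@dataclasses.dataclass(frozen=True)
class ToyRay:
    pos: np.ndarray  # (*batch, 3)
    dir: np.ndarray  # (*batch, 3)


@dataclasses.dataclass(frozen=True)
class ToyScene:
    experiment_name: str  # static field: not listed below, so never batched
    rays: ToyRay          # nested field with inner batch shape (3,)


# Hypothetical registry of per-field inner shapes (what FloatArray[..., 3] or
# dca.field(shape=(3,)) would declare). Fields missing here are treated as static.
INNER_SHAPES = {
    (ToyRay, 'pos'): (3,),
    (ToyRay, 'dir'): (3,),
    (ToyScene, 'rays'): (3,),
}


def toy_broadcast_to(obj, batch_shape):
    """Broadcasts array and nested fields to `batch_shape`; copies static fields."""
    updates = {}
    for f in dataclasses.fields(obj):
        key = (type(obj), f.name)
        if key not in INNER_SHAPES:
            continue  # static field: left untouched by dataclasses.replace
        value = getattr(obj, f.name)
        target = tuple(batch_shape) + INNER_SHAPES[key]
        if dataclasses.is_dataclass(value):
            # Nested dataclass: its own batch shape becomes (*batch_shape, *inner).
            updates[f.name] = toy_broadcast_to(value, target)
        else:
            updates[f.name] = np.broadcast_to(value, target)
    return dataclasses.replace(obj, **updates)


scene = ToyScene(
    experiment_name='some_experiment',
    rays=ToyRay(pos=np.zeros((3, 3)), dir=np.ones((3, 3))),
)
scene = toy_broadcast_to(scene, (5,))
print(scene.rays.pos.shape, scene.experiment_name)  # (5, 3, 3) some_experiment
```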
Vectorization#
@dca.vectorize_method allows your dataclass methods to automatically support batching:
Implement the method as if self.shape == ()
Decorate the method with @dca.vectorize_method
class MyRay(dca.DataclassArray):
    pos: FloatArray[..., 3]
    dir: FloatArray[..., 3]

    @dca.vectorize_method
    def get_xyz(self):
        # Inside `@dca.vectorize_method`, shape is always guaranteed to be `()`
        assert self.shape == ()
        assert self.pos.shape == (3,)
        x, y, z = self.pos
        return x, y, z
ray = MyRay(pos=np.zeros((5, 6, 3)), dir=np.ones((5, 6, 3)))
x, y, z = ray.get_xyz()
x.shape
(5, 6)
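You can emulate the effect of the decorator with a plain Python loop. Flatten the batch dims, call the function once per element, then stack the results back into the batch shape. The naive_vectorize helper below is hypothetical and only a behavioral sketch. It handles a single array with one trailing inner dim and functions that return tuples, and it ignores argument broadcasting.

```python
import numpy as np


def naive_vectorize(fn, arr, inner_ndim=1):
    """Hypothetical loop-based emulation of `@dca.vectorize_method`."""
    batch_shape = arr.shape[: arr.ndim - inner_ndim]
    flat = arr.reshape((-1,) + arr.shape[arr.ndim - inner_ndim:])  # like self.flatten()
    outs = [fn(elem) for elem in flat]  # fn only ever sees the inner shape
    # Stack each returned value and restore the original batch shape.
    return tuple(
        np.stack(vals).reshape(batch_shape + np.shape(vals[0]))
        for vals in zip(*outs)
    )


def get_xyz(pos):
    assert pos.shape == (3,)  # same guarantee as inside the decorated method
    x, y, z = pos
    return x, y, z


x, y, z = naive_vectorize(get_xyz, np.zeros((5, 6, 3)))
print(x.shape)  # (5, 6)
```

An efficient version would avoid the Python loop (for example via jax.vmap), which matters for large batches.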
@dca.vectorize_method is similar to jax.vmap, but:
It only works on dca.DataclassArray methods
Instead of vectorizing a single axis, @dca.vectorize_method vectorizes over *self.shape (not just self.shape[0]). This is as if vmap were applied to self.flatten()
With multiple arguments, axes with dimension 1 are broadcast.
For example, with __matmul__(self, x: T) -> T:
() @ (*x,) -> (*x,)
(b,) @ (b, *x) -> (b, *x)
(b,) @ (1, *x) -> (b, *x)
(1,) @ (b, *x) -> (b, *x)
(b, h, w) @ (b, h, w, *x) -> (b, h, w, *x)
(1, h, w) @ (b, 1, 1, *x) -> (b, h, w, *x)
(a, *x) @ (b, *x) -> Error: Incompatible a != b
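The shape rule in this table can be expressed in a few lines of NumPy. The sketch assumes that the leading len(self.shape) dims of x are its batch dims. These are broadcast against self.shape using the usual size-1 rule, and the remaining dims (*x) pass through unchanged. vectorized_out_shape is a hypothetical name, and np.broadcast_shapes requires NumPy 1.20 or newer.

```python
import numpy as np


def vectorized_out_shape(self_shape, x_shape):
    """Output shape of `self @ x` under the batching rule in the table above.

    Assumes the first len(self_shape) dims of `x` are batch dims and the rest
    are its inner shape (*x). Raises ValueError on incompatible batch dims.
    """
    n = len(self_shape)
    batch = np.broadcast_shapes(tuple(self_shape), tuple(x_shape[:n]))
    return tuple(batch) + tuple(x_shape[n:])


# Rows of the table with b=2, h=4, w=5 and *x = (3,):
print(vectorized_out_shape((), (3,)))               # (3,)
print(vectorized_out_shape((1,), (2, 3)))           # (2, 3)
print(vectorized_out_shape((1, 4, 5), (2, 1, 1, 3)))  # (2, 4, 5, 3)
```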
Support v3d features#
Your custom objects can easily opt in to v3d features by implementing the corresponding protocol:
Make your object visualizable (my_obj.fig, v3d.make_fig([my_obj])): Implement the visualization protocol
Make your object compatible with v3d.Transform (cam_from_world @ my_obj): Implement the transform protocol
Make your object compatible with 2D<>3D projection: Implement the camera projection protocol
Make your object compatible with camera.render(my_obj): Implement the rendering protocol
Visualization protocol#
You can make your dataclass visualizable by implementing the protocol:
def make_traces(self) -> list[<plotly traces>]:
plotly traces can be any go.Scatter3d, go.Mesh3d,…
Additionally, inheriting from v3d.Visualizable adds the .fig property. v3d.DataclassArray simply combines dca.DataclassArray and v3d.Visualizable.
class MyRay(dca.DataclassArray, v3d.Visualizable):  # Could inherit from v3d.DataclassArray instead
    pos: FloatArray[..., 3]
    dir: FloatArray[..., 3]

    def make_traces(self) -> list[plotly.basedatatypes.BaseTraceType]:
        return v3d.plotly.make_lines_traces(
            start=self.pos,
            end=self.pos + self.dir,
            end_marker='diamond',
        )
ray = MyRay(pos=[0, 0, 0], dir=[-1, 1, 1])
ray.fig
The v3d.plotly module is a small wrapper around plotly to simplify building plotly.graph_objects traces (e.g. to subsample with v3d.plotly.subsample).
Note: You can make any Python object visualizable (not only dca.DataclassArray) by inheriting from v3d.Visualizable.
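If you implement make_traces with plotly directly instead of through v3d.plotly, one common plotly idiom is to draw many disconnected segments in a single trace by separating them with None values. This is an assumption about a useful approach, not a description of how v3d.plotly.make_lines_traces works. segments_to_xyz is a hypothetical helper.

```python
import numpy as np


def segments_to_xyz(start, end):
    """Flattens N 3D segments into x/y/z lists for one plotly line trace.

    A `None` after each segment breaks the line, so a single
    go.Scatter3d(mode='lines') trace can draw many disconnected segments.
    """
    start = np.asarray(start, dtype=float).reshape(-1, 3)
    end = np.asarray(end, dtype=float).reshape(-1, 3)
    xs, ys, zs = [], [], []
    for s, e in zip(start, end):
        for coords, i in ((xs, 0), (ys, 1), (zs, 2)):
            coords.extend([float(s[i]), float(e[i]), None])
    return xs, ys, zs


xs, ys, zs = segments_to_xyz(start=[[0, 0, 0], [1, 0, 0]], end=[[-1, 1, 1], [1, 1, 1]])
print(xs)  # [0.0, -1.0, None, 1.0, 1.0, None]
```

With plotly installed, you can pass the lists as go.Scatter3d(x=xs, y=ys, z=zs, mode='lines') and return that trace in a list from make_traces.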
Transform protocol#
You can make your dataclass composable with v3d.Transform by implementing the protocol:
def apply_transform(self, tr: v3d.Transform) -> Self:
which will be called during tr @ my_obj.
The protocol automatically supports vectorization.
class MyRay(v3d.DataclassArray):
    pos: FloatArray[..., 3]
    dir: FloatArray[..., 3]

    def apply_transform(self, tr: v3d.Transform) -> MyRay:
        return self.replace(
            pos=tr @ self.pos,
            # apply_to_dir applies `tr.R` but not `tr.t`
            dir=tr.apply_to_dir(self.dir),
        )
tr = v3d.Transform.identity()
ray = MyRay(pos=[0, 0, 0], dir=[-1, 1, 1])
tr @ ray # Composing with identity is a no-op
MyRay(
pos=array([0., 0., 0.]),
dir=array([-1., 1., 1.]),
)
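The dispatch behind tr @ my_obj can be sketched with a toy rigid transform. Its __matmul__ defers to apply_transform when the right operand defines it and otherwise treats the operand as (..., 3) points. ToyTransform and ToyRay are hypothetical classes, not v3d.Transform. They only mirror the R/t split that apply_to_dir relies on.

```python
import numpy as np


class ToyTransform:
    """Hypothetical rigid transform x -> R @ x + t (not v3d.Transform)."""

    def __init__(self, R, t):
        self.R = np.asarray(R, dtype=float)
        self.t = np.asarray(t, dtype=float)

    def apply_to_dir(self, d):
        # Directions are rotated but not translated; works for any (..., 3) batch.
        return np.asarray(d, dtype=float) @ self.R.T

    def __matmul__(self, other):
        if hasattr(other, 'apply_transform'):
            return other.apply_transform(self)  # defer to the transform protocol
        return self.apply_to_dir(other) + self.t  # plain (..., 3) points


class ToyRay:
    def __init__(self, pos, dir):
        self.pos = np.asarray(pos, dtype=float)
        self.dir = np.asarray(dir, dtype=float)

    def apply_transform(self, tr):
        return ToyRay(pos=tr @ self.pos, dir=tr.apply_to_dir(self.dir))


# 90 degree rotation around z, then a translation by (1, 0, 0).
tr = ToyTransform(R=[[0, -1, 0], [1, 0, 0], [0, 0, 1]], t=[1, 0, 0])
ray = tr @ ToyRay(pos=[0, 0, 0], dir=[-1, 1, 1])
print(ray.pos, ray.dir)  # [1. 0. 0.] [-1. -1.  1.]
```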
Camera projection protocol#
You can make your dataclass support pixel <> 3D camera coordinate projection by implementing the protocols:
def apply_px_from_cam(self, spec: camera_spec_lib.CameraSpec) -> MyPoint2d:
def apply_cam_from_px(self, spec: camera_spec_lib.CameraSpec) -> MyPoint3d:
Look at the v3d.Point3d implementation for an example.
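What these two methods compute depends on the camera model described by spec. The following sketch assumes a plain pinhole camera with a 3x3 intrinsics matrix K and no lens distortion. px_from_cam, cam_from_px and the K layout are hypothetical and not v3d's CameraSpec API. Under this model, projection is a matrix product followed by a divide by depth. A pixel alone does not determine depth, so unprojection here returns points on the z = 1 plane.

```python
import numpy as np


def px_from_cam(points_cam, K):
    """Projects (..., 3) camera-frame points to (..., 2) pixel coords (pinhole)."""
    points_cam = np.asarray(points_cam, dtype=float)
    uvw = points_cam @ K.T  # homogeneous pixel coords; last entry is the depth z
    return uvw[..., :2] / uvw[..., 2:3]  # perspective divide


def cam_from_px(px, K):
    """Unprojects (..., 2) pixels to (..., 3) camera-frame points on the z=1 plane."""
    px = np.asarray(px, dtype=float)
    px_h = np.concatenate([px, np.ones_like(px[..., :1])], axis=-1)  # (u, v, 1)
    return px_h @ np.linalg.inv(K).T


# Hypothetical intrinsics: focal length 100 px, principal point (32, 24).
K = np.array([[100.0, 0.0, 32.0], [0.0, 100.0, 24.0], [0.0, 0.0, 1.0]])
print(px_from_cam([1.0, 2.0, 4.0], K))  # [57. 74.]
```

Round-tripping pixels through cam_from_px and then px_from_cam returns the original pixels, because the points stay on the camera ray through each pixel.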
Rendering protocol#
Rendering protocol is not supported at the moment. Please open an issue if you need this feature.