Segment Anything 3D

Sam3_3D_ObjectsPipelineSingleton

Singleton to cache the heavy 3D pipeline initialization.

Source code in inference/models/sam3_3d/segment_anything_3d.py
class Sam3_3D_ObjectsPipelineSingleton:
    """Singleton to cache the heavy 3D pipeline initialization."""

    _instances = weakref.WeakValueDictionary()
    _lock = Lock()

    def __new__(cls, config_key: str):
        with cls._lock:
            if config_key not in cls._instances:
                instance = super().__new__(cls)
                instance.config_key = config_key
                cls._instances[config_key] = instance
            return cls._instances[config_key]
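
Because the cache is a `WeakValueDictionary`, an entry survives only while something holds a strong reference to the instance; while it does, repeated construction with the same key returns the same object, so the heavy pipeline attached to it is built once. A minimal sketch of that behaviour (the key strings are made up for illustration):

a = Sam3_3D_ObjectsPipelineSingleton("cuda:0_pipeline.yaml")
b = Sam3_3D_ObjectsPipelineSingleton("cuda:0_pipeline.yaml")
c = Sam3_3D_ObjectsPipelineSingleton("cpu_pipeline.yaml")

assert a is b      # same config_key -> same cached instance
assert a is not c  # different config_key -> separate instance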

SegmentAnything3_3D_Objects

Bases: RoboflowCoreModel

Source code in inference/models/sam3_3d/segment_anything_3d.py
class SegmentAnything3_3D_Objects(RoboflowCoreModel):
    def __init__(
        self,
        *args,
        model_id: str = "sam3-3d-objects",
        torch_compile: bool = False,
        compile_res: int = 518,
        **kwargs,
    ):
        super().__init__(model_id=model_id, **kwargs)

        self.cache_dir = Path(get_cache_dir(model_id=self.endpoint))

        tdfy_dir = files(tdfy.sam3d_v1)
        pipeline_config_path = tdfy_dir / "checkpoints_configs" / "pipeline.yaml"
        moge_checkpoint_path = self.cache_dir / "moge-vitl.pth"
        ss_generator_checkpoint_path = self.cache_dir / "ss_generator.ckpt"
        slat_generator_checkpoint_path = self.cache_dir / "slat_generator.ckpt"
        ss_decoder_checkpoint_path = self.cache_dir / "ss_decoder.ckpt"
        slat_decoder_checkpoint_path = self.cache_dir / "slat_decoder_gs.ckpt"
        slat_decodergs4_checkpoint_path = self.cache_dir / "slat_decoder_gs_4.ckpt"
        slat_decoder_mesh_checkpoint_path = self.cache_dir / "slat_decoder_mesh.pt"
        dinov2_ckpt_path = self.cache_dir / "dinov2_vitl14_reg4_pretrain.pth"

        config_key = f"{DEVICE}_{pipeline_config_path}"
        singleton = Sam3_3D_ObjectsPipelineSingleton(config_key)

        if not hasattr(singleton, "pipeline"):
            self.pipeline_config = OmegaConf.load(str(pipeline_config_path))
            self.pipeline_config["device"] = DEVICE
            self.pipeline_config["workspace_dir"] = str(tdfy_dir)
            self.pipeline_config["compile_model"] = torch_compile
            self.pipeline_config["compile_res"] = compile_res
            self.pipeline_config["depth_model"]["model"][
                "pretrained_model_name_or_path"
            ] = str(moge_checkpoint_path)
            self.pipeline_config["ss_generator_ckpt_path"] = str(
                ss_generator_checkpoint_path
            )
            self.pipeline_config["slat_generator_ckpt_path"] = str(
                slat_generator_checkpoint_path
            )
            self.pipeline_config["ss_decoder_ckpt_path"] = str(
                ss_decoder_checkpoint_path
            )
            self.pipeline_config["slat_decoder_gs_ckpt_path"] = str(
                slat_decoder_checkpoint_path
            )
            self.pipeline_config["slat_decoder_gs_4_ckpt_path"] = str(
                slat_decodergs4_checkpoint_path
            )
            self.pipeline_config["slat_decoder_mesh_ckpt_path"] = str(
                slat_decoder_mesh_checkpoint_path
            )
            self.pipeline_config["dinov2_ckpt_path"] = str(dinov2_ckpt_path)
            singleton.pipeline = instantiate(self.pipeline_config)

        # Reference the singleton's pipeline
        self.pipeline = singleton.pipeline
        self._state_lock = Lock()

    def get_infer_bucket_file_list(self) -> list:
        """Get the list of required files for inference.

        Returns:
            list: A list of required files for inference, e.g., ["environment.json"].
        """
        return [
            "moge-vitl.pth",
            "ss_generator.ckpt",
            "slat_generator.ckpt",
            "ss_decoder.ckpt",
            "slat_decoder_gs.ckpt",
            "slat_decoder_gs_4.ckpt",
            "slat_decoder_mesh.pt",
        ]

    def download_model_from_roboflow_api(self) -> None:
        """Override parent method to use streaming downloads for large SAM3_3D model files."""
        lock_dir = MODEL_CACHE_DIR + "/_file_locks"
        os.makedirs(lock_dir, exist_ok=True)
        lock_file = os.path.join(lock_dir, f"{os.path.basename(self.cache_dir)}.lock")
        lock = FileLock(lock_file, timeout=120)
        with lock:
            api_data = get_roboflow_model_data(
                api_key=self.api_key,
                model_id="sam3-3d-weights-vc6vz/1",
                endpoint_type=ModelEndpointType.ORT,
                device_id=self.device_id,
            )["ort"]
            if "weights" not in api_data:
                raise ModelArtefactError(
                    f"`weights` key not available in Roboflow API response while downloading model weights."
                )
            for weights_url_key in api_data["weights"]:
                weights_url = api_data["weights"][weights_url_key]
                filename = weights_url.split("?")[0].split("/")[-1]
                stream_url_to_cache(
                    url=weights_url,
                    filename=filename,
                    model_id=self.endpoint,
                )

    def infer_from_request(
        self, request: Sam3_3D_Objects_InferenceRequest
    ) -> Sam3_3D_Objects_Response:
        with self._state_lock:
            t1 = perf_counter()
            raw_result = self.create_3d(**request.dict())
            inference_time = perf_counter() - t1
            return convert_3d_objects_result_to_api_response(
                raw_result=raw_result,
                inference_time=inference_time,
            )

    def create_3d(
        self,
        image: Optional[InferenceRequestImage],
        mask_input: Optional[Any] = None,
        **kwargs,
    ):
        """
        Generate 3D from image and mask(s).

        Args:
            image: Input image
            mask_input: Mask in any supported format:
                - np.ndarray (H,W) or (N,H,W): Binary mask(s)
                - List[float]: COCO polygon [x1,y1,x2,y2,...]
                - List[List[float]]: Multiple polygons
                - Dict with 'counts'/'size': RLE mask
                - List[Dict]: Multiple RLE masks
        """
        with torch.inference_mode():
            if image is None or mask_input is None:
                raise ValueError("Must provide image and mask!")

            image_np = load_image_rgb(image)
            if image_np.dtype != np.uint8:
                if image_np.max() <= 1:
                    image_np = (image_np * 255).astype(np.uint8)
                else:
                    image_np = image_np.astype(np.uint8)
            image_shape = (image_np.shape[0], image_np.shape[1])

            if _is_single_mask_input(mask_input):
                masks = [convert_mask_to_binary(mask_input, image_shape)]
            elif isinstance(mask_input, np.ndarray) and mask_input.ndim == 3:
                masks = [convert_mask_to_binary(m, image_shape) for m in mask_input]
            else:
                masks = [convert_mask_to_binary(m, image_shape) for m in mask_input]

            outputs = []
            for mask in masks:
                result = self.pipeline.run(image=image_np, mask=mask)
                outputs.append(result)

            if len(outputs) == 1:
                result = outputs[0]
                scene_gs = ready_gaussian_for_video_rendering(result["gs"])
                return {
                    "gs": scene_gs,
                    "glb": result["glb"],
                    "objects": outputs,
                }
            else:
                scene_gs = make_scene(*outputs)
                scene_gs = ready_gaussian_for_video_rendering(scene_gs)
                scene_gs = apply_gaussian_view_correction(scene_gs)
                scene_glb = make_scene_glb(*outputs)
                return {
                    "gs": scene_gs,
                    "glb": scene_glb,
                    "objects": outputs,
                }
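
A hedged end-to-end sketch of using the class directly; the API key, the image payload form, and the mask values are placeholders, not confirmed inputs:

import numpy as np

model = SegmentAnything3_3D_Objects(api_key="YOUR_API_KEY")  # weights stream into the cache dir on first load

mask = np.zeros((480, 640), dtype=np.uint8)
mask[100:300, 200:400] = 255  # mark the object to reconstruct

result = model.create_3d(
    image={"type": "url", "value": "https://example.com/scene.jpg"},  # assumed InferenceRequestImage shape
    mask_input=mask,
)
gaussians, mesh = result["gs"], result["glb"]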

create_3d(image, mask_input=None, **kwargs)

Generate 3D from image and mask(s).

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `image` | `Optional[InferenceRequestImage]` | Input image | *required* |
| `mask_input` | `Optional[Any]` | Mask in any supported format: `np.ndarray` (H, W) or (N, H, W) binary mask(s); `List[float]` COCO polygon `[x1, y1, x2, y2, ...]`; `List[List[float]]` multiple polygons; dict with `counts`/`size` RLE mask; `List[Dict]` multiple RLE masks | `None` |
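
To make the accepted `mask_input` forms concrete, an illustrative set of equivalent calls (reusing `model` and an `image` payload as in the sketch above; the polygon coordinates and RLE counts string are dummies):

import numpy as np

single = np.zeros((480, 640), dtype=np.uint8)
single[10:50, 10:50] = 255                      # (H, W) binary mask
stacked = np.stack([single, single])            # (N, H, W): one object per mask
polygon = [10.0, 10.0, 50.0, 10.0, 50.0, 50.0]  # flat COCO polygon [x1, y1, ...]
rle = {"counts": "dummy", "size": [480, 640]}   # COCO RLE dict (needs pycocotools)

for mask_input in (single, stacked, polygon, [polygon], rle, [rle]):
    model.create_3d(image=image, mask_input=mask_input)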
Source code in inference/models/sam3_3d/segment_anything_3d.py
def create_3d(
    self,
    image: Optional[InferenceRequestImage],
    mask_input: Optional[Any] = None,
    **kwargs,
):
    """
    Generate 3D from image and mask(s).

    Args:
        image: Input image
        mask_input: Mask in any supported format:
            - np.ndarray (H,W) or (N,H,W): Binary mask(s)
            - List[float]: COCO polygon [x1,y1,x2,y2,...]
            - List[List[float]]: Multiple polygons
            - Dict with 'counts'/'size': RLE mask
            - List[Dict]: Multiple RLE masks
    """
    with torch.inference_mode():
        if image is None or mask_input is None:
            raise ValueError("Must provide image and mask!")

        image_np = load_image_rgb(image)
        if image_np.dtype != np.uint8:
            if image_np.max() <= 1:
                image_np = (image_np * 255).astype(np.uint8)
            else:
                image_np = image_np.astype(np.uint8)
        image_shape = (image_np.shape[0], image_np.shape[1])

        if _is_single_mask_input(mask_input):
            masks = [convert_mask_to_binary(mask_input, image_shape)]
        elif isinstance(mask_input, np.ndarray) and mask_input.ndim == 3:
            masks = [convert_mask_to_binary(m, image_shape) for m in mask_input]
        else:
            masks = [convert_mask_to_binary(m, image_shape) for m in mask_input]

        outputs = []
        for mask in masks:
            result = self.pipeline.run(image=image_np, mask=mask)
            outputs.append(result)

        if len(outputs) == 1:
            result = outputs[0]
            scene_gs = ready_gaussian_for_video_rendering(result["gs"])
            return {
                "gs": scene_gs,
                "glb": result["glb"],
                "objects": outputs,
            }
        else:
            scene_gs = make_scene(*outputs)
            scene_gs = ready_gaussian_for_video_rendering(scene_gs)
            scene_gs = apply_gaussian_view_correction(scene_gs)
            scene_glb = make_scene_glb(*outputs)
            return {
                "gs": scene_gs,
                "glb": scene_glb,
                "objects": outputs,
            }

download_model_from_roboflow_api()

Override parent method to use streaming downloads for large SAM3_3D model files.

Source code in inference/models/sam3_3d/segment_anything_3d.py
def download_model_from_roboflow_api(self) -> None:
    """Override parent method to use streaming downloads for large SAM3_3D model files."""
    lock_dir = MODEL_CACHE_DIR + "/_file_locks"
    os.makedirs(lock_dir, exist_ok=True)
    lock_file = os.path.join(lock_dir, f"{os.path.basename(self.cache_dir)}.lock")
    lock = FileLock(lock_file, timeout=120)
    with lock:
        api_data = get_roboflow_model_data(
            api_key=self.api_key,
            model_id="sam3-3d-weights-vc6vz/1",
            endpoint_type=ModelEndpointType.ORT,
            device_id=self.device_id,
        )["ort"]
        if "weights" not in api_data:
            raise ModelArtefactError(
                f"`weights` key not available in Roboflow API response while downloading model weights."
            )
        for weights_url_key in api_data["weights"]:
            weights_url = api_data["weights"][weights_url_key]
            filename = weights_url.split("?")[0].split("/")[-1]
            stream_url_to_cache(
                url=weights_url,
                filename=filename,
                model_id=self.endpoint,
            )
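
The per-model `FileLock` is what lets multiple worker processes share one cache directory without racing on multi-gigabyte downloads. A stripped-down sketch of the same pattern, with `requests` standing in for the internal `stream_url_to_cache` helper and all paths hypothetical:

import os

import requests
from filelock import FileLock

def download_once(url: str, dest: str, lock_path: str) -> None:
    with FileLock(lock_path, timeout=120):
        if os.path.exists(dest):
            return  # another process already finished the download
        with requests.get(url, stream=True) as resp:
            resp.raise_for_status()
            with open(dest + ".tmp", "wb") as f:
                for chunk in resp.iter_content(chunk_size=1 << 20):  # 1 MiB chunks
                    f.write(chunk)
        os.replace(dest + ".tmp", dest)  # atomic rename: readers never see a partial file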

get_infer_bucket_file_list()

Get the list of required files for inference.

Returns:

| Name | Type | Description |
| --- | --- | --- |
| `list` | `list` | A list of required files for inference, e.g., `["environment.json"]`. |

Source code in inference/models/sam3_3d/segment_anything_3d.py
def get_infer_bucket_file_list(self) -> list:
    """Get the list of required files for inference.

    Returns:
        list: A list of required files for inference, e.g., ["environment.json"].
    """
    return [
        "moge-vitl.pth",
        "ss_generator.ckpt",
        "slat_generator.ckpt",
        "ss_decoder.ckpt",
        "slat_decoder_gs.ckpt",
        "slat_decoder_gs_4.ckpt",
        "slat_decoder_mesh.pt",
    ]

apply_gaussian_view_correction(scene_gs)

Apply view correction to Gaussian scene to match GLB orientation. Used for combined scene PLY.

Source code in inference/models/sam3_3d/segment_anything_3d.py
def apply_gaussian_view_correction(scene_gs):
    """
    Apply view correction to Gaussian scene to match GLB orientation.
    Used for combined scene PLY.
    """
    xyz = scene_gs.get_xyz
    device = xyz.device
    dtype = xyz.dtype

    R_view_zup = torch.tensor(
        [[-1.0, 0.0, 0.0], [0.0, 1.0, 0.0], [0.0, 0.0, -1.0]],
        device=device,
        dtype=dtype,
    )

    new_xyz = xyz @ R_view_zup
    scene_gs.from_xyz(new_xyz)

    q_correction = matrix_to_quaternion(R_view_zup.unsqueeze(0)).squeeze(0)
    old_rotations = scene_gs.get_rotation
    new_rotations = quaternion_multiply(
        q_correction.unsqueeze(0).expand(old_rotations.shape[0], -1), old_rotations
    )
    scene_gs.from_rotation(new_rotations)

    return scene_gs
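
`R_view_zup` is a 180° rotation about the Y axis, and the quaternion applied to the per-Gaussian rotations must agree with the matrix applied to positions. A small self-check, assuming the `matrix_to_quaternion`/`quaternion_to_matrix` helpers follow the common w-first convention:

import torch

R = torch.tensor([[-1.0, 0.0, 0.0], [0.0, 1.0, 0.0], [0.0, 0.0, -1.0]])

assert torch.allclose(R @ R, torch.eye(3))                    # 180 degrees applied twice is the identity
assert torch.isclose(torch.linalg.det(R), torch.tensor(1.0))  # proper rotation, no reflection

q = matrix_to_quaternion(R.unsqueeze(0)).squeeze(0)  # ~ (0, 0, 1, 0) up to sign
assert torch.allclose(quaternion_to_matrix(q.unsqueeze(0)).squeeze(0), R, atol=1e-6)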

convert_mask_to_binary(mask_input, image_shape)

Convert polygon, RLE, or binary mask to binary mask (H, W) with values 0/255.

Source code in inference/models/sam3_3d/segment_anything_3d.py
def convert_mask_to_binary(mask_input: Any, image_shape: Tuple[int, int]) -> np.ndarray:
    """Convert polygon, RLE, or binary mask to binary mask (H, W) with values 0/255."""
    height, width = image_shape

    if isinstance(mask_input, np.ndarray):
        return _normalize_binary_mask(mask_input, image_shape)

    if isinstance(mask_input, Image.Image):
        return _normalize_binary_mask(np.array(mask_input.convert("L")), image_shape)

    if isinstance(mask_input, dict) and "counts" in mask_input:
        if not PYCOCOTOOLS_AVAILABLE:
            raise ImportError(
                "pycocotools required for RLE. Install: pip install pycocotools"
            )
        rle = dict(mask_input)
        if isinstance(rle.get("counts"), str):
            rle["counts"] = rle["counts"].encode("utf-8")
        return _normalize_binary_mask(mask_utils.decode(rle), image_shape)

    if isinstance(mask_input, list):
        points = _parse_polygon_to_points(mask_input)
        if not points or len(points) < 3:
            return np.zeros((height, width), dtype=np.uint8)
        mask = Image.new("L", (width, height), 0)
        ImageDraw.Draw(mask).polygon(points, outline=255, fill=255)
        return np.array(mask, dtype=np.uint8)

    raise TypeError(f"Unsupported mask type: {type(mask_input)}")
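
Illustrative calls (the coordinates are arbitrary; the RLE branch additionally requires pycocotools):

import numpy as np

shape = (480, 640)  # (height, width)

m = convert_mask_to_binary([10.0, 10.0, 100.0, 10.0, 100.0, 100.0], shape)
assert m.shape == shape and set(np.unique(m)) <= {0, 255}

# Degenerate polygons (fewer than 3 points) yield an all-zero mask rather than an error.
assert convert_mask_to_binary([10.0, 10.0], shape).sum() == 0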

make_scene_glb(*outputs)

Combine multiple GLB meshes into a single scene. Applies layout transforms and a final view correction rotation.

Source code in inference/models/sam3_3d/segment_anything_3d.py
def make_scene_glb(*outputs):
    """
    Combine multiple GLB meshes into a single scene.
    Applies layout transforms and a final view correction rotation.
    """
    scene = trimesh.Scene()

    for i, output in enumerate(outputs):
        glb = output["glb"]
        glb = glb.copy()

        glb = transform_glb_to_world(
            glb,
            output["rotation"],
            output["translation"],
            output["scale"],
        )
        scene.add_geometry(glb, node_name=f"object_{i}")

    R_view = np.array([[-1, 0, 0], [0, 0, -1], [0, -1, 0]], dtype=np.float32)
    for geom_name in scene.geometry:
        mesh = scene.geometry[geom_name]
        mesh.vertices = (mesh.vertices.astype(np.float32)) @ R_view
        if (
            hasattr(mesh, "vertex_normals")
            and mesh.vertex_normals is not None
            and len(mesh.vertex_normals) > 0
        ):
            mesh.vertex_normals = (mesh.vertex_normals.astype(np.float32)) @ R_view

    return scene
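
The final `R_view` applied to every mesh appears to remap the combined layout's z-up frame into the y-up frame GLB viewers expect; whatever the intended convention, it can at least be verified to be a proper rotation (determinant +1, so winding order and normals stay valid):

import numpy as np

R_view = np.array([[-1, 0, 0], [0, 0, -1], [0, -1, 0]], dtype=np.float32)
assert np.allclose(R_view @ R_view.T, np.eye(3))  # orthogonal
assert np.isclose(np.linalg.det(R_view), 1.0)     # determinant +1: no reflection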

prepare_individual_object_for_export(gs)

Prepare an individual object Gaussian for PLY export.

Source code in inference/models/sam3_3d/segment_anything_3d.py
def prepare_individual_object_for_export(gs):
    """
    Prepare an individual object Gaussian for PLY export.
    """
    from copy import deepcopy

    gs_copy = deepcopy(gs)
    gs_copy = ready_gaussian_for_video_rendering(gs_copy)

    xyz = gs_copy.get_xyz
    device = xyz.device
    dtype = xyz.dtype

    R_view = torch.tensor(
        [[1.0, 0.0, 0.0], [0.0, 0.0, -1.0], [0.0, 1.0, 0.0]], device=device, dtype=dtype
    )

    new_xyz = xyz @ R_view
    gs_copy.from_xyz(new_xyz)

    q_correction = matrix_to_quaternion(R_view.unsqueeze(0)).squeeze(0)
    old_rotations = gs_copy.get_rotation
    new_rotations = quaternion_multiply(
        q_correction.unsqueeze(0).expand(old_rotations.shape[0], -1), old_rotations
    )
    gs_copy.from_rotation(new_rotations)

    return gs_copy

transform_glb_to_world(glb_mesh, rotation, translation, scale)

Transform a GLB mesh from local to world coordinates.

Source code in inference/models/sam3_3d/segment_anything_3d.py
def transform_glb_to_world(glb_mesh, rotation, translation, scale):
    """
    Transform a GLB mesh from local to world coordinates.
    """
    quat = rotation.squeeze()
    quat_normalized = quat / quat.norm()
    R_layout = quaternion_to_matrix(quat_normalized).cpu().numpy()
    t = translation.squeeze().cpu().numpy()
    s = scale.squeeze().cpu().numpy()[0]

    z_to_y_up = np.array([[1, 0, 0], [0, 0, -1], [0, 1, 0]], dtype=np.float32)
    y_to_z_up = np.array([[1, 0, 0], [0, 0, 1], [0, -1, 0]], dtype=np.float32)

    verts = glb_mesh.vertices.copy().astype(np.float32)

    verts = verts @ y_to_z_up

    verts = verts * s
    verts = verts @ R_layout
    verts = verts + t

    verts = verts @ z_to_y_up

    glb_mesh.vertices = verts

    if (
        hasattr(glb_mesh, "vertex_normals")
        and glb_mesh.vertex_normals is not None
        and len(glb_mesh.vertex_normals) > 0
    ):
        normals = glb_mesh.vertex_normals.copy().astype(np.float32)
        normals = normals @ y_to_z_up
        normals = normals @ R_layout
        normals = normals @ z_to_y_up
        glb_mesh.vertex_normals = normals

    return glb_mesh
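
`y_to_z_up` and `z_to_y_up` are mutual inverses, so the function round-trips geometry through z-up space: into z-up, then scale, rotate, and translate there, then back to the GLB's y-up frame. A quick sanity check:

import numpy as np

z_to_y_up = np.array([[1, 0, 0], [0, 0, -1], [0, 1, 0]], dtype=np.float32)
y_to_z_up = np.array([[1, 0, 0], [0, 0, 1], [0, -1, 0]], dtype=np.float32)

assert np.allclose(y_to_z_up @ z_to_y_up, np.eye(3))  # one undoes the other
assert np.allclose(z_to_y_up @ y_to_z_up, np.eye(3))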