Skip to content

inference API Reference

core/active_learning

Active learning loop: sampling strategies, data collection middleware, and configuration.

inference.core.active_learning.accounting

Functions

get_images_in_labeling_jobs_of_specific_batch

get_images_in_labeling_jobs_of_specific_batch(
    all_labeling_jobs, batch_id
)

Get the number of images in labeling jobs of a specific batch.

Parameters:

Name Type Description Default
all_labeling_jobs List[dict]

All labeling jobs.

required
batch_id str

ID of the batch.

required

Returns:

Type Description
int

The number of images in labeling jobs of the batch.

Source code in inference/core/active_learning/accounting.py
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
def get_images_in_labeling_jobs_of_specific_batch(
    all_labeling_jobs: List[dict],
    batch_id: str,
) -> int:
    """Get the number of images in labeling jobs of a specific batch.

    Args:
        all_labeling_jobs: All labeling jobs.
        batch_id: ID of the batch.

    Returns:
        The number of images in labeling jobs of the batch.
    """
    # Sum image counts over every job whose "sourceBatch" contains the batch id.
    return sum(
        job["numImages"]
        for job in all_labeling_jobs
        if batch_id in job["sourceBatch"]
    )

get_matching_labeling_batch

get_matching_labeling_batch(
    all_labeling_batches, batch_name
)

Get the matching labeling batch.

Parameters:

Name Type Description Default
all_labeling_batches List[dict]

All labeling batches.

required
batch_name str

Name of the batch.

required

Returns:

Type Description
Optional[dict]

The matching labeling batch if found, None otherwise.

Source code in inference/core/active_learning/accounting.py
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
def get_matching_labeling_batch(
    all_labeling_batches: List[dict],
    batch_name: str,
) -> Optional[dict]:
    """Get the matching labeling batch.

    Args:
        all_labeling_batches: All labeling batches.
        batch_name: Name of the batch.

    Returns:
        The matching labeling batch if found, None otherwise.
    """
    # First batch whose "name" equals batch_name wins; None when nothing matches.
    return next(
        (batch for batch in all_labeling_batches if batch["name"] == batch_name),
        None,
    )

image_can_be_submitted_to_batch

image_can_be_submitted_to_batch(
    batch_name,
    workspace_id,
    dataset_id,
    max_batch_images,
    api_key,
)

Check if an image can be submitted to a batch.

Parameters:

Name Type Description Default
batch_name str

Name of the batch.

required
workspace_id WorkspaceID

ID of the workspace.

required
dataset_id DatasetID

ID of the dataset.

required
max_batch_images Optional[int]

Maximum number of images allowed in the batch.

required
api_key str

API key to use for the request.

required

Returns:

Type Description
bool

True if the image can be submitted to the batch, False otherwise.

Source code in inference/core/active_learning/accounting.py
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
def image_can_be_submitted_to_batch(
    batch_name: str,
    workspace_id: WorkspaceID,
    dataset_id: DatasetID,
    max_batch_images: Optional[int],
    api_key: str,
) -> bool:
    """Check if an image can be submitted to a batch.

    Args:
        batch_name: Name of the batch.
        workspace_id: ID of the workspace.
        dataset_id: ID of the dataset.
        max_batch_images: Maximum number of images allowed in the batch.
        api_key: API key to use for the request.

    Returns:
        True if the image can be submitted to the batch, False otherwise.
    """
    # No limit configured - submission is always allowed.
    if max_batch_images is None:
        return True
    batches_response = get_roboflow_labeling_batches(
        api_key=api_key,
        workspace_id=workspace_id,
        dataset_id=dataset_id,
    )
    target_batch = get_matching_labeling_batch(
        all_labeling_batches=batches_response["batches"],
        batch_name=batch_name,
    )
    if target_batch is None:
        # Batch does not exist yet - allowed as long as the limit is positive.
        return max_batch_images > 0
    images_under_labeling = 0
    if target_batch["numJobs"] > 0:
        # Images handed over to labeling jobs are counted separately from the
        # batch's own "images" field, so fetch them and add both together.
        jobs_response = get_roboflow_labeling_jobs(
            api_key=api_key, workspace_id=workspace_id, dataset_id=dataset_id
        )
        images_under_labeling = get_images_in_labeling_jobs_of_specific_batch(
            all_labeling_jobs=jobs_response["jobs"],
            batch_id=target_batch["id"],
        )
    occupied_slots = target_batch["images"] + images_under_labeling
    return max_batch_images > occupied_slots

inference.core.active_learning.configuration

Classes

Functions

predictions_incompatible_with_dataset

predictions_incompatible_with_dataset(
    model_type, dataset_type
)

The incompatibility occurs when classification is mixed with detection. Detection-based predictions remain partially compatible with each other — for instance, for key-points detection we may register bounding boxes from object detection and manually provide the key-points annotations.

Source code in inference/core/active_learning/configuration.py
203
204
205
206
207
208
209
210
211
212
213
214
def predictions_incompatible_with_dataset(
    model_type: str,
    dataset_type: str,
) -> bool:
    """
    Decide whether predictions of `model_type` cannot be registered against a
    dataset of `dataset_type`. Incompatibility arises exactly when one side is
    classification and the other is not; detection-based predictions remain
    partially compatible with each other (e.g. for key-points detection we may
    register bboxes from object detection and annotate key-points manually).
    """
    # XOR: exactly one of the two is a classification task.
    return (CLASSIFICATION_TASK in model_type) ^ (CLASSIFICATION_TASK in dataset_type)

core/cache

Caching backends (in-memory, Redis) used for model artefacts and inference results.

inference.core.cache.base

Classes

BaseCache

BaseCache is an abstract base class that defines the interface for a cache.

Source code in inference/core/cache/base.py
  9
 10
 11
 12
 13
 14
 15
 16
 17
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
class BaseCache:
    """
    BaseCache is an abstract base class that defines the interface for a cache.

    Concrete implementations must override the plain key/value operations
    (get/set), the sorted-set operations (zadd/zrangebyscore/zremrangebyscore),
    the numpy helpers (set_numpy/get_numpy) and acquire_lock.
    """

    def get(self, key: str):
        """
        Gets the value associated with the given key.

        Args:
            key (str): The key to retrieve the value.

        Raises:
            NotImplementedError: This method must be implemented by subclasses.
        """
        raise NotImplementedError()

    def set(self, key: str, value: str, expire: float = None):
        """
        Sets a value for a given key with an optional expire time.

        Args:
            key (str): The key to store the value.
            value (str): The value to store.
            expire (float, optional): The time, in seconds, after which the key will expire. Defaults to None.

        Raises:
            NotImplementedError: This method must be implemented by subclasses.
        """
        raise NotImplementedError()

    def zadd(self, key: str, value: str, score: float, expire: float = None):
        """
        Adds a member with the specified score to the sorted set stored at key.

        Args:
            key (str): The key of the sorted set.
            value (str): The value to add to the sorted set.
            score (float): The score associated with the value.
            expire (float, optional): The time, in seconds, after which the key will expire. Defaults to None.

        Raises:
            NotImplementedError: This method must be implemented by subclasses.
        """
        raise NotImplementedError()

    def zrangebyscore(
        self,
        key: str,
        min: Optional[float] = -1,
        max: Optional[float] = float("inf"),
        withscores: bool = False,
    ):
        """
        Retrieves a range of members from a sorted set.

        Args:
            key (str): The key of the sorted set.
            min (float, optional): The minimum score of the range. Defaults to -1.
            max (float, optional): The maximum score of the range. Defaults to float("inf").
            withscores (bool, optional): Whether to return the scores along with the values. Defaults to False.

        Raises:
            NotImplementedError: This method must be implemented by subclasses.
        """
        raise NotImplementedError()

    def zremrangebyscore(
        self,
        key: str,
        start: Optional[int] = -1,
        stop: Optional[int] = float("inf"),
    ):
        """
        Removes all members in a sorted set within the given scores.

        Args:
            key (str): The key of the sorted set.
            start (int, optional): The minimum score of the range. Defaults to -1.
            stop (int, optional): The maximum score of the range. Defaults to float("inf").

        Raises:
            NotImplementedError: This method must be implemented by subclasses.
        """
        # NOTE(review): MemoryCache overrides this method with parameters named
        # min/max rather than start/stop, so keyword calls are not portable
        # across implementations — confirm before relying on keyword arguments.
        raise NotImplementedError()

    def acquire_lock(self, key: str, expire: float = None) -> Any:
        # Must return a held lock object exposing a release() method;
        # implemented by subclasses.
        raise NotImplementedError()

    @contextmanager
    def lock(self, key: str, expire: float = None) -> Any:
        """Context manager over acquire_lock: yields the held lock and always
        attempts to release it on exit, tolerating TTL expiry."""
        logger.debug(f"Acquiring lock at cache key: {key}")
        l = self.acquire_lock(key, expire=expire)
        try:
            yield l
        finally:
            logger.debug(f"Releasing lock at cache key: {key}")
            try:
                l.release()
            except LockNotOwnedError:
                # Lock TTL expired before release - this is expected in some cases
                logger.warning(f"Lock at cache key {key} expired before release")

    def set_numpy(self, key: str, value: Any, expire: float = None):
        """
        Caches a numpy array.

        Args:
            key (str): The key to store the value.
            value (Any): The value to store.
            expire (float, optional): The time, in seconds, after which the key will expire. Defaults to None.

        Raises:
            NotImplementedError: This method must be implemented by subclasses.
        """
        raise NotImplementedError()

    def get_numpy(self, key: str) -> Any:
        """
        Retrieves a numpy array from the cache.

        Args:
            key (str): The key of the value to retrieve.

        Raises:
            NotImplementedError: This method must be implemented by subclasses.
        """
        raise NotImplementedError()
Functions
get
get(key)

Gets the value associated with the given key.

Parameters:

Name Type Description Default
key str

The key to retrieve the value.

required

Raises:

Type Description
NotImplementedError

This method must be implemented by subclasses.

Source code in inference/core/cache/base.py
14
15
16
17
18
19
20
21
22
23
24
def get(self, key: str):
    """
    Look up the value stored under ``key``.

    Abstract stub — concrete cache implementations must override it.

    Args:
        key (str): Key whose value should be fetched.

    Raises:
        NotImplementedError: Always; subclasses provide the real lookup.
    """
    raise NotImplementedError()
get_numpy
get_numpy(key)

Retrieves a numpy array from the cache.

Parameters:

Name Type Description Default
key str

The key of the value to retrieve.

required

Raises:

Type Description
NotImplementedError

This method must be implemented by subclasses.

Source code in inference/core/cache/base.py
126
127
128
129
130
131
132
133
134
135
136
def get_numpy(self, key: str) -> Any:
    """
    Fetch a cached numpy array stored under ``key``.

    Abstract stub — concrete cache implementations must override it.

    Args:
        key (str): Key whose cached array should be fetched.

    Raises:
        NotImplementedError: Always; subclasses provide the real lookup.
    """
    raise NotImplementedError()
set
set(key, value, expire=None)

Sets a value for a given key with an optional expire time.

Parameters:

Name Type Description Default
key str

The key to store the value.

required
value str

The value to store.

required
expire float

The time, in seconds, after which the key will expire. Defaults to None.

None

Raises:

Type Description
NotImplementedError

This method must be implemented by subclasses.

Source code in inference/core/cache/base.py
26
27
28
29
30
31
32
33
34
35
36
37
38
def set(self, key: str, value: str, expire: float = None):
    """
    Store ``value`` under ``key``, optionally with a time-to-live.

    Abstract stub — concrete cache implementations must override it.

    Args:
        key (str): Key to store the value under.
        value (str): Value to store.
        expire (float, optional): Seconds until the entry expires.
            Defaults to None (no expiry).

    Raises:
        NotImplementedError: Always; subclasses provide the real storage.
    """
    raise NotImplementedError()
set_numpy
set_numpy(key, value, expire=None)

Caches a numpy array.

Parameters:

Name Type Description Default
key str

The key to store the value.

required
value Any

The value to store.

required
expire float

The time, in seconds, after which the key will expire. Defaults to None.

None

Raises:

Type Description
NotImplementedError

This method must be implemented by subclasses.

Source code in inference/core/cache/base.py
112
113
114
115
116
117
118
119
120
121
122
123
124
def set_numpy(self, key: str, value: Any, expire: float = None):
    """
    Cache a numpy array under ``key``, optionally with a time-to-live.

    Abstract stub — concrete cache implementations must override it.

    Args:
        key (str): Key to store the array under.
        value (Any): Array (or array-like value) to cache.
        expire (float, optional): Seconds until the entry expires.
            Defaults to None (no expiry).

    Raises:
        NotImplementedError: Always; subclasses provide the real storage.
    """
    raise NotImplementedError()
zadd
zadd(key, value, score, expire=None)

Adds a member with the specified score to the sorted set stored at key.

Parameters:

Name Type Description Default
key str

The key of the sorted set.

required
value str

The value to add to the sorted set.

required
score float

The score associated with the value.

required
expire float

The time, in seconds, after which the key will expire. Defaults to None.

None

Raises:

Type Description
NotImplementedError

This method must be implemented by subclasses.

Source code in inference/core/cache/base.py
40
41
42
43
44
45
46
47
48
49
50
51
52
53
def zadd(self, key: str, value: str, score: float, expire: float = None):
    """
    Insert ``value`` with ``score`` into the sorted set stored at ``key``.

    Abstract stub — concrete cache implementations must override it.

    Args:
        key (str): Key of the sorted set.
        value (str): Member to insert.
        score (float): Score ordering the member within the set.
        expire (float, optional): Seconds until the entry expires.
            Defaults to None (no expiry).

    Raises:
        NotImplementedError: Always; subclasses provide the real insertion.
    """
    raise NotImplementedError()
zrangebyscore
zrangebyscore(
    key, min=-1, max=float("inf"), withscores=False
)

Retrieves a range of members from a sorted set.

Parameters:

Name Type Description Default
key str

The key of the sorted set.

required
min float

The minimum score of the range. Defaults to -1.

-1
max float

The maximum score of the range. Defaults to float("inf").

float('inf')
withscores bool

Whether to return the scores along with the values. Defaults to False.

False

Raises:

Type Description
NotImplementedError

This method must be implemented by subclasses.

Source code in inference/core/cache/base.py
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
def zrangebyscore(
    self,
    key: str,
    min: Optional[float] = -1,
    max: Optional[float] = float("inf"),
    withscores: bool = False,
):
    """
    Retrieves a range of members from a sorted set.

    Args:
        key (str): The key of the sorted set.
        min (float, optional): The minimum score of the range. Defaults to -1.
        max (float, optional): The maximum score of the range. Defaults to float("inf").
        withscores (bool, optional): Whether to return the scores along with the values. Defaults to False.

    Raises:
        NotImplementedError: This method must be implemented by subclasses.
    """
    raise NotImplementedError()
zremrangebyscore
zremrangebyscore(key, start=-1, stop=float('inf'))

Removes all members in a sorted set within the given scores.

Parameters:

Name Type Description Default
key str

The key of the sorted set.

required
start int

The minimum score of the range. Defaults to -1.

-1
stop int

The maximum score of the range. Defaults to float("inf").

float('inf')

Raises:

Type Description
NotImplementedError

This method must be implemented by subclasses.

Source code in inference/core/cache/base.py
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
def zremrangebyscore(
    self,
    key: str,
    start: Optional[int] = -1,
    stop: Optional[int] = float("inf"),
):
    """
    Removes all members in a sorted set within the given scores.

    Args:
        key (str): The key of the sorted set.
        start (int, optional): The minimum score of the range. Defaults to -1.
        stop (int, optional): The maximum score of the range. Defaults to float("inf").

    Raises:
        NotImplementedError: This method must be implemented by subclasses.
    """
    # NOTE(review): MemoryCache overrides this with parameters named min/max,
    # so keyword calls are not portable across implementations — confirm.
    raise NotImplementedError()

inference.core.cache.memory

Classes

MemoryCache

Bases: BaseCache

MemoryCache is an in-memory cache that implements the BaseCache interface.

Attributes:

Name Type Description
cache dict

A dictionary to store the cache values.

expires dict

A dictionary to store the expiration times of the cache values.

zexpires dict

A dictionary to store the expiration times of the sorted set values.

_expire_thread Thread

A thread that runs the _expire method.

Source code in inference/core/cache/memory.py
 10
 11
 12
 13
 14
 15
 16
 17
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
class MemoryCache(BaseCache):
    """
    MemoryCache is an in-memory cache that implements the BaseCache interface.

    Attributes:
        cache (dict): A dictionary to store the cache values.
        expires (dict): A dictionary to store the expiration times of the cache values.
        zexpires (dict): A dictionary to store the expiration times of the sorted set values.
        _expire_thread (threading.Thread): A thread that runs the _expire method.
    """

    def __init__(self) -> None:
        """
        Initializes a new instance of the MemoryCache class.
        """
        self.cache = dict()
        self.expires = dict()
        self.zexpires = dict()

        # Daemon thread: background expiry must not keep the interpreter
        # alive at shutdown.
        self._expire_thread = threading.Thread(target=self._expire)
        self._expire_thread.daemon = True
        self._expire_thread.start()

    def _expire(self):
        """
        Removes the expired keys from the cache and zexpires dictionaries.

        This method runs in an infinite loop and sleeps for
        MEMORY_CACHE_EXPIRE_INTERVAL seconds between each iteration.
        """
        while True:
            now = time.time()
            expired_keys = [k for k, v in self.expires.copy().items() if v < now]
            for k in expired_keys:
                # pop() instead of del: get() may have evicted the same key
                # concurrently, and an unhandled KeyError here would kill the
                # expiry thread permanently.
                self.cache.pop(k, None)
                self.expires.pop(k, None)
            expired_members = [k for k, v in self.zexpires.copy().items() if v < now]
            for cache_key, score in expired_members:
                # The member may already have been removed by
                # zremrangebyscore(); tolerate both missing set and member.
                members = self.cache.get(cache_key)
                if members is not None:
                    members.pop(score, None)
                self.zexpires.pop((cache_key, score), None)
            while time.time() - now < MEMORY_CACHE_EXPIRE_INTERVAL:
                time.sleep(0.1)

    def get(self, key: str):
        """
        Gets the value associated with the given key.

        Args:
            key (str): The key to retrieve the value.

        Returns:
            str: The value associated with the key, or None if the key does not exist or is expired.
        """
        expiry = self.expires.get(key)
        if expiry is not None and expiry < time.time():
            # Lazy eviction on read; pop() tolerates a concurrent sweep by
            # the _expire thread having removed the entry already.
            self.cache.pop(key, None)
            self.expires.pop(key, None)
            return None
        return self.cache.get(key)

    def set(self, key: str, value: str, expire: float = None):
        """
        Sets a value for a given key with an optional expire time.

        Args:
            key (str): The key to store the value.
            value (str): The value to store.
            expire (float, optional): The time, in seconds, after which the key will expire. Defaults to None.
        """
        self.cache[key] = value
        # Truthiness check: expire=None (and 0) means "never expires".
        if expire:
            self.expires[key] = expire + time.time()

    def zadd(self, key: str, value: Any, score: float, expire: float = None):
        """
        Adds a member with the specified score to the sorted set stored at key.

        Args:
            key (str): The key of the sorted set.
            value (str): The value to add to the sorted set.
            score (float): The score associated with the value.
            expire (float, optional): The time, in seconds, after which the key will expire. Defaults to None.
        """
        # Members are keyed by score, so only one value per score is kept
        # (a later zadd with the same score overwrites the earlier value).
        members = self.cache.setdefault(key, dict())
        members[score] = value
        if expire:
            self.zexpires[(key, score)] = expire + time.time()

    def zrangebyscore(
        self,
        key: str,
        min: Optional[float] = -1,
        max: Optional[float] = float("inf"),
        withscores: bool = False,
    ):
        """
        Retrieves a range of members from a sorted set.

        Args:
            key (str): The key of the sorted set.
            min (float, optional): The minimum score of the range. Defaults to -1.
            max (float, optional): The maximum score of the range. Defaults to float("inf").
            withscores (bool, optional): Whether to return the scores along with the values. Defaults to False.

        Returns:
            list: A list of values (or value-score pairs if withscores is True) in the specified score range.
        """
        if key not in self.cache:
            return []
        members = self.cache[key]
        selected_scores = sorted(s for s in members.keys() if min <= s <= max)
        if withscores:
            return [(members[s], s) for s in selected_scores]
        return [members[s] for s in selected_scores]

    def zremrangebyscore(
        self,
        key: str,
        min: Optional[float] = -1,
        max: Optional[float] = float("inf"),
    ):
        """
        Removes all members in a sorted set within the given scores.

        Args:
            key (str): The key of the sorted set.
            min (float, optional): The minimum score of the range. Defaults to -1.
            max (float, optional): The maximum score of the range. Defaults to float("inf").

        Returns:
            int: The number of members removed from the sorted set.
        """
        removed = self.zrangebyscore(key, min=min, max=max, withscores=True)
        for _, score in removed:
            del self.cache[key][score]
            # Drop the matching expiry record too; leaving it behind would
            # make the _expire sweep target an already-removed member.
            self.zexpires.pop((key, score), None)
        return len(removed)

    def acquire_lock(self, key: str, expire=None) -> Any:
        """
        Acquires (creating if necessary) the threading.Lock stored at key.

        Args:
            key (str): Cache key under which the lock object lives.
            expire (float, optional): Acquire timeout and lock TTL; None
                blocks indefinitely (threading timeout -1).

        Returns:
            Lock: The held lock object.

        Raises:
            TimeoutError: If the lock could not be acquired within `expire`.
        """
        lock: Optional[Lock] = self.get(key)
        if lock is None:
            lock = Lock()
            self.set(key, lock, expire=expire)
        if expire is None:
            expire = -1
        acquired = lock.acquire(timeout=expire)
        if not acquired:
            raise TimeoutError()
        # refresh the lock
        self.set(key, lock, expire=expire)
        return lock

    def set_numpy(self, key: str, value: Any, expire: float = None):
        """Caches a numpy array; in-memory storage needs no serialization."""
        return self.set(key, value, expire=expire)

    def get_numpy(self, key: str):
        """Retrieves a numpy array from the cache; delegates to get()."""
        return self.get(key)
Functions
__init__
__init__()

Initializes a new instance of the MemoryCache class.

Source code in inference/core/cache/memory.py
21
22
23
24
25
26
27
28
29
30
31
def __init__(self) -> None:
    """
    Initializes a new instance of the MemoryCache class.

    Creates the value store (`cache`), per-key expiry times (`expires`) and
    sorted-set member expiry times (`zexpires`), then starts the background
    expiry thread.
    """
    self.cache = dict()
    self.expires = dict()
    self.zexpires = dict()

    # Daemon thread: expiry sweeping must not keep the interpreter alive on exit.
    self._expire_thread = threading.Thread(target=self._expire)
    self._expire_thread.daemon = True
    self._expire_thread.start()
get
get(key)

Gets the value associated with the given key.

Parameters:

Name Type Description Default
key str

The key to retrieve the value.

required

Returns:

Name Type Description
str

The value associated with the key, or None if the key does not exist or is expired.

Source code in inference/core/cache/memory.py
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
def get(self, key: str):
    """
    Fetch the cached value stored under ``key``.

    Performs lazy eviction: when the key's expiry timestamp has already
    passed, the entry is dropped and None is returned.

    Args:
        key (str): Key to look up.

    Returns:
        str: The stored value, or None when absent or expired.
    """
    deadline = self.expires.get(key)
    if deadline is not None and deadline < time.time():
        del self.cache[key]
        del self.expires[key]
        return None
    return self.cache.get(key)
set
set(key, value, expire=None)

Sets a value for a given key with an optional expire time.

Parameters:

Name Type Description Default
key str

The key to store the value.

required
value str

The value to store.

required
expire float

The time, in seconds, after which the key will expire. Defaults to None.

None
Source code in inference/core/cache/memory.py
75
76
77
78
79
80
81
82
83
84
85
86
def set(self, key: str, value: str, expire: float = None):
    """
    Store ``value`` under ``key``, optionally with a time-to-live.

    Args:
        key (str): Key to store the value under.
        value (str): Value to store.
        expire (float, optional): Seconds until the entry expires; None (or 0)
            means the entry never expires. Defaults to None.
    """
    self.cache[key] = value
    if expire:
        self.expires[key] = time.time() + expire
zadd
zadd(key, value, score, expire=None)

Adds a member with the specified score to the sorted set stored at key.

Parameters:

Name Type Description Default
key str

The key of the sorted set.

required
value str

The value to add to the sorted set.

required
score float

The score associated with the value.

required
expire float

The time, in seconds, after which the key will expire. Defaults to None.

None
Source code in inference/core/cache/memory.py
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
def zadd(self, key: str, value: Any, score: float, expire: float = None):
    """
    Insert ``value`` with ``score`` into the sorted set stored at ``key``.

    Members are stored in a dict keyed by score, so a later insert with the
    same score overwrites the earlier value.

    Args:
        key (str): Key of the sorted set.
        value (str): Member to insert.
        score (float): Score ordering the member within the set.
        expire (float, optional): Seconds until the member expires; None (or 0)
            means it never expires. Defaults to None.
    """
    members = self.cache.setdefault(key, dict())
    members[score] = value
    if expire:
        self.zexpires[(key, score)] = time.time() + expire
zrangebyscore
zrangebyscore(
    key, min=-1, max=float("inf"), withscores=False
)

Retrieves a range of members from a sorted set.

Parameters:

Name Type Description Default
key str

The key of the sorted set.

required
min float

The starting score of the range. Defaults to -1.

-1
max float

The ending score of the range. Defaults to float("inf").

float('inf')
withscores bool

Whether to return the scores along with the values. Defaults to False.

False

Returns:

Name Type Description
list

A list of values (or value-score pairs if withscores is True) in the specified score range.

Source code in inference/core/cache/memory.py
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
def zrangebyscore(
    self,
    key: str,
    min: Optional[float] = -1,
    max: Optional[float] = float("inf"),
    withscores: bool = False,
):
    """
    Return the members of the sorted set at ``key`` whose scores fall inside
    [min, max], ordered by ascending score.

    Args:
        key (str): Key of the sorted set.
        min (float, optional): Lower score bound (inclusive). Defaults to -1.
        max (float, optional): Upper score bound (inclusive). Defaults to float("inf").
        withscores (bool, optional): When True, return (value, score) pairs
            instead of bare values. Defaults to False.

    Returns:
        list: Values (or value-score pairs) in the requested score range.
    """
    if key not in self.cache:
        return []
    members = self.cache[key]
    ordered_scores = sorted(s for s in members if min <= s <= max)
    if withscores:
        return [(members[s], s) for s in ordered_scores]
    return [members[s] for s in ordered_scores]
zremrangebyscore
zremrangebyscore(key, min=-1, max=float('inf'))

Removes all members in a sorted set within the given scores.

Parameters:

Name Type Description Default
key str

The key of the sorted set.

required
min float

The minimum score of the range. Defaults to -1.

-1
max float

The maximum score of the range. Defaults to float("inf").

float('inf')

Returns:

Name Type Description
int

The number of members removed from the sorted set.

Source code in inference/core/cache/memory.py
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
def zremrangebyscore(
    self,
    key: str,
    min: Optional[float] = -1,
    max: Optional[float] = float("inf"),
):
    """
    Removes all members in a sorted set within the given scores.

    Args:
        key (str): The key of the sorted set.
        min (float, optional): The minimum score of the range. Defaults to -1.
        max (float, optional): The maximum score of the range. Defaults to float("inf").

    Returns:
        int: The number of members removed from the sorted set.
    """
    removed = self.zrangebyscore(key, min=min, max=max, withscores=True)
    for _, score in removed:
        del self.cache[key][score]
        # Also drop the matching expiry record; leaving it behind would make
        # the background expiry sweep target an already-removed member.
        self.zexpires.pop((key, score), None)
    return len(removed)

inference.core.cache.model_artifacts

Functions

clear_cache

clear_cache(model_id=None, delete_from_disk=True)

Clear the cache for a specific model or the entire cache directory.

Parameters:

Name Type Description Default
model_id Optional[str]

The model ID to clear cache for. If None, clears entire cache. Defaults to None.

None
delete_from_disk bool

Whether to delete cached files from disk. Defaults to True.

True
Source code in inference/core/cache/model_artifacts.py
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
def clear_cache(model_id: Optional[str] = None, delete_from_disk: bool = True) -> None:
    """Clear the cache for a specific model or the entire cache directory.

    Deletion is guarded by a per-directory file lock so concurrent processes
    do not race while removing the same cache directory, and the rmtree is
    retried with exponential backoff to tolerate transient filesystem errors.
    All failures are logged and swallowed: cache cleanup is best-effort.

    Args:
        model_id (Optional[str], optional): The model ID to clear cache for. If None, clears entire cache. Defaults to None.
        delete_from_disk (bool, optional): Whether to delete cached files from disk. Defaults to True.
    """
    if not delete_from_disk:
        return
    cache_dir = get_cache_dir(model_id=model_id)
    if not os.path.exists(cache_dir):
        return
    lock_dir = MODEL_CACHE_DIR + "/_file_locks"  # Dedicated lock directory
    os.makedirs(lock_dir, exist_ok=True)  # ensure lock directory exists.

    # Use the last 2 levels of the cache directory path as the lock file name suffix
    parts = os.path.normpath(cache_dir).split(os.sep)
    suffix = (
        os.path.join(*parts[-2:]) if len(parts) >= 2 else os.path.basename(cache_dir)
    )
    lock_file = os.path.join(lock_dir, f"{suffix}.lock")

    try:
        lock = FileLock(lock_file, timeout=10)  # 10 second timeout
        with lock:
            if not os.path.exists(cache_dir):  # Check again after acquiring lock
                return  # Already deleted by another process

            max_retries = 3
            retry_delay = 1  # Initial delay in seconds

            for attempt in range(max_retries):
                try:
                    shutil.rmtree(cache_dir, onerror=_rmtree_onerror)
                    return  # Success
                except FileNotFoundError:
                    return  # Already deleted by another process
                except Exception as e:
                    if attempt < max_retries - 1:
                        # Plain %-style literal: the original carried a spurious
                        # f-prefix on these lazy-formatted logging calls.
                        logger.warning(
                            "Error deleting cache %s: %s, retrying in %s seconds...",
                            cache_dir,
                            e,
                            retry_delay,
                        )
                        time.sleep(retry_delay)
                        retry_delay *= 2  # Exponential backoff
                    else:
                        logger.warning(
                            "Error deleting cache %s: %s, max retries exceeded.",
                            cache_dir,
                            e,
                        )
                        return
    except Exception as e:
        logger.warning(
            "Error acquiring lock for cache %s, skipping cache cleanup. %s",
            cache_dir,
            e,
        )

inference.core.cache.redis

Classes

RedisCache

Bases: BaseCache

RedisCache is a Redis-backed cache that implements the BaseCache interface.

Attributes:

Name Type Description
client Redis

The Redis client used for all operations.

zexpires dict

A dictionary to store the expiration times of the sorted set values.

_expire_thread Thread

A thread that runs the _expire method.

Source code in inference/core/cache/redis.py
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
class RedisCache(BaseCache):
    """
    RedisCache is a Redis-backed cache that implements the BaseCache interface.

    Values are JSON-serialized before being written through a `redis.Redis`
    client. Expiry of individual sorted-set members is tracked client-side in
    `zexpires` and reaped by a background daemon thread.

    Attributes:
        client (redis.Redis): The Redis client used for all operations.
        zexpires (dict): Maps (key, score) tuples to the wall-clock time at
            which the matching sorted-set member should be removed.
        _expire_thread (threading.Thread): A daemon thread that runs the _expire method.
    """

    def __init__(
        self,
        host: str = "localhost",
        port: int = 6379,
        db: int = 0,
        ssl: bool = False,
        timeout: float = 2.0,
    ) -> None:
        """
        Initializes a new instance of the RedisCache class.

        Args:
            host (str, optional): Redis server host. Defaults to "localhost".
            port (int, optional): Redis server port. Defaults to 6379.
            db (int, optional): Redis database index. Defaults to 0.
            ssl (bool, optional): Whether to connect over TLS. Defaults to False.
            timeout (float, optional): Socket connect/read timeout in seconds. Defaults to 2.0.
        """
        self.client = redis.Redis(
            host=host,
            port=port,
            db=db,
            decode_responses=False,
            ssl=ssl,
            socket_timeout=timeout,
            socket_connect_timeout=timeout,
        )
        logger.debug("Attempting to diagnose Redis connection...")
        self.client.ping()  # fail fast when the server is unreachable
        logger.debug("Redis connection established.")
        self.zexpires = dict()

        # Daemon thread: never blocks interpreter shutdown.
        self._expire_thread = threading.Thread(target=self._expire, daemon=True)
        self._expire_thread.start()

    def _expire(self):
        """
        Removes expired sorted-set members tracked in the zexpires dictionary.

        This method runs in an infinite loop and sleeps so that iterations
        start roughly MEMORY_CACHE_EXPIRE_INTERVAL seconds apart.
        """
        while True:
            now = time.time()
            # Snapshot the items so entries can be deleted while iterating.
            for k, v in copy(list(self.zexpires.items())):
                if v < now:
                    tolerance_factor = 1e-14  # floating point accuracy
                    # k is a (sorted-set key, score) tuple; remove members whose
                    # score matches within the tolerance window.
                    self.zremrangebyscore(
                        k[0], k[1] - tolerance_factor, k[1] + tolerance_factor
                    )
                    del self.zexpires[k]
            # Sleep only the remainder of the interval after the work above.
            sleep_time = MEMORY_CACHE_EXPIRE_INTERVAL - (time.time() - now)
            time.sleep(max(sleep_time, 0))

    def get(self, key: str):
        """
        Gets the value associated with the given key.

        Args:
            key (str): The key to retrieve the value.

        Returns:
            The JSON-decoded value associated with the key, the raw payload
            when it is not valid JSON, or None if the key does not exist.
        """
        item = self.client.get(key)
        if item is not None:
            try:
                return json.loads(item)
            except (TypeError, ValueError):
                # Not JSON (e.g. pickled bytes from set_numpy) — return as-is.
                return item

    def set(self, key: str, value: str, expire: float = None):
        """
        Sets a value for a given key with an optional expire time.

        Args:
            key (str): The key to store the value.
            value (str): The value to store; non-bytes values are JSON-encoded.
            expire (float, optional): The time, in seconds, after which the key will expire. Defaults to None.
        """
        if not isinstance(value, bytes):
            value = json.dumps(value)
        self.client.set(key, value, ex=expire)

    def zadd(self, key: str, value: Any, score: float, expire: float = None):
        """
        Adds a member with the specified score to the sorted set stored at key.

        Args:
            key (str): The key of the sorted set.
            value (Any): The value to add; JSON-encoded before storage.
            score (float): The score associated with the value.
            expire (float, optional): The time, in seconds, after which the member should be reaped by the expiry thread. Defaults to None.
        """
        # serializable_value = self.ensure_serializable(value)
        value = json.dumps(value)
        self.client.zadd(key, {value: score})
        if expire:
            # Record the absolute deadline; _expire() removes the member later.
            self.zexpires[(key, score)] = expire + time.time()

    def zrangebyscore(
        self,
        key: str,
        min: Optional[float] = -1,
        max: Optional[float] = float("inf"),
        withscores: bool = False,
    ):
        """
        Retrieves a range of members from a sorted set.

        Args:
            key (str): The key of the sorted set.
            min (float, optional): The minimum score of the range. Defaults to -1.
            max (float, optional): The maximum score of the range. Defaults to float("inf").
            withscores (bool, optional): Whether to return the scores along with the values. Defaults to False.

        Returns:
            list: A list of JSON-decoded values (or value-score pairs if withscores is True) in the specified score range.
        """
        res = self.client.zrangebyscore(key, min, max, withscores=withscores)
        if withscores:
            return [(json.loads(x), y) for x, y in res]
        else:
            return [json.loads(x) for x in res]

    def zremrangebyscore(
        self,
        key: str,
        min: Optional[float] = -1,
        max: Optional[float] = float("inf"),
    ):
        """
        Removes all members in a sorted set within the given scores.

        Args:
            key (str): The key of the sorted set.
            min (float, optional): The minimum score of the range. Defaults to -1.
            max (float, optional): The maximum score of the range. Defaults to float("inf").

        Returns:
            int: The number of members removed from the sorted set.
        """
        return self.client.zremrangebyscore(key, min, max)

    def ensure_serializable(self, value: Any):
        """Coerce dict entries into JSON-serializable forms, mutating in place.

        Exception values are replaced by their string representation. Returns
        the (possibly mutated) input value.
        """
        if isinstance(value, dict):
            for k, v in value.items():
                if isinstance(v, Exception):
                    value[k] = str(v)
                elif inspect.isclass(v) and isinstance(v, InferenceResponseImage):
                    # NOTE(review): inspect.isclass(v) (v is a class) combined
                    # with isinstance(v, InferenceResponseImage) (v is an
                    # instance) looks contradictory — this branch may never
                    # fire; confirm intent.
                    value[k] = v.dict()
        return value

    def acquire_lock(self, key: str, expire=None) -> Any:
        """Acquire a distributed Redis lock named `key`.

        Args:
            key (str): Name of the lock.
            expire (float, optional): Used both as the lock's auto-release
                timeout and as the maximum time to block while acquiring.
                Defaults to None (block indefinitely, no auto-release).

        Returns:
            The acquired redis lock object; caller is responsible for release.

        Raises:
            TimeoutError: If the lock could not be acquired in time.
        """
        l = self.client.lock(key, blocking=True, timeout=expire)
        acquired = l.acquire(blocking_timeout=expire)
        if not acquired:
            raise TimeoutError("Couldn't get lock")
        # refresh the lock
        if expire is not None:
            l.extend(expire)
        return l

    def set_numpy(self, key: str, value: Any, expire: float = None):
        """Store an arbitrary Python object under `key` via pickle.

        NOTE(review): pickled payloads are only safe when the Redis instance
        is trusted — get_numpy unpickles whatever it reads back.
        """
        serialized_value = pickle.dumps(value)
        self.set(key, serialized_value, expire=expire)

    def get_numpy(self, key: str) -> Any:
        """Load and unpickle the object stored under `key`, or None if absent."""
        serialized_value = self.get(key)
        if serialized_value is not None:
            return pickle.loads(serialized_value)
        else:
            return None
Functions
__init__
__init__(
    host="localhost",
    port=6379,
    db=0,
    ssl=False,
    timeout=2.0,
)

Initializes a new instance of the RedisCache class.

Source code in inference/core/cache/redis.py
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
def __init__(
    self,
    host: str = "localhost",
    port: int = 6379,
    db: int = 0,
    ssl: bool = False,
    timeout: float = 2.0,
) -> None:
    """
    Initializes a new instance of the RedisCache class.

    Args:
        host (str, optional): Redis server host. Defaults to "localhost".
        port (int, optional): Redis server port. Defaults to 6379.
        db (int, optional): Redis database index. Defaults to 0.
        ssl (bool, optional): Whether to connect over TLS. Defaults to False.
        timeout (float, optional): Socket connect/read timeout in seconds. Defaults to 2.0.
    """
    self.client = redis.Redis(
        host=host,
        port=port,
        db=db,
        decode_responses=False,
        ssl=ssl,
        socket_timeout=timeout,
        socket_connect_timeout=timeout,
    )
    logger.debug("Attempting to diagnose Redis connection...")
    self.client.ping()  # fail fast when the server is unreachable
    logger.debug("Redis connection established.")
    self.zexpires = dict()

    # Daemon thread: never blocks interpreter shutdown.
    self._expire_thread = threading.Thread(target=self._expire, daemon=True)
    self._expire_thread.start()
get
get(key)

Gets the value associated with the given key.

Parameters:

Name Type Description Default
key str

The key to retrieve the value.

required

Returns:

Name Type Description
str

The value associated with the key, or None if the key does not exist or is expired.

Source code in inference/core/cache/redis.py
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
def get(self, key: str):
    """
    Gets the value associated with the given key.

    Args:
        key (str): The key to retrieve the value.

    Returns:
        The JSON-decoded value for the key; the raw payload when it is not
        valid JSON; None when the key does not exist.
    """
    raw = self.client.get(key)
    if raw is None:
        return None
    try:
        return json.loads(raw)
    except (TypeError, ValueError):
        # Payload is not JSON (e.g. pickled bytes) — hand it back untouched.
        return raw
set
set(key, value, expire=None)

Sets a value for a given key with an optional expire time.

Parameters:

Name Type Description Default
key str

The key to store the value.

required
value str

The value to store.

required
expire float

The time, in seconds, after which the key will expire. Defaults to None.

None
Source code in inference/core/cache/redis.py
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
def set(self, key: str, value: str, expire: float = None):
    """
    Sets a value for a given key with an optional expire time.

    Args:
        key (str): The key to store the value.
        value (str): The value to store; non-bytes values are JSON-encoded.
        expire (float, optional): The time, in seconds, after which the key will expire. Defaults to None.
    """
    payload = value if isinstance(value, bytes) else json.dumps(value)
    self.client.set(key, payload, ex=expire)
zadd
zadd(key, value, score, expire=None)

Adds a member with the specified score to the sorted set stored at key.

Parameters:

Name Type Description Default
key str

The key of the sorted set.

required
value str

The value to add to the sorted set.

required
score float

The score associated with the value.

required
expire float

The time, in seconds, after which the key will expire. Defaults to None.

None
Source code in inference/core/cache/redis.py
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
def zadd(self, key: str, value: Any, score: float, expire: float = None):
    """
    Adds a member with the specified score to the sorted set stored at key.

    Args:
        key (str): The key of the sorted set.
        value (Any): The value to add; JSON-encoded before storage.
        score (float): The score associated with the value.
        expire (float, optional): Seconds after which the member should be
            reaped by the expiry thread. Defaults to None.
    """
    member = json.dumps(value)
    self.client.zadd(key, {member: score})
    if expire:
        # Record the absolute wall-clock deadline for the expiry thread.
        self.zexpires[(key, score)] = expire + time.time()
zrangebyscore
zrangebyscore(
    key, min=-1, max=float("inf"), withscores=False
)

Retrieves a range of members from a sorted set.

Parameters:

Name Type Description Default
key str

The key of the sorted set.

required
min float

The minimum score of the range. Defaults to -1.

-1
max float

The maximum score of the range. Defaults to float("inf").

float("inf")
withscores bool

Whether to return the scores along with the values. Defaults to False.

False

Returns:

Name Type Description
list

A list of values (or value-score pairs if withscores is True) in the specified score range.

Source code in inference/core/cache/redis.py
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
def zrangebyscore(
    self,
    key: str,
    min: Optional[float] = -1,
    max: Optional[float] = float("inf"),
    withscores: bool = False,
):
    """
    Retrieves a range of members from a sorted set.

    Args:
        key (str): The key of the sorted set.
        min (float, optional): The minimum score of the range. Defaults to -1.
        max (float, optional): The maximum score of the range. Defaults to float("inf").
        withscores (bool, optional): Whether to return the scores along with the values. Defaults to False.

    Returns:
        list: JSON-decoded values, or (value, score) pairs when withscores is True.
    """
    raw = self.client.zrangebyscore(key, min, max, withscores=withscores)
    if not withscores:
        return [json.loads(member) for member in raw]
    return [(json.loads(member), score) for member, score in raw]
zremrangebyscore
zremrangebyscore(key, min=-1, max=float('inf'))

Removes all members in a sorted set within the given scores.

Parameters:

Name Type Description Default
key str

The key of the sorted set.

required
min float

The minimum score of the range. Defaults to -1.

-1
max float

The maximum score of the range. Defaults to float("inf").

float("inf")

Returns:

Name Type Description
int

The number of members removed from the sorted set.

Source code in inference/core/cache/redis.py
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
def zremrangebyscore(
    self,
    key: str,
    min: Optional[float] = -1,
    max: Optional[float] = float("inf"),
):
    """
    Removes all members in a sorted set within the given scores.

    Args:
        key (str): The key of the sorted set.
        min (float, optional): The minimum score of the range. Defaults to -1.
        max (float, optional): The maximum score of the range. Defaults to float("inf").

    Returns:
        int: The number of members removed from the sorted set.
    """
    # Delegates directly to the Redis ZREMRANGEBYSCORE command.
    return self.client.zremrangebyscore(key, min, max)

core/devices

Hardware device detection and selection helpers.

inference.core.devices.utils

Functions

get_cpu_id

get_cpu_id()

Fetches the CPU ID based on the operating system.

Attempts to get the CPU ID for Windows, Linux, and MacOS. In case of any error or an unsupported OS, returns None.

Returns:

Type Description

Optional[str]: CPU ID string if available, None otherwise.

Source code in inference/core/devices/utils.py
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
def get_cpu_id():
    """Fetches a CPU identifier based on the operating system.

    Attempts to get a CPU identifier for Windows, Linux, and MacOS.
    In case of any error or an unsupported OS, returns None.

    Returns:
        Optional[str]: CPU ID string if available, None otherwise.
    """
    try:
        system = platform.system()
        if system == "Windows":
            return os.popen("wmic cpu get ProcessorId").read().strip()
        elif system == "Linux":
            # The previous parsing split /proc/cpuinfo on "processor" and took
            # the text *before* the first occurrence — which is empty, so the
            # branch always raised and fell through to None. Read the first
            # "processor" entry's value instead, closing the file properly.
            with open("/proc/cpuinfo") as cpuinfo:
                for line in cpuinfo:
                    if line.startswith("processor"):
                        return line.split(":", 1)[1].strip()
            return None
        elif system == "Darwin":
            import subprocess

            # NOTE: this is the CPU brand string rather than a serial number.
            return (
                subprocess.check_output(["sysctl", "-n", "machdep.cpu.brand_string"])
                .strip()
                .decode()
            )
    except Exception:
        return None

get_device_hostname

get_device_hostname()

Fetches the device's hostname.

Returns:

Name Type Description
str

The device's hostname.

Source code in inference/core/devices/utils.py
107
108
109
110
111
112
113
def get_device_hostname():
    """Fetches the device's hostname.

    Returns:
        str: The machine's network name as reported by the platform module.
    """
    hostname = platform.node()
    return hostname

get_gpu_id

get_gpu_id()

Fetches the GPU ID if a GPU is present.

Tries to import and use the pynvml (delivered by nvidia-ml-py) module to retrieve the GPU information.

Returns:

Type Description

Optional[int]: GPU ID if available, None otherwise.

Source code in inference/core/devices/utils.py
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
def get_gpu_id():
    """Fetches the GPU ID if a GPU is present.

    Tries to import and use the `pynvml` (delivered by nvidia-ml-py) module to retrieve the GPU information.

    Returns:
        Optional[int]: GPU ID if available, None otherwise.
    """
    try:
        from pynvml import nvmlDeviceGetCount, nvmlInit

        nvmlInit()
        if nvmlDeviceGetCount():
            # At least one device present; report the first GPU index.
            return 0
    except Exception:
        # Covers both a missing pynvml install and NVML runtime failures.
        return None

get_inference_server_id

get_inference_server_id()

Fetches a unique device ID.

Tries to get the GPU ID first, then falls back to the Jetson serial number; if neither is available, a random identifier is returned on its own.

Returns:

Name Type Description
str

A unique string representing the device. If unable to determine, returns "UNKNOWN".

Source code in inference/core/devices/utils.py
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
def get_inference_server_id():
    """Fetches a unique inference-server ID.

    Returns INFERENCE_SERVER_ID when configured; otherwise builds a random
    identifier and suffixes it with GPU or Jetson information when available.

    Returns:
        str: A unique string representing the server. If unable to determine, returns "UNKNOWN".
    """
    try:
        if INFERENCE_SERVER_ID is not None:
            return INFERENCE_SERVER_ID
        # Renamed from `id`, which shadowed the builtin.
        server_id = random_string(6)
        gpu_id = get_gpu_id()
        if gpu_id is not None:
            return f"{server_id}-GPU-{gpu_id}"
        jetson_id = get_jetson_id()
        if jetson_id is not None:
            return f"{server_id}-JETSON-{jetson_id}"
        return server_id
    except Exception:
        # Best-effort identifier: never let detection errors propagate.
        return "UNKNOWN"

get_jetson_id

get_jetson_id()

Fetches the Jetson device's serial number.

Attempts to read the serial number from the device tree. In case of any error, returns None.

Returns:

Type Description

Optional[str]: Jetson device serial number if available, None otherwise.

Source code in inference/core/devices/utils.py
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
def get_jetson_id():
    """Fetches the Jetson device's serial number.

    Attempts to read the serial number from the device tree.
    In case of any error, returns None.

    Returns:
        Optional[str]: Jetson device serial number if available, None otherwise.
    """
    try:
        if not os.path.exists("/proc/device-tree/serial-number"):
            return None
        # Read the file directly instead of shelling out to `cat` via os.popen.
        with open("/proc/device-tree/serial-number") as serial_file:
            serial_number = serial_file.read().strip()
        if serial_number == "":
            return None
        return serial_number
    except Exception:
        return None

is_running_in_docker

is_running_in_docker()

Checks if the current process is running inside a Docker container.

Returns:

Name Type Description
bool

True if running inside a Docker container, False otherwise.

Source code in inference/core/devices/utils.py
10
11
12
13
14
15
16
def is_running_in_docker():
    """Checks if the current process is running inside a Docker container.

    Returns:
        bool: True when the Docker sentinel file is present, False otherwise.
    """
    sentinel_path = "/.dockerenv"
    return os.path.exists(sentinel_path)

core/entities/requests

inference.core.entities.requests.clip

Classes

ClipCompareRequest

Bases: ClipInferenceRequest

Request for CLIP comparison.

Attributes:

Name Type Description
subject Union[InferenceRequestImage, str]

The type of image data provided, one of 'url' or 'base64'.

subject_type str

The type of subject, one of 'image' or 'text'.

prompt Union[List[InferenceRequestImage], InferenceRequestImage, str, List[str], Dict[str, Union[InferenceRequestImage, str]]]

The prompt for comparison.

prompt_type str

The type of prompt, one of 'image' or 'text'.

Source code in inference/core/entities/requests/clip.py
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
class ClipCompareRequest(ClipInferenceRequest):
    """Request for CLIP comparison.

    Attributes:
        subject (Union[InferenceRequestImage, str]): The subject of the comparison — an image payload or a text string.
        subject_type (str): The type of subject, one of 'image' or 'text'.
        prompt (Union[List[InferenceRequestImage], InferenceRequestImage, str, List[str], Dict[str, Union[InferenceRequestImage, str]]]): The prompt for comparison.
        prompt_type (str): The type of prompt, one of 'image' or 'text'.
    """

    # NOTE(review): the description below reads like it belongs to an
    # InferenceRequestImage `type` field ('url'/'base64') rather than to the
    # comparison subject itself — confirm and correct upstream if so.
    subject: Union[InferenceRequestImage, str] = Field(
        examples=["url"],
        description="The type of image data provided, one of 'url' or 'base64'",
    )
    subject_type: str = Field(
        default="image",
        examples=["image"],
        description="The type of subject, one of 'image' or 'text'",
    )
    prompt: Union[
        List[InferenceRequestImage],
        InferenceRequestImage,
        str,
        List[str],
        Dict[str, Union[InferenceRequestImage, str]],
    ]
    prompt_type: str = Field(
        default="text",
        examples=["text"],
        description="The type of prompt, one of 'image' or 'text'",
    )

ClipImageEmbeddingRequest

Bases: ClipInferenceRequest

Request for CLIP image embedding.

Attributes:

Name Type Description
image Union[List[InferenceRequestImage], InferenceRequestImage]

Image(s) to be embedded.

Source code in inference/core/entities/requests/clip.py
38
39
40
41
42
43
44
45
class ClipImageEmbeddingRequest(ClipInferenceRequest):
    """Request for CLIP image embedding.

    Attributes:
        image (Union[List[InferenceRequestImage], InferenceRequestImage]): Image(s) to be embedded.
    """

    # Accepts a single image payload or a batch of them.
    image: Union[List[InferenceRequestImage], InferenceRequestImage]

ClipInferenceRequest

Bases: BaseRequest

Request for CLIP inference.

Attributes:

Name Type Description
api_key Optional[str]

Roboflow API Key.

clip_version_id Optional[str]

The version ID of CLIP to be used for this request.

Source code in inference/core/entities/requests/clip.py
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
class ClipInferenceRequest(BaseRequest):
    """Request for CLIP inference.

    Attributes:
        api_key (Optional[str]): Roboflow API Key.
        clip_version_id (Optional[str]): The version ID of CLIP to be used for this request.
        model_id (Optional[str]): Model identifier; derived as "clip/<version>"
            by the validator when not supplied explicitly.
    """

    clip_version_id: Optional[str] = Field(
        default=CLIP_VERSION_ID,
        examples=["ViT-B-16"],
        description="The version ID of CLIP to be used for this request. Must be one of RN101, RN50, RN50x16, RN50x4, RN50x64, ViT-B-16, ViT-B-32, ViT-L-14-336px, and ViT-L-14.",
    )
    model_id: Optional[str] = Field(None)

    # TODO[pydantic]: We couldn't refactor the `validator`, please replace it by `field_validator` manually.
    # Check https://docs.pydantic.dev/dev-v2/migration/#changes-to-validators for more information.
    @validator("model_id", always=True)
    def validate_model_id(cls, value, values):
        # Default model_id to "clip/<version>" when the caller did not set one.
        if value is not None:
            return value
        if values.get("clip_version_id") is None:
            return None
        return f"clip/{values['clip_version_id']}"

ClipTextEmbeddingRequest

Bases: ClipInferenceRequest

Request for CLIP text embedding.

Attributes:

Name Type Description
text Union[List[str], str]

A string or list of strings.

Source code in inference/core/entities/requests/clip.py
48
49
50
51
52
53
54
55
56
57
58
class ClipTextEmbeddingRequest(ClipInferenceRequest):
    """Request for CLIP text embedding.

    Attributes:
        text (Union[List[str], str]): A string or list of strings to embed.
    """

    text: Union[List[str], str] = Field(
        examples=["The quick brown fox jumps over the lazy dog"],
        description="A string or list of strings",
    )

inference.core.entities.requests.doctr

Classes

DoctrOCRInferenceRequest

Bases: BaseRequest

DocTR inference request.

Attributes:

Name Type Description
api_key Optional[str]

Roboflow API Key.

Source code in inference/core/entities/requests/doctr.py
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
class DoctrOCRInferenceRequest(BaseRequest):
    """
    DocTR inference request.

    Attributes:
        api_key (Optional[str]): Roboflow API Key.
        image (Union[List[InferenceRequestImage], InferenceRequestImage]): Image(s) to run OCR on.
        doctr_version_id (Optional[str]): DocTR version identifier.
        model_id (Optional[str]): Derived as "doctr/<version>" by the validator when not provided.
        generate_bounding_boxes (Optional[bool]): Emit bounding-box data instead of only a string.
    """

    image: Union[List[InferenceRequestImage], InferenceRequestImage]
    doctr_version_id: Optional[str] = "default"
    model_id: Optional[str] = Field(None)
    # flag to generate bounding box data rather than just a string, set to False for backwards compatibility
    generate_bounding_boxes: Optional[bool] = False

    # TODO[pydantic]: We couldn't refactor the `validator`, please replace it by `field_validator` manually.
    # Check https://docs.pydantic.dev/dev-v2/migration/#changes-to-validators for more information.
    @validator("model_id", always=True, allow_reuse=True)
    def validate_model_id(cls, value, values):
        # Default model_id to "doctr/<version>" when not supplied explicitly.
        if value is not None:
            return value
        if values.get("doctr_version_id") is None:
            return None
        return f"doctr/{values['doctr_version_id']}"

inference.core.entities.requests.dynamic_class_base

Classes

DynamicClassBaseInferenceRequest

Bases: CVInferenceRequest

Request for zero-shot object detection models (with dynamic class lists).

Attributes:

Name Type Description
text List[str]

A list of strings.

Source code in inference/core/entities/requests/dynamic_class_base.py
 8
 9
10
11
12
13
14
15
16
17
18
19
class DynamicClassBaseInferenceRequest(CVInferenceRequest):
    """Request for zero-shot object detection models (with dynamic class lists).

    Attributes:
        text (List[str]): The list of class names to detect.
        model_id (Optional[str]): Optional explicit model identifier.
    """

    model_id: Optional[str] = Field(None)
    text: List[str] = Field(
        examples=[["person", "dog", "cat"]],
        description="A list of strings",
    )

inference.core.entities.requests.easy_ocr

Classes

EasyOCRInferenceRequest

Bases: BaseRequest

EasyOCR inference request.

Attributes:

Name Type Description
api_key Optional[str]

Roboflow API Key.

Source code in inference/core/entities/requests/easy_ocr.py
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
class EasyOCRInferenceRequest(BaseRequest):
    """
    EasyOCR inference request.

    Attributes:
        api_key (Optional[str]): Roboflow API Key.
        image (Union[List[InferenceRequestImage], InferenceRequestImage]): Image(s) to run OCR on.
        easy_ocr_version_id (Optional[str]): EasyOCR version identifier.
        model_id (Optional[str]): Derived as "easy_ocr/<version>" by the validator when not provided.
        language_codes (Optional[List[str]]): Languages to recognize.
        quantize (Optional[bool]): Whether to use a quantized model.
    """

    image: Union[List[InferenceRequestImage], InferenceRequestImage]
    easy_ocr_version_id: Optional[str] = EASYOCR_VERSION_ID
    model_id: Optional[str] = Field(None)
    # Mutable default is safe here: pydantic copies field defaults per instance.
    language_codes: Optional[List[str]] = Field(default=["en"])
    quantize: Optional[bool] = Field(
        default=False,
        description="Quantized models are smaller and faster, but may be less accurate and won't work correctly on all hardware.",
    )

    # TODO[pydantic]: We couldn't refactor the `validator`, please replace it by `field_validator` manually.
    # Check https://docs.pydantic.dev/dev-v2/migration/#changes-to-validators for more information.
    @validator("model_id", always=True, allow_reuse=True)
    def validate_model_id(cls, value, values):
        # Default model_id to "easy_ocr/<version>" when not supplied explicitly.
        if value is not None:
            return value
        if values.get("easy_ocr_version_id") is None:
            return None
        return f"easy_ocr/{values['easy_ocr_version_id']}"

inference.core.entities.requests.gaze

Classes

GazeDetectionInferenceRequest

Bases: BaseRequest

Request for gaze detection inference.

Attributes:

Name Type Description
api_key Optional[str]

Roboflow API Key.

gaze_version_id Optional[str]

The version ID of Gaze to be used for this request.

do_run_face_detection Optional[bool]

If true, face detection will be applied; if false, face detection will be ignored and the whole input image will be used for gaze detection.

image Union[List[InferenceRequestImage], InferenceRequestImage]

Image(s) for inference.

Source code in inference/core/entities/requests/gaze.py
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
class GazeDetectionInferenceRequest(BaseRequest):
    """Request for gaze detection inference.

    Attributes:
        api_key (Optional[str]): Roboflow API Key.
        gaze_version_id (Optional[str]): The version ID of Gaze to be used for this request.
        do_run_face_detection (Optional[bool]): If true, face detection will be applied; if false, face detection will be ignored and the whole input image will be used for gaze detection.
        image (Union[List[InferenceRequestImage], InferenceRequestImage]): Image(s) for inference.
        model_id (Optional[str]): Model identifier; derived from gaze_version_id when not provided.
    """

    gaze_version_id: Optional[str] = Field(
        default=GAZE_VERSION_ID,
        examples=["L2CS"],
        description="The version ID of Gaze to be used for this request. Must be one of l2cs.",
    )

    do_run_face_detection: Optional[bool] = Field(
        default=True,
        examples=[False],
        description="If true, face detection will be applied; if false, face detection will be ignored and the whole input image will be used for gaze detection",
    )

    image: Union[List[InferenceRequestImage], InferenceRequestImage]
    model_id: Optional[str] = Field(None)

    # TODO[pydantic]: We couldn't refactor the `validator`, please replace it by `field_validator` manually.
    # Check https://docs.pydantic.dev/dev-v2/migration/#changes-to-validators for more information.
    @validator("model_id", always=True, allow_reuse=True)
    def validate_model_id(cls, value, values):
        # Default model_id to "gaze/<version>" when not explicitly supplied.
        if value is not None:
            return value
        if values.get("gaze_version_id") is None:
            return None
        return f"gaze/{values['gaze_version_id']}"

inference.core.entities.requests.groundingdino

Classes

GroundingDINOInferenceRequest

Bases: DynamicClassBaseInferenceRequest

Request for Grounding DINO zero-shot predictions.

Attributes:

Name Type Description
text List[str]

A list of strings.

Source code in inference/core/entities/requests/groundingdino.py
 9
10
11
12
13
14
15
16
17
18
19
class GroundingDINOInferenceRequest(DynamicClassBaseInferenceRequest):
    """Request for Grounding DINO zero-shot predictions.

    Attributes:
        text (List[str]): A list of strings (presumably declared on
            DynamicClassBaseInferenceRequest — confirm in the parent class).
        box_threshold (Optional[float]): Threshold applied to box predictions.
        grounding_dino_version_id (Optional[str]): Model version to use.
        text_threshold (Optional[float]): Threshold applied to text matches.
        class_agnostic_nms (Optional[bool]): If true, NMS runs across all classes at once.
    """

    box_threshold: Optional[float] = 0.5
    grounding_dino_version_id: Optional[str] = "default"
    text_threshold: Optional[float] = 0.5
    class_agnostic_nms: Optional[bool] = CLASS_AGNOSTIC_NMS

inference.core.entities.requests.inference

Classes

BaseRequest

Bases: BaseModel

Base request for inference.

Attributes:

Name Type Description
id str

A unique request identifier.

api_key Optional[str]

Roboflow API Key that will be passed to the model during initialization for artifact retrieval.

start Optional[float]

start time of request

disable_model_monitoring Optional[bool]

If true, disables model monitoring for this request.

Source code in inference/core/entities/requests/inference.py
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
class BaseRequest(BaseModel):
    """Base request for inference.

    Attributes:
        id (str): A unique request identifier.
        api_key (Optional[str]): Roboflow API Key that will be passed to the model during initialization for artifact retrieval.
        usage_billable (bool): Whether the request counts towards usage billing (inferred from the name — confirm).
        start (Optional[float]): start time of request
        source (Optional[str]): Origin of the request — TODO confirm semantics with callers.
        source_info (Optional[str]): Additional detail about the request source.
        disable_model_monitoring (Optional[bool]): If true, disables model monitoring for this request.
    """

    def __init__(self, **kwargs):
        # Auto-assign a UUID when the caller does not provide an id.
        kwargs["id"] = kwargs.get("id", str(uuid4()))
        super().__init__(**kwargs)

    # Clear protected namespaces so field names starting with "model_" are allowed.
    model_config = ConfigDict(protected_namespaces=())
    id: str
    api_key: Optional[str] = ApiKey
    usage_billable: bool = True
    start: Optional[float] = None
    source: Optional[str] = None
    source_info: Optional[str] = None
    disable_model_monitoring: Optional[bool] = Field(
        default=False, description="If true, disables model monitoring for this request"
    )

CVInferenceRequest

Bases: InferenceRequest

Computer Vision inference request.

Attributes:

Name Type Description
image Union[List[InferenceRequestImage], InferenceRequestImage]

Image(s) for inference.

disable_preproc_auto_orient Optional[bool]

If true, the auto orient preprocessing step is disabled for this call. Default is False.

disable_preproc_contrast Optional[bool]

If true, the auto contrast preprocessing step is disabled for this call. Default is False.

disable_preproc_grayscale Optional[bool]

If true, the grayscale preprocessing step is disabled for this call. Default is False.

disable_preproc_static_crop Optional[bool]

If true, the static crop preprocessing step is disabled for this call. Default is False.

Source code in inference/core/entities/requests/inference.py
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
class CVInferenceRequest(InferenceRequest):
    """Computer Vision inference request.

    Attributes:
        image (Union[List[InferenceRequestImage], InferenceRequestImage]): Image(s) for inference.
        disable_preproc_auto_orient (Optional[bool]): If true, the auto orient preprocessing step is disabled for this call. Default is False.
        disable_preproc_contrast (Optional[bool]): If true, the auto contrast preprocessing step is disabled for this call. Default is False.
        disable_preproc_grayscale (Optional[bool]): If true, the grayscale preprocessing step is disabled for this call. Default is False.
        disable_preproc_static_crop (Optional[bool]): If true, the static crop preprocessing step is disabled for this call. Default is False.
    """

    # A single image or a batch of images to run inference against.
    image: Union[List[InferenceRequestImage], InferenceRequestImage]
    # Each flag below opts this single request out of one preprocessing step.
    disable_preproc_auto_orient: Optional[bool] = Field(
        default=False,
        description="If true, the auto orient preprocessing step is disabled for this call.",
    )
    disable_preproc_contrast: Optional[bool] = Field(
        default=False,
        description="If true, the auto contrast preprocessing step is disabled for this call.",
    )
    disable_preproc_grayscale: Optional[bool] = Field(
        default=False,
        description="If true, the grayscale preprocessing step is disabled for this call.",
    )
    disable_preproc_static_crop: Optional[bool] = Field(
        default=False,
        description="If true, the static crop preprocessing step is disabled for this call.",
    )

ClassificationInferenceRequest

Bases: CVInferenceRequest

Classification inference request.

Attributes:

Name Type Description
confidence Optional[float]

The confidence threshold used to filter out predictions.

visualization_stroke_width Optional[int]

The stroke width used when visualizing predictions.

visualize_predictions Optional[bool]

If true, the predictions will be drawn on the original image and returned as a base64 string.

Source code in inference/core/entities/requests/inference.py
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
class ClassificationInferenceRequest(CVInferenceRequest):
    """Classification inference request.

    Attributes:
        confidence (Optional[float]): The confidence threshold used to filter out predictions.
        visualization_stroke_width (Optional[int]): The stroke width used when visualizing predictions.
        visualize_predictions (Optional[bool]): If true, the predictions will be drawn on the original image and returned as a base64 string.
        disable_active_learning (Optional[bool]): If true, the predictions will be prevented from registration by Active Learning.
        active_learning_target_dataset (Optional[str]): Alternative dataset for Active Learning data registration.
    """

    def __init__(self, **kwargs):
        # Pin the task type so downstream dispatch treats this request correctly.
        kwargs["model_type"] = "classification"
        super().__init__(**kwargs)

    confidence: Optional[float] = Field(
        default=0.4,
        examples=[0.5],
        description="The confidence threshold used to filter out predictions",
    )
    visualization_stroke_width: Optional[int] = Field(
        default=1,
        examples=[1],
        description="The stroke width used when visualizing predictions",
    )
    visualize_predictions: Optional[bool] = Field(
        default=False,
        examples=[False],
        description="If true, the predictions will be drawn on the original image and returned as a base64 string",
    )
    disable_active_learning: Optional[bool] = Field(
        default=False,
        examples=[False],
        description="If true, the predictions will be prevented from registration by Active Learning (if the functionality is enabled)",
    )
    active_learning_target_dataset: Optional[str] = Field(
        default=None,
        examples=["my_dataset"],
        description="Parameter to be used when Active Learning data registration should happen against different dataset than the one pointed by model_id",
    )

DepthEstimationRequest

Bases: InferenceRequest

Request for depth estimation.

Attributes:

Name Type Description
image Union[List[InferenceRequestImage], InferenceRequestImage]

Image(s) to be estimated.

model_id str

The model ID to use for depth estimation.

depth_version_id Optional[str]

The version ID of the depth estimation model.

Source code in inference/core/entities/requests/inference.py
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
class DepthEstimationRequest(InferenceRequest):
    """Request for depth estimation.

    Attributes:
        image (Union[List[InferenceRequestImage], InferenceRequestImage]): Image(s) to be estimated.
        model_id (Optional[str]): The model ID to use; derived from depth_version_id when not provided.
        depth_version_id (Optional[str]): The version ID of the depth estimation model.
    """

    image: Union[List[InferenceRequestImage], InferenceRequestImage]
    model_id: Optional[str] = Field(None)
    depth_version_id: Optional[str] = Field(
        default="small",
        examples=["small"],
        description="The version ID of the depth estimation model",
    )

    @validator("model_id", always=True)
    def validate_model_id(cls, value, values):
        # Default model_id to "depth-anything-v2/<version>" when not supplied.
        if value is not None:
            return value
        if values.get("depth_version_id") is None:
            return None
        return f"depth-anything-v2/{values['depth_version_id']}"

InferenceRequest

Bases: BaseRequest

Base request for inference.

Attributes:

Name Type Description
model_id str

A unique model identifier.

model_type Optional[str]

The type of the model, usually referring to what task the model performs.

Source code in inference/core/entities/requests/inference.py
35
36
37
38
39
40
41
42
43
44
class InferenceRequest(BaseRequest):
    """Base request for inference against a specific model.

    Attributes:
        model_id (Optional[str]): A unique model identifier.
        model_type (Optional[str]): The type of the model, usually referring to what task the model performs.
    """

    model_id: Optional[str] = ModelID
    model_type: Optional[str] = ModelType

InferenceRequestImage

Bases: BaseModel

Image data for inference request.

Attributes:

Name Type Description
type str

The type of image data provided, one of 'url', 'base64', or 'numpy'.

value Optional[Any]

Image data corresponding to the image type.

Source code in inference/core/entities/requests/inference.py
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
class InferenceRequestImage(BaseModel):
    """Image data for inference request.

    Attributes:
        type (str): The type of image data provided, one of 'url', 'base64', or 'numpy'.
        value (Optional[Any]): Image data corresponding to the image type.
    """

    # Discriminator telling the server how to decode `value`.
    type: str = Field(
        examples=["url"],
        description="The type of image data provided, one of 'url', 'base64', or 'numpy'",
    )
    # Payload; its shape depends on `type` (see description below).
    value: Optional[Any] = Field(
        None,
        examples=["http://www.example-image-url.com"],
        description="Image data corresponding to the image type, if type = 'url' then value is a string containing the url of an image, else if type = 'base64' then value is a string containing base64 encoded image data, else if type = 'numpy' then value is binary numpy data serialized using pickle.dumps(); array should 3 dimensions, channels last, with values in the range [0,255].",
    )

InstanceSegmentationInferenceRequest

Bases: ObjectDetectionInferenceRequest

Instance Segmentation inference request.

Attributes:

Name Type Description
mask_decode_mode Optional[str]

The mode used to decode instance segmentation masks, one of 'accurate', 'fast', 'tradeoff'.

tradeoff_factor Optional[float]

The amount to tradeoff between 0='fast' and 1='accurate'.

Source code in inference/core/entities/requests/inference.py
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
class InstanceSegmentationInferenceRequest(ObjectDetectionInferenceRequest):
    """Instance Segmentation inference request.

    Attributes:
        mask_decode_mode (Optional[str]): The mode used to decode instance segmentation masks, one of 'accurate', 'fast', 'tradeoff'.
        tradeoff_factor (Optional[float]): The amount to tradeoff between 0='fast' and 1='accurate'.
    """

    mask_decode_mode: Optional[str] = Field(
        default="accurate",
        examples=["accurate"],
        description="The mode used to decode instance segmentation masks, one of 'accurate', 'fast', 'tradeoff'",
    )
    # Only meaningful when mask_decode_mode == 'tradeoff'.
    tradeoff_factor: Optional[float] = Field(
        default=0.0,
        examples=[0.5],
        description="The amount to tradeoff between 0='fast' and 1='accurate'",
    )

ObjectDetectionInferenceRequest

Bases: CVInferenceRequest

Object Detection inference request.

Attributes:

Name Type Description
class_agnostic_nms Optional[bool]

If true, NMS is applied to all detections at once, if false, NMS is applied per class.

class_filter Optional[List[str]]

If provided, only predictions for the listed classes will be returned.

confidence Optional[float]

The confidence threshold used to filter out predictions.

fix_batch_size Optional[bool]

If true, the batch size will be fixed to the maximum batch size configured for this server.

iou_threshold Optional[float]

The IoU threshold that must be met for a box pair to be considered duplicate during NMS.

max_detections Optional[int]

The maximum number of detections that will be returned.

max_candidates Optional[int]

The maximum number of candidate detections passed to NMS.

visualization_labels Optional[bool]

If true, labels will be rendered on prediction visualizations.

visualization_stroke_width Optional[int]

The stroke width used when visualizing predictions.

visualize_predictions Optional[bool]

If true, the predictions will be drawn on the original image and returned as a base64 string.

Source code in inference/core/entities/requests/inference.py
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
class ObjectDetectionInferenceRequest(CVInferenceRequest):
    """Object Detection inference request.

    Attributes:
        class_agnostic_nms (Optional[bool]): If true, NMS is applied to all detections at once, if false, NMS is applied per class.
        class_filter (Optional[List[str]]): If provided, only predictions for the listed classes will be returned.
        confidence (Optional[float]): The confidence threshold used to filter out predictions.
        fix_batch_size (Optional[bool]): If true, the batch size will be fixed to the maximum batch size configured for this server.
        iou_threshold (Optional[float]): The IoU threshold that must be met for a box pair to be considered duplicate during NMS.
        max_detections (Optional[int]): The maximum number of detections that will be returned.
        max_candidates (Optional[int]): The maximum number of candidate detections passed to NMS.
        visualization_labels (Optional[bool]): If true, labels will be rendered on prediction visualizations.
        visualization_stroke_width (Optional[int]): The stroke width used when visualizing predictions.
        visualize_predictions (Optional[bool]): If true, the predictions will be drawn on the original image and returned as a base64 string.
        disable_active_learning (Optional[bool]): If true, the predictions will be prevented from registration by Active Learning.
        active_learning_target_dataset (Optional[str]): Alternative dataset for Active Learning data registration.
    """

    class_agnostic_nms: Optional[bool] = Field(
        default=False,
        examples=[False],
        description="If true, NMS is applied to all detections at once, if false, NMS is applied per class",
    )
    class_filter: Optional[List[str]] = Field(
        default=None,
        examples=[["class-1", "class-2", "class-n"]],
        description="If provided, only predictions for the listed classes will be returned",
    )
    confidence: Optional[float] = Field(
        default=0.4,
        examples=[0.5],
        description="The confidence threshold used to filter out predictions",
    )
    fix_batch_size: Optional[bool] = Field(
        default=False,
        examples=[False],
        description="If true, the batch size will be fixed to the maximum batch size configured for this server",
    )
    iou_threshold: Optional[float] = Field(
        default=0.3,
        examples=[0.5],
        # Fixed typo in user-facing description: "threhsold" -> "threshold".
        description="The IoU threshold that must be met for a box pair to be considered duplicate during NMS",
    )
    max_detections: Optional[int] = Field(
        default=300,
        examples=[300],
        description="The maximum number of detections that will be returned",
    )
    max_candidates: Optional[int] = Field(
        default=3000,
        description="The maximum number of candidate detections passed to NMS",
    )
    visualization_labels: Optional[bool] = Field(
        default=False,
        examples=[False],
        description="If true, labels will be rendered on prediction visualizations",
    )
    visualization_stroke_width: Optional[int] = Field(
        default=1,
        examples=[1],
        description="The stroke width used when visualizing predictions",
    )
    visualize_predictions: Optional[bool] = Field(
        default=False,
        examples=[False],
        description="If true, the predictions will be drawn on the original image and returned as a base64 string",
    )
    disable_active_learning: Optional[bool] = Field(
        default=False,
        examples=[False],
        description="If true, the predictions will be prevented from registration by Active Learning (if the functionality is enabled)",
    )
    active_learning_target_dataset: Optional[str] = Field(
        default=None,
        examples=["my_dataset"],
        description="Parameter to be used when Active Learning data registration should happen against different dataset than the one pointed by model_id",
    )

SemanticSegmentationInferenceRequest

Bases: CVInferenceRequest

Semantic Segmentation inference request.

Source code in inference/core/entities/requests/inference.py
227
228
229
230
231
232
class SemanticSegmentationInferenceRequest(CVInferenceRequest):
    """Semantic Segmentation inference request.

    Forces ``model_type`` to ``"semantic-segmentation"`` regardless of input.
    """

    def __init__(self, **kwargs):
        # Pin the task type so downstream dispatch treats this request correctly.
        kwargs["model_type"] = "semantic-segmentation"
        super().__init__(**kwargs)

Functions

request_from_type

request_from_type(model_type, request_dict)

Uses original request id

Source code in inference/core/entities/requests/inference.py
292
293
294
295
296
297
298
299
300
301
302
303
304
305
def request_from_type(model_type, request_dict):
    """Build the inference request matching *model_type*.

    Preserves the original request id when one is present in *request_dict*.

    Args:
        model_type: Task type; one of "classification", "instance-segmentation",
            "object-detection" or "semantic-segmentation".
        request_dict: Keyword arguments forwarded to the request constructor.

    Returns:
        An instance of the request class for the given task.

    Raises:
        ValueError: If *model_type* is not a recognized task type.
    """
    # Dispatch table keeps the mapping in one place.
    constructors = {
        "classification": ClassificationInferenceRequest,
        "instance-segmentation": InstanceSegmentationInferenceRequest,
        "object-detection": ObjectDetectionInferenceRequest,
        "semantic-segmentation": SemanticSegmentationInferenceRequest,
    }
    constructor = constructors.get(model_type)
    if constructor is None:
        raise ValueError(f"Unknown task type {model_type}")
    request = constructor(**request_dict)
    # Keep the caller-supplied id instead of the auto-generated one.
    request.id = request_dict.get("id", request.id)
    return request

inference.core.entities.requests.moondream2

Classes

Moondream2InferenceRequest

Bases: DynamicClassBaseInferenceRequest

Request for Moondream 2 zero-shot predictions.

Attributes:

Name Type Description
prompt str

The prompt to be used for prediction.

Source code in inference/core/entities/requests/moondream2.py
 6
 7
 8
 9
10
11
12
13
class Moondream2InferenceRequest(DynamicClassBaseInferenceRequest):
    """Request for Moondream 2 zero-shot predictions.

    Attributes:
        prompt (str): The prompt to be used for prediction.
    """

    # The prompt/question passed to the model alongside the image.
    prompt: str

inference.core.entities.requests.owlv2

Classes

OwlV2InferenceRequest

Bases: BaseRequest

Request for OWLv2 inference.

Attributes:

Name Type Description
api_key Optional[str]

Roboflow API Key.

owlv2_version_id Optional[str]

The version ID of OWLv2 to be used for this request.

image Union[List[InferenceRequestImage], InferenceRequestImage]

Image(s) for inference.

training_data List[TrainingImage]

Training data to ground the model on

confidence float

Confidence threshold to filter predictions by

Source code in inference/core/entities/requests/owlv2.py
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
class OwlV2InferenceRequest(BaseRequest):
    """Request for OWLv2 few-shot object detection inference.

    Attributes:
        api_key (Optional[str]): Roboflow API Key.
        owlv2_version_id (Optional[str]): The version ID of OWLv2 to be used for this request.
        model_id (Optional[str]): Model id; derived from owlv2_version_id when not given.
        image (Union[List[InferenceRequestImage], InferenceRequestImage]): Image(s) for inference.
        training_data (List[TrainingImage]): Training data to ground the model on.
        confidence (Optional[float]): Confidence threshold to filter predictions by.
        visualization_labels (Optional[bool]): If true, labels are rendered on visualizations.
        visualization_stroke_width (Optional[int]): Stroke width used when visualizing predictions.
        visualize_predictions (Optional[bool]): If true, predictions are drawn on the image and returned as a base64 string.
    """

    owlv2_version_id: Optional[str] = Field(
        default=OWLV2_VERSION_ID,
        examples=["owlv2-base-patch16-ensemble"],
        description="The version ID of owlv2 to be used for this request.",
    )
    model_id: Optional[str] = Field(
        default=None, description="Model id to be used in the request."
    )

    image: Union[List[InferenceRequestImage], InferenceRequestImage] = Field(
        description="Images to run the model on"
    )
    training_data: List[TrainingImage] = Field(
        description="Training images for the owlvit model to learn form"
    )
    confidence: Optional[float] = Field(
        default=0.99,
        examples=[0.99],
        description="Default confidence threshold for owlvit predictions. "
        "Needs to be much higher than you're used to, probably 0.99 - 0.9999",
    )
    visualization_labels: Optional[bool] = Field(
        default=False,
        examples=[False],
        description="If true, labels will be rendered on prediction visualizations",
    )
    visualization_stroke_width: Optional[int] = Field(
        default=1,
        examples=[1],
        description="The stroke width used when visualizing predictions",
    )
    # NOTE: this field was declared twice in the original; the later (this)
    # declaration was the effective one, so keeping only it preserves behavior.
    visualize_predictions: Optional[bool] = Field(
        default=False,
        examples=[False],
        description="If true, the predictions will be drawn on the original image and returned as a base64 string",
    )

    # TODO[pydantic]: We couldn't refactor the `validator`, please replace it by `field_validator` manually.
    # Check https://docs.pydantic.dev/dev-v2/migration/#changes-to-validators for more information.
    @validator("model_id", always=True, allow_reuse=True)
    def validate_model_id(cls, value, values):
        # Derive a default model_id from the version field when not supplied.
        # Bug fix: the original read the non-existent key "owl2_version_id"
        # (the field is "owlv2_version_id"), so the fallback never fired.
        if value is not None:
            return value
        if values.get("owlv2_version_id") is None:
            return None
        # NOTE(review): "google/" prefix kept from the original — confirm it
        # matches the model registry's naming convention for OWLv2.
        return f"google/{values['owlv2_version_id']}"

inference.core.entities.requests.perception_encoder

Classes

PerceptionEncoderCompareRequest

Bases: PerceptionEncoderInferenceRequest

Request for PERCEPTION_ENCODER comparison.

Attributes:

Name Type Description
subject Union[InferenceRequestImage, str]

The type of image data provided, one of 'url' or 'base64'.

subject_type str

The type of subject, one of 'image' or 'text'.

prompt Union[List[InferenceRequestImage], InferenceRequestImage, str, List[str], Dict[str, Union[InferenceRequestImage, str]]]

The prompt for comparison.

prompt_type str

The type of prompt, one of 'image' or 'text'.

Source code in inference/core/entities/requests/perception_encoder.py
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
class PerceptionEncoderCompareRequest(PerceptionEncoderInferenceRequest):
    """Request for PERCEPTION_ENCODER comparison.

    Attributes:
        subject (Union[InferenceRequestImage, str]): The type of image data provided, one of 'url' or 'base64'.
        subject_type (str): The type of subject, one of 'image' or 'text'.
        prompt (Union[List[InferenceRequestImage], InferenceRequestImage, str, List[str], Dict[str, Union[InferenceRequestImage, str]]]): The prompt for comparison.
        prompt_type (str): The type of prompt, one of 'image' or 'text'.
    """

    # The item being compared; interpreted according to subject_type.
    subject: Union[InferenceRequestImage, str] = Field(
        examples=["url"],
        description="The type of image data provided, one of 'url' or 'base64'",
    )
    subject_type: str = Field(
        default="image",
        examples=["image"],
        description="The type of subject, one of 'image' or 'text'",
    )
    # One or more comparison targets; interpreted according to prompt_type.
    prompt: Union[
        List[InferenceRequestImage],
        InferenceRequestImage,
        str,
        List[str],
        Dict[str, Union[InferenceRequestImage, str]],
    ]
    prompt_type: str = Field(
        default="text",
        examples=["text"],
        description="The type of prompt, one of 'image' or 'text'",
    )

PerceptionEncoderImageEmbeddingRequest

Bases: PerceptionEncoderInferenceRequest

Request for PERCEPTION_ENCODER image embedding.

Attributes:

Name Type Description
image Union[List[InferenceRequestImage], InferenceRequestImage]

Image(s) to be embedded.

Source code in inference/core/entities/requests/perception_encoder.py
38
39
40
41
42
43
44
45
class PerceptionEncoderImageEmbeddingRequest(PerceptionEncoderInferenceRequest):
    """Request for PERCEPTION_ENCODER image embedding.

    Attributes:
        image (Union[List[InferenceRequestImage], InferenceRequestImage]): Image(s) to be embedded.
    """

    # A single image or a batch of images to embed.
    image: Union[List[InferenceRequestImage], InferenceRequestImage]

PerceptionEncoderInferenceRequest

Bases: BaseRequest

Request for PERCEPTION_ENCODER inference.

Attributes:

Name Type Description
api_key Optional[str]

Roboflow API Key.

perception_encoder_version_id Optional[str]

The version ID of PERCEPTION_ENCODER to be used for this request.

Source code in inference/core/entities/requests/perception_encoder.py
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
class PerceptionEncoderInferenceRequest(BaseRequest):
    """Request for PERCEPTION_ENCODER inference.

    Attributes:
        api_key (Optional[str]): Roboflow API Key.
        perception_encoder_version_id (Optional[str]): The version ID of PERCEPTION_ENCODER to be used for this request.
        model_id (Optional[str]): Model identifier; derived from perception_encoder_version_id when not provided.
    """

    perception_encoder_version_id: Optional[str] = Field(
        default=PERCEPTION_ENCODER_VERSION_ID,
        examples=["PE-Core-L14-336"],
        description="The version ID of PERCEPTION_ENCODER to be used for this request. Must be one of RN101, RN50, RN50x16, RN50x4, RN50x64, ViT-B-16, ViT-B-32, ViT-L-14-336px, and ViT-L-14.",
    )
    model_id: Optional[str] = Field(None)

    # TODO[pydantic]: We couldn't refactor the `validator`, please replace it by `field_validator` manually.
    # Check https://docs.pydantic.dev/dev-v2/migration/#changes-to-validators for more information.
    @validator("model_id", always=True)
    def validate_model_id(cls, value, values):
        # Default model_id to "perception_encoder/<version>" when not supplied.
        if value is not None:
            return value
        if values.get("perception_encoder_version_id") is None:
            return None
        return f"perception_encoder/{values['perception_encoder_version_id']}"

PerceptionEncoderTextEmbeddingRequest

Bases: PerceptionEncoderInferenceRequest

Request for PERCEPTION_ENCODER text embedding.

Attributes:

Name Type Description
text Union[List[str], str]

A string or list of strings.

Source code in inference/core/entities/requests/perception_encoder.py
48
49
50
51
52
53
54
55
56
57
58
class PerceptionEncoderTextEmbeddingRequest(PerceptionEncoderInferenceRequest):
    """Request for PERCEPTION_ENCODER text embedding.

    Attributes:
        text (Union[List[str], str]): A string or list of strings.
    """

    # A single string or a batch of strings to embed.
    text: Union[List[str], str] = Field(
        examples=["The quick brown fox jumps over the lazy dog"],
        description="A string or list of strings",
    )

inference.core.entities.requests.sam

Classes

SamEmbeddingRequest

Bases: SamInferenceRequest

SAM embedding request.

Attributes:

Name Type Description
image Optional[InferenceRequestImage]

The image to be embedded.

image_id Optional[str]

The ID of the image to be embedded used to cache the embedding.

format Optional[str]

The format of the response. Must be one of json or binary.

Source code in inference/core/entities/requests/sam.py
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
class SamEmbeddingRequest(SamInferenceRequest):
    """SAM embedding request.

    Attributes:
        image (Optional[InferenceRequestImage]): The image to be embedded.
        image_id (Optional[str]): The ID of the image to be embedded used to cache the embedding.
        format (Optional[str]): The format of the response. Must be one of json or binary.
    """

    image: Optional[InferenceRequestImage] = Field(
        default=None,
        description="The image to be embedded",
    )
    # Optional cache key so repeated requests can reuse a stored embedding.
    image_id: Optional[str] = Field(
        default=None,
        examples=["image_id"],
        description="The ID of the image to be embedded used to cache the embedding.",
    )
    format: Optional[str] = Field(
        default="json",
        examples=["json"],
        description="The format of the response. Must be one of json or binary. If binary, embedding is returned as a binary numpy array.",
    )

SamInferenceRequest

Bases: BaseRequest

SAM inference request.

Attributes:

Name Type Description
api_key Optional[str]

Roboflow API Key.

sam_version_id Optional[str]

The version ID of SAM to be used for this request.

Source code in inference/core/entities/requests/sam.py
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
class SamInferenceRequest(BaseRequest):
    """SAM inference request.

    Base request shared by the SAM endpoints; derives ``model_id`` from the
    requested SAM version when the caller does not supply one explicitly.

    Attributes:
        api_key (Optional[str]): Roboflow API Key.
        sam_version_id (Optional[str]): The version ID of SAM to be used for this request.
        model_id (Optional[str]): Model identifier; derived from sam_version_id when absent.
    """

    sam_version_id: Optional[str] = Field(
        default=SAM_VERSION_ID,
        examples=["vit_h"],
        description="The version ID of SAM to be used for this request. Must be one of vit_h, vit_l, or vit_b.",
    )

    model_id: Optional[str] = Field(None)

    # TODO[pydantic]: We couldn't refactor the `validator`, please replace it by `field_validator` manually.
    # Check https://docs.pydantic.dev/dev-v2/migration/#changes-to-validators for more information.
    @validator("model_id", always=True)
    def validate_model_id(cls, value, values):
        # An explicitly provided model_id always wins over the derived one.
        if value is not None:
            return value
        if values.get("sam_version_id") is None:
            return None
        # Derive e.g. "sam/vit_h" from the requested version.
        return f"sam/{values['sam_version_id']}"

SamSegmentationRequest

Bases: SamInferenceRequest

SAM segmentation request.

Attributes:

Name Type Description
embeddings Optional[Union[List[List[List[List[float]]]], Any]]

The embeddings to be decoded.

embeddings_format Optional[str]

The format of the embeddings.

format Optional[str]

The format of the response.

image Optional[InferenceRequestImage]

The image to be segmented.

image_id Optional[str]

The ID of the image to be segmented used to retrieve cached embeddings.

has_mask_input Optional[bool]

Whether or not the request includes a mask input.

mask_input Optional[Union[List[List[List[float]]], Any]]

The set of output masks.

mask_input_format Optional[str]

The format of the mask input.

orig_im_size Optional[List[int]]

The original size of the image used to generate the embeddings.

point_coords Optional[List[List[float]]]

The coordinates of the interactive points used during decoding.

point_labels Optional[List[float]]

The labels of the interactive points used during decoding.

use_mask_input_cache Optional[bool]

Whether or not to use the mask input cache.

Source code in inference/core/entities/requests/sam.py
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
class SamSegmentationRequest(SamInferenceRequest):
    """SAM segmentation request.

    Decodes masks from either a supplied image or previously computed
    embeddings, optionally guided by interactive points and a prior mask.

    Attributes:
        embeddings (Optional[Union[List[List[List[List[float]]]], Any]]): The embeddings to be decoded.
        embeddings_format (Optional[str]): The format of the embeddings.
        format (Optional[str]): The format of the response.
        image (Optional[InferenceRequestImage]): The image to be segmented.
        image_id (Optional[str]): The ID of the image to be segmented used to retrieve cached embeddings.
        has_mask_input (Optional[bool]): Whether or not the request includes a mask input.
        mask_input (Optional[Union[List[List[List[float]]], Any]]): The set of output masks.
        mask_input_format (Optional[str]): The format of the mask input.
        orig_im_size (Optional[List[int]]): The original size of the image used to generate the embeddings.
        point_coords (Optional[List[List[float]]]): The coordinates of the interactive points used during decoding.
        point_labels (Optional[List[float]]): The labels of the interactive points used during decoding.
        use_mask_input_cache (Optional[bool]): Whether or not to use the mask input cache.
    """

    # Either embeddings or image must be provided (see descriptions below).
    embeddings: Optional[Union[List[List[List[List[float]]]], Any]] = Field(
        None,
        examples=["[[[[0.1, 0.2, 0.3, ...] ...] ...]]"],
        description="The embeddings to be decoded. The dimensions of the embeddings are 1 x 256 x 64 x 64. If embeddings is not provided, image must be provided.",
    )
    embeddings_format: Optional[str] = Field(
        default="json",
        examples=["json"],
        description="The format of the embeddings. Must be one of json or binary. If binary, embeddings are expected to be a binary numpy array.",
    )
    format: Optional[str] = Field(
        default="json",
        examples=["json"],
        description="The format of the response. Must be one of json or binary. If binary, masks are returned as binary numpy arrays. If json, masks are converted to polygons, then returned as json.",
    )
    image: Optional[InferenceRequestImage] = Field(
        default=None,
        description="The image to be segmented. Only required if embeddings are not provided.",
    )
    image_id: Optional[str] = Field(
        default=None,
        examples=["image_id"],
        description="The ID of the image to be segmented used to retrieve cached embeddings. If an embedding is cached, it will be used instead of generating a new embedding. If no embedding is cached, a new embedding will be generated and cached.",
    )
    has_mask_input: Optional[bool] = Field(
        default=False,
        examples=[True],
        description="Whether or not the request includes a mask input. If true, the mask input must be provided.",
    )
    mask_input: Optional[Union[List[List[List[float]]], Any]] = Field(
        default=None,
        description="The set of output masks. If request format is json, masks is a list of polygons, where each polygon is a list of points, where each point is a tuple containing the x,y pixel coordinates of the point. If request format is binary, masks is a list of binary numpy arrays. The dimensions of each mask are 256 x 256. This is the same as the output, low resolution mask from the previous inference.",
    )
    mask_input_format: Optional[str] = Field(
        default="json",
        examples=["json"],
        description="The format of the mask input. Must be one of json or binary. If binary, mask input is expected to be a binary numpy array.",
    )
    orig_im_size: Optional[List[int]] = Field(
        default=None,
        examples=[[640, 320]],
        description="The original size of the image used to generate the embeddings. This is only required if the image is not provided.",
    )
    # NOTE: mutable defaults are safe here — pydantic deep-copies field
    # defaults per model instance.
    point_coords: Optional[List[List[float]]] = Field(
        default=[[0.0, 0.0]],
        examples=[[[10.0, 10.0]]],
        description="The coordinates of the interactive points used during decoding. Each point (x,y pair) corresponds to a label in point_labels.",
    )
    point_labels: Optional[List[float]] = Field(
        default=[-1],
        examples=[[1]],
        description="The labels of the interactive points used during decoding. A 1 represents a positive point (part of the object to be segmented). A -1 represents a negative point (not part of the object to be segmented). Each label corresponds to a point in point_coords.",
    )
    use_mask_input_cache: Optional[bool] = Field(
        default=True,
        examples=[True],
        description="Whether or not to use the mask input cache. If true, the mask input cache will be used if it exists. If false, the mask input cache will not be used.",
    )

inference.core.entities.requests.sam2

Classes

Sam2EmbeddingRequest

Bases: Sam2InferenceRequest

SAM2 embedding request.

Attributes:

Name Type Description
image Optional[InferenceRequestImage]

The image to be embedded.

image_id Optional[str]

The ID of the image to be embedded used to cache the embedding.

format Optional[str]

The format of the response. Must be one of json or binary.

Source code in inference/core/entities/requests/sam2.py
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
class Sam2EmbeddingRequest(Sam2InferenceRequest):
    """SAM2 embedding request.

    Computes (or retrieves from cache) the SAM2 image embedding for one image.
    Unlike the SAM v1 equivalent, this request has no ``format`` field.

    Attributes:
        image (Optional[inference.core.entities.requests.inference.InferenceRequestImage]): The image to be embedded.
        image_id (Optional[str]): The ID of the image to be embedded used to cache the embedding.
    """

    # Optional: a previously cached embedding may be looked up via image_id
    # instead of re-submitting the image itself.
    image: Optional[InferenceRequestImage] = Field(
        default=None,
        description="The image to be embedded",
    )
    image_id: Optional[str] = Field(
        default=None,
        examples=["image_id"],
        description="The ID of the image to be embedded used to cache the embedding.",
    )

Sam2InferenceRequest

Bases: BaseRequest

SAM2 inference request.

Attributes:

Name Type Description
api_key Optional[str]

Roboflow API Key.

sam2_version_id Optional[str]

The version ID of SAM2 to be used for this request.

Source code in inference/core/entities/requests/sam2.py
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
class Sam2InferenceRequest(BaseRequest):
    """SAM2 inference request.

    Base request shared by the SAM2 endpoints; derives ``model_id`` from the
    requested SAM2 version when the caller does not supply one explicitly.

    Attributes:
        api_key (Optional[str]): Roboflow API Key.
        sam2_version_id (Optional[str]): The version ID of SAM2 to be used for this request.
        model_id (Optional[str]): Model identifier; derived from sam2_version_id when absent.
    """

    sam2_version_id: Optional[str] = Field(
        default=SAM2_VERSION_ID,
        examples=["hiera_large"],
        description="The version ID of SAM to be used for this request. Must be one of hiera_tiny, hiera_small, hiera_large, hiera_b_plus",
    )

    model_id: Optional[str] = Field(None)

    # TODO[pydantic]: We couldn't refactor the `validator`, please replace it by `field_validator` manually.
    # Check https://docs.pydantic.dev/dev-v2/migration/#changes-to-validators for more information.
    @validator("model_id", always=True)
    def validate_model_id(cls, value, values):
        # An explicitly provided model_id always wins over the derived one.
        if value is not None:
            return value
        # Bug fix: this class's field is `sam2_version_id` (not
        # `sam_version_id`, as in the SAM v1 request). The previous lookup of
        # "sam_version_id" always returned None, so model_id was never
        # derived from the requested SAM2 version.
        if values.get("sam2_version_id") is None:
            return None
        return f"sam2/{values['sam2_version_id']}"

Sam2SegmentationRequest

Bases: Sam2InferenceRequest

SAM2 segmentation request.

Attributes:

Name Type Description
format Optional[str]

The format of the response.

image InferenceRequestImage

The image to be segmented.

image_id Optional[str]

The ID of the image to be segmented used to retrieve cached embeddings.

point_coords Optional[List[List[float]]]

The coordinates of the interactive points used during decoding.

point_labels Optional[List[float]]

The labels of the interactive points used during decoding.

Source code in inference/core/entities/requests/sam2.py
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
class Sam2SegmentationRequest(Sam2InferenceRequest):
    """SAM2 segmentation request.

    Decodes masks for one image, guided by a set of prompts (points and/or
    boxes). Prompts are coerced into a ``Sam2PromptSet`` from several
    convenience shapes (see ``_coerce_prompts``).

    Attributes:
        format (Optional[str]): The format of the response ('json', 'rle', or 'binary').
        image (InferenceRequestImage): The image to be segmented.
        image_id (Optional[str]): The ID of the image to be segmented used to retrieve cached embeddings.
        prompts (Sam2PromptSet): Prompts (boxes and/or points) for the masks to predict.
        multimask_output (bool): Whether the model returns three candidate masks instead of one.
        save_logits_to_cache (bool): Whether to store low-resolution logits in the cache.
        load_logits_from_cache (bool): Whether to seed decoding from previously cached logits.
    """

    format: Optional[str] = Field(
        default="json",
        examples=["json"],
        description="The format of the response. Must be one of 'json', 'rle', or 'binary'. If binary, masks are returned as binary numpy arrays. If json, masks are converted to polygons. If rle, masks are converted to RLE format.",
    )
    image: InferenceRequestImage = Field(
        description="The image to be segmented.",
    )
    image_id: Optional[str] = Field(
        default=None,
        examples=["image_id"],
        description="The ID of the image to be segmented used to retrieve cached embeddings. If an embedding is cached, it will be used instead of generating a new embedding. If no embedding is cached, a new embedding will be generated and cached.",
    )
    # `examples=` (plural) matches the rest of this module and avoids the
    # deprecated `example=` extra kwarg; "positive" typo fixed in description.
    prompts: Sam2PromptSet = Field(
        default=Sam2PromptSet(prompts=None),
        examples=[{"prompts": [{"points": [{"x": 100, "y": 100, "positive": True}]}]}],
        description="A list of prompts for masks to predict. Each prompt can include a bounding box and / or a set of positive or negative points. "
        "Also accepts a flat array of prompts (e.g. 'prompts': [{...}, {...}]) for convenience.",
    )
    multimask_output: bool = Field(
        default=True,
        examples=[True],
        description="If true, the model will return three masks. "
        "For ambiguous input prompts (such as a single click), this will often "
        "produce better masks than a single prediction. If only a single "
        "mask is needed, the model's predicted quality score can be used "
        "to select the best mask. For non-ambiguous prompts, such as multiple "
        "input prompts, multimask_output=False can give better results.",
    )

    # TODO[pydantic]: We couldn't refactor the `validator`, please replace it by `field_validator` manually.
    # Check https://docs.pydantic.dev/dev-v2/migration/#changes-to-validators for more information.
    @validator("prompts", pre=True, always=True)
    def _coerce_prompts(cls, value):
        """
        Accepts any of the following and coerces to Sam2PromptSet:
        - None
        - Sam2PromptSet
        - {"prompts": [...]} (nested)
        - [...] (flat list of prompts)
        - single prompt dict (wrapped to list)
        """
        if value is None:
            return Sam2PromptSet(prompts=None)
        if isinstance(value, Sam2PromptSet):
            return value
        # Nested dict with key 'prompts'
        if isinstance(value, dict):
            if "prompts" in value:
                return Sam2PromptSet(**value)
            # Single prompt dict – wrap and parse
            try:
                return Sam2PromptSet(prompts=[Sam2Prompt(**value)])
            except Exception:
                # Fall-through to attempt generic construction
                return Sam2PromptSet(**value)
        # Flat list of prompts (dicts or Sam2Prompt instances)
        if isinstance(value, list):
            prompts: List[Sam2Prompt] = []
            for item in value:
                if isinstance(item, Sam2Prompt):
                    prompts.append(item)
                elif isinstance(item, dict):
                    prompts.append(Sam2Prompt(**item))
                else:
                    raise ValueError(
                        "Invalid prompt entry; expected dict or Sam2Prompt instance"
                    )
            return Sam2PromptSet(prompts=prompts)
        # Fallback: let Pydantic try
        return value

    save_logits_to_cache: bool = Field(
        default=False,
        description="If True, saves the low-resolution logits to the cache for potential future use. "
        "This can speed up subsequent requests with similar prompts on the same image. "
        "This feature is ignored if DISABLE_SAM2_LOGITS_CACHE env variable is set True",
    )
    load_logits_from_cache: bool = Field(
        default=False,
        description="If True, attempts to load previously cached low-resolution logits for the given image and prompt set. "
        "This can significantly speed up inference when making multiple similar requests on the same image. "
        "This feature is ignored if DISABLE_SAM2_LOGITS_CACHE env variable is set True",
    )

inference.core.entities.requests.sam3

Classes

Sam3InferenceRequest

Bases: BaseRequest

SAM3 inference request.

Attributes:

Name Type Description
model_id Optional[str]

The model ID to be used, typically sam3.

Source code in inference/core/entities/requests/sam3.py
76
77
78
79
80
81
82
83
84
85
86
class Sam3InferenceRequest(BaseRequest):
    """SAM3 inference request.

    Attributes:
        model_id (Optional[str]): The model ID to be used, typically `sam3`.
    """

    # Default targets the generic base model; callers may override to pin a
    # specific fine-tuned SAM3 variant.
    model_id: Optional[str] = Field(
        default="sam3/sam3_final",
        description="The model ID of SAM3. Use 'sam3/sam3_final' to target the generic base model.",
    )

Sam3Prompt

Bases: BaseModel

Unified prompt that can contain text and/or geometry.

Absolute pixel coordinates are used for boxes. Labels accept 0/1 or booleans.

Source code in inference/core/entities/requests/sam3.py
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
class Sam3Prompt(BaseModel):
    """Unified prompt that can contain text and/or geometry.

    Absolute pixel coordinates are used for boxes. Labels accept 0/1 or booleans.
    """

    type: Optional[str] = Field(
        default=None, description="Optional hint: 'text' or 'visual'"
    )
    # Free-text prompt; may be combined with boxes.
    text: Optional[str] = Field(default=None)

    output_prob_thresh: Optional[float] = Field(
        default=None,
        description="Score threshold for this prompt's outputs. Overrides request-level threshold if set.",
    )

    # Absolute-coordinate boxes (preferred) in pixels.
    # XYWH absolute pixels
    class Box(BaseModel):
        x: float
        y: float
        width: float
        height: float

    # XYXY absolute pixels
    class BoxXYXY(BaseModel):
        x0: float
        y0: float
        x1: float
        y1: float

    # Single unified boxes field; each entry can be XYWH or XYXY
    boxes: Optional[List[Union[Box, BoxXYXY]]] = Field(
        default=None,
        description="Absolute pixel boxes as either XYWH or XYXY entries",
    )
    box_labels: Optional[List[Union[int, bool]]] = Field(
        default=None, description="List of 0/1 or booleans for boxes"
    )

    @validator("boxes", always=True)
    def _validate_visual_boxes(cls, boxes, values):
        # A prompt explicitly typed 'visual' is meaningless without geometry.
        prompt_type = values.get("type")
        if prompt_type == "visual":
            if not boxes or len(boxes) == 0:
                raise ValueError("Visual prompt requires at least one box")
        return boxes

    @validator("box_labels", always=True)
    def _validate_box_labels(cls, labels, values):
        # Labels are optional, but when present must pair 1:1 with boxes.
        boxes = values.get("boxes")
        if labels is None:
            return labels
        if boxes is None or len(labels) != len(boxes):
            raise ValueError("box_labels must match boxes length when provided")
        return labels

    @validator("output_prob_thresh")
    def _validate_output_prob_thresh(cls, v):
        # Threshold is a probability; constrain to [0, 1].
        if v is not None and (v < 0.0 or v > 1.0):
            raise ValueError("output_prob_thresh must be between 0.0 and 1.0")
        return v

inference.core.entities.requests.sam3_3d

Classes

Sam3_3D_Objects_InferenceRequest

Bases: BaseRequest

SAM3D inference request for 3D object generation.

Attributes:

Name Type Description
api_key Optional[str]

Roboflow API Key.

image InferenceRequestImage

The input image to be used for 3D generation.

mask_input Any

Mask(s) in any supported format - polygon, binary mask, or RLE.

Source code in inference/core/entities/requests/sam3_3d.py
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
class Sam3_3D_Objects_InferenceRequest(BaseRequest):
    """SAM3D inference request for 3D object generation.

    Attributes:
        api_key (Optional[str]): Roboflow API Key.
        image (InferenceRequestImage): The input image to be used for 3D generation.
        mask_input: Mask(s) in any supported format - polygon, binary mask, or RLE.
    """

    image: InferenceRequestImage = Field(
        description="The input image to be used for 3D generation.",
    )

    # Typed Any because several mask encodings are accepted; the model/handler
    # is responsible for detecting and parsing the concrete format.
    mask_input: Any = Field(
        description="Mask input in any supported format: "
        "polygon [x1,y1,x2,y2,...], binary mask (base64), RLE dict, or list of these.",
    )

    model_id: Optional[str] = Field(
        default="sam3-3d-objects", description="The model ID for SAM3_3D."
    )

    output_meshes: Optional[bool] = Field(
        default=True,
        description="SAM3 3D always outputs object gaussians, and can optionally output object meshes if output_meshes is True.",
    )

    output_scene: Optional[bool] = Field(
        default=True,
        description="Output the combined scene reconstruction in addition to individual object reconstructions.",
    )

    with_mesh_postprocess: Optional[bool] = Field(
        default=True, description="Enable mesh postprocessing."
    )

    with_texture_baking: Optional[bool] = Field(
        default=True, description="Enable texture baking for meshes."
    )

    use_distillations: Optional[bool] = Field(
        default=False, description="Use the distilled versions of the model components."
    )

    @validator("model_id", always=True)
    def validate_model_id(cls, value):
        # Guard against an explicit None overriding the default model id.
        if value is not None:
            return value
        return "sam3-3d-objects"

inference.core.entities.requests.server_state

Classes

AddModelRequest

Bases: BaseModel

Request to add a model to the inference server.

Attributes:

Name Type Description
model_id str

A unique model identifier.

model_type Optional[str]

The type of the model, usually referring to what task the model performs.

api_key Optional[str]

Roboflow API Key that will be passed to the model during initialization for artifact retrieval.

Source code in inference/core/entities/requests/server_state.py
 8
 9
10
11
12
13
14
15
16
17
18
19
20
class AddModelRequest(BaseModel):
    """Request to add a model to the inference server.

    Attributes:
        model_id (str): A unique model identifier.
        model_type (Optional[str]): The type of the model, usually referring to what task the model performs.
        api_key (Optional[str]): Roboflow API Key that will be passed to the model during initialization for artifact retrieval.
    """

    # Allow field names starting with "model_" (pydantic v2 reserves that
    # prefix by default).
    model_config = ConfigDict(protected_namespaces=())
    model_id: str = ModelID
    model_type: Optional[str] = ModelType
    api_key: Optional[str] = ApiKey

ClearModelRequest

Bases: BaseModel

Request to clear a model from the inference server.

Attributes:

Name Type Description
model_id str

A unique model identifier.

Source code in inference/core/entities/requests/server_state.py
23
24
25
26
27
28
29
30
31
class ClearModelRequest(BaseModel):
    """Request to clear a model from the inference server.

    Attributes:
        model_id (str): A unique model identifier.
    """

    # Allow field names starting with "model_" (pydantic v2 reserves that
    # prefix by default).
    model_config = ConfigDict(protected_namespaces=())
    model_id: str = ModelID

inference.core.entities.requests.trocr

Classes

TrOCRInferenceRequest

Bases: BaseRequest

TrOCR inference request.

Attributes:

Name Type Description
api_key Optional[str]

Roboflow API Key.

Source code in inference/core/entities/requests/trocr.py
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
class TrOCRInferenceRequest(BaseRequest):
    """
    TrOCR inference request.

    Attributes:
        api_key (Optional[str]): Roboflow API Key.
        image (Union[List[InferenceRequestImage], InferenceRequestImage]): One image or a batch of images to run OCR on.
        trocr_version_id (Optional[str]): TrOCR variant to use; defaults to "trocr-base-printed".
        model_id (Optional[str]): Model identifier; derived from trocr_version_id when absent.
    """

    image: Union[List[InferenceRequestImage], InferenceRequestImage]
    trocr_version_id: Optional[str] = "trocr-base-printed"
    model_id: Optional[str] = Field(None)

    # TODO[pydantic]: We couldn't refactor the `validator`, please replace it by `field_validator` manually.
    # Check https://docs.pydantic.dev/dev-v2/migration/#changes-to-validators for more information.
    @validator("model_id", always=True, allow_reuse=True)
    def validate_model_id(cls, value, values):
        # An explicitly provided model_id always wins over the derived one.
        if value is not None:
            return value
        if values.get("trocr_version_id") is None:
            return None
        # Derive e.g. "trocr/trocr-base-printed" from the version id.
        return f"trocr/{values['trocr_version_id']}"

inference.core.entities.requests.yolo_world

Classes

YOLOWorldInferenceRequest

Bases: DynamicClassBaseInferenceRequest

Request for YOLO-World zero-shot predictions.

Attributes:

Name Type Description
text List[str]

A list of strings.

Source code in inference/core/entities/requests/yolo_world.py
 9
10
11
12
13
14
15
16
17
class YOLOWorldInferenceRequest(DynamicClassBaseInferenceRequest):
    """Request for YOLO-World zero-shot predictions.

    Attributes:
        yolo_world_version_id (Optional[str]): YOLO-World model variant; defaults to "l".
        confidence (Optional[float]): Minimum detection confidence threshold.
        text (List[str]): A list of class-name strings (presumably inherited
            from DynamicClassBaseInferenceRequest — confirm against that class).
    """

    yolo_world_version_id: Optional[str] = "l"
    confidence: Optional[float] = DEFAULT_CONFIDENCE

core/entities/responses

inference.core.entities.responses.clip

Classes

ClipCompareResponse

Bases: InferenceResponse

Response for CLIP comparison.

Attributes:

Name Type Description
similarity Union[List[float], Dict[str, float]]

Similarity scores.

time float

The time in seconds it took to produce the similarity scores including preprocessing.

Source code in inference/core/entities/responses/clip.py
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
class ClipCompareResponse(InferenceResponse):
    """Response for CLIP comparison.

    Attributes:
        similarity (Union[List[float], Dict[str, float]]): Similarity scores.
        time (Optional[float]): The time in seconds it took to produce the similarity scores including preprocessing.
        parent_id (Optional[str]): Identifier of the parent image region that was the input to inference.
    """

    similarity: Union[List[float], Dict[str, float]]
    time: Optional[float] = Field(
        default=None,
        description="The time in seconds it took to produce the similarity scores including preprocessing",
    )
    parent_id: Optional[str] = Field(
        description="Identifier of parent image region. Useful when stack of detection-models is in use to refer the RoI being the input to inference",
        default=None,
    )

ClipEmbeddingResponse

Bases: InferenceResponse

Response for CLIP embedding.

Attributes:

Name Type Description
embeddings List[List[float]]

A list of embeddings, each embedding is a list of floats.

time float

The time in seconds it took to produce the embeddings including preprocessing.

Source code in inference/core/entities/responses/clip.py
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
class ClipEmbeddingResponse(InferenceResponse):
    """Response for CLIP embedding.

    Attributes:
        embeddings (List[List[float]]): A list of embeddings, each embedding is a list of floats.
        time (Optional[float]): The time in seconds it took to produce the embeddings including preprocessing.
    """

    embeddings: List[List[float]] = Field(
        examples=["[[0.12, 0.23, 0.34, ..., 0.43]]"],
        description="A list of embeddings, each embedding is a list of floats",
    )
    time: Optional[float] = Field(
        default=None,
        description="The time in seconds it took to produce the embeddings including preprocessing",
    )

inference.core.entities.responses.gaze

Classes

GazeDetectionInferenceResponse

Bases: BaseModel

Response for gaze detection inference.

Attributes:

Name Type Description
predictions List[GazeDetectionPrediction]

List of gaze detection predictions.

time float

The processing time (second).

Source code in inference/core/entities/responses/gaze.py
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
class GazeDetectionInferenceResponse(BaseModel):
    """Response for gaze detection inference.

    Attributes:
        predictions (List[inference.core.entities.responses.gaze.GazeDetectionPrediction]): List of gaze detection predictions.
        time (float): The processing time (second).
        time_face_det (Optional[float]): The face detection time (second).
        time_gaze_det (Optional[float]): The gaze detection time (second).
    """

    predictions: List[GazeDetectionPrediction]

    time: float = Field(description="The processing time (second)")
    time_face_det: Optional[float] = Field(
        None, description="The face detection time (second)"
    )
    time_gaze_det: Optional[float] = Field(
        None, description="The gaze detection time (second)"
    )

GazeDetectionPrediction

Bases: BaseModel

Gaze Detection prediction.

Attributes:

Name Type Description
face FaceDetectionPrediction

The face prediction.

yaw float

Yaw (radian) of the detected face.

pitch float

Pitch (radian) of the detected face.

Source code in inference/core/entities/responses/gaze.py
 8
 9
10
11
12
13
14
15
16
17
18
19
20
class GazeDetectionPrediction(BaseModel):
    """Gaze Detection prediction.

    Attributes:
        face (inference.core.entities.responses.inference.FaceDetectionPrediction): The face prediction.
        yaw (Optional[float]): Yaw (radian) of the detected face.
        pitch (Optional[float]): Pitch (radian) of the detected face.
    """

    face: FaceDetectionPrediction

    # NOTE(review): annotated Optional but declared without a default, so
    # pydantic still treats these as required fields — confirm that is intended.
    yaw: Optional[float] = Field(description="Yaw (radian) of the detected face")
    pitch: Optional[float] = Field(description="Pitch (radian) of the detected face")

inference.core.entities.responses.inference

Classes

ClassificationInferenceResponse

Bases: CvInferenceResponse, WithVisualizationResponse

Classification inference response.

Attributes:

Name Type Description
predictions List[ClassificationPrediction]

List of classification predictions.

top str

The top predicted class label.

confidence float

The confidence of the top predicted class label.

Source code in inference/core/entities/responses/inference.py
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
class ClassificationInferenceResponse(CvInferenceResponse, WithVisualizationResponse):
    """Classification inference response.

    Attributes:
        predictions (List[inference.core.entities.responses.inference.ClassificationPrediction]): List of classification predictions.
        top (str): The top predicted class label.
        confidence (float): The confidence of the top predicted class label.
        parent_id (Optional[str]): Identifier of the parent image region that was the input to inference.
    """

    predictions: List[ClassificationPrediction]
    top: str = Field(
        description="The top predicted class label", default=""
    )  # Not making this field optional to avoid breaking change - in other parts of the codebase `model_dump` is called with `exclude_none=True`
    confidence: float = Field(
        description="The confidence of the top predicted class label",
        default=0.0,
    )
    parent_id: Optional[str] = Field(
        description="Identifier of parent image region. Useful when stack of detection-models is in use to refer the RoI being the input to inference",
        default=None,
    )

ClassificationPrediction

Bases: BaseModel

Classification prediction.

Attributes:

Name Type Description
class_name str

The predicted class label.

class_id int

Numeric ID associated with the class label.

confidence float

The class label confidence as a fraction between 0 and 1.

Source code in inference/core/entities/responses/inference.py
128
129
130
131
132
133
134
135
136
137
138
139
140
141
class ClassificationPrediction(BaseModel):
    """Classification prediction.

    Attributes:
        class_name (str): The predicted class label.
        class_id (int): Numeric ID associated with the class label.
        confidence (float): The class label confidence as a fraction between 0 and 1.
    """

    # Aliased to "class" on the wire; "class" is a reserved word in Python.
    class_name: str = Field(alias="class", description="The predicted class label")
    class_id: int = Field(description="Numeric ID associated with the class label")
    confidence: float = Field(
        description="The class label confidence as a fraction between 0 and 1"
    )

CvInferenceResponse

Bases: InferenceResponse

Computer Vision inference response.

Attributes:

Name Type Description
image Union[List[InferenceResponseImage], InferenceResponseImage]

Image(s) used in inference.

Source code in inference/core/entities/responses/inference.py
194
195
196
197
198
199
200
201
class CvInferenceResponse(InferenceResponse):
    """Computer Vision inference response.

    Extends the base response with the image (or list of images) the
    predictions refer to.

    Attributes:
        image (Union[List[InferenceResponseImage], InferenceResponseImage]): Image(s) used in inference.
    """

    image: Union[List[InferenceResponseImage], InferenceResponseImage]

DepthEstimationResponse

Bases: BaseModel

Response for depth estimation inference.

Attributes:

Name Type Description
normalized_depth List[List[float]]

The normalized depth map as a 2D array of floats between 0 and 1.

image Optional[str]

Base64 encoded visualization of the depth map if visualize_predictions is True.

time float

The processing time in seconds.

visualization Optional[str]

Base64 encoded visualization of the depth map if visualize_predictions is True.

Source code in inference/core/entities/responses/inference.py
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
class DepthEstimationResponse(BaseModel):
    """Response for depth estimation inference.

    Attributes:
        normalized_depth (List[List[float]]): The normalized depth map as a 2D array of floats between 0 and 1.
        image (Optional[str]): Base64 encoded visualization of the depth map if visualize_predictions is True.
    """

    # Fix: the docstring previously documented ``time`` and ``visualization``
    # attributes that this model does not define; it now matches the fields.
    normalized_depth: List[List[float]] = Field(
        description="The normalized depth map as a 2D array of floats between 0 and 1"
    )
    image: Optional[str] = Field(
        default=None,
        description="Base64 encoded visualization of the depth map if visualize_predictions is True",
    )

FaceDetectionPrediction

Bases: ObjectDetectionPrediction

Face Detection prediction.

Attributes:

Name Type Description
class_name str

fixed value "face".

landmarks Union[List[Point], List[Point3D]]

The detected face landmarks.

Source code in inference/core/entities/responses/inference.py
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
class FaceDetectionPrediction(ObjectDetectionPrediction):
    """Face Detection prediction.

    Attributes:
        class_id (Optional[int]): The class id of the prediction; defaults to 0.
        class_name (str): fixed value "face".
        landmarks (Union[List[Point], List[Point3D]]): The detected face landmarks.
    """

    # Overrides the parent's required ``class_id`` with a default of 0.
    class_id: Optional[int] = Field(
        description="The class id of the prediction", default=0
    )
    # ``class`` is a Python keyword, hence the alias.
    class_name: str = Field(
        alias="class", default="face", description="The predicted class label"
    )
    landmarks: Union[List[Point], List[Point3D]]

InferenceResponse

Bases: BaseModel

Base inference response.

Attributes:

Name Type Description
inference_id Optional[str]

Unique identifier of inference

frame_id Optional[int]

The frame id of the image used in inference if the input was a video.

time Optional[float]

The time in seconds it took to produce the predictions including image preprocessing.

Source code in inference/core/entities/responses/inference.py
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
class InferenceResponse(BaseModel):
    """Common base for all inference responses.

    Attributes:
        inference_id (Optional[str]): Unique identifier of inference.
        frame_id (Optional[int]): The frame id of the image used in inference if the input was a video.
        time (Optional[float]): The time in seconds it took to produce the predictions including image preprocessing.
    """

    # Permit field names beginning with ``model_`` without pydantic warnings.
    model_config = ConfigDict(protected_namespaces=())

    inference_id: Optional[str] = Field(
        default=None, description="Unique identifier of inference"
    )
    frame_id: Optional[int] = Field(
        description="The frame id of the image used in inference if the input was a video",
        default=None,
    )
    time: Optional[float] = Field(
        description="The time in seconds it took to produce the predictions including image preprocessing",
        default=None,
    )

InferenceResponseImage

Bases: BaseModel

Inference response image information.

Attributes:

Name Type Description
width int

The original width of the image used in inference.

height int

The original height of the image used in inference.

Source code in inference/core/entities/responses/inference.py
157
158
159
160
161
162
163
164
165
166
167
168
class InferenceResponseImage(BaseModel):
    """Dimensions of the image an inference was performed on.

    Attributes:
        width (int): The original width of the image used in inference.
        height (int): The original height of the image used in inference.
    """

    width: int = Field(
        description="The original width of the image used in inference"
    )
    height: int = Field(description="The original height of the image used in inference")

InstanceSegmentationInferenceResponse

Bases: CvInferenceResponse, WithVisualizationResponse

Instance Segmentation inference response.

Attributes:

Name Type Description
predictions List[InstanceSegmentationPrediction]

List of instance segmentation predictions.

Source code in inference/core/entities/responses/inference.py
251
252
253
254
255
256
257
258
259
260
class InstanceSegmentationInferenceResponse(
    CvInferenceResponse, WithVisualizationResponse
):
    """Response returned by instance-segmentation models.

    Attributes:
        predictions (List[InstanceSegmentationPrediction]): List of instance segmentation predictions.
    """

    predictions: List[InstanceSegmentationPrediction]

MultiLabelClassificationInferenceResponse

Bases: CvInferenceResponse, WithVisualizationResponse

Multi-label Classification inference response.

Attributes:

Name Type Description
predictions Dict[str, MultiLabelClassificationPrediction]

Dictionary of multi-label classification predictions.

predicted_classes List[str]

The list of predicted classes.

Source code in inference/core/entities/responses/inference.py
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
class MultiLabelClassificationInferenceResponse(
    CvInferenceResponse, WithVisualizationResponse
):
    """Multi-label Classification inference response.

    Attributes:
        predictions (Dict[str, MultiLabelClassificationPrediction]): Dictionary of multi-label classification predictions.
        predicted_classes (List[str]): The list of predicted classes.
        parent_id (Optional[str]): Identifier of the parent image region, if any.
    """

    predictions: Dict[str, MultiLabelClassificationPrediction]
    predicted_classes: List[str] = Field(description="The list of predicted classes")
    parent_id: Optional[str] = Field(
        description="Identifier of parent image region. Useful when stack of detection-models is in use to refer the RoI being the input to inference",
        default=None,
    )

MultiLabelClassificationPrediction

Bases: BaseModel

Multi-label Classification prediction.

Attributes:

Name Type Description
confidence float

The class label confidence as a fraction between 0 and 1.

Source code in inference/core/entities/responses/inference.py
144
145
146
147
148
149
150
151
152
153
154
class MultiLabelClassificationPrediction(BaseModel):
    """Multi-label Classification prediction.

    Attributes:
        confidence (float): The class label confidence as a fraction between 0 and 1.
        class_id (int): Numeric ID associated with the class label.
    """

    confidence: float = Field(
        description="The class label confidence as a fraction between 0 and 1"
    )
    class_id: int = Field(description="Numeric ID associated with the class label")

ObjectDetectionInferenceResponse

Bases: CvInferenceResponse, WithVisualizationResponse

Object Detection inference response.

Attributes:

Name Type Description
predictions List[ObjectDetectionPrediction]

List of object detection predictions.

Source code in inference/core/entities/responses/inference.py
223
224
225
226
227
228
229
230
class ObjectDetectionInferenceResponse(CvInferenceResponse, WithVisualizationResponse):
    """Response returned by object-detection models.

    Attributes:
        predictions (List[ObjectDetectionPrediction]): List of object detection predictions.
    """

    predictions: List[ObjectDetectionPrediction]

ObjectDetectionPrediction

Bases: BaseModel

Object Detection prediction.

Attributes:

Name Type Description
x float

The center x-axis pixel coordinate of the prediction.

y float

The center y-axis pixel coordinate of the prediction.

width float

The width of the prediction bounding box in number of pixels.

height float

The height of the prediction bounding box in number of pixels.

confidence float

The detection confidence as a fraction between 0 and 1.

class_name str

The predicted class label.

class_confidence Union[float, None]

The class label confidence as a fraction between 0 and 1.

class_id int

The class id of the prediction

Source code in inference/core/entities/responses/inference.py
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
class ObjectDetectionPrediction(BaseModel):
    """Object Detection prediction.

    Attributes:
        x (float): The center x-axis pixel coordinate of the prediction.
        y (float): The center y-axis pixel coordinate of the prediction.
        width (float): The width of the prediction bounding box in number of pixels.
        height (float): The height of the prediction bounding box in number of pixels.
        confidence (float): The detection confidence as a fraction between 0 and 1.
        class_name (str): The predicted class label.
        class_confidence (Union[float, None]): The class label confidence as a fraction between 0 and 1.
        class_id (int): The class id of the prediction
        tracker_id (Optional[int]): The tracker id of the prediction if tracking is enabled.
        detection_id (str): Unique identifier of the detection; a fresh UUID4 when not supplied.
        parent_id (Optional[str]): Identifier of the parent image region, if any.
    """

    x: float = Field(description="The center x-axis pixel coordinate of the prediction")
    y: float = Field(description="The center y-axis pixel coordinate of the prediction")
    width: float = Field(
        description="The width of the prediction bounding box in number of pixels"
    )
    height: float = Field(
        description="The height of the prediction bounding box in number of pixels"
    )
    confidence: float = Field(
        description="The detection confidence as a fraction between 0 and 1"
    )
    # ``class`` is a Python keyword, hence the alias.
    class_name: str = Field(alias="class", description="The predicted class label")

    class_confidence: Union[float, None] = Field(
        None, description="The class label confidence as a fraction between 0 and 1"
    )
    class_id: int = Field(description="The class id of the prediction")
    tracker_id: Optional[int] = Field(
        description="The tracker id of the prediction if tracking is enabled",
        default=None,
    )
    # default_factory yields a distinct UUID per instance (never shared).
    detection_id: str = Field(
        description="Unique identifier of detection",
        default_factory=lambda: str(uuid4()),
    )
    parent_id: Optional[str] = Field(
        description="Identifier of parent image region. Useful when stack of detection-models is in use to refer the RoI being the input to inference",
        default=None,
    )

Point

Bases: BaseModel

Point coordinates.

Attributes:

Name Type Description
x float

The x-axis pixel coordinate of the point.

y float

The y-axis pixel coordinate of the point.

Source code in inference/core/entities/responses/inference.py
53
54
55
56
57
58
59
60
61
62
class Point(BaseModel):
    """A 2D point in pixel coordinates.

    Attributes:
        x (float): The x-axis pixel coordinate of the point.
        y (float): The y-axis pixel coordinate of the point.
    """

    x: float = Field(
        description="The x-axis pixel coordinate of the point"
    )
    y: float = Field(
        description="The y-axis pixel coordinate of the point"
    )

Point3D

Bases: Point

3D Point coordinates.

Attributes:

Name Type Description
z float

The z-axis pixel coordinate of the point.

Source code in inference/core/entities/responses/inference.py
65
66
67
68
69
70
71
72
class Point3D(Point):
    """A point in 3D pixel space; extends ``Point`` with depth.

    Attributes:
        z (float): The z-axis pixel coordinate of the point.
    """

    z: float = Field(
        description="The z-axis pixel coordinate of the point"
    )

SemanticSegmentationInferenceResponse

Bases: CvInferenceResponse, WithVisualizationResponse

Semantic Segmentation inference response.

Attributes:

Name Type Description
predictions SemanticSegmentationPrediction

Semantic segmentation predictions.

Source code in inference/core/entities/responses/inference.py
263
264
265
266
267
268
269
270
271
272
class SemanticSegmentationInferenceResponse(
    CvInferenceResponse, WithVisualizationResponse
):
    """Response returned by semantic-segmentation models.

    Attributes:
        predictions (SemanticSegmentationPrediction): Semantic segmentation predictions.
    """

    predictions: SemanticSegmentationPrediction

WithVisualizationResponse

Bases: BaseModel

Response with visualization.

Attributes:

Name Type Description
visualization Optional[Any]

Base64 encoded string containing prediction visualization image data.

Source code in inference/core/entities/responses/inference.py
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
class WithVisualizationResponse(BaseModel):
    """Mixin carrying an optional rendered-prediction payload.

    Attributes:
        visualization (Optional[Any]): Base64 encoded string containing prediction visualization image data.
    """

    visualization: Optional[Any] = Field(
        description="Base64 encoded string containing prediction visualization image data",
        default=None,
    )

    @field_serializer("visualization", when_used="json")
    def serialize_visualisation(self, visualization: Optional[Any]) -> Optional[str]:
        # Raw image bytes are base64-encoded only when dumping to JSON.
        if visualization is not None:
            return base64.b64encode(visualization).decode("utf-8")
        return None

inference.core.entities.responses.notebooks

Classes

NotebookStartResponse

Bases: BaseModel

Response model for notebook start request

Source code in inference/core/entities/responses/notebooks.py
4
5
6
7
8
class NotebookStartResponse(BaseModel):
    """Response model for notebook start request.

    Attributes:
        success (str): Status of the request.
        message (str): Message of the request; optional, defaults to "".
    """

    success: str = Field(..., description="Status of the request")
    # Fix: ``optional=True`` is not a pydantic Field() parameter - it was
    # silently stored as extra JSON-schema metadata while the field remained
    # required. A default value makes the field genuinely optional.
    message: str = Field(default="", description="Message of the request")

inference.core.entities.responses.ocr

Classes

OCRInferenceResponse

Bases: BaseModel

OCR Inference response.

Attributes:

Name Type Description
result str

The combined OCR recognition result.

predictions List[ObjectDetectionPrediction]

List of objects detected by OCR

time float

The time in seconds it took to produce the inference including preprocessing

Source code in inference/core/entities/responses/ocr.py
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
class OCRInferenceResponse(BaseModel):
    """
    OCR Inference response.

    Attributes:
        result (str): The combined OCR recognition result.
        image (Optional[InferenceResponseImage]): Metadata about input image dimensions.
        predictions (List[ObjectDetectionPrediction]): List of objects detected by OCR
        time (float): The time in seconds it took to produce the inference including preprocessing
        parent_id (Optional[str]): Identifier of the parent image region, if any.
    """

    result: str = Field(description="The combined OCR recognition result.")
    image: Optional[InferenceResponseImage] = Field(
        description="Metadata about input image dimensions", default=None
    )
    predictions: Optional[List[ObjectDetectionPrediction]] = Field(
        description="List of objects detected by OCR",
        default=None,
    )
    time: float = Field(
        description="The time in seconds it took to produce the inference including preprocessing."
    )
    parent_id: Optional[str] = Field(
        description="Identifier of parent image region. Useful when stack of detection-models is in use to refer the RoI being the input to inference",
        default=None,
    )

inference.core.entities.responses.perception_encoder

Classes

PerceptionEncoderCompareResponse

Bases: InferenceResponse

Response for PERCEPTION_ENCODER comparison.

Attributes:

Name Type Description
similarity Union[List[float], Dict[str, float]]

Similarity scores.

time float

The time in seconds it took to produce the similarity scores including preprocessing.

Source code in inference/core/entities/responses/perception_encoder.py
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
class PerceptionEncoderCompareResponse(InferenceResponse):
    """Response for PERCEPTION_ENCODER comparison.

    Attributes:
        similarity (Union[List[float], Dict[str, float]]): Similarity scores.
        time (Optional[float]): The time in seconds it took to produce the similarity scores including preprocessing.
        parent_id (Optional[str]): Identifier of the parent image region, if any.
    """

    similarity: Union[List[float], Dict[str, float]]
    # Overrides the base-class ``time`` field with an endpoint-specific description.
    time: Optional[float] = Field(
        default=None,
        description="The time in seconds it took to produce the similarity scores including preprocessing",
    )
    parent_id: Optional[str] = Field(
        description="Identifier of parent image region. Useful when stack of detection-models is in use to refer the RoI being the input to inference",
        default=None,
    )

PerceptionEncoderEmbeddingResponse

Bases: InferenceResponse

Response for PERCEPTION_ENCODER embedding.

Attributes:

Name Type Description
embeddings List[List[float]]

A list of embeddings, each embedding is a list of floats.

time float

The time in seconds it took to produce the embeddings including preprocessing.

Source code in inference/core/entities/responses/perception_encoder.py
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
class PerceptionEncoderEmbeddingResponse(InferenceResponse):
    """Response for PERCEPTION_ENCODER embedding.

    Attributes:
        embeddings (List[List[float]]): A list of embeddings, each embedding is a list of floats.
        time (Optional[float]): The time in seconds it took to produce the embeddings including preprocessing.
    """

    embeddings: List[List[float]] = Field(
        description="A list of embeddings, each embedding is a list of floats",
        examples=["[[0.12, 0.23, 0.34, ..., 0.43]]"],
    )
    time: Optional[float] = Field(
        default=None,
        description="The time in seconds it took to produce the embeddings including preprocessing",
    )

inference.core.entities.responses.sam

Classes

SamEmbeddingResponse

Bases: BaseModel

SAM embedding response.

Attributes:

Name Type Description
embeddings Union[List[List[List[List[float]]]], Any]

The SAM embedding.

time float

The time in seconds it took to produce the embeddings including preprocessing.

Source code in inference/core/entities/responses/sam.py
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
class SamEmbeddingResponse(BaseModel):
    """SAM embedding response.

    Attributes:
        embeddings (Union[List[List[List[List[float]]]], Any]): The SAM embedding.
        time (float): The time in seconds it took to produce the embeddings including preprocessing.
    """

    # JSON requests get nested lists; binary requests get a numpy array.
    embeddings: Union[List[List[List[List[float]]]], Any] = Field(
        description="If request format is json, embeddings is a series of nested lists representing the SAM embedding. If request format is binary, embeddings is a binary numpy array. The dimensions of the embedding are 1 x 256 x 64 x 64.",
        examples=["[[[[0.1, 0.2, 0.3, ...] ...] ...]]"],
    )
    time: float = Field(
        description="The time in seconds it took to produce the embeddings including preprocessing"
    )

SamSegmentationResponse

Bases: BaseModel

SAM segmentation response.

Attributes:

Name Type Description
masks Union[List[List[List[int]]], Any]

The set of output masks.

low_res_masks Union[List[List[List[int]]], Any]

The set of output low-resolution masks.

time float

The time in seconds it took to produce the segmentation including preprocessing.

Source code in inference/core/entities/responses/sam.py
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
class SamSegmentationResponse(BaseModel):
    """SAM segmentation response.

    Attributes:
        masks (Union[List[List[List[int]]], Any]): The set of output masks.
        low_res_masks (Union[List[List[List[int]]], Any]): The set of output low-resolution masks.
        time (float): The time in seconds it took to produce the segmentation including preprocessing.
    """

    # Full-resolution masks (polygons for JSON, numpy arrays for binary).
    masks: Union[List[List[List[int]]], Any] = Field(
        description="The set of output masks. If request format is json, masks is a list of polygons, where each polygon is a list of points, where each point is a tuple containing the x,y pixel coordinates of the point. If request format is binary, masks is a list of binary numpy arrays. The dimensions of each mask are the same as the dimensions of the input image.",
    )
    # Fixed 256x256 low-resolution counterparts of ``masks``.
    low_res_masks: Union[List[List[List[int]]], Any] = Field(
        description="The set of output masks. If request format is json, masks is a list of polygons, where each polygon is a list of points, where each point is a tuple containing the x,y pixel coordinates of the point. If request format is binary, masks is a list of binary numpy arrays. The dimensions of each mask are 256 x 256",
    )
    time: float = Field(
        description="The time in seconds it took to produce the segmentation including preprocessing"
    )

inference.core.entities.responses.sam2

Classes

Sam2EmbeddingResponse

Bases: BaseModel

SAM embedding response.

Attributes:

Name Type Description
image_id str

Image id the embeddings are cached to.

time float

The time in seconds it took to produce the embeddings including preprocessing.

Source code in inference/core/entities/responses/sam2.py
 6
 7
 8
 9
10
11
12
13
14
15
16
17
class Sam2EmbeddingResponse(BaseModel):
    """SAM 2 embedding response.

    Unlike the SAM 1 response, the embedding tensor itself is not returned;
    only the id under which the embeddings are cached.

    Attributes:
        image_id (str): Image id embeddings are cached to.
        time (float): The time in seconds it took to produce the embeddings including preprocessing.
    """

    # Fix: the docstring previously documented a non-existent ``embeddings``
    # attribute (copied from SamEmbeddingResponse); it now matches the fields.
    image_id: str = Field(description="Image id embeddings are cached to")
    time: float = Field(
        description="The time in seconds it took to produce the embeddings including preprocessing"
    )

Sam2SegmentationPrediction

Bases: BaseModel

SAM segmentation prediction.

Attributes:

Name Type Description
masks Union[List[List[List[int]]], Dict[str, Any], Any]

Mask data - either polygon coordinates or RLE encoding.

confidence float

Masks confidences.

format Optional[str]

Format of the mask data: 'polygon' or 'rle'.

Source code in inference/core/entities/responses/sam2.py
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
class Sam2SegmentationPrediction(BaseModel):
    """SAM segmentation prediction.

    Attributes:
        masks (Union[List[List[List[int]]], Dict[str, Any]]): Mask data - either polygon coordinates or RLE encoding.
        confidence (float): Masks confidences.
        format (Optional[str]): Format of the mask data: 'polygon' or 'rle'.
    """

    masks: Union[List[List[List[int]]], Dict[str, Any]] = Field(
        description="If polygon format, masks is a list of polygons, where each polygon is a list of points, where each point is a tuple containing the x,y pixel coordinates of the point. If rle format, masks is a dictionary with the keys 'size' and 'counts' containing the size and counts of the RLE encoding."
    )
    confidence: float = Field(description="Masks confidences")
    format: Optional[str] = Field(
        default="polygon", description="Format of the mask data: 'polygon' or 'rle'"
    )

inference.core.entities.responses.sam3_3d

Classes

Sam3_3D_Object_Item

Bases: BaseModel

Individual 3D object output with mesh, gaussian, and transformation metadata.

Source code in inference/core/entities/responses/sam3_3d.py
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
class Sam3_3D_Object_Item(BaseModel):
    """Individual 3D object output with mesh, gaussian, and transformation metadata.

    Attributes:
        mesh_glb (Optional[bytes]): The 3D mesh in GLB format (binary).
        gaussian_ply (Optional[bytes]): The Gaussian splatting in PLY format (binary).
        metadata (Sam3_3D_Objects_Metadata): 3D transformation metadata (rotation, translation, scale).
    """

    mesh_glb: Optional[bytes] = Field(
        default=None, description="The 3D mesh in GLB format (binary)"
    )
    gaussian_ply: Optional[bytes] = Field(
        default=None, description="The Gaussian splatting in PLY format (binary)"
    )
    metadata: Sam3_3D_Objects_Metadata = Field(
        default_factory=Sam3_3D_Objects_Metadata,
        description="3D transformation metadata (rotation, translation, scale)",
    )

    # NOTE(review): legacy pydantic-v1 style config; presumably enables the
    # binary/metadata field types here - confirm before migrating to ConfigDict.
    class Config:
        arbitrary_types_allowed = True

inference.core.entities.responses.server_state

Classes

ServerVersionInfo

Bases: BaseModel

Server version information.

Attributes:

Name Type Description
name str

Server name.

version str

Server version.

uuid str

Server UUID.

Source code in inference/core/entities/responses/server_state.py
 8
 9
10
11
12
13
14
15
16
17
18
19
class ServerVersionInfo(BaseModel):
    """Identity and version details reported by the server.

    Attributes:
        name (str): Server name.
        version (str): Server version.
        uuid (str): Server UUID.
    """

    # Only examples are declared; all three fields are required strings.
    name: str = Field(examples=["Roboflow Inference Server"])
    version: str = Field(examples=["0.0.1"])
    uuid: str = Field(examples=["9c18c6f4-2266-41fb-8a0f-c12ae28f6fbe"])

core

Core framework internals: environment config, data entities, and shared utilities.

inference.core.exceptions

Classes

ContentTypeInvalid

Bases: Exception

Raised when the content type is invalid.

Attributes:

Name Type Description
message str

Optional message describing the error.

Source code in inference/core/exceptions.py
4
5
6
7
8
9
class ContentTypeInvalid(Exception):
    """Signals that a request carried an unsupported content type.

    Attributes:
        message (str): Optional message describing the error.
    """

ContentTypeMissing

Bases: Exception

Raised when the content type is missing.

Attributes:

Name Type Description
message str

Optional message describing the error.

Source code in inference/core/exceptions.py
12
13
14
15
16
17
class ContentTypeMissing(Exception):
    """Signals that a request arrived without any content type.

    Attributes:
        message (str): Optional message describing the error.
    """

EngineIgnitionFailure

Bases: Exception

Raised when the engine fails to ignite.

Attributes:

Name Type Description
message str

Optional message describing the error.

Source code in inference/core/exceptions.py
20
21
22
23
24
25
class EngineIgnitionFailure(Exception):
    """Signals that the engine could not be started.

    Attributes:
        message (str): Optional message describing the error.
    """

InferenceModelNotFound

Bases: Exception

Raised when the inference model is not found.

Attributes:

Name Type Description
message str

Optional message describing the error.

Source code in inference/core/exceptions.py
28
29
30
31
32
33
class InferenceModelNotFound(Exception):
    """Signals that the requested inference model could not be located.

    Attributes:
        message (str): Optional message describing the error.
    """

InvalidEnvironmentVariableError

Bases: Exception

Raised when an environment variable is invalid.

Attributes:

Name Type Description
message str

Optional message describing the error.

Source code in inference/core/exceptions.py
36
37
38
39
40
41
class InvalidEnvironmentVariableError(Exception):
    """Signals that an environment variable holds an invalid value.

    Attributes:
        message (str): Optional message describing the error.
    """

InvalidMaskDecodeArgument

Bases: Exception

Raised when an invalid argument is provided for mask decoding.

Attributes:

Name Type Description
message str

Optional message describing the error.

Source code in inference/core/exceptions.py
44
45
46
47
48
49
class InvalidMaskDecodeArgument(Exception):
    """Signals that mask decoding received an invalid argument.

    Attributes:
        message (str): Optional message describing the error.
    """

InvalidNumpyInput

Bases: InputImageLoadError

Raised when the input is an invalid NumPy array.

Attributes:

Name Type Description
message str

Optional message describing the error.

Source code in inference/core/exceptions.py
94
95
96
97
98
99
class InvalidNumpyInput(InputImageLoadError):
    """Signals that the supplied input is not a valid NumPy array.

    Attributes:
        message (str): Optional message describing the error.
    """

MissingApiKeyError

Bases: Exception

Raised when the API key is missing.

Attributes:

Name Type Description
message str

Optional message describing the error.

Source code in inference/core/exceptions.py
52
53
54
55
56
57
class MissingApiKeyError(Exception):
    """Signals that no API key was provided.

    Attributes:
        message (str): Optional message describing the error.
    """

MissingServiceSecretError

Bases: Exception

Raised when the service secret is missing.

Attributes:

Name Type Description
message str

Optional message describing the error.

Source code in inference/core/exceptions.py
60
61
62
63
64
65
class MissingServiceSecretError(Exception):
    """Signals that the service secret was not provided.

    Attributes:
        message (str): Optional message describing the error.
    """

OnnxProviderNotAvailable

Bases: Exception

Raised when the ONNX provider is not available.

Attributes:

Name Type Description
message str

Optional message describing the error.

Source code in inference/core/exceptions.py
68
69
70
71
72
73
class OnnxProviderNotAvailable(Exception):
    """Signals that the requested ONNX execution provider is unavailable.

    Attributes:
        message (str): Optional message describing the error.
    """

WorkspaceLoadError

Bases: Exception

Raised when there is an error loading the workspace.

Attributes:

Name Type Description
message str

Optional message describing the error.

Source code in inference/core/exceptions.py
76
77
78
79
80
81
class WorkspaceLoadError(Exception):
    """Signals a failure while loading the workspace.

    Attributes:
        message (str): Optional message describing the error.
    """

inference.core.nms

Functions

non_max_suppression_fast

non_max_suppression_fast(boxes, overlapThresh)

Applies non-maximum suppression to bounding boxes.

Parameters:

Name Type Description Default
boxes ndarray

Array of bounding boxes with confidence scores.

required
overlapThresh float

Overlap threshold for suppression.

required

Returns:

Name Type Description
list

List of bounding boxes after non-maximum suppression.

Source code in inference/core/nms.py
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
def non_max_suppression_fast(boxes, overlapThresh):
    """Greedy non-maximum suppression over confidence-scored boxes.

    Args:
        boxes (np.ndarray): Rows of [x1, y1, x2, y2, confidence, ...].
        overlapThresh (float): A candidate overlapping a kept box by more
            than this ratio (relative to the candidate's own area) is dropped.

    Returns:
        list: Empty list when ``boxes`` is empty; otherwise the surviving
            rows of ``boxes`` as a float array, in pick order.
    """
    if len(boxes) == 0:
        return []
    # The overlap ratio below divides areas, so promote integer boxes to float.
    if boxes.dtype.kind == "i":
        boxes = boxes.astype("float")
    left, top = boxes[:, 0], boxes[:, 1]
    right, bottom = boxes[:, 2], boxes[:, 3]
    scores = boxes[:, 4]
    # Inclusive pixel extents, hence the +1 on each side.
    areas = (right - left + 1) * (bottom - top + 1)
    # Ascending sort by confidence: the best remaining box sits at the end.
    order = np.argsort(scores)
    kept = []
    while len(order) > 0:
        tail = len(order) - 1
        best = order[tail]
        kept.append(best)
        # Intersection of the picked box with every remaining candidate.
        ix1 = np.maximum(left[best], left[order[:tail]])
        iy1 = np.maximum(top[best], top[order[:tail]])
        ix2 = np.minimum(right[best], right[order[:tail]])
        iy2 = np.minimum(bottom[best], bottom[order[:tail]])
        inter_w = np.maximum(0, ix2 - ix1 + 1)
        inter_h = np.maximum(0, iy2 - iy1 + 1)
        # Overlap is measured against each candidate's own area.
        ratio = (inter_w * inter_h) / areas[order[:tail]]
        # Remove the picked box together with everything it suppresses.
        order = np.delete(
            order, np.concatenate(([tail], np.where(ratio > overlapThresh)[0]))
        )
    return boxes[kept].astype("float")

w_np_non_max_suppression

w_np_non_max_suppression(
    prediction,
    conf_thresh=0.25,
    iou_thresh=0.45,
    class_agnostic=False,
    max_detections=300,
    max_candidate_detections=3000,
    timeout_seconds=None,
    num_masks=0,
    box_format="xywh",
)

Applies non-maximum suppression to predictions.

Parameters:

Name Type Description Default
prediction ndarray

Array of predictions. Format for single prediction is [bbox x 4, max_class_confidence, (confidence) x num_of_classes, additional_element x num_masks]

required
conf_thresh float

Confidence threshold. Defaults to 0.25.

0.25
iou_thresh float

IOU threshold. Defaults to 0.45.

0.45
class_agnostic bool

Whether to ignore class labels. Defaults to False.

False
max_detections int

Maximum number of detections. Defaults to 300.

300
max_candidate_detections int

Maximum number of candidate detections. Defaults to 3000.

3000
timeout_seconds Optional[int]

Timeout in seconds. Defaults to None.

None
num_masks int

Number of masks. Defaults to 0.

0
box_format str

Format of bounding boxes. Either 'xywh' or 'xyxy'. Defaults to 'xywh'.

'xywh'

Returns:

Name Type Description
list

List of filtered predictions after non-maximum suppression. Format of a single result is: [bbox x 4, max_class_confidence, max_class_confidence, id_of_class_with_max_confidence, additional_element x num_masks]

Source code in inference/core/nms.py
  6
  7
  8
  9
 10
 11
 12
 13
 14
 15
 16
 17
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
def w_np_non_max_suppression(
    prediction,
    conf_thresh: float = 0.25,
    iou_thresh: float = 0.45,
    class_agnostic: bool = False,
    max_detections: int = 300,
    max_candidate_detections: int = 3000,
    timeout_seconds: Optional[int] = None,
    num_masks: int = 0,
    box_format: str = "xywh",
):
    """Applies non-maximum suppression to predictions.

    Args:
        prediction (np.ndarray): Array of predictions. Format for single prediction is
            [bbox x 4, max_class_confidence, (confidence) x num_of_classes, additional_element x num_masks]
        conf_thresh (float, optional): Confidence threshold. Defaults to 0.25.
        iou_thresh (float, optional): IOU threshold. Defaults to 0.45.
        class_agnostic (bool, optional): Whether to ignore class labels. Defaults to False.
        max_detections (int, optional): Maximum number of detections. Defaults to 300.
        max_candidate_detections (int, optional): Accepted for backward compatibility;
            currently unused.
        timeout_seconds (Optional[int], optional): Accepted for backward compatibility;
            currently unused.
        num_masks (int, optional): Number of masks. Defaults to 0.
        box_format (str, optional): Format of bounding boxes. Either 'xywh' or 'xyxy'. Defaults to 'xywh'.

    Returns:
        list: List of filtered predictions after non-maximum suppression. Format of a single result is:
            [bbox x 4, max_class_confidence, max_class_confidence, id_of_class_with_max_confidence,
            additional_element x num_masks]

    Raises:
        ValueError: If ``box_format`` is neither 'xywh' nor 'xyxy'.

    Note:
        When ``box_format == "xywh"``, box coordinates of ``prediction`` are
        rewritten to xyxy **in place**, mutating the caller's array.
    """
    num_classes = prediction.shape[2] - 5 - num_masks

    if box_format == "xywh":
        pred_view = prediction[:, :, :4]

        # Calculate all values without allocating a new array
        x1 = pred_view[:, :, 0] - pred_view[:, :, 2] / 2
        y1 = pred_view[:, :, 1] - pred_view[:, :, 3] / 2
        x2 = pred_view[:, :, 0] + pred_view[:, :, 2] / 2
        y2 = pred_view[:, :, 1] + pred_view[:, :, 3] / 2

        # Assign directly to the view
        pred_view[:, :, 0] = x1
        pred_view[:, :, 1] = y1
        pred_view[:, :, 2] = x2
        pred_view[:, :, 3] = y2
    elif box_format != "xyxy":
        raise ValueError(
            "box_format must be either 'xywh' or 'xyxy', got {}".format(box_format)
        )

    batch_predictions = []

    for np_image_i, np_image_pred in enumerate(prediction):
        # Column 4 holds the max class confidence used for pre-filtering.
        np_conf_mask = np_image_pred[:, 4] >= conf_thresh
        if not np.any(np_conf_mask):  # Quick check if no boxes pass threshold
            batch_predictions.append([])
            continue

        np_image_pred = np_image_pred[np_conf_mask]

        # Handle empty case after filtering
        if np_image_pred.shape[0] == 0:
            batch_predictions.append([])
            continue

        cls_confs = np_image_pred[:, 5 : num_classes + 5]
        # Check for empty classes after slicing
        if cls_confs.shape[1] == 0:
            batch_predictions.append([])
            continue

        np_class_conf = np.max(cls_confs, axis=1, keepdims=True)
        np_class_pred = np.argmax(cls_confs, axis=1, keepdims=True)
        # Extract mask predictions if any
        if num_masks > 0:
            np_mask_pred = np_image_pred[:, 5 + num_classes :]
            # Construct final detections array directly
            np_detections = np.concatenate(
                [
                    np_image_pred[:, :5],
                    np_class_conf,
                    np_class_pred.astype(np.float32),
                    np_mask_pred,
                ],
                axis=1,
            )
        else:
            # Optimization: Avoid concatenation when no masks are present
            np_detections = np.concatenate(
                [np_image_pred[:, :5], np_class_conf, np_class_pred.astype(np.float32)],
                axis=1,
            )
        filtered_predictions = []
        if class_agnostic:
            # Sort by confidence directly
            sorted_indices = np.argsort(-np_detections[:, 4])
            np_detections_sorted = np_detections[sorted_indices]
            # Directly pass to optimized NMS
            filtered_predictions.extend(
                non_max_suppression_fast(np_detections_sorted, iou_thresh)
            )
        else:
            np_unique_labels = np.unique(np_class_pred)

            # Process each class independently so boxes of different classes
            # never suppress each other.
            for c in np_unique_labels:
                class_mask = np.atleast_1d(np_class_pred.squeeze() == c)
                np_detections_class = np_detections[class_mask]

                # Skip empty arrays
                if np_detections_class.shape[0] == 0:
                    continue

                # Sort by confidence (highest first)
                sorted_indices = np.argsort(-np_detections_class[:, 4])
                np_detections_sorted = np_detections_class[sorted_indices]

                # Apply optimized NMS and extend filtered predictions
                filtered_predictions.extend(
                    non_max_suppression_fast(np_detections_sorted, iou_thresh)
                )

        # Sort final predictions by confidence and limit to max_detections
        if filtered_predictions:
            # Use numpy sort for better performance
            filtered_np = np.array(filtered_predictions)
            idx = np.argsort(-filtered_np[:, 4])
            filtered_np = filtered_np[idx]

            # Limit to max_detections
            if len(filtered_np) > max_detections:
                filtered_np = filtered_np[:max_detections]

            batch_predictions.append(list(filtered_np))
        else:
            batch_predictions.append([])

    return batch_predictions

inference.core.roboflow_api

Classes

Functions

post_to_roboflow_api

post_to_roboflow_api(
    endpoint,
    api_key,
    payload=None,
    params=None,
    http_errors_handlers=None,
)

Generic function to make a POST request to the Roboflow API.

Parameters:

Name Type Description Default
endpoint str

API endpoint path

required
api_key Optional[str]

Roboflow API key

required
payload Optional[dict]

JSON payload

None
params Optional[List[Tuple[str, str]]]

Additional URL parameters

None
http_errors_handlers Optional[Dict[int, Callable[[Union[HTTPError]], None]]]

Optional custom HTTP error handlers by status code

None
Source code in inference/core/roboflow_api.py
1145
1146
1147
1148
1149
1150
1151
1152
1153
1154
1155
1156
1157
1158
1159
1160
1161
1162
1163
1164
1165
1166
1167
1168
1169
1170
1171
1172
1173
1174
1175
1176
1177
1178
1179
1180
1181
1182
1183
1184
1185
1186
1187
1188
1189
def post_to_roboflow_api(
    endpoint: str,
    api_key: Optional[str],
    payload: Optional[dict] = None,
    params: Optional[List[Tuple[str, str]]] = None,
    http_errors_handlers: Optional[
        Dict[int, Callable[[requests.exceptions.HTTPError], None]]
    ] = None,
) -> dict:
    """Generic function to make a POST request to the Roboflow API.

    Args:
        endpoint: API endpoint path
        api_key: Roboflow API key; appended as an ``api_key`` URL parameter when provided
        payload: JSON payload
        params: Additional URL parameters
        http_errors_handlers: Optional custom HTTP error handlers by status code

    Returns:
        The JSON-decoded response body.
    """

    @wrap_roboflow_api_errors(http_errors_handlers=http_errors_handlers)
    def _make_request():
        url_params = []
        if api_key:
            url_params.append(("api_key", api_key))
        if params:
            url_params.extend(params)

        full_url = _add_params_to_url(
            url=f"{API_BASE_URL}/{endpoint.strip('/')}", params=url_params
        )
        wrapped_url = wrap_url(full_url)

        headers = build_roboflow_api_headers()

        response = requests.post(
            url=wrapped_url,
            json=payload,
            headers=headers,
            timeout=ROBOFLOW_API_REQUEST_TIMEOUT,
            verify=ROBOFLOW_API_VERIFY_SSL,
        )
        # Raises with the API key redacted from the error message.
        api_key_safe_raise_for_status(response=response)
        return response.json()

    return _make_request()

inference.core.usage

Functions

trackUsage

trackUsage(endpoint, actor, n=1)

Tracks the usage of an endpoint by an actor.

This function increments the usage count for a given endpoint by an actor. It also handles initialization if the count does not exist.

Parameters:

Name Type Description Default
endpoint str

The endpoint being accessed.

required
actor str

The actor accessing the endpoint.

required
n int

The number of times the endpoint was accessed. Defaults to 1.

1

Returns:

Name Type Description
None

This function does not return anything but updates the memcache client.

Source code in inference/core/usage.py
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
def trackUsage(endpoint, actor, n=1):
    """Tracks the usage of an endpoint by an actor.

    This function increments the usage count for a given endpoint by an actor.
    It also handles initialization if the count does not exist.

    Args:
        endpoint (str): The endpoint being accessed.
        actor (str): The actor accessing the endpoint.
        n (int, optional): The number of times the endpoint was accessed. Defaults to 1.

    Returns:
        None: This function does not return anything but updates the memcache client.
    """
    # count an inference; failures are swallowed (logged at debug) so that
    # usage accounting can never break the inference path itself
    try:
        job = endpoint + "endpoint:::actor" + actor
        current_infers = memcache_client.incr(job, n)
        if current_infers is None:  # not yet set; initialize at n
            memcache_client.set(job, n)
            current_infers = n

            # remember the job key so usage can be enumerated later
            job_keys = memcache_client.get("JOB_KEYS")
            if job_keys is None:
                memcache_client.add("JOB_KEYS", json.dumps([job]))
            else:
                decoded = json.loads(job_keys)
                decoded.append(job)
                decoded = list(set(decoded))  # de-duplicate
                memcache_client.set("JOB_KEYS", json.dumps(decoded))

            actor_keys = memcache_client.get("ACTOR_KEYS")
            if actor_keys is None:
                memcache_client.add("ACTOR_KEYS", json.dumps({actor: n}))
            else:
                # BUG FIX: update the decoded dict, not the raw JSON string
                # fetched from memcache — indexing the string raised TypeError,
                # which the broad except below silently swallowed, so actor
                # counts were never updated after initialization.
                decoded = json.loads(actor_keys)
                if actor in decoded:
                    decoded[actor] += n
                else:
                    decoded[actor] = n
                memcache_client.set("ACTOR_KEYS", json.dumps(decoded))

    except Exception as e:
        logger.debug("WARNING: there was an error in counting this inference")
        logger.debug(e)

core/interfaces

High-level inference interfaces: camera, HTTP, and stream processing.

inference.core.interfaces.base

Classes

BaseInterface

Base interface class which accepts a model manager on initialization

Source code in inference/core/interfaces/base.py
4
5
6
7
8
class BaseInterface:
    """Common base for inference interfaces; holds the model manager supplied at construction."""

    def __init__(self, model_manager: ModelManager) -> None:
        # Subclasses route all model operations through this manager.
        self.model_manager = model_manager

core/interfaces/camera

inference.core.interfaces.camera.camera

Classes

WebcamStream

Class to handle webcam streaming using a separate thread.

Attributes:

Name Type Description
stream_id int

The ID of the webcam stream.

frame_id int

A counter for the current frame.

vcap VideoCapture

OpenCV video capture object.

width int

The width of the video frame.

height int

The height of the video frame.

fps_input_stream int

Frames per second of the input stream.

grabbed bool

A flag indicating if a frame was successfully grabbed.

frame array

The current frame as a NumPy array.

pil_image Image

The current frame as a PIL image.

stopped bool

A flag indicating if the stream is stopped.

t Thread

The thread used to update the stream.

Source code in inference/core/interfaces/camera/camera.py
 11
 12
 13
 14
 15
 16
 17
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
class WebcamStream:
    """Class to handle webcam streaming using a separate thread.

    A background thread continuously grabs frames from the capture device and
    publishes the latest one on ``self.frame``; consumers poll via
    ``read_opencv()``.

    Attributes:
        stream_id (int): The ID of the webcam stream.
        frame_id (int): A counter for the current frame.
        vcap (VideoCapture): OpenCV video capture object.
        width (int): The width of the video frame.
        height (int): The height of the video frame.
        fps_input_stream (int): Frames per second of the input stream.
        grabbed (bool): A flag indicating if a frame was successfully grabbed.
        frame (array): The current frame as a NumPy array.
        pil_image (Image): The current frame as a PIL image.
        stopped (bool): A flag indicating if the stream is stopped.
        t (Thread): The thread used to update the stream.
        enforce_fps (bool): Whether to grab file frames 1:1 with processing speed.
        file_mode (bool): True when the source is a video file rather than a live stream.
        max_fps (Optional[int]): Retrieval rate cap; defaults to 30 if never set externally.
    """

    def __init__(self, stream_id=0, enforce_fps=False):
        """Initialize the webcam stream.

        Opens the capture device, applies any ``CV2_CAP_PROP*`` environment
        overrides, and reads one frame eagerly to prime ``self.frame``.

        Args:
            stream_id (int, optional): The ID of the webcam stream. Defaults to 0.
            enforce_fps (bool, optional): When True and the source is a video file,
                grab frames 1:1 with processing speed. Ignored (with a warning)
                for live streams. Defaults to False.
        """
        self.stream_id = stream_id
        self.enforce_fps = enforce_fps
        self.frame_id = 0
        self.vcap = cv2.VideoCapture(self.stream_id)

        # Allow OpenCV capture properties to be overridden via environment
        # variables named CV2_<PROPERTY>, e.g. CV2_CAP_PROP_FRAME_WIDTH.
        for key in os.environ:
            if key.startswith("CV2_CAP_PROP"):
                opencv_prop = key[4:]  # strip the "CV2_" prefix
                opencv_constant = getattr(cv2, opencv_prop, None)
                if opencv_constant is not None:
                    value = int(os.getenv(key))  # values are parsed as integers
                    self.vcap.set(opencv_constant, value)
                    logger.info(f"set {opencv_prop} to {value}")
                else:
                    logger.warning(f"Property {opencv_prop} not found in cv2")

        self.width = int(self.vcap.get(cv2.CAP_PROP_FRAME_WIDTH))
        self.height = int(self.vcap.get(cv2.CAP_PROP_FRAME_HEIGHT))
        # A positive frame count means the source is a file, not a live stream.
        self.file_mode = self.vcap.get(cv2.CAP_PROP_FRAME_COUNT) > 0
        if self.enforce_fps and not self.file_mode:
            logger.warning(
                "Ignoring enforce_fps flag for this stream. It is not compatible with streams and will cause the process to crash"
            )
            self.enforce_fps = False
        self.max_fps = None
        # NOTE(review): exiting the whole process on capture failure is
        # aggressive — consider raising an exception instead.
        if self.vcap.isOpened() is False:
            logger.debug("[Exiting]: Error accessing webcam stream.")
            exit(0)
        self.fps_input_stream = int(self.vcap.get(cv2.CAP_PROP_FPS))
        logger.debug(
            "FPS of webcam hardware/input stream: {}".format(self.fps_input_stream)
        )
        # Prime the first frame synchronously so consumers never see None.
        self.grabbed, self.frame = self.vcap.read()
        self.pil_image = Image.fromarray(cv2.cvtColor(self.frame, cv2.COLOR_BGR2RGB))
        if self.grabbed is False:
            logger.debug("[Exiting] No more frames to read")
            exit(0)
        # Thread is created stopped; start() launches it.
        self.stopped = True
        self.t = Thread(target=self.update, args=())
        self.t.daemon = True

    def start(self):
        """Start the thread for reading frames."""
        self.stopped = False
        self.t.start()

    def update(self):
        """Update the frame by reading from the webcam.

        Runs on the background thread: grabs every frame, but retrieves
        (decodes) them only as often as ``self.max_fps`` allows, and releases
        the capture device when stopped or the source is exhausted.
        """
        frame_id = 0
        next_frame_time = 0
        t0 = time.perf_counter()
        while True:
            t1 = time.perf_counter()
            if self.stopped is True:
                break

            self.grabbed = self.vcap.grab()
            if self.grabbed is False:
                logger.debug("[Exiting] No more frames to read")
                self.stopped = True
                break
            frame_id += 1
            # We can't retrieve each frame on nano and other lower powered devices quickly enough to keep up with the stream.
            # By default, we will only retrieve frames when we'll be ready to process them (determined by self.max_fps).
            if t1 > next_frame_time:
                ret, frame = self.vcap.retrieve()
                if frame is None:
                    logger.debug("[Exiting] Frame not available for read")
                    self.stopped = True
                    break
                logger.debug(
                    f"retrieved frame {frame_id}, effective FPS: {frame_id / (t1 - t0):.2f}"
                )
                self.frame_id = frame_id
                self.frame = frame
                while self.file_mode and self.enforce_fps and self.max_fps is None:
                    # sleep until we have processed the first frame and we know what our FPS should be
                    time.sleep(0.01)
                if self.max_fps is None:
                    # no consumer declared a rate; fall back to 30 FPS
                    self.max_fps = 30
                # next retrieve no sooner than one max_fps period (+0.02 s pad)
                next_frame_time = t1 + (1 / self.max_fps) + 0.02
            if self.file_mode:
                t2 = time.perf_counter()
                if self.enforce_fps:
                    # when enforce_fps is true, grab video frames 1:1 with inference speed
                    time_to_sleep = next_frame_time - t2
                else:
                    # otherwise, grab at native FPS of the video file
                    time_to_sleep = (1 / self.fps_input_stream) - (t2 - t1)
                if time_to_sleep > 0:
                    time.sleep(time_to_sleep)
        self.vcap.release()

    def read_opencv(self):
        """Read the current frame using OpenCV.

        Returns:
            array, int: The current frame as a NumPy array, and the frame ID.
        """
        return self.frame, self.frame_id

    def stop(self):
        """Stop the webcam stream."""
        self.stopped = True
Functions
__init__
__init__(stream_id=0, enforce_fps=False)

Initialize the webcam stream.

Parameters:

Name Type Description Default
stream_id int

The ID of the webcam stream. Defaults to 0.

0
Source code in inference/core/interfaces/camera/camera.py
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
def __init__(self, stream_id=0, enforce_fps=False):
    """Initialize the webcam stream.

    Opens the capture device, applies any ``CV2_CAP_PROP*`` environment
    overrides, and reads one frame eagerly to prime ``self.frame``.

    Args:
        stream_id (int, optional): The ID of the webcam stream. Defaults to 0.
        enforce_fps (bool, optional): When True and the source is a video file,
            grab frames 1:1 with processing speed. Ignored (with a warning)
            for live streams. Defaults to False.
    """
    self.stream_id = stream_id
    self.enforce_fps = enforce_fps
    self.frame_id = 0
    self.vcap = cv2.VideoCapture(self.stream_id)

    # Allow OpenCV capture properties to be overridden via environment
    # variables named CV2_<PROPERTY>, e.g. CV2_CAP_PROP_FRAME_WIDTH.
    for key in os.environ:
        if key.startswith("CV2_CAP_PROP"):
            opencv_prop = key[4:]  # strip the "CV2_" prefix
            opencv_constant = getattr(cv2, opencv_prop, None)
            if opencv_constant is not None:
                value = int(os.getenv(key))  # values are parsed as integers
                self.vcap.set(opencv_constant, value)
                logger.info(f"set {opencv_prop} to {value}")
            else:
                logger.warning(f"Property {opencv_prop} not found in cv2")

    self.width = int(self.vcap.get(cv2.CAP_PROP_FRAME_WIDTH))
    self.height = int(self.vcap.get(cv2.CAP_PROP_FRAME_HEIGHT))
    # A positive frame count means the source is a file, not a live stream.
    self.file_mode = self.vcap.get(cv2.CAP_PROP_FRAME_COUNT) > 0
    if self.enforce_fps and not self.file_mode:
        logger.warning(
            "Ignoring enforce_fps flag for this stream. It is not compatible with streams and will cause the process to crash"
        )
        self.enforce_fps = False
    self.max_fps = None
    # NOTE(review): exiting the whole process on capture failure is
    # aggressive — consider raising an exception instead.
    if self.vcap.isOpened() is False:
        logger.debug("[Exiting]: Error accessing webcam stream.")
        exit(0)
    self.fps_input_stream = int(self.vcap.get(cv2.CAP_PROP_FPS))
    logger.debug(
        "FPS of webcam hardware/input stream: {}".format(self.fps_input_stream)
    )
    # Prime the first frame synchronously so consumers never see None.
    self.grabbed, self.frame = self.vcap.read()
    self.pil_image = Image.fromarray(cv2.cvtColor(self.frame, cv2.COLOR_BGR2RGB))
    if self.grabbed is False:
        logger.debug("[Exiting] No more frames to read")
        exit(0)
    # Thread is created stopped; start() launches it.
    self.stopped = True
    self.t = Thread(target=self.update, args=())
    self.t.daemon = True
read_opencv
read_opencv()

Read the current frame using OpenCV.

Returns:

Type Description

array, int: The current frame as a NumPy array, and the frame ID.

Source code in inference/core/interfaces/camera/camera.py
127
128
129
130
131
132
133
def read_opencv(self):
    """Return the most recently captured frame and its counter.

    Returns:
        array, int: The current frame as a NumPy array, and the frame ID.
    """
    current_frame, current_id = self.frame, self.frame_id
    return current_frame, current_id
start
start()

Start the thread for reading frames.

Source code in inference/core/interfaces/camera/camera.py
75
76
77
78
def start(self):
    """Launch the background thread that keeps the current frame fresh."""
    # Clear the stop flag before launching so the loop does not exit at once.
    self.stopped = False
    self.t.start()
stop
stop()

Stop the webcam stream.

Source code in inference/core/interfaces/camera/camera.py
135
136
137
def stop(self):
    """Signal the background reader to terminate after its current iteration."""
    self.stopped = True
update
update()

Update the frame by reading from the webcam.

Source code in inference/core/interfaces/camera/camera.py
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
def update(self):
    """Update the frame by reading from the webcam.

    Runs on the background thread: grabs every frame, but retrieves
    (decodes) them only as often as ``self.max_fps`` allows, and releases
    the capture device when stopped or the source is exhausted.
    """
    frame_id = 0
    next_frame_time = 0
    t0 = time.perf_counter()
    while True:
        t1 = time.perf_counter()
        if self.stopped is True:
            break

        self.grabbed = self.vcap.grab()
        if self.grabbed is False:
            logger.debug("[Exiting] No more frames to read")
            self.stopped = True
            break
        frame_id += 1
        # We can't retrieve each frame on nano and other lower powered devices quickly enough to keep up with the stream.
        # By default, we will only retrieve frames when we'll be ready to process them (determined by self.max_fps).
        if t1 > next_frame_time:
            ret, frame = self.vcap.retrieve()
            if frame is None:
                logger.debug("[Exiting] Frame not available for read")
                self.stopped = True
                break
            logger.debug(
                f"retrieved frame {frame_id}, effective FPS: {frame_id / (t1 - t0):.2f}"
            )
            self.frame_id = frame_id
            self.frame = frame
            while self.file_mode and self.enforce_fps and self.max_fps is None:
                # sleep until we have processed the first frame and we know what our FPS should be
                time.sleep(0.01)
            if self.max_fps is None:
                # no consumer declared a rate; fall back to 30 FPS
                self.max_fps = 30
            # next retrieve no sooner than one max_fps period (+0.02 s pad)
            next_frame_time = t1 + (1 / self.max_fps) + 0.02
        if self.file_mode:
            t2 = time.perf_counter()
            if self.enforce_fps:
                # when enforce_fps is true, grab video frames 1:1 with inference speed
                time_to_sleep = next_frame_time - t2
            else:
                # otherwise, grab at native FPS of the video file
                time_to_sleep = (1 / self.fps_input_stream) - (t2 - t1)
            if time_to_sleep > 0:
                time.sleep(time_to_sleep)
    self.vcap.release()

inference.core.interfaces.camera.entities

Classes

StatusUpdate dataclass

Represents a status update event in the system.

Attributes:

Name Type Description
timestamp datetime

The timestamp when the status update was created.

severity UpdateSeverity

The severity level of the update.

event_type str

A string representing the type of the event.

payload dict

A dictionary containing data relevant to the update.

context str

A string providing additional context about the update.

Source code in inference/core/interfaces/camera/entities.py
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
@dataclass(frozen=True)
class StatusUpdate:
    """Immutable record describing a single status event emitted by the system.

    Attributes:
        timestamp (datetime): When the update was produced.
        severity (UpdateSeverity): How serious the update is.
        event_type (str): Machine-readable name of the event.
        payload (dict): Event-specific data.
        context (str): Additional context about where the update originated.
    """

    timestamp: datetime
    severity: UpdateSeverity
    event_type: str
    payload: dict
    context: str

UpdateSeverity

Bases: Enum

Enumeration for defining different levels of update severity.

Attributes:

Name Type Description
DEBUG int

A debugging severity level.

INFO int

An informational severity level.

WARNING int

A warning severity level.

ERROR int

An error severity level.

Source code in inference/core/interfaces/camera/entities.py
13
14
15
16
17
18
19
20
21
22
23
24
25
26
class UpdateSeverity(Enum):
    """Severity levels for status updates; values mirror stdlib ``logging`` levels.

    Attributes:
        DEBUG (int): A debugging severity level.
        INFO (int): An informational severity level.
        WARNING (int): A warning severity level.
        ERROR (int): An error severity level.
    """

    DEBUG = logging.DEBUG
    INFO = logging.INFO
    WARNING = logging.WARNING
    ERROR = logging.ERROR

VideoFrame dataclass

Represents a single frame of video data.

Attributes:

Name Type Description
image ndarray

The image data of the frame as a NumPy array.

frame_id FrameID

A unique identifier for the frame.

frame_timestamp FrameTimestamp

The timestamp when the frame was captured.

source_id int

The index of the video_reference element which was passed to InferencePipeline for this frame (useful when multiple streams are passed to InferencePipeline).

fps Optional[float]

declared FPS of source (if possible to be acquired)

measured_fps Optional[float]

measured FPS of live stream

comes_from_video_file Optional[bool]

flag to determine if frame comes from video file

Source code in inference/core/interfaces/camera/entities.py
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
@dataclass(frozen=True)
class VideoFrame:
    """Immutable container for one decoded video frame and its metadata.

    Attributes:
        image (np.ndarray): The image data of the frame as a NumPy array.
        frame_id (FrameID): A unique identifier for the frame.
        frame_timestamp (FrameTimestamp): The timestamp when the frame was captured.
        fps (Optional[float]): declared FPS of source (if possible to be acquired)
        measured_fps (Optional[float]): measured FPS of live stream
        source_id (int): The index of the video_reference element which was passed to
            InferencePipeline for this frame (useful when multiple streams are passed
            to InferencePipeline).
        comes_from_video_file (Optional[bool]): flag to determine if frame comes from video file
    """

    image: np.ndarray
    frame_id: FrameID
    frame_timestamp: FrameTimestamp
    # TODO: in next major version of inference replace `fps` with `declared_fps`
    fps: Optional[float] = None
    measured_fps: Optional[float] = None
    source_id: Optional[int] = None
    comes_from_video_file: Optional[bool] = None

inference.core.interfaces.camera.utils

Classes

RateLimiter

Implements an upper-bound rate limit: estimate_next_action_delay() returns the remaining time until 1 / desired_fps has elapsed since the last tick, so a client that honours the returned delays never exceeds the assumed rate.

Source code in inference/core/interfaces/camera/utils.py
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
class RateLimiter:
    """
    Upper-bounds the pace of client actions.

    `estimate_next_action_delay()` reports how long the caller should wait
    before acting again so that the effective rate never exceeds
    `desired_fps`; a client that honours the returned delays stays at or
    below the assumed rate.
    """

    def __init__(self, desired_fps: Union[float, int]):
        # Clamp to a sane minimum so delay computation never divides by zero.
        self._desired_fps = max(desired_fps, MINIMAL_FPS)
        self._last_tick: Optional[float] = None

    def tick(self) -> None:
        # Record the moment of the most recent action.
        self._last_tick = time.monotonic()

    def estimate_next_action_delay(self) -> float:
        # Nothing to throttle against before the first tick.
        if self._last_tick is None:
            return 0.0
        elapsed = time.monotonic() - self._last_tick
        remaining = (1 / self._desired_fps) - elapsed
        return remaining if remaining > 0.0 else 0.0

VideoSourcesManager

This class should be treated as internal building block of stream multiplexer - not for external use.

Source code in inference/core/interfaces/camera/utils.py
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
class VideoSourcesManager:
    """
    Tracks a set of video sources for the stream multiplexer: collects frame
    batches, registers permanently ended sources and manages background
    reconnection threads.

    This class should be treated as internal building block of stream multiplexer - not for external use.
    """

    @classmethod
    def init(
        cls,
        video_sources: VideoSources,
        should_stop: Callable[[], bool],
        on_reconnection_error: Callable[[Optional[int], SourceConnectionError], None],
    ) -> "VideoSourcesManager":
        """Alternate constructor kept for API symmetry with sibling components."""
        return cls(
            video_sources=video_sources,
            should_stop=should_stop,
            on_reconnection_error=on_reconnection_error,
        )

    def __init__(
        self,
        video_sources: VideoSources,
        should_stop: Callable[[], bool],
        on_reconnection_error: Callable[[Optional[int], SourceConnectionError], None],
    ):
        self._video_sources = video_sources
        # Running reconnection threads, keyed by source ordinal.
        self._reconnection_threads: Dict[int, Thread] = {}
        self._external_should_stop = should_stop
        self._on_reconnection_error = on_reconnection_error
        # Per-source flags used to force a reconnection thread to give up.
        self._enforce_stop: Dict[int, bool] = {}
        # Ordinals of sources that ended for good (EOS or failed reconnection).
        self._ended_sources: Set[int] = set()
        # Ordinals whose reconnection threads finished and await joining.
        self._threads_to_join: Set[int] = set()
        self._last_batch_yielded_time = datetime.now()

    def retrieve_frames_from_sources(
        self,
        batch_collection_timeout: Optional[float],
    ) -> Optional[List[VideoFrame]]:
        """Collect at most one frame from each active source.

        Args:
            batch_collection_timeout: maximum time (in seconds) for assembling
                the whole batch, counted from the moment the previous batch was
                yielded; `None` means no time limit.

        Returns:
            The list of collected frames (possibly empty), or `None` when the
            external stop signal fired - in that case all reconnection threads
            are joined before returning.
        """
        batch_frames = []
        if batch_collection_timeout is not None:
            batch_timeout_moment = self._last_batch_yielded_time + timedelta(
                seconds=batch_collection_timeout
            )
        else:
            batch_timeout_moment = None
        # NOTE: previously this loop also unpacked `allow_reconnection`, but the
        # value was unused - reconnection policy is consulted lazily in
        # `_register_end_of_stream()`.
        for source_ord, source in enumerate(self._video_sources.all_sources):
            if self._external_should_stop():
                self.join_all_reconnection_threads(include_not_finished=True)
                return None
            if self._is_source_inactive(source_ord=source_ord):
                continue
            # Remaining time budget for this batch (never negative).
            batch_time_left = (
                None
                if batch_timeout_moment is None
                else max((batch_timeout_moment - datetime.now()).total_seconds(), 0.0)
            )
            try:
                frame = source.read_frame(timeout=batch_time_left)
                if frame is not None:
                    batch_frames.append(frame)
            except EndOfStreamError:
                self._register_end_of_stream(source_ord=source_ord)
        self.join_all_reconnection_threads()
        self._last_batch_yielded_time = datetime.now()
        return batch_frames

    def all_sources_ended(self) -> bool:
        """Check whether every managed source has permanently ended."""
        return len(self._ended_sources) >= len(self._video_sources.all_sources)

    def join_all_reconnection_threads(self, include_not_finished: bool = False) -> None:
        """Join finished reconnection threads; optionally force-join running ones.

        Args:
            include_not_finished: when True, also stops and joins threads that
                are still attempting reconnection (used on shutdown).
        """
        # Iterate over a copy - purging mutates `_threads_to_join`.
        for source_ord in copy(self._threads_to_join):
            self._purge_reconnection_thread(source_ord=source_ord)
        if not include_not_finished:
            return None
        for source_ord in list(self._reconnection_threads.keys()):
            self._purge_reconnection_thread(source_ord=source_ord)

    def _is_source_inactive(self, source_ord: int) -> bool:
        # A source is inactive while it has ended or is mid-reconnection.
        return (
            source_ord in self._ended_sources
            or source_ord in self._reconnection_threads
        )

    def _register_end_of_stream(self, source_ord: int) -> None:
        # Per-source policy decides: spawn a reconnection or mark as ended.
        source_should_reconnect = self._video_sources.allow_reconnection[source_ord]
        if source_should_reconnect:
            self._reconnect_source(source_ord=source_ord)
        else:
            self._ended_sources.add(source_ord)

    def _reconnect_source(self, source_ord: int) -> None:
        # Idempotent: do nothing if a reconnection attempt is already running.
        if source_ord in self._reconnection_threads:
            return None
        self._reconnection_threads[source_ord] = Thread(
            target=_attempt_reconnect,
            args=(
                self._video_sources.all_sources[source_ord],
                partial(self._should_stop, source_ord=source_ord),
                self._on_reconnection_error,
                partial(self._register_thread_to_join, source_ord=source_ord),
                partial(self._register_reconnection_fatal_error, source_ord=source_ord),
            ),
        )
        self._reconnection_threads[source_ord].start()

    def _register_reconnection_fatal_error(self, source_ord: int) -> None:
        # Reconnection gave up for good - treat the source as ended.
        self._register_thread_to_join(source_ord=source_ord)
        self._ended_sources.add(source_ord)

    def _register_thread_to_join(self, source_ord: int) -> None:
        self._threads_to_join.add(source_ord)

    def _purge_reconnection_thread(self, source_ord: int) -> None:
        # Stop (if still running) and join the reconnection thread of the given
        # source, then clear all bookkeeping associated with it.
        if source_ord not in self._reconnection_threads:
            return None
        self._enforce_stop[source_ord] = True
        self._reconnection_threads[source_ord].join()
        del self._reconnection_threads[source_ord]
        self._enforce_stop[source_ord] = False
        if source_ord in self._threads_to_join:
            self._threads_to_join.remove(source_ord)

    def _should_stop(self, source_ord: int) -> bool:
        # Stop signal seen by reconnection threads: external stop OR forced purge.
        if self._external_should_stop():
            return True
        return self._enforce_stop.get(source_ord, False)

Functions

get_video_frames_generator

get_video_frames_generator(
    video, max_fps=None, limiter_strategy=None
)

Util function to create a frames generator from VideoSource with the possibility to limit the FPS of consumed frames and dictate what to do if frames are produced too fast.

Parameters:

Name Type Description Default
video Union[VideoSource, str, int]

Either instance of VideoSource or video reference accepted by VideoSource.init(...)

required
max_fps Optional[Union[float, int]]

value of maximum FPS rate of generated frames - can be used to limit generation frequency

None
limiter_strategy Optional[FPSLimiterStrategy]

strategy used to deal with frames decoding exceeding limit of max_fps. By default - for files, in the interest of processing all frames - generation will be awaited, for streams - frames will be dropped on the floor.

None
Example
from inference.core.interfaces.camera.utils import get_video_frames_generator

for frame in get_video_frames_generator(
    video="./some.mp4",
    max_fps=50,
):
     pass
Source code in inference/core/interfaces/camera/utils.py
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
def get_video_frames_generator(
    video: Union[VideoSource, str, int],
    max_fps: Optional[Union[float, int]] = None,
    limiter_strategy: Optional[FPSLimiterStrategy] = None,
) -> Generator[VideoFrame, None, None]:
    """
    Util function to create a frames generator from `VideoSource` with possibility to
    limit FPS of consumed frames and dictate what to do if frames are produced too fast.

    Args:
        video (Union[VideoSource, str, int]): Either instance of VideoSource or video reference accepted
            by VideoSource.init(...)
        max_fps (Optional[Union[float, int]]): value of maximum FPS rate of generated frames - can be used to limit
            generation frequency
        limiter_strategy (Optional[FPSLimiterStrategy]): strategy used to deal with frames decoding exceeding
            limit of `max_fps`. By default - for files, in the interest of processing all frames -
            generation will be awaited, for streams - frames will be dropped on the floor.
    Returns: generator of `VideoFrame`

    Example:
        ```python
        from inference.core.interfaces.camera.utils import get_video_frames_generator

        for frame in get_video_frames_generator(
            video="./some.mp4",
            max_fps=50,
        ):
             pass
        ```
    """
    is_managed_source = False
    if isinstance(video, (str, int)):
        # A raw reference was given - this function owns the source lifecycle.
        video = VideoSource.init(
            video_reference=video,
        )
        video.start()
        is_managed_source = True
    try:
        if max_fps is None:
            yield from video
            return None
        limiter_strategy = resolve_limiter_strategy(
            explicitly_defined_strategy=limiter_strategy,
            source_properties=video.describe_source().source_properties,
        )
        yield from limit_frame_rate(
            frames_generator=video, max_fps=max_fps, strategy=limiter_strategy
        )
    finally:
        # Guarantee cleanup of a source this function created, even when the
        # consumer abandons the generator early or iteration raises.
        if is_managed_source:
            video.terminate(purge_frames_buffer=True)

multiplex_videos

multiplex_videos(
    videos,
    max_fps=None,
    limiter_strategy=None,
    batch_collection_timeout=None,
    force_stream_reconnection=True,
    should_stop=never_stop,
    on_reconnection_error=log_error,
)

Function that provides a generator over frames from multiple video sources. It is capable of initialising VideoSource from references to video files or streams and grabbing frames from all the sources - each running individual decoding on a separate thread. In each cycle it attempts to grab frames from all sources (waiting at most batch_collection_timeout for the whole batch to be collected). If a frame from a specific source cannot be collected in that time - it is simply not included in the returned list. If the list of frames is empty after batch collection - a new collection starts immediately. Collection does not account for sources that lost connectivity (example: streams that went offline). If a stream stays connected but has large latency - without a reasonable batch_collection_timeout it will slow down processing - so please set it up in PROD solutions. In case of video streams (not video files) - given that force_stream_reconnection=True - the function will attempt to re-connect to a disconnected source using a background thread, without impairing batch frames collection, and that source is not going to block frames retrieval even if an infinite batch_collection_timeout=None is set. Similarly, when processing files - a video file that is shorter than others passed into processing will not block the whole flow after End Of Stream (EOS).

All sources must be accessible on start - if that's not the case - the function raises SourceConnectionError and closes all video sources it opened on its own. Disconnections at later stages are handled by the re-connection mechanism.

Parameters:

Name Type Description Default
videos List[Union[VideoSource, str, int]]

List with references to video sources. Elements can be pre-initialised VideoSource instances, str with stream URI or file location or int representing camera device attached to the PC/server running the code.

required
max_fps Optional[Union[float, int]]

Upper-bound of processing speed - to be used when one wants at max max_fps video frames per second to be yielded from all sources by the generator.

None
limiter_strategy Optional[FPSLimiterStrategy]

strategy used to deal with frames decoding exceeding limit of max_fps. For video files, in the interest of processing all frames - we recommend WAIT mode, for streams - frames should be dropped on the floor with DROP strategy. Not setting the strategy equals using automatic mode - WAIT if all sources are files and DROP otherwise

None
batch_collection_timeout Optional[float]

maximum await time to get batch of predictions from all sources. None means infinite timeout.

None
force_stream_reconnection bool

Flag to decide on reconnection to streams (files are never re-connected)

True
should_stop Callable[[], bool]

external stop signal that is periodically checked - to denote that video consumption stopped - make the function to return True

never_stop
on_reconnection_error Callable[[Optional[int], SourceConnectionError], None]

Function that will be called whenever source cannot re-connect after disconnection. First parameter is source_id, second is connection error instance.

log_error

Returns Generator[List[VideoFrame], None, None]: allowing to iterate through frames from multiple video sources.

Raises:

Type Description
SourceConnectionError

when one or more source is not reachable at start of generation

Example
from inference.core.interfaces.camera.utils import multiplex_videos

for frames in multiplex_videos(videos=["./some.mp4", "./other.mp4"]):
     for frame in frames:
        pass  # do something with frame
Source code in inference/core/interfaces/camera/utils.py
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
def multiplex_videos(
    videos: List[Union[VideoSource, str, int]],
    max_fps: Optional[Union[float, int]] = None,
    limiter_strategy: Optional[FPSLimiterStrategy] = None,
    batch_collection_timeout: Optional[float] = None,
    force_stream_reconnection: bool = True,
    should_stop: Callable[[], bool] = never_stop,
    on_reconnection_error: Callable[
        [Optional[int], SourceConnectionError], None
    ] = log_error,
) -> Generator[List[VideoFrame], None, None]:
    """
    Provide a generator over batches of frames grabbed from multiple video sources.

    Each element of `videos` may be a pre-initialised `VideoSource`, a `str` with a
    stream URI or file location, or an `int` denoting a camera device attached to the
    machine running the code. Every source is decoded on its own thread; each cycle
    collects at most one frame per source, waiting at most `batch_collection_timeout`
    seconds for the whole batch - frames not ready in time are simply skipped, and an
    empty batch triggers a new collection immediately. Streams that disconnect are
    re-connected on a background thread when `force_stream_reconnection=True`, so a
    dropped source does not block batch collection even with an infinite
    `batch_collection_timeout=None`. Video files are never re-connected; a file
    shorter than the others will not block the flow after End Of Stream (EOS).
    For PROD use with high-latency streams, set a reasonable
    `batch_collection_timeout` to avoid slowing down processing.

    All sources must be reachable at start - otherwise `SourceConnectionError` is
    raised and all sources opened by this function are closed. Later disconnections
    are handled by the re-connection mechanism.

    Args:
        videos (List[Union[VideoSource, str, int]]): References to the video sources.
        max_fps (Optional[Union[float, int]]): Upper bound of processing speed - at
            most `max_fps` frames per second are yielded from all sources combined.
        limiter_strategy (Optional[FPSLimiterStrategy]): How to handle decoding that
            outpaces `max_fps`. WAIT is recommended for video files (to process all
            frames), DROP for streams; `None` selects WAIT when all sources are
            files and DROP otherwise.
        batch_collection_timeout (Optional[float]): Maximum await time for a full
            batch; `None` means infinite timeout.
        force_stream_reconnection (bool): Whether to re-connect dropped streams
            (files are never re-connected).
        should_stop (Callable[[], bool]): External stop signal checked periodically -
            make it return True to stop video consumption.
        on_reconnection_error (Callable[[Optional[int], SourceConnectionError], None]):
            Called whenever a source cannot re-connect after disconnection; receives
            the source_id and the connection error instance.

    Returns:
        Generator[List[VideoFrame], None, None] iterating over frame batches.

    Raises:
        SourceConnectionError: when one or more sources are not reachable at start
            of generation.

    Example:
        ```python
        from inference.core.interfaces.camera.utils import multiplex_videos

        for frames in multiplex_videos(videos=["./some.mp4", "./other.mp4"]):
             for frame in frames:
                pass  # do something with frame
        ```
    """
    video_sources = _prepare_video_sources(
        videos=videos, force_stream_reconnection=force_stream_reconnection
    )
    connection_state_missing = any(
        rule is None for rule in video_sources.allow_reconnection
    )
    if connection_state_missing:
        logger.warning("Could not connect to all sources.")
        return None
    frames_generator = _multiplex_videos(
        video_sources=video_sources,
        batch_collection_timeout=batch_collection_timeout,
        should_stop=should_stop,
        on_reconnection_error=on_reconnection_error,
    )
    if max_fps is None:
        yield from frames_generator
        return None
    # The limiter is applied per batch cycle, so the global budget is
    # divided across the sources.
    per_source_fps = max_fps / len(videos)
    if limiter_strategy is None:
        limiter_strategy = negotiate_rate_limiter_strategy_for_multiple_sources(
            video_sources=video_sources.all_sources,
        )
    yield from limit_frame_rate(
        frames_generator=frames_generator,
        max_fps=per_source_fps,
        strategy=limiter_strategy,
    )

inference.core.interfaces.camera.video_source

Classes

VideoConsumer

This class should be treated as part of the internal implementation. It provides an abstraction around stream consumption strategies.

It must always be given the same video source for consecutive invocations, otherwise the internal state does not make sense.

Source code in inference/core/interfaces/camera/video_source.py
 772
 773
 774
 775
 776
 777
 778
 779
 780
 781
 782
 783
 784
 785
 786
 787
 788
 789
 790
 791
 792
 793
 794
 795
 796
 797
 798
 799
 800
 801
 802
 803
 804
 805
 806
 807
 808
 809
 810
 811
 812
 813
 814
 815
 816
 817
 818
 819
 820
 821
 822
 823
 824
 825
 826
 827
 828
 829
 830
 831
 832
 833
 834
 835
 836
 837
 838
 839
 840
 841
 842
 843
 844
 845
 846
 847
 848
 849
 850
 851
 852
 853
 854
 855
 856
 857
 858
 859
 860
 861
 862
 863
 864
 865
 866
 867
 868
 869
 870
 871
 872
 873
 874
 875
 876
 877
 878
 879
 880
 881
 882
 883
 884
 885
 886
 887
 888
 889
 890
 891
 892
 893
 894
 895
 896
 897
 898
 899
 900
 901
 902
 903
 904
 905
 906
 907
 908
 909
 910
 911
 912
 913
 914
 915
 916
 917
 918
 919
 920
 921
 922
 923
 924
 925
 926
 927
 928
 929
 930
 931
 932
 933
 934
 935
 936
 937
 938
 939
 940
 941
 942
 943
 944
 945
 946
 947
 948
 949
 950
 951
 952
 953
 954
 955
 956
 957
 958
 959
 960
 961
 962
 963
 964
 965
 966
 967
 968
 969
 970
 971
 972
 973
 974
 975
 976
 977
 978
 979
 980
 981
 982
 983
 984
 985
 986
 987
 988
 989
 990
 991
 992
 993
 994
 995
 996
 997
 998
 999
1000
1001
1002
1003
1004
1005
1006
1007
1008
1009
1010
1011
1012
1013
1014
1015
1016
1017
1018
1019
1020
1021
1022
1023
1024
1025
1026
1027
1028
1029
1030
1031
1032
1033
1034
1035
1036
1037
1038
1039
1040
1041
1042
1043
1044
1045
1046
1047
1048
1049
1050
1051
1052
1053
1054
1055
1056
1057
1058
1059
1060
1061
1062
1063
1064
1065
1066
1067
1068
1069
1070
1071
1072
1073
1074
1075
1076
1077
1078
1079
1080
1081
1082
1083
1084
1085
1086
1087
1088
1089
1090
1091
1092
1093
1094
1095
1096
1097
1098
1099
1100
1101
1102
class VideoConsumer:
    """
    This class should be consumed as part of internal implementation.
    It provides abstraction around stream consumption strategies.

    It must always be given the same video source for consecutive invocations,
    otherwise the internal state does not make sense.
    """

    @classmethod
    def init(
        cls,
        buffer_filling_strategy: Optional[BufferFillingStrategy],
        adaptive_mode_stream_pace_tolerance: float,
        adaptive_mode_reader_pace_tolerance: float,
        minimum_adaptive_mode_samples: int,
        maximum_adaptive_frames_dropped_in_row: int,
        status_update_handlers: List[Callable[[StatusUpdate], None]],
        desired_fps: Optional[Union[float, int]] = None,
    ) -> "VideoConsumer":
        """Create a consumer with freshly initialised pace monitors.

        Args:
            buffer_filling_strategy: Strategy applied when the buffer is full.
                When None, a default is selected in `reset()` based on whether
                the source is a file or a live stream.
            adaptive_mode_stream_pace_tolerance: Max allowed FPS gap between the
                announced stream pace and measured consumption pace before
                frames are adaptively dropped.
            adaptive_mode_reader_pace_tolerance: Max allowed FPS gap between
                decoding pace and reader pace before frames are adaptively
                dropped.
            minimum_adaptive_mode_samples: Minimum number of pace observations
                required before adaptive dropping may trigger (floored at 2).
            maximum_adaptive_frames_dropped_in_row: Upper bound on consecutive
                adaptive drops, to guarantee frames keep flowing.
            status_update_handlers: Callbacks notified about consumer events.
            desired_fps: Target FPS used to sub-sample the source; None means
                no sub-sampling.

        Returns:
            A configured `VideoConsumer` instance.
        """
        # At least 2 samples are needed for any pace measurement to be meaningful.
        minimum_adaptive_mode_samples = max(minimum_adaptive_mode_samples, 2)
        # Monitors use a window 10x the minimum sample count to smooth estimates.
        reader_pace_monitor = sv.FPSMonitor(
            sample_size=10 * minimum_adaptive_mode_samples
        )
        stream_consumption_pace_monitor = sv.FPSMonitor(
            sample_size=10 * minimum_adaptive_mode_samples
        )
        decoding_pace_monitor = sv.FPSMonitor(
            sample_size=10 * minimum_adaptive_mode_samples
        )
        return cls(
            buffer_filling_strategy=buffer_filling_strategy,
            adaptive_mode_stream_pace_tolerance=adaptive_mode_stream_pace_tolerance,
            adaptive_mode_reader_pace_tolerance=adaptive_mode_reader_pace_tolerance,
            minimum_adaptive_mode_samples=minimum_adaptive_mode_samples,
            maximum_adaptive_frames_dropped_in_row=maximum_adaptive_frames_dropped_in_row,
            status_update_handlers=status_update_handlers,
            reader_pace_monitor=reader_pace_monitor,
            stream_consumption_pace_monitor=stream_consumption_pace_monitor,
            decoding_pace_monitor=decoding_pace_monitor,
            desired_fps=desired_fps,
        )

    def __init__(
        self,
        buffer_filling_strategy: Optional[BufferFillingStrategy],
        adaptive_mode_stream_pace_tolerance: float,
        adaptive_mode_reader_pace_tolerance: float,
        minimum_adaptive_mode_samples: int,
        maximum_adaptive_frames_dropped_in_row: int,
        status_update_handlers: List[Callable[[StatusUpdate], None]],
        reader_pace_monitor: sv.FPSMonitor,
        stream_consumption_pace_monitor: sv.FPSMonitor,
        decoding_pace_monitor: sv.FPSMonitor,
        desired_fps: Optional[Union[float, int]],
    ):
        self._buffer_filling_strategy = buffer_filling_strategy
        # Count of frames grabbed so far; also used as frame id in updates.
        self._frame_counter = 0
        self._adaptive_mode_stream_pace_tolerance = adaptive_mode_stream_pace_tolerance
        self._adaptive_mode_reader_pace_tolerance = adaptive_mode_reader_pace_tolerance
        self._minimum_adaptive_mode_samples = minimum_adaptive_mode_samples
        self._maximum_adaptive_frames_dropped_in_row = (
            maximum_adaptive_frames_dropped_in_row
        )
        self._adaptive_frames_dropped_in_row = 0
        self._reader_pace_monitor = reader_pace_monitor
        self._stream_consumption_pace_monitor = stream_consumption_pace_monitor
        self._decoding_pace_monitor = decoding_pace_monitor
        self._desired_fps = desired_fps
        # Lazily discovered on first `consume_frame()` call.
        self._declared_source_fps = None
        self._is_source_video_file = None
        self._timestamp_created: Optional[datetime] = None
        self._status_update_handlers = status_update_handlers
        # Next frame id to let through when FPS sub-sampling is active.
        self._next_frame_from_video_to_accept = 1

    @property
    def buffer_filling_strategy(self) -> Optional[BufferFillingStrategy]:
        return self._buffer_filling_strategy

    def reset(self, source_properties: SourceProperties) -> None:
        """Re-initialise internal state for a (re)connected source.

        Selects a default buffer-filling strategy if none was given, clears all
        pace monitors and re-arms the sub-sampling cursor.
        """
        if source_properties.is_file:
            self._set_file_mode_buffering_strategies()
        else:
            self._set_stream_mode_buffering_strategies()
        self._reader_pace_monitor.reset()
        self.reset_stream_consumption_pace()
        self._decoding_pace_monitor.reset()
        self._adaptive_frames_dropped_in_row = 0
        self._next_frame_from_video_to_accept = self._frame_counter + 1

    def reset_stream_consumption_pace(self) -> None:
        self._stream_consumption_pace_monitor.reset()

    def notify_frame_consumed(self) -> None:
        """Record a tick in the reader pace monitor (a consumer took a frame)."""
        self._reader_pace_monitor.tick()

    def consume_frame(
        self,
        video: VideoFrameProducer,
        declared_source_fps: Optional[float],
        is_source_video_file: Optional[bool],
        buffer: Queue,
        frames_buffering_allowed: bool,
        source_id: Optional[int] = None,
    ) -> bool:
        """Grab one frame from `video` and, subject to the active buffering /
        sub-sampling / adaptive-dropping policies, decode it into `buffer`.

        Returns:
            False when the grab failed (e.g. end of stream), True otherwise —
            including when the frame was deliberately dropped.
        """
        if self._is_source_video_file is None:
            # First invocation: discover and cache static source properties.
            source_properties = video.discover_source_properties()
            self._is_source_video_file = source_properties.is_file
            self._declared_source_fps = source_properties.fps
            self._timestamp_created = source_properties.timestamp_created

        if self._timestamp_created and self._declared_source_fps:
            # Guard on truthy FPS: some backends report fps as 0 (or None),
            # which would otherwise raise on the division below.
            frame_timestamp = self._timestamp_created + timedelta(
                seconds=self._frame_counter / self._declared_source_fps
            )
        else:
            frame_timestamp = datetime.now()

        success = video.grab()
        self._stream_consumption_pace_monitor.tick()
        if not success:
            return False
        self._frame_counter += 1
        if self._status_update_handlers:
            send_video_source_status_update(
                severity=UpdateSeverity.DEBUG,
                event_type=FRAME_CAPTURED_EVENT,
                payload={
                    "frame_timestamp": frame_timestamp,
                    "frame_id": self._frame_counter,
                    "source_id": source_id,
                },
                status_update_handlers=self._status_update_handlers,
            )
        measured_source_fps = declared_source_fps
        if not is_source_video_file:
            # For live streams, the measured consumption pace is more reliable
            # than the declared FPS.
            if hasattr(self._stream_consumption_pace_monitor, "fps"):
                measured_source_fps = self._stream_consumption_pace_monitor.fps
            else:
                # Older supervision versions expose FPS via __call__().
                measured_source_fps = self._stream_consumption_pace_monitor()

        if self._video_fps_should_be_sub_sampled():
            # Frame grabbed but intentionally skipped to honour desired_fps.
            return True
        return self._consume_stream_frame(
            video=video,
            declared_source_fps=declared_source_fps,
            measured_source_fps=measured_source_fps,
            is_source_video_file=is_source_video_file,
            frame_timestamp=frame_timestamp,
            buffer=buffer,
            frames_buffering_allowed=frames_buffering_allowed,
            source_id=source_id,
        )

    def _set_file_mode_buffering_strategies(self) -> None:
        # Files can be read at the consumer's own pace, so waiting is safe.
        if self._buffer_filling_strategy is None:
            self._buffer_filling_strategy = BufferFillingStrategy.WAIT

    def _set_stream_mode_buffering_strategies(self) -> None:
        # Live streams must keep up with emission, so stale frames are dropped.
        if self._buffer_filling_strategy is None:
            self._buffer_filling_strategy = BufferFillingStrategy.ADAPTIVE_DROP_OLDEST

    def _video_fps_should_be_sub_sampled(self) -> bool:
        """Decide whether the just-grabbed frame should be skipped to match
        `desired_fps`. Returns True when the frame is to be skipped."""
        if self._desired_fps is None:
            return False
        if self._is_source_video_file:
            actual_fps = self._declared_source_fps
        else:
            fraction_of_pace_monitor_samples = (
                len(self._stream_consumption_pace_monitor.all_timestamps)
                / self._stream_consumption_pace_monitor.all_timestamps.maxlen
            )
            if fraction_of_pace_monitor_samples < 0.9:
                # Not enough measurements yet - trust the declared FPS.
                actual_fps = self._declared_source_fps
            elif hasattr(self._stream_consumption_pace_monitor, "fps"):
                actual_fps = self._stream_consumption_pace_monitor.fps
            else:
                actual_fps = self._stream_consumption_pace_monitor()
        if self._frame_counter == self._next_frame_from_video_to_accept:
            # Accept this frame and schedule the next one based on the stride.
            stride = calculate_video_file_stride(
                actual_fps=actual_fps,
                desired_fps=self._desired_fps,
            )
            self._next_frame_from_video_to_accept += stride
            return False
        # skipping frame
        return True

    def _consume_stream_frame(
        self,
        video: VideoFrameProducer,
        declared_source_fps: Optional[float],
        measured_source_fps: Optional[float],
        is_source_video_file: Optional[bool],
        frame_timestamp: datetime,
        buffer: Queue,
        frames_buffering_allowed: bool,
        source_id: Optional[int],
    ) -> bool:
        """
        Returns: boolean flag with success status
        """
        if not frames_buffering_allowed:
            send_frame_drop_update(
                frame_timestamp=frame_timestamp,
                frame_id=self._frame_counter,
                cause="Buffering not allowed at the moment",
                status_update_handlers=self._status_update_handlers,
                source_id=source_id,
            )
            return True
        if self._frame_should_be_adaptively_dropped(
            declared_source_fps=declared_source_fps
        ):
            self._adaptive_frames_dropped_in_row += 1
            send_frame_drop_update(
                frame_timestamp=frame_timestamp,
                frame_id=self._frame_counter,
                cause="ADAPTIVE strategy",
                status_update_handlers=self._status_update_handlers,
                source_id=source_id,
            )
            return True
        self._adaptive_frames_dropped_in_row = 0
        if (
            not buffer.full()
            or self._buffer_filling_strategy is BufferFillingStrategy.WAIT
        ):
            # Either there is room, or WAIT strategy blocks until there is.
            return decode_video_frame_to_buffer(
                frame_timestamp=frame_timestamp,
                frame_id=self._frame_counter,
                video=video,
                buffer=buffer,
                decoding_pace_monitor=self._decoding_pace_monitor,
                source_id=source_id,
                declared_source_fps=declared_source_fps,
                measured_source_fps=measured_source_fps,
                comes_from_video_file=is_source_video_file,
            )
        if self._buffer_filling_strategy in DROP_OLDEST_STRATEGIES:
            return self._process_stream_frame_dropping_oldest(
                frame_timestamp=frame_timestamp,
                video=video,
                buffer=buffer,
                source_id=source_id,
                is_video_file=is_source_video_file,
            )
        # Remaining strategies keep the buffer as-is and drop the new frame.
        send_frame_drop_update(
            frame_timestamp=frame_timestamp,
            frame_id=self._frame_counter,
            cause="DROP_LATEST strategy",
            status_update_handlers=self._status_update_handlers,
            source_id=source_id,
        )
        return True

    def _frame_should_be_adaptively_dropped(
        self, declared_source_fps: Optional[float]
    ) -> bool:
        """Heuristic deciding whether to drop the current frame under an
        ADAPTIVE strategy, based on stream vs. consumption vs. reader pace."""
        if self._buffer_filling_strategy not in ADAPTIVE_STRATEGIES:
            return False
        if (
            self._adaptive_frames_dropped_in_row
            >= self._maximum_adaptive_frames_dropped_in_row
        ):
            # Hard cap on consecutive drops - always let a frame through.
            return False
        if (
            len(self._stream_consumption_pace_monitor.all_timestamps)
            <= self._minimum_adaptive_mode_samples
        ):
            # not enough observations
            return False
        if hasattr(self._stream_consumption_pace_monitor, "fps"):
            stream_consumption_pace = self._stream_consumption_pace_monitor.fps
        else:
            stream_consumption_pace = self._stream_consumption_pace_monitor()
        announced_stream_fps = stream_consumption_pace
        if declared_source_fps is not None and declared_source_fps > 0:
            announced_stream_fps = declared_source_fps
        if (
            announced_stream_fps - stream_consumption_pace
            > self._adaptive_mode_stream_pace_tolerance
        ):
            # cannot keep up with stream emission
            return True
        if (
            len(self._reader_pace_monitor.all_timestamps)
            <= self._minimum_adaptive_mode_samples
        ) or (
            len(self._decoding_pace_monitor.all_timestamps)
            <= self._minimum_adaptive_mode_samples
        ):
            # not enough observations
            return False
        actual_reader_pace = get_fps_if_tick_happens_now(
            fps_monitor=self._reader_pace_monitor
        )
        if hasattr(self._decoding_pace_monitor, "fps"):
            decoding_pace = self._decoding_pace_monitor.fps
        else:
            decoding_pace = self._decoding_pace_monitor()
        if (
            decoding_pace - actual_reader_pace
            > self._adaptive_mode_reader_pace_tolerance
        ):
            # we are too fast for the reader - time to save compute on decoding
            return True
        return False

    def _process_stream_frame_dropping_oldest(
        self,
        frame_timestamp: datetime,
        video: VideoFrameProducer,
        buffer: Queue,
        source_id: Optional[int],
        is_video_file: bool,
    ) -> bool:
        """Make room by evicting the oldest buffered frame, then decode the
        new frame into the buffer. Returns the decoding success flag."""
        drop_single_frame_from_buffer(
            buffer=buffer,
            cause="DROP_OLDEST strategy",
            status_update_handlers=self._status_update_handlers,
        )
        return decode_video_frame_to_buffer(
            frame_timestamp=frame_timestamp,
            frame_id=self._frame_counter,
            video=video,
            buffer=buffer,
            decoding_pace_monitor=self._decoding_pace_monitor,
            source_id=source_id,
            comes_from_video_file=is_video_file,
        )

VideoSource

Source code in inference/core/interfaces/camera/video_source.py
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
class VideoSource:
    @classmethod
    def init(
        cls,
        video_reference: VideoSourceIdentifier,
        buffer_size: int = DEFAULT_BUFFER_SIZE,
        status_update_handlers: Optional[List[Callable[[StatusUpdate], None]]] = None,
        buffer_filling_strategy: Optional[BufferFillingStrategy] = None,
        buffer_consumption_strategy: Optional[BufferConsumptionStrategy] = None,
        adaptive_mode_stream_pace_tolerance: float = DEFAULT_ADAPTIVE_MODE_STREAM_PACE_TOLERANCE,
        adaptive_mode_reader_pace_tolerance: float = DEFAULT_ADAPTIVE_MODE_READER_PACE_TOLERANCE,
        minimum_adaptive_mode_samples: int = DEFAULT_MINIMUM_ADAPTIVE_MODE_SAMPLES,
        maximum_adaptive_frames_dropped_in_row: int = DEFAULT_MAXIMUM_ADAPTIVE_FRAMES_DROPPED_IN_ROW,
        video_source_properties: Optional[Dict[str, float]] = None,
        source_id: Optional[int] = None,
        desired_fps: Optional[Union[float, int]] = None,
    ):
        """
        This class is meant to represent abstraction over video sources - both video files and
        on-line streams that are possible to be consumed and used by other components of `inference`
        library.

        Before digging into details of the class behaviour, it is advised to familiarise with the following
        concepts and implementation assumptions:

        1. Video file can be accessed from local (or remote) storage by the consumer in a pace dictated by
            its processing capabilities. If processing is faster than the frame rate of video, operations
            may be executed in a time shorter than the time of video playback. In the opposite case - consumer
            may freely decode and process frames in its own pace, without risk for failures due to temporal
            dependencies of processing - this is classical offline processing example.
        2. Video streams, on the other hand, usually need to be consumed in a pace near to their frame-rate -
            in other words - this is on-line processing example. Consumer being faster than incoming stream
            frames cannot utilise its resources to the full extent as not-yet-delivered data would be needed.
            Slow consumer, however, may not be able to process everything on time and to keep up with the pace
            of stream - some frames would need to be dropped. Otherwise - over time, consumer could go out of
            sync with the stream causing decoding failures or unpredictable behavior.

        To fit those two types of video sources, `VideoSource` introduces the concept of buffered decoding of
        video stream (like at the YouTube - player buffers some frames that are soon to be displayed).
        The way on how buffer is filled and consumed dictates the behavior of `VideoSource`.

        Starting from `BufferFillingStrategy` - we have 3 basic options:
        * WAIT: in case of slow video consumption, when buffer is full - `VideoSource` will wait for
        the empty spot in buffer before next frame will be processed - this is suitable in cases when
        we want to ensure EACH FRAME of the video to be processed
        * DROP_OLDEST: when buffer is full, the frame that sits there for the longest time will be dropped -
        this is suitable for cases when we want to process the most recent frames possible
        * DROP_LATEST: when buffer is full, the newly decoded frame is dropped - useful in cases when
        it is expected to have processing performance drops, but we would like to consume portions of
        video that are locally smooth - but this is probably the least common use-case.

        On top of that - there are two ADAPTIVE strategies: ADAPTIVE_DROP_OLDEST and ADAPTIVE_DROP_LATEST,
        which are equivalent to DROP_OLDEST and DROP_LATEST with adaptive decoding feature enabled. The notion
        of that mode will be described later.

        Naturally, decoded frames must also be consumed. `VideoSource` provides a handy interface for reading
        a video source frames by a SINGLE consumer. Consumption strategy can also be dictated via
        `BufferConsumptionStrategy`:
        * LAZY - consume all the frames from decoding buffer one-by-one
        * EAGER - at each readout - take all frames already buffered, drop all of them apart from the most recent

        In consequence - there are various combinations of `BufferFillingStrategy` and `BufferConsumptionStrategy`.
        The most popular would be:
        * `BufferFillingStrategy.WAIT` and `BufferConsumptionStrategy.LAZY` - to always decode and process each and
            every frame of the source (useful while processing video files - and default behaviour enforced by
            `inference` if there is no explicit configuration)
        * `BufferFillingStrategy.DROP_OLDEST` and `BufferConsumptionStrategy.EAGER` - to always process the most
            recent frames of source (useful while processing video streams when low latency [real-time experience]
            is required - ADAPTIVE version of this is default for streams)

        ADAPTIVE strategies were introduced to handle corner-cases, when consumer hardware is not capable to consume
        video stream and process frames at the same time (for instance - Nvidia Jetson devices running processing
        against hi-res streams with high FPS ratio). It acts with buffer in nearly the same way as `DROP_OLDEST`
        and `DROP_LATEST` strategies, but there are two more conditions that may influence frame drop:
        * announced rate of source - which in fact dictate the pace of frames grabbing from incoming stream that
        MUST be met by consumer to avoid strange decoding issues causing decoder to fail - if the pace of frame grabbing
        deviates too much - decoding will be postponed, and frames dropped to grab next ones sooner
        * consumption rate - in resource constraints environment, not only decoding is problematic from the performance
        perspective - but also heavy processing. If consumer is not quick enough - allocating more useful resources
        for decoding frames that may never be processed is a waste. That's why - if decoding happens more frequently
        than consumption of frame - ADAPTIVE mode causes decoding to be done in a slower pace and more frames are just
        grabbed and dropped on the floor.
        ADAPTIVE mode increases latency slightly, but may be the only way to operate in some cases.
        Behaviour of adaptive mode, including the maximum acceptable deviations of frames grabbing pace from source,
        reader pace and maximum number of consecutive frames dropped in ADAPTIVE mode are configurable by clients,
        with reasonable defaults being set.

        `VideoSource` emits events regarding its activity - which can be intercepted by custom handlers. Take
        into account that they are always executed in context of thread invoking them (and should be fast to complete,
        otherwise may block the flow of stream consumption). All errors raised will be emitted as logger warnings only.

        `VideoSource` implementation is naturally multithreading, with different thread decoding video and different
        one consuming it and manipulating source state. Implementation of user interface is thread-safe, although
        stream it is meant to be consumed by a single thread only.

        ENV variables involved:
        * VIDEO_SOURCE_BUFFER_SIZE - default: 64
        * VIDEO_SOURCE_ADAPTIVE_MODE_STREAM_PACE_TOLERANCE - default: 0.1
        * VIDEO_SOURCE_ADAPTIVE_MODE_READER_PACE_TOLERANCE - default: 5.0
        * VIDEO_SOURCE_MINIMUM_ADAPTIVE_MODE_SAMPLES - default: 10
        * VIDEO_SOURCE_MAXIMUM_ADAPTIVE_FRAMES_DROPPED_IN_ROW - default: 16

        As an `inference` user, please use .init() method instead of constructor to instantiate objects.

        Args:
            video_reference (Union[str, int]): Either str with file or stream reference, or int representing device ID
            buffer_size (int): size of decoding buffer
            status_update_handlers (Optional[List[Callable[[StatusUpdate], None]]]): List of handlers for status updates
            buffer_filling_strategy (Optional[BufferFillingStrategy]): Settings for buffer filling strategy - if not
                given - automatic choice regarding source type will be applied
            buffer_consumption_strategy (Optional[BufferConsumptionStrategy]): Settings for buffer consumption strategy,
                if not given - automatic choice regarding source type will be applied
            adaptive_mode_stream_pace_tolerance (float): Maximum deviation between frames grabbing pace and stream pace
                that will not trigger adaptive mode frame drop
            adaptive_mode_reader_pace_tolerance (float): Maximum deviation between decoding pace and stream consumption
                pace that will not trigger adaptive mode frame drop
            minimum_adaptive_mode_samples (int): Minimal number of frames to be used to establish actual pace of
                processing, before adaptive mode can drop any frame
            maximum_adaptive_frames_dropped_in_row (int): Maximum number of frames dropped in row due to application of
                adaptive strategy
            video_source_properties (Optional[dict[str, float]]): Optional dictionary with video source properties
                corresponding to OpenCV VideoCapture properties cv2.CAP_PROP_* to set values for the video source.
            source_id (Optional[int]): Optional identifier of video source - mainly useful to recognise specific source
                when multiple ones are in use. Identifier will be added to emitted frames and updates. It is advised
                to keep it unique within all sources in use.
            desired_fps (Optional[Union[float, int]]): Optional target frames-per-second value forwarded to the
                internal `VideoConsumer` - presumably used to cap/steer the effective frame consumption rate
                (confirm exact semantics against `VideoConsumer.init`).

        Returns: Instance of `VideoSource` class
        """
        # The buffer is the hand-off point between the decoding thread and the consumer;
        # its size bounds how many decoded frames may be pending at once.
        frames_buffer = Queue(maxsize=buffer_size)
        if status_update_handlers is None:
            status_update_handlers = []
        video_consumer = VideoConsumer.init(
            buffer_filling_strategy=buffer_filling_strategy,
            adaptive_mode_stream_pace_tolerance=adaptive_mode_stream_pace_tolerance,
            adaptive_mode_reader_pace_tolerance=adaptive_mode_reader_pace_tolerance,
            minimum_adaptive_mode_samples=minimum_adaptive_mode_samples,
            maximum_adaptive_frames_dropped_in_row=maximum_adaptive_frames_dropped_in_row,
            status_update_handlers=status_update_handlers,
            desired_fps=desired_fps,
        )
        return cls(
            stream_reference=video_reference,
            frames_buffer=frames_buffer,
            status_update_handlers=status_update_handlers,
            buffer_consumption_strategy=buffer_consumption_strategy,
            video_consumer=video_consumer,
            video_source_properties=video_source_properties,
            source_id=source_id,
        )

    def __init__(
        self,
        stream_reference: VideoSourceIdentifier,
        frames_buffer: Queue,
        status_update_handlers: List[Callable[[StatusUpdate], None]],
        buffer_consumption_strategy: Optional[BufferConsumptionStrategy],
        video_consumer: "VideoConsumer",
        video_source_properties: Optional[Dict[str, float]],
        source_id: Optional[int],
    ):
        """Low-level constructor - prefer `VideoSource.init()` as an `inference` user."""
        self._stream_reference = stream_reference
        # Created lazily in _start(); None while the source is not connected.
        self._video: Optional[VideoFrameProducer] = None
        self._source_properties: Optional[SourceProperties] = None
        self._frames_buffer = frames_buffer
        self._status_update_handlers = status_update_handlers
        self._buffer_consumption_strategy = buffer_consumption_strategy
        self._video_consumer = video_consumer
        self._state = StreamState.NOT_STARTED
        # Cleared on pause - the decoding thread blocks on this event before each grab.
        self._playback_allowed = Event()
        self._frames_buffering_allowed = True
        self._stream_consumption_thread: Optional[Thread] = None
        # Guards state transitions triggered via the public API (see @lock_state_transition).
        self._state_change_lock = Lock()
        self._video_source_properties = video_source_properties or {}
        self._source_id = source_id
        self._last_frame_timestamp: int = time.time_ns()
        self._fps: Optional[float] = None
        self._is_file: Optional[bool] = None

    @property
    def source_id(self) -> Optional[int]:
        return self._source_id

    @lock_state_transition
    def restart(
        self, wait_on_frames_consumption: bool = True, purge_frames_buffer: bool = False
    ) -> None:
        """
        Method to restart source consumption. Eligible to be used in states:
        [MUTED, RUNNING, PAUSED, ENDED, ERROR].
        End state:
        * INITIALISING - that should change into RUNNING once first frame is ready to be grabbed
        * ERROR - if it was not possible to connect with source

        Thread safe - only one transition of states possible at the time.

        Args:
            wait_on_frames_consumption (bool): Flag telling if all frames from buffer must be consumed before
                completion of this operation.
            purge_frames_buffer (bool): Flag telling if the frames buffer should be drained before restart.

        Returns: None
        Throws:
            * StreamOperationNotAllowedError: if executed in context of incorrect state of the source
            * SourceConnectionError: if source cannot be connected
        """
        if self._state not in RESTART_ELIGIBLE_STATES:
            raise StreamOperationNotAllowedError(
                f"Could not RESTART stream in state: {self._state}"
            )
        self._restart(
            wait_on_frames_consumption=wait_on_frames_consumption,
            purge_frames_buffer=purge_frames_buffer,
        )

    @lock_state_transition
    def start(self) -> None:
        """
        Method to be used to start source consumption. Eligible to be used in states:
        [NOT_STARTED, ENDED, (RESTARTING - which is internal state only)]
        End state:
        * INITIALISING - that should change into RUNNING once first frame is ready to be grabbed
        * ERROR - if it was not possible to connect with source

        Thread safe - only one transition of states possible at the time.

        Returns: None
        Throws:
            * StreamOperationNotAllowedError: if executed in context of incorrect state of the source
            * SourceConnectionError: if source cannot be connected
        """
        if self._state not in START_ELIGIBLE_STATES:
            raise StreamOperationNotAllowedError(
                f"Could not START stream in state: {self._state}"
            )
        self._start()

    @lock_state_transition
    def terminate(
        self, wait_on_frames_consumption: bool = True, purge_frames_buffer: bool = False
    ) -> None:
        """
        Method to be used to terminate source consumption. Eligible to be used in states:
        [MUTED, RUNNING, PAUSED, ENDED, ERROR, (RESTARTING - which is internal state only)]
        End state:
        * ENDED - indicating success of the process
        * ERROR - if error with processing occurred

        Must be used to properly dispose resources at the end.

        Thread safe - only one transition of states possible at the time.

        Args:
            wait_on_frames_consumption (bool): Flag telling if all frames from buffer must be consumed before
                completion of this operation.
            purge_frames_buffer (bool): Flag telling if the frames buffer should be drained on termination.

        Returns: None
        Throws:
            * StreamOperationNotAllowedError: if executed in context of incorrect state of the source
        """
        if self._state not in TERMINATE_ELIGIBLE_STATES:
            raise StreamOperationNotAllowedError(
                f"Could not TERMINATE stream in state: {self._state}"
            )
        self._terminate(
            wait_on_frames_consumption=wait_on_frames_consumption,
            purge_frames_buffer=purge_frames_buffer,
        )

    @lock_state_transition
    def pause(self) -> None:
        """
        Method to be used to pause source consumption. During pause - no new frames are consumed.
        Used on on-line streams for too long may cause stream disconnection.
        Eligible to be used in states:
        [RUNNING]
        End state:
        * PAUSED

        Thread safe - only one transition of states possible at the time.

        Returns: None
        Throws:
            * StreamOperationNotAllowedError: if executed in context of incorrect state of the source
        """
        if self._state not in PAUSE_ELIGIBLE_STATES:
            raise StreamOperationNotAllowedError(
                f"Could not PAUSE stream in state: {self._state}"
            )
        self._pause()

    @lock_state_transition
    def mute(self) -> None:
        """
        Method to be used to mute source consumption. Muting is an equivalent of pause for stream - where
        frames grabbing is not put on hold, just new frames decoding and buffering is not allowed - causing
        intermediate frames to be dropped. May be also used against files, although arguably less useful.
        Eligible to be used in states:
        [RUNNING]
        End state:
        * MUTED

        Thread safe - only one transition of states possible at the time.

        Returns: None
        Throws:
            * StreamOperationNotAllowedError: if executed in context of incorrect state of the source
        """
        if self._state not in MUTE_ELIGIBLE_STATES:
            raise StreamOperationNotAllowedError(
                f"Could not MUTE stream in state: {self._state}"
            )
        self._mute()

    @lock_state_transition
    def resume(self) -> None:
        """
        Method to recover from pause or mute into running state.
        [PAUSED, MUTED]
        End state:
        * RUNNING

        Thread safe - only one transition of states possible at the time.

        Returns: None
        Throws:
            * StreamOperationNotAllowedError: if executed in context of incorrect state of the source
        """
        if self._state not in RESUME_ELIGIBLE_STATES:
            raise StreamOperationNotAllowedError(
                f"Could not RESUME stream in state: {self._state}"
            )
        self._resume()

    def get_state(self) -> StreamState:
        """
        Method to get current state of the `VideoSource`

        Returns: StreamState
        """
        return self._state

    def frame_ready(self) -> bool:
        """
        Method to check if decoded frame is ready for consumer

        Returns: boolean flag indicating frame readiness
        """
        return not self._frames_buffer.empty()

    def read_frame(self, timeout: Optional[float] = None) -> Optional[VideoFrame]:
        """
        Method to be used by the consumer to get decoded source frame.

        Args:
            timeout (Optional[float]): Maximum number of seconds to wait for a frame from the buffer;
                None means block until a frame is available. On timeout, None is returned.

        Returns: VideoFrame object with decoded frame and its metadata.
        Throws:
            * EndOfStreamError: when trying to get the frame from closed source.
        """
        # Lazily cache source characteristics on the first read.
        if self._is_file is None:
            source_metadata: SourceMetadata = self.describe_source()
            self._is_file = source_metadata.source_properties.is_file
            self._fps = source_metadata.source_properties.fps
            # Guard against sources reporting missing or nonsensical FPS values.
            if not self._fps or self._fps <= 0 or self._fps > 1000:
                self._fps = 30  # sane default
        video_frame: Optional[Union[VideoFrame, str]] = get_from_queue(
            queue=self._frames_buffer,
            on_successful_read=self._video_consumer.notify_frame_consumed,
            timeout=timeout,
            purge=self._buffer_consumption_strategy is BufferConsumptionStrategy.EAGER,
        )
        # POISON_PILL is enqueued by the decoding thread to signal end of stream.
        if video_frame == POISON_PILL:
            raise EndOfStreamError(
                "Attempted to retrieve frame from stream that already ended."
            )
        if video_frame is not None and self._status_update_handlers:
            send_video_source_status_update(
                severity=UpdateSeverity.DEBUG,
                event_type=FRAME_CONSUMED_EVENT,
                payload={
                    "frame_timestamp": video_frame.frame_timestamp,
                    "frame_id": video_frame.frame_id,
                    "source_id": video_frame.source_id,
                },
                status_update_handlers=self._status_update_handlers,
            )
        return video_frame

    def describe_source(self) -> SourceMetadata:
        """Return a snapshot of source configuration and current state."""
        serialized_source_reference = self._stream_reference
        # Callable references (factories) cannot be serialized directly - use their repr.
        if callable(serialized_source_reference):
            serialized_source_reference = str(self._stream_reference)
        return SourceMetadata(
            source_properties=self._source_properties,
            source_reference=serialized_source_reference,
            buffer_size=self._frames_buffer.maxsize,
            state=self._state,
            buffer_filling_strategy=self._video_consumer.buffer_filling_strategy,
            buffer_consumption_strategy=self._buffer_consumption_strategy,
            source_id=self._source_id,
        )

    def _restart(
        self, wait_on_frames_consumption: bool = True, purge_frames_buffer: bool = False
    ) -> None:
        # Full teardown followed by a fresh start - RESTARTING is an internal-only state.
        self._terminate(
            wait_on_frames_consumption=wait_on_frames_consumption,
            purge_frames_buffer=purge_frames_buffer,
        )
        self._change_state(target_state=StreamState.RESTARTING)
        self._playback_allowed = Event()
        self._frames_buffering_allowed = True
        self._video: Optional[VideoFrameProducer] = None
        self._source_properties: Optional[SourceProperties] = None
        self._start()

    def _start(self) -> None:
        self._change_state(target_state=StreamState.INITIALISING)
        if callable(self._stream_reference):
            # Reference may be a factory producing a custom VideoFrameProducer.
            self._video = self._stream_reference()
        elif _is_test_pattern_reference(self._stream_reference):
            from inference.core.interfaces.camera.test_pattern_producer import (
                TestPatternStreamProducer,
            )

            self._video = TestPatternStreamProducer()
        else:
            self._video = CV2VideoFrameProducer(self._stream_reference)
        if not self._video.isOpened():
            self._change_state(target_state=StreamState.ERROR)
            raise SourceConnectionError(
                f"Cannot connect to video source under reference: {self._stream_reference}"
            )
        self._video.initialize_source_properties(self._video_source_properties)
        self._source_properties = self._video.discover_source_properties()
        self._video_consumer.reset(source_properties=self._source_properties)
        if self._source_properties.is_file:
            self._set_file_mode_consumption_strategies()
        else:
            self._set_stream_mode_consumption_strategies()
        self._playback_allowed.set()
        self._stream_consumption_thread = Thread(target=self._consume_video)
        self._stream_consumption_thread.start()

    def _terminate(
        self, wait_on_frames_consumption: bool, purge_frames_buffer: bool
    ) -> None:
        # A paused/muted source must be resumed first, otherwise the decoding
        # thread could stay blocked on the playback event and never join.
        if self._state in RESUME_ELIGIBLE_STATES:
            self._resume()
        previous_state = self._state
        self._change_state(target_state=StreamState.TERMINATING)
        if purge_frames_buffer:
            _ = get_from_queue(queue=self._frames_buffer, timeout=0.0, purge=True)
        if self._stream_consumption_thread is not None:
            self._stream_consumption_thread.join()
        if wait_on_frames_consumption:
            self._frames_buffer.join()
        # Preserve ERROR as the terminal state when termination follows a failure.
        if previous_state is not StreamState.ERROR:
            self._change_state(target_state=StreamState.ENDED)

    def _pause(self) -> None:
        self._playback_allowed.clear()
        self._change_state(target_state=StreamState.PAUSED)

    def _mute(self) -> None:
        self._frames_buffering_allowed = False
        self._change_state(target_state=StreamState.MUTED)

    def _resume(self) -> None:
        previous_state = self._state
        self._change_state(target_state=StreamState.RUNNING)
        if previous_state is StreamState.PAUSED:
            # Pace statistics are stale after a pause - reset before unblocking playback.
            self._video_consumer.reset_stream_consumption_pace()
            self._playback_allowed.set()
        if previous_state is StreamState.MUTED:
            self._frames_buffering_allowed = True

    def _set_file_mode_consumption_strategies(self) -> None:
        # Files default to LAZY consumption: every decoded frame is read in order.
        if self._buffer_consumption_strategy is None:
            self._buffer_consumption_strategy = BufferConsumptionStrategy.LAZY

    def _set_stream_mode_consumption_strategies(self) -> None:
        # Streams default to EAGER consumption: always read the most recent frame.
        if self._buffer_consumption_strategy is None:
            self._buffer_consumption_strategy = BufferConsumptionStrategy.EAGER

    def _consume_video(self) -> None:
        """Decoding-thread entry point: grab frames until stream ends or termination is requested."""
        send_video_source_status_update(
            severity=UpdateSeverity.INFO,
            event_type=VIDEO_CONSUMPTION_STARTED_EVENT,
            status_update_handlers=self._status_update_handlers,
            payload={"source_id": self._source_id},
        )
        logger.info("Video consumption started")
        try:
            if self._state is not StreamState.TERMINATING:
                self._change_state(target_state=StreamState.RUNNING)
            declared_source_fps, is_video_file = None, None
            if self._source_properties is not None:
                declared_source_fps = self._source_properties.fps
                is_video_file = self._source_properties.is_file
            while self._video.isOpened():
                if self._state is StreamState.TERMINATING:
                    break
                # Blocks while the source is paused.
                self._playback_allowed.wait()
                success = self._video_consumer.consume_frame(
                    video=self._video,
                    declared_source_fps=declared_source_fps,
                    is_source_video_file=is_video_file,
                    buffer=self._frames_buffer,
                    frames_buffering_allowed=self._frames_buffering_allowed,
                    source_id=self._source_id,
                )
                if not success:
                    break
            # Signal end-of-stream to the reader side (see read_frame()).
            self._frames_buffer.put(POISON_PILL)
            self._video.release()
            self._change_state(target_state=StreamState.ENDED)
            send_video_source_status_update(
                severity=UpdateSeverity.INFO,
                event_type=VIDEO_CONSUMPTION_FINISHED_EVENT,
                status_update_handlers=self._status_update_handlers,
                payload={"source_id": self._source_id},
            )
            logger.info("Video consumption finished")
        except Exception as error:
            self._change_state(target_state=StreamState.ERROR)
            payload = {
                "source_id": self._source_id,
                "error_type": error.__class__.__name__,
                "error_message": str(error),
                "error_context": "stream_consumer_thread",
            }
            send_video_source_status_update(
                severity=UpdateSeverity.ERROR,
                event_type=SOURCE_ERROR_EVENT,
                payload=payload,
                status_update_handlers=self._status_update_handlers,
            )
            logger.exception("Encountered error in video consumption thread")

    def _change_state(self, target_state: StreamState) -> None:
        payload = {
            "previous_state": self._state,
            "new_state": target_state,
            "source_id": self._source_id,
        }
        self._state = target_state
        send_video_source_status_update(
            severity=UpdateSeverity.INFO,
            event_type=SOURCE_STATE_UPDATE_EVENT,
            payload=payload,
            status_update_handlers=self._status_update_handlers,
        )

    def __iter__(self) -> "VideoSource":
        return self

    def __next__(self) -> VideoFrame:
        """
        Method allowing to use `VideoSource` convenient to read frames

        Returns: VideoFrame

        Example:
            ```python
            source = VideoSource.init(video_reference="./some.mp4")
            source.start()

            for frame in source:
                 pass
            ```
        """
        try:
            return self.read_frame()
        except EndOfStreamError:
            raise StopIteration()
Functions
__next__
__next__()

Method allowing to use VideoSource convenient to read frames

Returns: VideoFrame

Example
source = VideoSource.init(video_reference="./some.mp4")
source.start()

for frame in source:
     pass
Source code in inference/core/interfaces/camera/video_source.py
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
def __next__(self) -> VideoFrame:
    """
    Method allowing to use `VideoSource` convenient to read frames

    Returns: VideoFrame

    Example:
        ```python
        source = VideoSource.init(video_reference="./some.mp4")
        source.start()

        for frame in source:
             pass
        ```
    """
    try:
        return self.read_frame()
    except EndOfStreamError:
        raise StopIteration()
frame_ready
frame_ready()

Method to check if decoded frame is ready for consumer

Returns: boolean flag indicating frame readiness

Source code in inference/core/interfaces/camera/video_source.py
537
538
539
540
541
542
543
def frame_ready(self) -> bool:
    """
    Method to check if decoded frame is ready for consumer

    Returns: boolean flag indicating frame readiness
    """
    return not self._frames_buffer.empty()
get_state
get_state()

Method to get current state of the VideoSource

Returns: StreamState

Source code in inference/core/interfaces/camera/video_source.py
529
530
531
532
533
534
535
def get_state(self) -> StreamState:
    """
    Method exposing the current state of the `VideoSource`.

    Returns: StreamState
    """
    current_state = self._state
    return current_state
init classmethod
init(
    video_reference,
    buffer_size=DEFAULT_BUFFER_SIZE,
    status_update_handlers=None,
    buffer_filling_strategy=None,
    buffer_consumption_strategy=None,
    adaptive_mode_stream_pace_tolerance=DEFAULT_ADAPTIVE_MODE_STREAM_PACE_TOLERANCE,
    adaptive_mode_reader_pace_tolerance=DEFAULT_ADAPTIVE_MODE_READER_PACE_TOLERANCE,
    minimum_adaptive_mode_samples=DEFAULT_MINIMUM_ADAPTIVE_MODE_SAMPLES,
    maximum_adaptive_frames_dropped_in_row=DEFAULT_MAXIMUM_ADAPTIVE_FRAMES_DROPPED_IN_ROW,
    video_source_properties=None,
    source_id=None,
    desired_fps=None,
)

This class is meant to represent abstraction over video sources - both video files and on-line streams that are possible to be consumed and used by other components of inference library.

Before digging into details of the class behaviour, it is advised to familiarise with the following concepts and implementation assumptions:

  1. Video file can be accessed from local (or remote) storage by the consumer in a pace dictated by its processing capabilities. If processing is faster than the frame rate of video, operations may be executed in a time shorter than the time of video playback. In the opposite case - consumer may freely decode and process frames in its own pace, without risk for failures due to temporal dependencies of processing - this is classical offline processing example.
  2. Video streams, on the other hand, usually need to be consumed in a pace near to their frame-rate - in other words - this is on-line processing example. Consumer being faster than incoming stream frames cannot utilise its resources to the full extent as not-yet-delivered data would be needed. Slow consumer, however, may not be able to process everything on time and to keep up with the pace of stream - some frames would need to be dropped. Otherwise - over time, consumer could go out of sync with the stream causing decoding failures or unpredictable behavior.

To fit those two types of video sources, VideoSource introduces the concept of buffered decoding of video stream (like at the YouTube - player buffers some frames that are soon to be displayed). The way on how buffer is filled and consumed dictates the behavior of VideoSource.

Starting from BufferFillingStrategy - we have 3 basic options: * WAIT: in case of slow video consumption, when buffer is full - VideoSource will wait for the empty spot in buffer before next frame will be processed - this is suitable in cases when we want to ensure EACH FRAME of the video to be processed * DROP_OLDEST: when buffer is full, the frame that sits there for the longest time will be dropped - this is suitable for cases when we want to process the most recent frames possible * DROP_LATEST: when buffer is full, the newly decoded frame is dropped - useful in cases when it is expected to have processing performance drops, but we would like to consume portions of video that are locally smooth - but this is probably the least common use-case.

On top of that - there are two ADAPTIVE strategies: ADAPTIVE_DROP_OLDEST and ADAPTIVE_DROP_LATEST, which are equivalent to DROP_OLDEST and DROP_LATEST with adaptive decoding feature enabled. The notion of that mode will be described later.

Naturally, decoded frames must also be consumed. VideoSource provides a handy interface for reading a video source frames by a SINGLE consumer. Consumption strategy can also be dictated via BufferConsumptionStrategy: * LAZY - consume all the frames from decoding buffer one-by-one * EAGER - at each readout - take all frames already buffered, drop all of them apart from the most recent

In consequence - there are various combinations of BufferFillingStrategy and BufferConsumptionStrategy. The most popular would be: * BufferFillingStrategy.WAIT and BufferConsumptionStrategy.LAZY - to always decode and process each and every frame of the source (useful while processing video files - and default behaviour enforced by inference if there is no explicit configuration) * BufferFillingStrategy.DROP_OLDEST and BufferConsumptionStrategy.EAGER - to always process the most recent frames of source (useful while processing video streams when low latency [real-time experience] is required - ADAPTIVE version of this is default for streams)

ADAPTIVE strategies were introduced to handle corner-cases, when consumer hardware is not capable to consume video stream and process frames at the same time (for instance - Nvidia Jetson devices running processing against hi-res streams with high FPS ratio). It acts with buffer in nearly the same way as DROP_OLDEST and DROP_LATEST strategies, but there are two more conditions that may influence frame drop: * announced rate of source - which in fact dictate the pace of frames grabbing from incoming stream that MUST be met by consumer to avoid strange decoding issues causing decoder to fail - if the pace of frame grabbing deviates too much - decoding will be postponed, and frames dropped to grab next ones sooner * consumption rate - in resource constraints environment, not only decoding is problematic from the performance perspective - but also heavy processing. If consumer is not quick enough - allocating more useful resources for decoding frames that may never be processed is a waste. That's why - if decoding happens more frequently than consumption of frame - ADAPTIVE mode causes decoding to be done in a slower pace and more frames are just grabbed and dropped on the floor. ADAPTIVE mode increases latency slightly, but may be the only way to operate in some cases. Behaviour of adaptive mode, including the maximum acceptable deviations of frames grabbing pace from source, reader pace and maximum number of consecutive frames dropped in ADAPTIVE mode are configurable by clients, with reasonable defaults being set.

VideoSource emits events regarding its activity - which can be intercepted by custom handlers. Take into account that they are always executed in context of thread invoking them (and should be fast to complete, otherwise may block the flow of stream consumption). All errors raised will be emitted as logger warnings only.

VideoSource implementation is naturally multithreaded, with one thread decoding video and another consuming it and manipulating source state. Implementation of user interface is thread-safe, although the stream is meant to be consumed by a single thread only.

ENV variables involved: * VIDEO_SOURCE_BUFFER_SIZE - default: 64 * VIDEO_SOURCE_ADAPTIVE_MODE_STREAM_PACE_TOLERANCE - default: 0.1 * VIDEO_SOURCE_ADAPTIVE_MODE_READER_PACE_TOLERANCE - default: 5.0 * VIDEO_SOURCE_MINIMUM_ADAPTIVE_MODE_SAMPLES - default: 10 * VIDEO_SOURCE_MAXIMUM_ADAPTIVE_FRAMES_DROPPED_IN_ROW - default: 16

As an inference user, please use .init() method instead of constructor to instantiate objects.

Parameters:

Name Type Description Default
video_reference Union[str, int]

Either str with file or stream reference, or int representing device ID

required
buffer_size int

size of decoding buffer

DEFAULT_BUFFER_SIZE
status_update_handlers Optional[List[Callable[[StatusUpdate], None]]]

List of handlers for status updates

None
buffer_filling_strategy Optional[BufferFillingStrategy]

Settings for buffer filling strategy - if not given - automatic choice regarding source type will be applied

None
buffer_consumption_strategy Optional[BufferConsumptionStrategy]

Settings for buffer consumption strategy, if not given - automatic choice regarding source type will be applied

None
adaptive_mode_stream_pace_tolerance float

Maximum deviation between frames grabbing pace and stream pace that will not trigger adaptive mode frame drop

DEFAULT_ADAPTIVE_MODE_STREAM_PACE_TOLERANCE
adaptive_mode_reader_pace_tolerance float

Maximum deviation between decoding pace and stream consumption pace that will not trigger adaptive mode frame drop

DEFAULT_ADAPTIVE_MODE_READER_PACE_TOLERANCE
minimum_adaptive_mode_samples int

Minimal number of frames to be used to establish actual pace of processing, before adaptive mode can drop any frame

DEFAULT_MINIMUM_ADAPTIVE_MODE_SAMPLES
maximum_adaptive_frames_dropped_in_row int

Maximum number of frames dropped in row due to application of adaptive strategy

DEFAULT_MAXIMUM_ADAPTIVE_FRAMES_DROPPED_IN_ROW
video_source_properties Optional[dict[str, float]]

Optional dictionary with video source properties corresponding to OpenCV VideoCapture properties cv2.CAP_PROP_* to set values for the video source.

None
source_id Optional[int]

Optional identifier of video source - mainly useful to recognise specific source when multiple ones are in use. Identifier will be added to emitted frames and updates. It is advised to keep it unique within all sources in use.

None
Source code in inference/core/interfaces/camera/video_source.py
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
@classmethod
def init(
    cls,
    video_reference: VideoSourceIdentifier,
    buffer_size: int = DEFAULT_BUFFER_SIZE,
    status_update_handlers: Optional[List[Callable[[StatusUpdate], None]]] = None,
    buffer_filling_strategy: Optional[BufferFillingStrategy] = None,
    buffer_consumption_strategy: Optional[BufferConsumptionStrategy] = None,
    adaptive_mode_stream_pace_tolerance: float = DEFAULT_ADAPTIVE_MODE_STREAM_PACE_TOLERANCE,
    adaptive_mode_reader_pace_tolerance: float = DEFAULT_ADAPTIVE_MODE_READER_PACE_TOLERANCE,
    minimum_adaptive_mode_samples: int = DEFAULT_MINIMUM_ADAPTIVE_MODE_SAMPLES,
    maximum_adaptive_frames_dropped_in_row: int = DEFAULT_MAXIMUM_ADAPTIVE_FRAMES_DROPPED_IN_ROW,
    video_source_properties: Optional[Dict[str, float]] = None,
    source_id: Optional[int] = None,
    desired_fps: Optional[Union[float, int]] = None,
):
    """
    This class is meant to represent abstraction over video sources - both video files and
    on-line streams that are possible to be consumed and used by other components of `inference`
    library.

    Before digging into details of the class behaviour, it is advised to familiarise with the following
    concepts and implementation assumptions:

    1. Video file can be accessed from local (or remote) storage by the consumer in a pace dictated by
        its processing capabilities. If processing is faster than the frame rate of video, operations
        may be executed in a time shorter than the time of video playback. In the opposite case - consumer
        may freely decode and process frames in its own pace, without risk for failures due to temporal
        dependencies of processing - this is classical offline processing example.
    2. Video streams, on the other hand, usually need to be consumed in a pace near to their frame-rate -
        in other words - this is on-line processing example. Consumer being faster than incoming stream
        frames cannot utilise its resources to the full extent as not-yet-delivered data would be needed.
        Slow consumer, however, may not be able to process everything on time and to keep up with the pace
        of stream - some frames would need to be dropped. Otherwise - over time, consumer could go out of
        sync with the stream causing decoding failures or unpredictable behavior.

    To fit those two types of video sources, `VideoSource` introduces the concept of buffered decoding of
    video stream (like at the YouTube - player buffers some frames that are soon to be displayed).
    The way on how buffer is filled and consumed dictates the behavior of `VideoSource`.

    Starting from `BufferFillingStrategy` - we have 3 basic options:
    * WAIT: in case of slow video consumption, when buffer is full - `VideoSource` will wait for
    the empty spot in buffer before next frame will be processed - this is suitable in cases when
    we want to ensure EACH FRAME of the video to be processed
    * DROP_OLDEST: when buffer is full, the frame that sits there for the longest time will be dropped -
    this is suitable for cases when we want to process the most recent frames possible
    * DROP_LATEST: when buffer is full, the newly decoded frame is dropped - useful in cases when
    it is expected to have processing performance drops, but we would like to consume portions of
    video that are locally smooth - but this is probably the least common use-case.

    On top of that - there are two ADAPTIVE strategies: ADAPTIVE_DROP_OLDEST and ADAPTIVE_DROP_LATEST,
    which are equivalent to DROP_OLDEST and DROP_LATEST with adaptive decoding feature enabled. The notion
    of that mode will be described later.

    Naturally, decoded frames must also be consumed. `VideoSource` provides a handy interface for reading
    a video source frames by a SINGLE consumer. Consumption strategy can also be dictated via
    `BufferConsumptionStrategy`:
    * LAZY - consume all the frames from decoding buffer one-by-one
    * EAGER - at each readout - take all frames already buffered, drop all of them apart from the most recent

    In consequence - there are various combinations of `BufferFillingStrategy` and `BufferConsumptionStrategy`.
    The most popular would be:
    * `BufferFillingStrategy.WAIT` and `BufferConsumptionStrategy.LAZY` - to always decode and process each and
        every frame of the source (useful while processing video files - and default behaviour enforced by
        `inference` if there is no explicit configuration)
    * `BufferFillingStrategy.DROP_OLDEST` and `BufferConsumptionStrategy.EAGER` - to always process the most
        recent frames of source (useful while processing video streams when low latency [real-time experience]
        is required - ADAPTIVE version of this is default for streams)

    ADAPTIVE strategies were introduced to handle corner-cases, when consumer hardware is not capable to consume
    video stream and process frames at the same time (for instance - Nvidia Jetson devices running processing
    against hi-res streams with high FPS ratio). It acts with buffer in nearly the same way as `DROP_OLDEST`
    and `DROP_LATEST` strategies, but there are two more conditions that may influence frame drop:
    * announced rate of source - which in fact dictate the pace of frames grabbing from incoming stream that
    MUST be met by consumer to avoid strange decoding issues causing decoder to fail - if the pace of frame grabbing
    deviates too much - decoding will be postponed, and frames dropped to grab next ones sooner
    * consumption rate - in resource constraints environment, not only decoding is problematic from the performance
    perspective - but also heavy processing. If consumer is not quick enough - allocating more useful resources
    for decoding frames that may never be processed is a waste. That's why - if decoding happens more frequently
    than consumption of frame - ADAPTIVE mode causes decoding to be done in a slower pace and more frames are just
    grabbed and dropped on the floor.
    ADAPTIVE mode increases latency slightly, but may be the only way to operate in some cases.
    Behaviour of adaptive mode, including the maximum acceptable deviations of frames grabbing pace from source,
    reader pace and maximum number of consecutive frames dropped in ADAPTIVE mode are configurable by clients,
    with reasonable defaults being set.

    `VideoSource` emits events regarding its activity - which can be intercepted by custom handlers. Take
    into account that they are always executed in context of thread invoking them (and should be fast to complete,
    otherwise may block the flow of stream consumption). All errors raised will be emitted as logger warnings only.

    `VideoSource` implementation is naturally multithreaded, with one thread decoding video and another
    one consuming it and manipulating source state. Implementation of user interface is thread-safe, although
    the stream is meant to be consumed by a single thread only.

    ENV variables involved:
    * VIDEO_SOURCE_BUFFER_SIZE - default: 64
    * VIDEO_SOURCE_ADAPTIVE_MODE_STREAM_PACE_TOLERANCE - default: 0.1
    * VIDEO_SOURCE_ADAPTIVE_MODE_READER_PACE_TOLERANCE - default: 5.0
    * VIDEO_SOURCE_MINIMUM_ADAPTIVE_MODE_SAMPLES - default: 10
    * VIDEO_SOURCE_MAXIMUM_ADAPTIVE_FRAMES_DROPPED_IN_ROW - default: 16

    As an `inference` user, please use .init() method instead of constructor to instantiate objects.

    Args:
        video_reference (Union[str, int]): Either str with file or stream reference, or int representing device ID
        buffer_size (int): size of decoding buffer
        status_update_handlers (Optional[List[Callable[[StatusUpdate], None]]]): List of handlers for status updates
        buffer_filling_strategy (Optional[BufferFillingStrategy]): Settings for buffer filling strategy - if not
            given - automatic choice regarding source type will be applied
        buffer_consumption_strategy (Optional[BufferConsumptionStrategy]): Settings for buffer consumption strategy,
            if not given - automatic choice regarding source type will be applied
        adaptive_mode_stream_pace_tolerance (float): Maximum deviation between frames grabbing pace and stream pace
            that will not trigger adaptive mode frame drop
        adaptive_mode_reader_pace_tolerance (float): Maximum deviation between decoding pace and stream consumption
            pace that will not trigger adaptive mode frame drop
        minimum_adaptive_mode_samples (int): Minimal number of frames to be used to establish actual pace of
            processing, before adaptive mode can drop any frame
        maximum_adaptive_frames_dropped_in_row (int): Maximum number of frames dropped in row due to application of
            adaptive strategy
        video_source_properties (Optional[dict[str, float]]): Optional dictionary with video source properties
            corresponding to OpenCV VideoCapture properties cv2.CAP_PROP_* to set values for the video source.
        source_id (Optional[int]): Optional identifier of video source - mainly useful to recognise specific source
            when multiple ones are in use. Identifier will be added to emitted frames and updates. It is advised
            to keep it unique within all sources in use.
        desired_fps (Optional[Union[float, int]]): Optional desired pace of frames consumption - forwarded to the
            internal `VideoConsumer` (NOTE(review): exact throttling semantics are implemented in
            `VideoConsumer.init()` - confirm there)

    Returns: Instance of `VideoSource` class
    """
    frames_buffer = Queue(maxsize=buffer_size)
    if status_update_handlers is None:
        status_update_handlers = []
    video_consumer = VideoConsumer.init(
        buffer_filling_strategy=buffer_filling_strategy,
        adaptive_mode_stream_pace_tolerance=adaptive_mode_stream_pace_tolerance,
        adaptive_mode_reader_pace_tolerance=adaptive_mode_reader_pace_tolerance,
        minimum_adaptive_mode_samples=minimum_adaptive_mode_samples,
        maximum_adaptive_frames_dropped_in_row=maximum_adaptive_frames_dropped_in_row,
        status_update_handlers=status_update_handlers,
        desired_fps=desired_fps,
    )
    return cls(
        stream_reference=video_reference,
        frames_buffer=frames_buffer,
        status_update_handlers=status_update_handlers,
        buffer_consumption_strategy=buffer_consumption_strategy,
        video_consumer=video_consumer,
        video_source_properties=video_source_properties,
        source_id=source_id,
    )
mute
mute()

Method to be used to mute source consumption. Muting is an equivalent of pause for stream - where frames grabbing is not put on hold, just new frames decoding and buffering is not allowed - causing intermediate frames to be dropped. May be also used against files, although arguably less useful. Eligible to be used in states: [RUNNING] End state: * MUTED

Thread safe - only one transition of states possible at the time.

Source code in inference/core/interfaces/camera/video_source.py
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
@lock_state_transition
def mute(self) -> None:
    """
    Method muting source consumption. Muting is the stream equivalent of pause -
    frame grabbing keeps running, but decoding and buffering of new frames is
    suspended, so intermediate frames get dropped. May also be used against files,
    although arguably less useful.
    Eligible to be used in states:
    [RUNNING]
    End state:
    * MUTED

    Thread safe - only one transition of states possible at the time.

    Returns: None
    Throws:
        * StreamOperationNotAllowedError: if executed in context of incorrect state of the source
    """
    if self._state in MUTE_ELIGIBLE_STATES:
        self._mute()
        return
    raise StreamOperationNotAllowedError(
        f"Could not MUTE stream in state: {self._state}"
    )
pause
pause()

Method to be used to pause source consumption. During pause - no new frames are consumed. Used on on-line streams for too long may cause stream disconnection. Eligible to be used in states: [RUNNING] End state: * PAUSED

Thread safe - only one transition of states possible at the time.

Source code in inference/core/interfaces/camera/video_source.py
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
@lock_state_transition
def pause(self) -> None:
    """
    Method pausing source consumption - while paused, no new frames are consumed.
    Keeping an on-line stream paused for too long may cause stream disconnection.
    Eligible to be used in states:
    [RUNNING]
    End state:
    * PAUSED

    Thread safe - only one transition of states possible at the time.

    Returns: None
    Throws:
        * StreamOperationNotAllowedError: if executed in context of incorrect state of the source
    """
    if self._state in PAUSE_ELIGIBLE_STATES:
        self._pause()
        return
    raise StreamOperationNotAllowedError(
        f"Could not PAUSE stream in state: {self._state}"
    )
read_frame
read_frame(timeout=None)

Method to be used by the consumer to get decoded source frame.

Source code in inference/core/interfaces/camera/video_source.py
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
def read_frame(self, timeout: Optional[float] = None) -> Optional[VideoFrame]:
    """
    Method to be used by the consumer to get decoded source frame.

    Args:
        timeout (Optional[float]): maximum time (in seconds) to wait for a frame
            to appear in the buffer; None means wait indefinitely

    Returns: VideoFrame object with decoded frame and its metadata.
    Throws:
        * EndOfStreamError: when trying to get the frame from closed source.
    """
    # Lazy initialisation of cached source properties on the first read
    if self._is_file is None:
        source_metadata: SourceMetadata = self.describe_source()
        self._is_file = source_metadata.source_properties.is_file
        self._fps = source_metadata.source_properties.fps
        # guard against missing or implausible FPS metadata reported by the source
        if not self._fps or self._fps <= 0 or self._fps > 1000:
            self._fps = 30  # sane default
    # With EAGER consumption, purge the buffer and keep only the freshest frame
    video_frame: Optional[Union[VideoFrame, str]] = get_from_queue(
        queue=self._frames_buffer,
        on_successful_read=self._video_consumer.notify_frame_consumed,
        timeout=timeout,
        purge=self._buffer_consumption_strategy is BufferConsumptionStrategy.EAGER,
    )
    # POISON_PILL in the buffer signals the end of the stream
    if video_frame == POISON_PILL:
        raise EndOfStreamError(
            "Attempted to retrieve frame from stream that already ended."
        )
    # Announce consumption of the frame to registered status-update handlers
    if video_frame is not None and self._status_update_handlers:
        send_video_source_status_update(
            severity=UpdateSeverity.DEBUG,
            event_type=FRAME_CONSUMED_EVENT,
            payload={
                "frame_timestamp": video_frame.frame_timestamp,
                "frame_id": video_frame.frame_id,
                "source_id": video_frame.source_id,
            },
            status_update_handlers=self._status_update_handlers,
        )
    return video_frame
restart
restart(
    wait_on_frames_consumption=True,
    purge_frames_buffer=False,
)

Method to restart source consumption. Eligible to be used in states: [MUTED, RUNNING, PAUSED, ENDED, ERROR]. End state: * INITIALISING - that should change into RUNNING once first frame is ready to be grabbed * ERROR - if it was not possible to connect with source

Thread safe - only one transition of states possible at the time.

Parameters:

Name Type Description Default
wait_on_frames_consumption bool

Flag telling if all frames from buffer must be consumed before completion of this operation.

True
Source code in inference/core/interfaces/camera/video_source.py
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
@lock_state_transition
def restart(
    self, wait_on_frames_consumption: bool = True, purge_frames_buffer: bool = False
) -> None:
    """
    Method to restart source consumption. Eligible to be used in states:
    [MUTED, RUNNING, PAUSED, ENDED, ERROR].
    End state:
    * INITIALISING - that should change into RUNNING once first frame is ready to be grabbed
    * ERROR - if it was not possible to connect with source

    Thread safe - only one transition of states possible at the time.

    Args:
        wait_on_frames_consumption (bool): Flag telling if all frames from buffer must be consumed before
            completion of this operation.
        purge_frames_buffer (bool): Flag telling if the frames buffer should be purged as part of the
            restart (presumably dropping not-yet-consumed frames - semantics delegated to internal
            `_restart()`, confirm there)

    Returns: None
    Throws:
        * StreamOperationNotAllowedError: if executed in context of incorrect state of the source
        * SourceConnectionError: if source cannot be connected
    """
    if self._state not in RESTART_ELIGIBLE_STATES:
        raise StreamOperationNotAllowedError(
            f"Could not RESTART stream in state: {self._state}"
        )
    self._restart(
        wait_on_frames_consumption=wait_on_frames_consumption,
        purge_frames_buffer=purge_frames_buffer,
    )
resume
resume()

Method to recover from pause or mute into running state. Eligible to be used in states: [PAUSED, MUTED] End state: * RUNNING

Thread safe - only one transition of states possible at the time.

Source code in inference/core/interfaces/camera/video_source.py
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
@lock_state_transition
def resume(self) -> None:
    """
    Method to recover from pause or mute into running state.
    Eligible to be used in states:
    [PAUSED, MUTED]
    End state:
    * RUNNING

    Thread safe - only one transition of states possible at the time.

    Returns: None
    Throws:
        * StreamOperationNotAllowedError: if executed in context of incorrect state of the source
    """
    if self._state in RESUME_ELIGIBLE_STATES:
        self._resume()
        return
    raise StreamOperationNotAllowedError(
        f"Could not RESUME stream in state: {self._state}"
    )
start
start()

Method to be used to start source consumption. Eligible to be used in states: [NOT_STARTED, ENDED, (RESTARTING - which is internal state only)] End state: * INITIALISING - that should change into RUNNING once first frame is ready to be grabbed * ERROR - if it was not possible to connect with source

Thread safe - only one transition of states possible at the time.

Source code in inference/core/interfaces/camera/video_source.py
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
@lock_state_transition
def start(self) -> None:
    """
    Method starting source consumption. Eligible to be used in states:
    [NOT_STARTED, ENDED, (RESTARTING - which is internal state only)]
    End state:
    * INITIALISING - that should change into RUNNING once first frame is ready to be grabbed
    * ERROR - if it was not possible to connect with source

    Thread safe - only one transition of states possible at the time.

    Returns: None
    Throws:
        * StreamOperationNotAllowedError: if executed in context of incorrect state of the source
        * SourceConnectionError: if source cannot be connected
    """
    if self._state in START_ELIGIBLE_STATES:
        self._start()
        return
    raise StreamOperationNotAllowedError(
        f"Could not START stream in state: {self._state}"
    )
terminate
terminate(
    wait_on_frames_consumption=True,
    purge_frames_buffer=False,
)

Method to be used to terminate source consumption. Eligible to be used in states: [MUTED, RUNNING, PAUSED, ENDED, ERROR, (RESTARTING - which is internal state only)] End state: * ENDED - indicating success of the process * ERROR - if error with processing occurred

Must be used to properly dispose resources at the end.

Thread safe - only one transition of states possible at the time.

Parameters:

Name Type Description Default
wait_on_frames_consumption bool

Flag telling if all frames from buffer must be consumed before completion of this operation.

True
Source code in inference/core/interfaces/camera/video_source.py
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
@lock_state_transition
def terminate(
    self, wait_on_frames_consumption: bool = True, purge_frames_buffer: bool = False
) -> None:
    """
    Method to be used to terminate source consumption. Eligible to be used in states:
    [MUTED, RUNNING, PAUSED, ENDED, ERROR, (RESTARTING - which is internal state only)]
    End state:
    * ENDED - indicating success of the process
    * ERROR - if error with processing occurred

    Must be used to properly dispose resources at the end.

    Thread safe - only one transition of states possible at the time.

    Args:
        wait_on_frames_consumption (bool): Flag telling if all frames from buffer must be consumed before
            completion of this operation.
        purge_frames_buffer (bool): Flag telling if the frames buffer should be purged as part of
            termination (presumably dropping not-yet-consumed frames - semantics delegated to internal
            `_terminate()`, confirm there)

    Returns: None
    Throws:
        * StreamOperationNotAllowedError: if executed in context of incorrect state of the source
    """
    if self._state not in TERMINATE_ELIGIBLE_STATES:
        raise StreamOperationNotAllowedError(
            f"Could not TERMINATE stream in state: {self._state}"
        )
    self._terminate(
        wait_on_frames_consumption=wait_on_frames_consumption,
        purge_frames_buffer=purge_frames_buffer,
    )

Functions

get_from_queue

get_from_queue(
    queue,
    timeout=None,
    on_successful_read=lambda: None,
    purge=False,
)

Function is supposed to take element from the queue waiting on the first element to appear using timeout parameter. One may ask to go to the very last element of the queue and return it - then purge should be set to True. No additional wait on new elements to appear happen and the purge stops once queue is free returning last element consumed. queue.task_done() and on_successful_read(...) will be called on each received element.

Source code in inference/core/interfaces/camera/video_source.py
1105
1106
1107
1108
1109
1110
1111
1112
1113
1114
1115
1116
1117
1118
1119
1120
1121
1122
1123
1124
1125
1126
1127
1128
1129
1130
def get_from_queue(
    queue: Queue,
    timeout: Optional[float] = None,
    on_successful_read: Callable[[], None] = lambda: None,
    purge: bool = False,
) -> Optional[Any]:
    """
    Fetch an element from `queue`, blocking up to `timeout` seconds for the first
    element to appear. When `purge` is True, the queue is drained and the last
    element consumed is returned; no additional waiting for new elements happens
    during the drain — it stops as soon as the queue is observed empty.
    For every element received, queue.task_done() and on_successful_read() are invoked.
    """
    item = None
    # Block for the first element unless a purge can start on an already
    # non-empty queue (then we skip straight to draining).
    if queue.empty() or not purge:
        try:
            item = queue.get(timeout=timeout)
        except Empty:
            pass
        else:
            queue.task_done()
            on_successful_read()
    # Drain mode: keep consuming until the queue is seen empty.
    while purge and not queue.empty():
        item = queue.get()
        queue.task_done()
        on_successful_read()
    return item

core/interfaces/http/builder

inference.core.interfaces.http.builder.routes

Functions

builder_browse async

builder_browse()

Loads the main builder UI (editor.html). Injects the CSRF token and BUILDER_ORIGIN so the client can parse them on page load.

Source code in inference/core/interfaces/http/builder/routes.py
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
@router.get(
    "",
    summary="Workflow Builder List",
    description="Loads the list of Workflows available for editing",
)
@with_route_exceptions_async
async def builder_browse():
    """
    Serve the main builder UI (editor.html), substituting the
    {{BUILDER_ORIGIN}} and {{CSRF}} placeholders so the client can
    read them on page load.
    """
    editor_file = Path(__file__).parent / "editor.html"
    html = editor_file.read_text(encoding="utf-8")
    html = html.replace("{{BUILDER_ORIGIN}}", BUILDER_ORIGIN).replace("{{CSRF}}", csrf)
    return HTMLResponse(html)

builder_edit async

builder_edit(workflow_id)

Loads a specific workflow for editing.

Parameters:

Name Type Description Default
workflow_id str

The ID of the workflow to be edited.

required
Source code in inference/core/interfaces/http/builder/routes.py
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
@router.get(
    "/edit/{workflow_id}",
    summary="Workflow Builder",
    description="Loads a specific workflow for editing",
)
@with_route_exceptions_async
async def builder_edit(workflow_id: str):
    """
    Serve the builder UI (editor.html) for editing a specific workflow,
    with the {{BUILDER_ORIGIN}} and {{CSRF}} placeholders substituted.

    Args:
        workflow_id (str): The ID of the workflow to be edited.
    """
    editor_file = Path(__file__).parent / "editor.html"
    html = editor_file.read_text(encoding="utf-8")
    html = html.replace("{{BUILDER_ORIGIN}}", BUILDER_ORIGIN).replace("{{CSRF}}", csrf)
    return HTMLResponse(html)

builder_maybe_redirect async

builder_maybe_redirect(workflow_id)

If the workflow_id.json file exists, redirect to /build/edit/{workflow_id}. Otherwise, redirect back to /build.

Source code in inference/core/interfaces/http/builder/routes.py
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
@router.get("/{workflow_id}", include_in_schema=False)
@with_route_exceptions_async
async def builder_maybe_redirect(workflow_id: str):
    """
    Redirect to /build/edit/{workflow_id} when the workflow's JSON file
    exists on disk; otherwise fall back to /build.
    """
    # Reject ids that are not purely word characters / hyphens.
    if re.match(r"^[\w\-]+$", workflow_id):
        digest = sha256(workflow_id.encode()).hexdigest()
        if (workflow_local_dir / f"{digest}.json").exists():
            return RedirectResponse(url=f"/build/edit/{workflow_id}", status_code=302)
    return RedirectResponse(url="/build", status_code=302)

builder_redirect async

builder_redirect()

If user hits /build/ with trailing slash, redirect to /build

Source code in inference/core/interfaces/http/builder/routes.py
70
71
72
73
74
75
@router.get("/", include_in_schema=False)
async def builder_redirect():
    """Redirect /build/ (trailing slash) to the canonical /build path."""
    return RedirectResponse(url="/build", status_code=302)

create_or_overwrite_workflow async

create_or_overwrite_workflow(
    workflow_id, request_body=Body(...)
)

Create or overwrite a workflow's JSON file on disk. Protected by CSRF token check.

Source code in inference/core/interfaces/http/builder/routes.py
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
@router.post("/api/{workflow_id}", dependencies=[Depends(verify_csrf_token)])
@with_route_exceptions_async
async def create_or_overwrite_workflow(
    workflow_id: str, request_body: dict = Body(...)
):
    """
    Persist the workflow's JSON to disk (keyed by a hash of its id),
    creating or overwriting the file. Protected by CSRF token check.
    """
    if not re.match(r"^[\w\-]+$", workflow_id):
        return JSONResponse({"error": "invalid id"}, status_code=HTTP_400_BAD_REQUEST)

    workflow_local_dir.mkdir(parents=True, exist_ok=True)

    # A payload carrying a different id means the workflow was renamed:
    # remove the file stored under the previous id.
    previous_id = request_body.get("id")
    if previous_id and previous_id != workflow_id:
        if not re.match(r"^[\w\-]+$", previous_id):
            return JSONResponse(
                {"error": "invalid id"}, status_code=HTTP_400_BAD_REQUEST
            )

        previous_path = (
            workflow_local_dir / f"{sha256(previous_id.encode()).hexdigest()}.json"
        )
        if previous_path.exists():
            try:
                previous_path.unlink()
            except Exception as e:
                logger.error(f"Error deleting {previous_id} from {previous_path}: {e}")
                return JSONResponse({"error": "unable to delete file"}, status_code=500)

    # Normalise the stored id to the path parameter.
    request_body["id"] = workflow_id

    target_path = workflow_local_dir / f"{sha256(workflow_id.encode()).hexdigest()}.json"
    try:
        with target_path.open("w", encoding="utf-8") as f:
            json.dump(request_body, f, indent=2)
    except Exception as e:
        logger.error(f"Error writing JSON for {workflow_id} to {target_path}: {e}")
        return JSONResponse({"error": "unable to write file"}, status_code=500)

    return JSONResponse(
        {"message": f"Workflow '{workflow_id}' created/updated successfully."},
        status_code=HTTP_201_CREATED,
    )

delete_workflow async

delete_workflow(workflow_id)

Delete a workflow's JSON file from disk. Protected by CSRF token check.

Source code in inference/core/interfaces/http/builder/routes.py
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
@router.delete("/api/{workflow_id}", dependencies=[Depends(verify_csrf_token)])
@with_route_exceptions_async
async def delete_workflow(workflow_id: str):
    """
    Remove the workflow's JSON file from local disk.
    Protected by CSRF token check.
    """
    if not re.match(r"^[\w\-]+$", workflow_id):
        return JSONResponse({"error": "invalid id"}, status_code=HTTP_400_BAD_REQUEST)

    target_path = workflow_local_dir / f"{sha256(workflow_id.encode()).hexdigest()}.json"
    if not target_path.exists():
        return JSONResponse({"error": "not found"}, status_code=HTTP_404_NOT_FOUND)

    try:
        target_path.unlink()
    except Exception as e:
        logger.error(f"Error deleting {workflow_id} from {target_path}: {e}")
        return JSONResponse({"error": "unable to delete file"}, status_code=500)

    return JSONResponse(
        {"message": f"Workflow '{workflow_id}' deleted successfully."}, status_code=200
    )

get_all_workflows async

get_all_workflows()

Returns JSON info about all .json files in {MODEL_CACHE_DIR}/workflow/local. Protected by CSRF token check.

Source code in inference/core/interfaces/http/builder/routes.py
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
@router.get("/api", dependencies=[Depends(verify_csrf_token)])
@with_route_exceptions_async
async def get_all_workflows():
    """
    Returns JSON info about all .json files in {MODEL_CACHE_DIR}/workflow/local.
    Protected by CSRF token check.
    """
    data = {}
    for config_path in workflow_local_dir.glob("*.json"):
        file_stats = config_path.stat()
        try:
            with config_path.open("r", encoding="utf-8") as f:
                config_contents: Dict[str, Any] = json.load(f)
        except json.JSONDecodeError as e:
            # Unparseable files are skipped rather than failing the listing.
            logger.error(f"Error decoding JSON from {config_path}: {e}")
            continue

        # Key by the embedded id when present, otherwise the file stem.
        workflow_key = config_contents.get("id", config_path.stem)
        data[workflow_key] = {
            "createTime": {"_seconds": int(file_stats.st_ctime)},
            "updateTime": {"_seconds": int(file_stats.st_mtime)},
            "config": config_contents,
        }

    return Response(
        content=json.dumps({"data": data}, indent=4),
        media_type="application/json",
        status_code=200,
    )

get_workflow async

get_workflow(workflow_id)

Return JSON for workflow_id.json, or 404 if missing.

Source code in inference/core/interfaces/http/builder/routes.py
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
@router.get("/api/{workflow_id}", dependencies=[Depends(verify_csrf_token)])
@with_route_exceptions_async
async def get_workflow(workflow_id: str):
    """
    Return JSON for workflow_id.json, or 404 if missing.
    """
    if not re.match(r"^[\w\-]+$", workflow_id):
        return JSONResponse({"error": "invalid id"}, status_code=HTTP_400_BAD_REQUEST)

    target_path = workflow_local_dir / f"{sha256(workflow_id.encode()).hexdigest()}.json"
    if not target_path.exists():
        return JSONResponse({"error": "not found"}, status_code=HTTP_404_NOT_FOUND)

    file_stats = target_path.stat()
    try:
        with target_path.open("r", encoding="utf-8") as f:
            config_contents = json.load(f)
    except json.JSONDecodeError as e:
        logger.error(f"Error reading JSON for {workflow_id} from '{target_path}': {e}")
        return JSONResponse({"error": "invalid JSON"}, status_code=500)

    payload = {
        "data": {
            "createTime": int(file_stats.st_ctime),
            "updateTime": int(file_stats.st_mtime),
            "config": config_contents,
        }
    }
    return Response(
        content=json.dumps(payload, indent=4),
        media_type="application/json",
        status_code=200,
    )

core/interfaces/http

inference.core.interfaces.http.error_handlers

Classes

Functions

with_route_exceptions

with_route_exceptions(route)

A decorator that wraps a FastAPI route to handle specific exceptions. If an exception is caught, it returns a JSON response with the error message.

Parameters:

Name Type Description Default
route Callable

The FastAPI route to be wrapped.

required

Returns:

Name Type Description
Callable

The wrapped route.

Source code in inference/core/interfaces/http/error_handlers.py
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
def with_route_exceptions(route):
    """
    A decorator that wraps a FastAPI route to handle specific exceptions. If an exception
    is caught, it returns a JSON response with the error message.

    Synchronous variant (wraps a plain ``def`` route). The ``except`` clauses
    are ordered most-specific first; reordering them can change which handler
    fires for subclass exceptions (e.g. ClientCausedStepExecutionError before
    StepExecutionError before WorkflowError).

    Args:
        route (Callable): The FastAPI route to be wrapped.

    Returns:
        Callable: The wrapped route.
    """

    @wraps(route)
    def wrapped_route(*args, **kwargs):
        try:
            return route(*args, **kwargs)
        # --- 400: malformed client requests ---
        except ContentTypeInvalid as error:
            logger.exception("%s: %s", type(error).__name__, error)
            resp = JSONResponse(
                status_code=400,
                content={
                    "message": "Invalid Content-Type header provided with request."
                },
            )
        except ContentTypeMissing as error:
            logger.exception("%s: %s", type(error).__name__, error)
            resp = JSONResponse(
                status_code=400,
                content={"message": "Content-Type header not provided with request."},
            )
        except InputImageLoadError as error:
            logger.exception("%s: %s", type(error).__name__, error)
            resp = JSONResponse(
                status_code=400,
                content={
                    "message": f"Could not load input image. Cause: {error.get_public_error_details()}"
                },
            )
        except ModelInputError as error:
            logger.exception("%s: %s", type(error).__name__, error)
            resp = JSONResponse(
                status_code=400,
                content={
                    "message": f"Error with model input. Cause: {error}",
                    "help_url": error.help_url,
                },
            )
        except InvalidModelIDError as error:
            logger.exception("%s: %s", type(error).__name__, error)
            resp = JSONResponse(
                status_code=400,
                content={"message": "Invalid Model ID sent in request."},
            )
        except InvalidMaskDecodeArgument as error:
            logger.exception("%s: %s", type(error).__name__, error)
            resp = JSONResponse(
                status_code=400,
                content={
                    "message": "Invalid mask decode argument sent. tradeoff_factor must be in [0.0, 1.0], "
                    "mask_decode_mode: must be one of ['accurate', 'fast', 'tradeoff']"
                },
            )
        except MissingApiKeyError as error:
            logger.exception("%s: %s", type(error).__name__, error)
            resp = JSONResponse(
                status_code=400,
                content={
                    "message": "Required Roboflow API key is missing. Visit https://docs.roboflow.com/api-reference/authentication#retrieve-an-api-key "
                    "to learn how to retrieve one."
                },
            )
        # --- 400: workflow-definition errors, rich (block-level) error payload ---
        except (
            WorkflowSyntaxError,
            InvalidReferenceTargetError,
            ExecutionGraphStructureError,
            StepInputDimensionalityError,
        ) as error:
            logger.exception("%s: %s", type(error).__name__, error)
            content = WorkflowErrorResponse(
                message=str(error.public_message),
                error_type=error.__class__.__name__,
                context=str(error.context),
                inner_error_type=str(error.inner_error_type),
                inner_error_message=str(error.inner_error),
                blocks_errors=error.blocks_errors,
            )
            resp = JSONResponse(status_code=400, content=content.model_dump())
        # --- 400: workflow-definition errors, plain dict payload ---
        except (
            WorkflowDefinitionError,
            ReferenceTypeError,
            RuntimeInputError,
            InvalidInputTypeError,
            OperationTypeNotRecognisedError,
            DynamicBlockError,
            WorkflowExecutionEngineVersionError,
            NotSupportedExecutionEngineError,
        ) as error:
            logger.exception("%s: %s", type(error).__name__, error)
            resp = JSONResponse(
                status_code=400,
                content={
                    "message": error.public_message,
                    "error_type": error.__class__.__name__,
                    "context": error.context,
                    "inner_error_type": error.inner_error_type,
                    "inner_error_message": str(error.inner_error),
                },
            )
        except (
            ProcessesManagerInvalidPayload,
            MalformedPayloadError,
            MessageToBigError,
        ) as error:
            logger.exception("%s: %s", type(error).__name__, error)
            resp = JSONResponse(
                status_code=400,
                content={
                    "message": error.public_message,
                    "error_type": error.__class__.__name__,
                    "inner_error_type": error.inner_error_type,
                },
            )
        # --- 401 / 402 / 403 / 423: auth and billing ---
        except (
            RoboflowAPINotAuthorizedError,
            ProcessesManagerAuthorisationError,
            UnauthorizedModelAccessError,
        ) as error:
            logger.exception("%s: %s", type(error).__name__, error)
            resp = JSONResponse(
                status_code=401,
                content={
                    "message": "Unauthorized access to roboflow API - check API key and make sure the key is valid for "
                    "workspace you use. Visit https://docs.roboflow.com/api-reference/authentication#retrieve-an-api-key "
                    "to learn how to retrieve one."
                },
            )
        except PaymentRequiredError as error:
            # warning (not exception): expected billing condition, no stack trace needed
            logger.warning("%s: %s", type(error).__name__, error)
            resp = JSONResponse(
                status_code=402,
                content={
                    "message": "Not enough credits to perform this request. Verify your workspace billing page."
                },
            )
        except RoboflowAPIForbiddenError as error:
            logger.exception("%s: %s", type(error).__name__, error)
            resp = JSONResponse(
                status_code=403,
                content={
                    "message": "Unauthorized access to roboflow API - check API key and make sure the key is valid and "
                    "have required scopes. Visit https://docs.roboflow.com/api-reference/authentication#retrieve-an-api-key "
                    "to learn how to retrieve one."
                },
            )
        except RoboflowAPIUsagePausedError as error:
            logger.exception("%s: %s", type(error).__name__, error)
            resp = JSONResponse(
                status_code=423,
                content={
                    "message": "Roboflow API usage is paused. Please contact your workspace administrator to re-enable api keys."
                },
            )
        # --- 404: missing resources ---
        except (RoboflowAPINotNotFoundError, ModelNotFoundError) as error:
            logger.exception("%s: %s", type(error).__name__, error)
            resp = JSONResponse(
                status_code=404,
                content={
                    "message": "Requested Roboflow resource not found. Make sure that workspace, project or model "
                    "you referred in request exists."
                },
            )
        except ProcessesManagerNotFoundError as error:
            logger.exception("%s: %s", type(error).__name__, error)
            resp = JSONResponse(
                status_code=404,
                content={
                    "message": error.public_message,
                    "error_type": error.__class__.__name__,
                    "inner_error_type": error.inner_error_type,
                },
            )
        # --- 500: server-side model / configuration failures ---
        except ModelPackageNegotiationError as error:
            logger.exception("%s: %s", type(error).__name__, error)
            resp = JSONResponse(
                status_code=500,
                content={
                    "message": f"Could not negotiate model package - {error}",
                    "help_url": error.help_url,
                },
            )
        except (
            InvalidEnvironmentVariableError,
            MissingServiceSecretError,
            ServiceConfigurationError,
            EnvironmentConfigurationError,
            InvalidEnvVariable,
            JetsonTypeResolutionError,
            MissingDependencyError,
            InvalidParameterError,
        ) as error:
            logger.exception("%s: %s", type(error).__name__, error)
            resp = JSONResponse(
                status_code=500, content={"message": "Service misconfiguration."}
            )
        except (
            PreProcessingError,
            PostProcessingError,
        ) as error:
            logger.exception("%s: %s", type(error).__name__, error)
            resp = JSONResponse(
                status_code=500,
                content={
                    "message": "Model configuration related to pre- or post-processing is invalid."
                },
            )
        except ModelArtefactError as error:
            logger.exception("%s: %s", type(error).__name__, error)
            resp = JSONResponse(
                status_code=500, content={"message": "Model package is broken."}
            )
        except ModelLoadingError as error:
            logger.exception("%s: %s", type(error).__name__, error)
            resp = JSONResponse(
                status_code=500,
                content={
                    "message": f"Model loading failed: {error}",
                    "help_url": error.help_url,
                },
            )
        except (UntrustedFileError, FileHashSumMissmatch) as error:
            logger.exception("%s: %s", type(error).__name__, error)
            resp = JSONResponse(
                status_code=500,
                content={
                    "message": f"Issue with model package file: {error}",
                    "help_url": error.help_url,
                },
            )
        except ModelRetrievalError as error:
            logger.exception("%s: %s", type(error).__name__, error)
            resp = JSONResponse(
                status_code=500,
                content={
                    "message": f"Could not retrieve model {error}",
                    "help_url": error.help_url,
                },
            )
        # --- 501 / 502 / 503 / 504: environment and upstream failures ---
        except OnnxProviderNotAvailable as error:
            logger.exception("%s: %s", type(error).__name__, error)
            resp = JSONResponse(
                status_code=501,
                content={
                    "message": "Could not find requested ONNX Runtime Provider. Check that you are using "
                    "the correct docker image on a supported device."
                },
            )
        except (
            MalformedRoboflowAPIResponseError,
            RoboflowAPIUnsuccessfulRequestError,
            WorkspaceLoadError,
            MalformedWorkflowResponseError,
        ) as error:
            logger.exception("%s: %s", type(error).__name__, error)
            resp = JSONResponse(
                status_code=502,
                content={"message": "Internal error. Request to Roboflow API failed."},
            )
        except InferenceModelNotFound as error:
            logger.exception("%s: %s", type(error).__name__, error)
            resp = JSONResponse(
                status_code=503,
                content={"message": "Model is temporarily not ready - retry request."},
                headers={"Retry-After": "1"},
            )
        except RoboflowAPIConnectionError as error:
            logger.exception("%s: %s", type(error).__name__, error)
            resp = JSONResponse(
                status_code=503,
                content={
                    "message": "Internal error. Could not connect to Roboflow API."
                },
            )
        except ModelManagerLockAcquisitionError as error:
            logger.exception("%s: %s", type(error).__name__, error)
            resp = JSONResponse(
                status_code=503,
                content={
                    "message": "Could not acquire model manager lock due to other request performing "
                    "blocking operation. Try again...."
                },
                headers={"Retry-After": "1"},
            )
        except RoboflowAPITimeoutError as error:
            logger.exception("%s: %s", type(error).__name__, error)
            resp = JSONResponse(
                status_code=504,
                content={
                    "message": "Timeout when attempting to connect to Roboflow API."
                },
            )
        # --- workflow step failures: most specific first (client-caused carries its own status) ---
        except ClientCausedStepExecutionError as error:
            logger.exception("%s: %s", type(error).__name__, error)
            content = WorkflowErrorResponse(
                message=str(error.public_message),
                error_type=error.__class__.__name__,
                context=str(error.context),
                inner_error_type=str(error.inner_error_type),
                inner_error_message=str(error.inner_error),
                blocks_errors=[
                    WorkflowBlockError(
                        block_id=error.block_id,
                    ),
                ],
            )
            resp = JSONResponse(
                status_code=error.status_code,
                content=content.model_dump(),
            )
        except StepExecutionError as error:
            logger.exception("%s: %s", type(error).__name__, error)
            content = WorkflowErrorResponse(
                message=str(error.public_message),
                error_type=error.__class__.__name__,
                context=str(error.context),
                inner_error_type=str(error.inner_error_type),
                inner_error_message=str(error.inner_error),
                blocks_errors=[
                    WorkflowBlockError(
                        block_id=error.block_id,
                        block_type=error.block_type,
                    ),
                ],
            )
            resp = JSONResponse(
                status_code=500,
                content=content.model_dump(),
            )
        except WorkflowError as error:
            logger.exception("%s: %s", type(error).__name__, error)
            resp = JSONResponse(
                status_code=500,
                content={
                    "message": error.public_message,
                    "error_type": error.__class__.__name__,
                    "context": error.context,
                    "inner_error_type": error.inner_error_type,
                    "inner_error_message": str(error.inner_error),
                },
            )
        except (
            ProcessesManagerClientError,
            CommunicationProtocolError,
        ) as error:
            logger.exception("%s: %s", type(error).__name__, error)
            resp = JSONResponse(
                status_code=500,
                content={
                    "message": error.public_message,
                    "error_type": error.__class__.__name__,
                    "inner_error_type": error.inner_error_type,
                },
            )
        except WebRTCConfigurationError as error:
            logger.error("%s: %s", type(error).__name__, error)
            resp = JSONResponse(
                status_code=400,
                content={
                    "message": str(error),
                    "error_type": "WebRTCConfigurationError",
                },
            )
        except CreditsExceededError as error:
            logger.error("%s: %s", type(error).__name__, error)
            resp = JSONResponse(
                status_code=402,
                content={
                    "message": "Not enough credits to perform this request.",
                    "error_type": "CreditsExceededError",
                },
            )
        # --- last-resort boundary: never leak internals to the client ---
        except Exception as error:
            logger.exception("%s: %s", type(error).__name__, error)
            resp = JSONResponse(status_code=500, content={"message": "Internal error."})
        return resp

    return wrapped_route

with_route_exceptions_async

with_route_exceptions_async(route)

A decorator that wraps a FastAPI route to handle specific exceptions. If an exception is caught, it returns a JSON response with the error message.

Parameters:

Name Type Description Default
route Callable

The FastAPI route to be wrapped.

required

Returns:

Name Type Description
Callable

The wrapped route.

Source code in inference/core/interfaces/http/error_handlers.py
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
def with_route_exceptions_async(route):
    """
    A decorator that wraps a FastAPI route to handle specific exceptions. If an exception
    is caught, it returns a JSON response with the error message.

    The mapping is roughly: client input problems -> 4xx, service/model/config
    problems -> 5xx, upstream Roboflow API problems -> 502/503/504, with a
    catch-all 500 for anything unrecognised.

    Args:
        route (Callable): The FastAPI route to be wrapped.

    Returns:
        Callable: The wrapped route.
    """

    # NOTE: the except clauses below are order-sensitive. More specific types
    # (e.g. ClientCausedStepExecutionError) are listed before their broader
    # relatives (StepExecutionError, WorkflowError); reordering them would
    # change which handler fires for subclasses.
    @wraps(route)
    async def wrapped_route(*args, **kwargs):
        try:
            # Success path: just await the wrapped coroutine.
            return await route(*args, **kwargs)
        # --- 400 Bad Request: malformed client input ---------------------
        except ContentTypeInvalid as error:
            logger.exception("%s: %s", type(error).__name__, error)
            resp = JSONResponse(
                status_code=400,
                content={
                    "message": "Invalid Content-Type header provided with request."
                },
            )
        except ContentTypeMissing as error:
            logger.exception("%s: %s", type(error).__name__, error)
            resp = JSONResponse(
                status_code=400,
                content={"message": "Content-Type header not provided with request."},
            )
        except InputImageLoadError as error:
            logger.exception("%s: %s", type(error).__name__, error)
            resp = JSONResponse(
                status_code=400,
                content={
                    # Only the sanitised error details are exposed to the client.
                    "message": f"Could not load input image. Cause: {error.get_public_error_details()}"
                },
            )
        except ModelInputError as error:
            logger.exception("%s: %s", type(error).__name__, error)
            resp = JSONResponse(
                status_code=400,
                content={
                    "message": f"Error with model input. Cause: {error}",
                    "help_url": error.help_url,
                },
            )
        except InvalidModelIDError as error:
            logger.exception("%s: %s", type(error).__name__, error)
            resp = JSONResponse(
                status_code=400,
                content={"message": "Invalid Model ID sent in request."},
            )
        except InvalidMaskDecodeArgument as error:
            logger.exception("%s: %s", type(error).__name__, error)
            resp = JSONResponse(
                status_code=400,
                content={
                    "message": "Invalid mask decode argument sent. tradeoff_factor must be in [0.0, 1.0], "
                    "mask_decode_mode: must be one of ['accurate', 'fast', 'tradeoff']"
                },
            )
        except MissingApiKeyError as error:
            logger.exception("%s: %s", type(error).__name__, error)
            resp = JSONResponse(
                status_code=400,
                content={
                    "message": "Required Roboflow API key is missing. Visit https://docs.roboflow.com/api-reference/authentication#retrieve-an-api-key "
                    "to learn how to retrieve one."
                },
            )
        # Workflow definition errors that carry per-block diagnostics: use
        # the structured WorkflowErrorResponse model as the payload.
        except (
            WorkflowSyntaxError,
            InvalidReferenceTargetError,
            ExecutionGraphStructureError,
            StepInputDimensionalityError,
        ) as error:
            logger.exception("%s: %s", type(error).__name__, error)
            content = WorkflowErrorResponse(
                message=str(error.public_message),
                error_type=error.__class__.__name__,
                context=str(error.context),
                inner_error_type=str(error.inner_error_type),
                inner_error_message=str(error.inner_error),
                blocks_errors=error.blocks_errors,
            )
            resp = JSONResponse(status_code=400, content=content.model_dump())
        # Other workflow/compilation errors: flat dict payload (no block list).
        except (
            WorkflowDefinitionError,
            ReferenceTypeError,
            RuntimeInputError,
            InvalidInputTypeError,
            OperationTypeNotRecognisedError,
            DynamicBlockError,
            WorkflowExecutionEngineVersionError,
            NotSupportedExecutionEngineError,
        ) as error:
            logger.exception("%s: %s", type(error).__name__, error)
            resp = JSONResponse(
                status_code=400,
                content={
                    "message": error.public_message,
                    "error_type": error.__class__.__name__,
                    "context": error.context,
                    "inner_error_type": error.inner_error_type,
                    "inner_error_message": str(error.inner_error),
                },
            )
        except (
            ProcessesManagerInvalidPayload,
            MalformedPayloadError,
            MessageToBigError,
        ) as error:
            logger.exception("%s: %s", type(error).__name__, error)
            resp = JSONResponse(
                status_code=400,
                content={
                    "message": error.public_message,
                    "error_type": error.__class__.__name__,
                    "inner_error_type": error.inner_error_type,
                },
            )
        # --- 401/402/403/423: authentication / billing / authorisation ---
        except (
            RoboflowAPINotAuthorizedError,
            ProcessesManagerAuthorisationError,
            UnauthorizedModelAccessError,
        ) as error:
            logger.exception("%s: %s", type(error).__name__, error)
            resp = JSONResponse(
                status_code=401,
                content={
                    "message": "Unauthorized access to roboflow API - check API key and make sure the key is valid for "
                    "workspace you use. Visit https://docs.roboflow.com/api-reference/authentication#retrieve-an-api-key "
                    "to learn how to retrieve one."
                },
            )
        except PaymentRequiredError as error:
            # Billing issues are expected operational events — log as warning,
            # not as an exception with a traceback.
            logger.warning("%s: %s", type(error).__name__, error)
            resp = JSONResponse(
                status_code=402,
                content={
                    "message": "Not enough credits to perform this request. Verify your workspace billing page."
                },
            )
        except RoboflowAPIForbiddenError as error:
            logger.exception("%s: %s", type(error).__name__, error)
            resp = JSONResponse(
                status_code=403,
                content={
                    "message": "Unauthorized access to roboflow API - check API key and make sure the key is valid and "
                    "have required scopes. Visit https://docs.roboflow.com/api-reference/authentication#retrieve-an-api-key "
                    "to learn how to retrieve one."
                },
            )
        except RoboflowAPIUsagePausedError as error:
            logger.exception("%s: %s", type(error).__name__, error)
            resp = JSONResponse(
                status_code=423,
                content={
                    "message": "Roboflow API usage is paused. Please contact your workspace administrator to re-enable api keys."
                },
            )
        # --- 404 Not Found ------------------------------------------------
        except (RoboflowAPINotNotFoundError, ModelNotFoundError) as error:
            logger.exception("%s: %s", type(error).__name__, error)
            resp = JSONResponse(
                status_code=404,
                content={
                    "message": "Requested Roboflow resource not found. Make sure that workspace, project or model "
                    "you referred in request exists."
                },
            )
        except ProcessesManagerNotFoundError as error:
            logger.exception("%s: %s", type(error).__name__, error)
            resp = JSONResponse(
                status_code=404,
                content={
                    "message": error.public_message,
                    "error_type": error.__class__.__name__,
                    "inner_error_type": error.inner_error_type,
                },
            )
        # --- 500 Internal Server Error: server-side model/config failures -
        except ModelPackageNegotiationError as error:
            logger.exception("%s: %s", type(error).__name__, error)
            resp = JSONResponse(
                status_code=500,
                content={
                    "message": f"Could not negotiate model package - {error}",
                    "help_url": error.help_url,
                },
            )
        except (
            InvalidEnvironmentVariableError,
            MissingServiceSecretError,
            ServiceConfigurationError,
            EnvironmentConfigurationError,
            InvalidEnvVariable,
            JetsonTypeResolutionError,
            MissingDependencyError,
            InvalidParameterError,
        ) as error:
            logger.exception("%s: %s", type(error).__name__, error)
            resp = JSONResponse(
                status_code=500, content={"message": "Service misconfiguration."}
            )
        except (
            PreProcessingError,
            PostProcessingError,
        ) as error:
            logger.exception("%s: %s", type(error).__name__, error)
            resp = JSONResponse(
                status_code=500,
                content={
                    "message": "Model configuration related to pre- or post-processing is invalid."
                },
            )
        except ModelArtefactError as error:
            logger.exception("%s: %s", type(error).__name__, error)
            resp = JSONResponse(
                status_code=500, content={"message": "Model package is broken."}
            )
        except ModelLoadingError as error:
            logger.exception("%s: %s", type(error).__name__, error)
            resp = JSONResponse(
                status_code=500,
                content={
                    "message": f"Model loading failed: {error}",
                    "help_url": error.help_url,
                },
            )
        except (UntrustedFileError, FileHashSumMissmatch) as error:
            logger.exception("%s: %s", type(error).__name__, error)
            resp = JSONResponse(
                status_code=500,
                content={
                    "message": f"Issue with model package file: {error}",
                    "help_url": error.help_url,
                },
            )
        except ModelRetrievalError as error:
            logger.exception("%s: %s", type(error).__name__, error)
            resp = JSONResponse(
                status_code=500,
                content={
                    "message": f"Could not retrieve model {error}",
                    "help_url": error.help_url,
                },
            )
        # --- 501 Not Implemented -----------------------------------------
        except OnnxProviderNotAvailable as error:
            logger.exception("%s: %s", type(error).__name__, error)
            resp = JSONResponse(
                status_code=501,
                content={
                    "message": "Could not find requested ONNX Runtime Provider. Check that you are using "
                    "the correct docker image on a supported device."
                },
            )
        # --- 502/503/504: upstream Roboflow API / availability problems ---
        except (
            MalformedRoboflowAPIResponseError,
            RoboflowAPIUnsuccessfulRequestError,
            WorkspaceLoadError,
            MalformedWorkflowResponseError,
        ) as error:
            logger.exception("%s: %s", type(error).__name__, error)
            resp = JSONResponse(
                status_code=502,
                content={"message": "Internal error. Request to Roboflow API failed."},
            )
        except InferenceModelNotFound as error:
            logger.exception("%s: %s", type(error).__name__, error)
            resp = JSONResponse(
                status_code=503,
                content={"message": "Model is temporarily not ready - retry request."},
                # Hint clients that an immediate retry is reasonable.
                headers={"Retry-After": "1"},
            )
        except RoboflowAPIConnectionError as error:
            logger.exception("%s: %s", type(error).__name__, error)
            resp = JSONResponse(
                status_code=503,
                content={
                    "message": "Internal error. Could not connect to Roboflow API."
                },
            )
        except ModelManagerLockAcquisitionError as error:
            logger.exception("%s: %s", type(error).__name__, error)
            resp = JSONResponse(
                status_code=503,
                content={
                    "message": "Could not acquire model manager lock due to other request performing "
                    "blocking operation. Try again...."
                },
                headers={"Retry-After": "1"},
            )
        except RoboflowAPITimeoutError as error:
            logger.exception("%s: %s", type(error).__name__, error)
            resp = JSONResponse(
                status_code=504,
                content={
                    "message": "Timeout when attempting to connect to Roboflow API."
                },
            )
        # --- Workflow step execution errors -------------------------------
        # Client-caused step errors carry their own status code; this clause
        # must precede the generic StepExecutionError handler below.
        except ClientCausedStepExecutionError as error:
            logger.exception("%s: %s", type(error).__name__, error)
            content = WorkflowErrorResponse(
                message=str(error.public_message),
                error_type=error.__class__.__name__,
                context=str(error.context),
                inner_error_type=str(error.inner_error_type),
                inner_error_message=str(error.inner_error),
                blocks_errors=[
                    WorkflowBlockError(
                        block_id=error.block_id,
                    ),
                ],
            )
            resp = JSONResponse(
                status_code=error.status_code,
                content=content.model_dump(),
            )
        except StepExecutionError as error:
            logger.exception("%s: %s", type(error).__name__, error)
            content = WorkflowErrorResponse(
                message=str(error.public_message),
                error_type=error.__class__.__name__,
                context=str(error.context),
                inner_error_type=str(error.inner_error_type),
                inner_error_message=str(error.inner_error),
                blocks_errors=[
                    WorkflowBlockError(
                        block_id=error.block_id,
                        block_type=error.block_type,
                    ),
                ],
            )
            resp = JSONResponse(
                status_code=500,
                content=content.model_dump(),
            )
        # Broad workflow fallback — must come after the specific workflow
        # handlers above.
        except WorkflowError as error:
            logger.exception("%s: %s", type(error).__name__, error)
            resp = JSONResponse(
                status_code=500,
                content={
                    "message": error.public_message,
                    "error_type": error.__class__.__name__,
                    "context": error.context,
                    "inner_error_type": error.inner_error_type,
                    "inner_error_message": str(error.inner_error),
                },
            )
        except (
            ProcessesManagerClientError,
            CommunicationProtocolError,
        ) as error:
            logger.exception("%s: %s", type(error).__name__, error)
            resp = JSONResponse(
                status_code=500,
                content={
                    "message": error.public_message,
                    "error_type": error.__class__.__name__,
                    "inner_error_type": error.inner_error_type,
                },
            )
        except WebRTCConfigurationError as error:
            # Client configuration issue — logged without traceback.
            logger.error("%s: %s", type(error).__name__, error)
            resp = JSONResponse(
                status_code=400,
                content={
                    "message": str(error),
                    "error_type": "WebRTCConfigurationError",
                },
            )
        except CreditsExceededError as error:
            logger.error("%s: %s", type(error).__name__, error)
            resp = JSONResponse(
                status_code=402,
                content={
                    "message": "Not enough credits to perform this request.",
                    "error_type": "CreditsExceededError",
                },
            )
        # Catch-all boundary: never let an unhandled exception escape a route.
        except Exception as error:
            logger.exception("%s: %s", type(error).__name__, error)
            resp = JSONResponse(status_code=500, content={"message": "Internal error."})
        # Only reached when an except clause above assigned `resp`.
        return resp

    return wrapped_route

inference.core.interfaces.http.http_api

Classes

HttpInterface

Bases: BaseInterface

Roboflow defined HTTP interface for a general-purpose inference server.

This class sets up the FastAPI application and adds necessary middleware, as well as initializes the model manager and model registry for the inference server.

Attributes:

Name Type Description
app FastAPI

The FastAPI application instance.

model_manager ModelManager

The manager for handling different models.

Source code in inference/core/interfaces/http/http_api.py
 377
 378
 379
 380
 381
 382
 383
 384
 385
 386
 387
 388
 389
 390
 391
 392
 393
 394
 395
 396
 397
 398
 399
 400
 401
 402
 403
 404
 405
 406
 407
 408
 409
 410
 411
 412
 413
 414
 415
 416
 417
 418
 419
 420
 421
 422
 423
 424
 425
 426
 427
 428
 429
 430
 431
 432
 433
 434
 435
 436
 437
 438
 439
 440
 441
 442
 443
 444
 445
 446
 447
 448
 449
 450
 451
 452
 453
 454
 455
 456
 457
 458
 459
 460
 461
 462
 463
 464
 465
 466
 467
 468
 469
 470
 471
 472
 473
 474
 475
 476
 477
 478
 479
 480
 481
 482
 483
 484
 485
 486
 487
 488
 489
 490
 491
 492
 493
 494
 495
 496
 497
 498
 499
 500
 501
 502
 503
 504
 505
 506
 507
 508
 509
 510
 511
 512
 513
 514
 515
 516
 517
 518
 519
 520
 521
 522
 523
 524
 525
 526
 527
 528
 529
 530
 531
 532
 533
 534
 535
 536
 537
 538
 539
 540
 541
 542
 543
 544
 545
 546
 547
 548
 549
 550
 551
 552
 553
 554
 555
 556
 557
 558
 559
 560
 561
 562
 563
 564
 565
 566
 567
 568
 569
 570
 571
 572
 573
 574
 575
 576
 577
 578
 579
 580
 581
 582
 583
 584
 585
 586
 587
 588
 589
 590
 591
 592
 593
 594
 595
 596
 597
 598
 599
 600
 601
 602
 603
 604
 605
 606
 607
 608
 609
 610
 611
 612
 613
 614
 615
 616
 617
 618
 619
 620
 621
 622
 623
 624
 625
 626
 627
 628
 629
 630
 631
 632
 633
 634
 635
 636
 637
 638
 639
 640
 641
 642
 643
 644
 645
 646
 647
 648
 649
 650
 651
 652
 653
 654
 655
 656
 657
 658
 659
 660
 661
 662
 663
 664
 665
 666
 667
 668
 669
 670
 671
 672
 673
 674
 675
 676
 677
 678
 679
 680
 681
 682
 683
 684
 685
 686
 687
 688
 689
 690
 691
 692
 693
 694
 695
 696
 697
 698
 699
 700
 701
 702
 703
 704
 705
 706
 707
 708
 709
 710
 711
 712
 713
 714
 715
 716
 717
 718
 719
 720
 721
 722
 723
 724
 725
 726
 727
 728
 729
 730
 731
 732
 733
 734
 735
 736
 737
 738
 739
 740
 741
 742
 743
 744
 745
 746
 747
 748
 749
 750
 751
 752
 753
 754
 755
 756
 757
 758
 759
 760
 761
 762
 763
 764
 765
 766
 767
 768
 769
 770
 771
 772
 773
 774
 775
 776
 777
 778
 779
 780
 781
 782
 783
 784
 785
 786
 787
 788
 789
 790
 791
 792
 793
 794
 795
 796
 797
 798
 799
 800
 801
 802
 803
 804
 805
 806
 807
 808
 809
 810
 811
 812
 813
 814
 815
 816
 817
 818
 819
 820
 821
 822
 823
 824
 825
 826
 827
 828
 829
 830
 831
 832
 833
 834
 835
 836
 837
 838
 839
 840
 841
 842
 843
 844
 845
 846
 847
 848
 849
 850
 851
 852
 853
 854
 855
 856
 857
 858
 859
 860
 861
 862
 863
 864
 865
 866
 867
 868
 869
 870
 871
 872
 873
 874
 875
 876
 877
 878
 879
 880
 881
 882
 883
 884
 885
 886
 887
 888
 889
 890
 891
 892
 893
 894
 895
 896
 897
 898
 899
 900
 901
 902
 903
 904
 905
 906
 907
 908
 909
 910
 911
 912
 913
 914
 915
 916
 917
 918
 919
 920
 921
 922
 923
 924
 925
 926
 927
 928
 929
 930
 931
 932
 933
 934
 935
 936
 937
 938
 939
 940
 941
 942
 943
 944
 945
 946
 947
 948
 949
 950
 951
 952
 953
 954
 955
 956
 957
 958
 959
 960
 961
 962
 963
 964
 965
 966
 967
 968
 969
 970
 971
 972
 973
 974
 975
 976
 977
 978
 979
 980
 981
 982
 983
 984
 985
 986
 987
 988
 989
 990
 991
 992
 993
 994
 995
 996
 997
 998
 999
1000
1001
1002
1003
1004
1005
1006
1007
1008
1009
1010
1011
1012
1013
1014
1015
1016
1017
1018
1019
1020
1021
1022
1023
1024
1025
1026
1027
1028
1029
1030
1031
1032
1033
1034
1035
1036
1037
1038
1039
1040
1041
1042
1043
1044
1045
1046
1047
1048
1049
1050
1051
1052
1053
1054
1055
1056
1057
1058
1059
1060
1061
1062
1063
1064
1065
1066
1067
1068
1069
1070
1071
1072
1073
1074
1075
1076
1077
1078
1079
1080
1081
1082
1083
1084
1085
1086
1087
1088
1089
1090
1091
1092
1093
1094
1095
1096
1097
1098
1099
1100
1101
1102
1103
1104
1105
1106
1107
1108
1109
1110
1111
1112
1113
1114
1115
1116
1117
1118
1119
1120
1121
1122
1123
1124
1125
1126
1127
1128
1129
1130
1131
1132
1133
1134
1135
1136
1137
1138
1139
1140
1141
1142
1143
1144
1145
1146
1147
1148
1149
1150
1151
1152
1153
1154
1155
1156
1157
1158
1159
1160
1161
1162
1163
1164
1165
1166
1167
1168
1169
1170
1171
1172
1173
1174
1175
1176
1177
1178
1179
1180
1181
1182
1183
1184
1185
1186
1187
1188
1189
1190
1191
1192
1193
1194
1195
1196
1197
1198
1199
1200
1201
1202
1203
1204
1205
1206
1207
1208
1209
1210
1211
1212
1213
1214
1215
1216
1217
1218
1219
1220
1221
1222
1223
1224
1225
1226
1227
1228
1229
1230
1231
1232
1233
1234
1235
1236
1237
1238
1239
1240
1241
1242
1243
1244
1245
1246
1247
1248
1249
1250
1251
1252
1253
1254
1255
1256
1257
1258
1259
1260
1261
1262
1263
1264
1265
1266
1267
1268
1269
1270
1271
1272
1273
1274
1275
1276
1277
1278
1279
1280
1281
1282
1283
1284
1285
1286
1287
1288
1289
1290
1291
1292
1293
1294
1295
1296
1297
1298
1299
1300
1301
1302
1303
1304
1305
1306
1307
1308
1309
1310
1311
1312
1313
1314
1315
1316
1317
1318
1319
1320
1321
1322
1323
1324
1325
1326
1327
1328
1329
1330
1331
1332
1333
1334
1335
1336
1337
1338
1339
1340
1341
1342
1343
1344
1345
1346
1347
1348
1349
1350
1351
1352
1353
1354
1355
1356
1357
1358
1359
1360
1361
1362
1363
1364
1365
1366
1367
1368
1369
1370
1371
1372
1373
1374
1375
1376
1377
1378
1379
1380
1381
1382
1383
1384
1385
1386
1387
1388
1389
1390
1391
1392
1393
1394
1395
1396
1397
1398
1399
1400
1401
1402
1403
1404
1405
1406
1407
1408
1409
1410
1411
1412
1413
1414
1415
1416
1417
1418
1419
1420
1421
1422
1423
1424
1425
1426
1427
1428
1429
1430
1431
1432
1433
1434
1435
1436
1437
1438
1439
1440
1441
1442
1443
1444
1445
1446
1447
1448
1449
1450
1451
1452
1453
1454
1455
1456
1457
1458
1459
1460
1461
1462
1463
1464
1465
1466
1467
1468
1469
1470
1471
1472
1473
1474
1475
1476
1477
1478
1479
1480
1481
1482
1483
1484
1485
1486
1487
1488
1489
1490
1491
1492
1493
1494
1495
1496
1497
1498
1499
1500
1501
1502
1503
1504
1505
1506
1507
1508
1509
1510
1511
1512
1513
1514
1515
1516
1517
1518
1519
1520
1521
1522
1523
1524
1525
1526
1527
1528
1529
1530
1531
1532
1533
1534
1535
1536
1537
1538
1539
1540
1541
1542
1543
1544
1545
1546
1547
1548
1549
1550
1551
1552
1553
1554
1555
1556
1557
1558
1559
1560
1561
1562
1563
1564
1565
1566
1567
1568
1569
1570
1571
1572
1573
1574
1575
1576
1577
1578
1579
1580
1581
1582
1583
1584
1585
1586
1587
1588
1589
1590
1591
1592
1593
1594
1595
1596
1597
1598
1599
1600
1601
1602
1603
1604
1605
1606
1607
1608
1609
1610
1611
1612
1613
1614
1615
1616
1617
1618
1619
1620
1621
1622
1623
1624
1625
1626
1627
1628
1629
1630
1631
1632
1633
1634
1635
1636
1637
1638
1639
1640
1641
1642
1643
1644
1645
1646
1647
1648
1649
1650
1651
1652
1653
1654
1655
1656
1657
1658
1659
1660
1661
1662
1663
1664
1665
1666
1667
1668
1669
1670
1671
1672
1673
1674
1675
1676
1677
1678
1679
1680
1681
1682
1683
1684
1685
1686
1687
1688
1689
1690
1691
1692
1693
1694
1695
1696
1697
1698
1699
1700
1701
1702
1703
1704
1705
1706
1707
1708
1709
1710
1711
1712
1713
1714
1715
1716
1717
1718
1719
1720
1721
1722
1723
1724
1725
1726
1727
1728
1729
1730
1731
1732
1733
1734
1735
1736
1737
1738
1739
1740
1741
1742
1743
1744
1745
1746
1747
1748
1749
1750
1751
1752
1753
1754
1755
1756
1757
1758
1759
1760
1761
1762
1763
1764
1765
1766
1767
1768
1769
1770
1771
1772
1773
1774
1775
1776
1777
1778
1779
1780
1781
1782
1783
1784
1785
1786
1787
1788
1789
1790
1791
1792
1793
1794
1795
1796
1797
1798
1799
1800
1801
1802
1803
1804
1805
1806
1807
1808
1809
1810
1811
1812
1813
1814
1815
1816
1817
1818
1819
1820
1821
1822
1823
1824
1825
1826
1827
1828
1829
1830
1831
1832
1833
1834
1835
1836
1837
1838
1839
1840
1841
1842
1843
1844
1845
1846
1847
1848
1849
1850
1851
1852
1853
1854
1855
1856
1857
1858
1859
1860
1861
1862
1863
1864
1865
1866
1867
1868
1869
1870
1871
1872
1873
1874
1875
1876
1877
1878
1879
1880
1881
1882
1883
1884
1885
1886
1887
1888
1889
1890
1891
1892
1893
1894
1895
1896
1897
1898
1899
1900
1901
1902
1903
1904
1905
1906
1907
1908
1909
1910
1911
1912
1913
1914
1915
1916
1917
1918
1919
1920
1921
1922
1923
1924
1925
1926
1927
1928
1929
1930
1931
1932
1933
1934
1935
1936
1937
1938
1939
1940
1941
1942
1943
1944
1945
1946
1947
1948
1949
1950
1951
1952
1953
1954
1955
1956
1957
1958
1959
1960
1961
1962
1963
1964
1965
1966
1967
1968
1969
1970
1971
1972
1973
1974
1975
1976
1977
1978
1979
1980
1981
1982
1983
1984
1985
1986
1987
1988
1989
1990
1991
1992
1993
1994
1995
1996
1997
1998
1999
2000
2001
2002
2003
2004
2005
2006
2007
2008
2009
2010
2011
2012
2013
2014
2015
2016
2017
2018
2019
2020
2021
2022
2023
2024
2025
2026
2027
2028
2029
2030
2031
2032
2033
2034
2035
2036
2037
2038
2039
2040
2041
2042
2043
2044
2045
2046
2047
2048
2049
2050
2051
2052
2053
2054
2055
2056
2057
2058
2059
2060
2061
2062
2063
2064
2065
2066
2067
2068
2069
2070
2071
2072
2073
2074
2075
2076
2077
2078
2079
2080
2081
2082
2083
2084
2085
2086
2087
2088
2089
2090
2091
2092
2093
2094
2095
2096
2097
2098
2099
2100
2101
2102
2103
2104
2105
2106
2107
2108
2109
2110
2111
2112
2113
2114
2115
2116
2117
2118
2119
2120
2121
2122
2123
2124
2125
2126
2127
2128
2129
2130
2131
2132
2133
2134
2135
2136
2137
2138
2139
2140
2141
2142
2143
2144
2145
2146
2147
2148
2149
2150
2151
2152
2153
2154
2155
2156
2157
2158
2159
2160
2161
2162
2163
2164
2165
2166
2167
2168
2169
2170
2171
2172
2173
2174
2175
2176
2177
2178
2179
2180
2181
2182
2183
2184
2185
2186
2187
2188
2189
2190
2191
2192
2193
2194
2195
2196
2197
2198
2199
2200
2201
2202
2203
2204
2205
2206
2207
2208
2209
2210
2211
2212
2213
2214
2215
2216
2217
2218
2219
2220
2221
2222
2223
2224
2225
2226
2227
2228
2229
2230
2231
2232
2233
2234
2235
2236
2237
2238
2239
2240
2241
2242
2243
2244
2245
2246
2247
2248
2249
2250
2251
2252
2253
2254
2255
2256
2257
2258
2259
2260
2261
2262
2263
2264
2265
2266
2267
2268
2269
2270
2271
2272
2273
2274
2275
2276
2277
2278
2279
2280
2281
2282
2283
2284
2285
2286
2287
2288
2289
2290
2291
2292
2293
2294
2295
2296
2297
2298
2299
2300
2301
2302
2303
2304
2305
2306
2307
2308
2309
2310
2311
2312
2313
2314
2315
2316
2317
2318
2319
2320
2321
2322
2323
2324
2325
2326
2327
2328
2329
2330
2331
2332
2333
2334
2335
2336
2337
2338
2339
2340
2341
2342
2343
2344
2345
2346
2347
2348
2349
2350
2351
2352
2353
2354
2355
2356
2357
2358
2359
2360
2361
2362
2363
2364
2365
2366
2367
2368
2369
2370
2371
2372
2373
2374
2375
2376
2377
2378
2379
2380
2381
2382
2383
2384
2385
2386
2387
2388
2389
2390
2391
2392
2393
2394
2395
2396
2397
2398
2399
2400
2401
2402
2403
2404
2405
2406
2407
2408
2409
2410
2411
2412
2413
2414
2415
2416
2417
2418
2419
2420
2421
2422
2423
2424
2425
2426
2427
2428
2429
2430
2431
2432
2433
2434
2435
2436
2437
2438
2439
2440
2441
2442
2443
2444
2445
2446
2447
2448
2449
2450
2451
2452
2453
2454
2455
2456
2457
2458
2459
2460
2461
2462
2463
2464
2465
2466
2467
2468
2469
2470
2471
2472
2473
2474
2475
2476
2477
2478
2479
2480
2481
2482
2483
2484
2485
2486
2487
2488
2489
2490
2491
2492
2493
2494
2495
2496
2497
2498
2499
2500
2501
2502
2503
2504
2505
2506
2507
2508
2509
2510
2511
2512
2513
2514
2515
2516
2517
2518
2519
2520
2521
2522
2523
2524
2525
2526
2527
2528
2529
2530
2531
2532
2533
2534
2535
2536
2537
2538
2539
2540
2541
2542
2543
2544
2545
2546
2547
2548
2549
2550
2551
2552
2553
2554
2555
2556
2557
2558
2559
2560
2561
2562
2563
2564
2565
2566
2567
2568
2569
2570
2571
2572
2573
2574
2575
2576
2577
2578
2579
2580
2581
2582
2583
2584
2585
2586
2587
2588
2589
2590
2591
2592
2593
2594
2595
2596
2597
2598
2599
2600
2601
2602
2603
2604
2605
2606
2607
2608
2609
2610
2611
2612
2613
2614
2615
2616
2617
2618
2619
2620
2621
2622
2623
2624
2625
2626
2627
2628
2629
2630
2631
2632
2633
2634
2635
2636
2637
2638
2639
2640
2641
2642
2643
2644
2645
2646
2647
2648
2649
2650
2651
2652
2653
2654
2655
2656
2657
2658
2659
2660
2661
2662
2663
2664
2665
2666
2667
2668
2669
2670
2671
2672
2673
2674
2675
2676
2677
2678
2679
2680
2681
2682
2683
2684
2685
2686
2687
2688
2689
2690
2691
2692
2693
2694
2695
2696
2697
2698
2699
2700
2701
2702
2703
2704
2705
2706
2707
2708
2709
2710
2711
2712
2713
2714
2715
2716
2717
2718
2719
2720
2721
2722
2723
2724
2725
2726
2727
2728
2729
2730
2731
2732
2733
2734
2735
2736
2737
2738
2739
2740
2741
2742
2743
2744
2745
2746
2747
2748
2749
2750
2751
2752
2753
2754
2755
2756
2757
2758
2759
2760
2761
2762
2763
2764
2765
2766
2767
2768
2769
2770
2771
2772
2773
2774
2775
2776
2777
2778
2779
2780
2781
2782
2783
2784
2785
2786
2787
2788
2789
2790
2791
2792
2793
2794
2795
2796
2797
2798
2799
2800
2801
2802
2803
2804
2805
2806
2807
2808
2809
2810
2811
2812
2813
2814
2815
2816
2817
2818
2819
2820
2821
2822
2823
2824
2825
2826
2827
2828
2829
2830
2831
2832
2833
2834
2835
2836
2837
2838
2839
2840
2841
2842
2843
2844
2845
2846
2847
2848
2849
2850
2851
2852
2853
2854
2855
2856
2857
2858
2859
2860
2861
2862
2863
2864
2865
2866
2867
2868
2869
2870
2871
2872
2873
2874
2875
2876
2877
2878
2879
2880
2881
2882
2883
2884
2885
2886
2887
2888
2889
2890
2891
2892
2893
2894
2895
2896
2897
2898
2899
2900
2901
2902
2903
2904
2905
2906
2907
2908
2909
2910
2911
2912
2913
2914
2915
2916
2917
2918
2919
2920
2921
2922
2923
2924
2925
2926
2927
2928
2929
2930
2931
2932
2933
2934
2935
2936
2937
2938
2939
2940
2941
2942
2943
2944
2945
2946
2947
2948
2949
2950
2951
2952
2953
2954
2955
2956
2957
2958
2959
2960
2961
2962
2963
2964
2965
2966
2967
2968
2969
2970
2971
2972
2973
2974
2975
2976
2977
2978
2979
2980
2981
2982
2983
2984
2985
2986
2987
2988
2989
2990
2991
2992
2993
2994
2995
2996
2997
2998
2999
3000
3001
3002
3003
3004
3005
3006
3007
3008
3009
3010
3011
3012
3013
3014
3015
3016
3017
3018
3019
3020
3021
3022
3023
3024
3025
3026
3027
3028
3029
3030
3031
3032
3033
3034
3035
3036
3037
3038
3039
3040
3041
3042
3043
3044
3045
3046
3047
3048
3049
3050
3051
3052
3053
3054
3055
3056
3057
3058
3059
3060
3061
3062
3063
3064
3065
3066
3067
3068
3069
3070
3071
3072
3073
3074
3075
3076
3077
3078
3079
3080
3081
3082
3083
3084
3085
3086
3087
3088
3089
3090
3091
3092
3093
3094
3095
3096
3097
3098
3099
3100
3101
3102
3103
3104
3105
3106
3107
3108
3109
3110
3111
3112
3113
3114
3115
3116
3117
3118
3119
3120
3121
3122
3123
3124
3125
3126
3127
3128
3129
3130
3131
3132
3133
3134
3135
3136
3137
3138
3139
3140
3141
3142
3143
3144
3145
3146
3147
3148
3149
3150
3151
3152
3153
3154
3155
3156
3157
3158
3159
3160
3161
3162
3163
3164
3165
3166
3167
3168
3169
3170
3171
3172
3173
3174
3175
3176
3177
3178
3179
3180
3181
3182
3183
3184
3185
3186
3187
3188
3189
3190
3191
3192
3193
3194
3195
3196
3197
3198
3199
3200
3201
3202
3203
3204
3205
3206
3207
3208
3209
3210
3211
3212
3213
3214
3215
3216
3217
3218
3219
3220
3221
3222
3223
3224
3225
3226
3227
3228
3229
3230
3231
3232
3233
3234
3235
3236
3237
3238
3239
3240
3241
3242
3243
3244
3245
3246
3247
3248
3249
3250
3251
3252
3253
3254
3255
3256
3257
3258
3259
3260
3261
3262
3263
3264
3265
3266
3267
3268
3269
3270
3271
3272
3273
3274
3275
3276
3277
3278
3279
3280
3281
3282
3283
3284
3285
3286
3287
3288
3289
3290
3291
3292
3293
3294
3295
3296
3297
3298
3299
3300
3301
3302
3303
3304
3305
3306
3307
3308
3309
3310
3311
3312
3313
3314
3315
3316
3317
3318
3319
3320
3321
3322
3323
3324
3325
3326
3327
3328
3329
3330
3331
3332
3333
3334
3335
3336
3337
3338
3339
3340
3341
3342
3343
3344
3345
3346
3347
3348
3349
3350
3351
3352
3353
3354
3355
3356
3357
3358
3359
3360
3361
3362
3363
3364
3365
3366
3367
3368
3369
3370
3371
3372
3373
3374
3375
3376
3377
3378
3379
3380
3381
3382
3383
3384
3385
3386
3387
3388
3389
3390
3391
3392
3393
3394
3395
3396
3397
3398
3399
3400
3401
3402
3403
3404
3405
3406
3407
3408
3409
3410
3411
3412
3413
3414
3415
3416
3417
3418
3419
3420
3421
3422
3423
3424
3425
3426
3427
3428
3429
3430
3431
3432
3433
3434
3435
3436
3437
3438
3439
3440
3441
3442
3443
3444
3445
3446
3447
3448
3449
3450
3451
3452
3453
3454
3455
3456
3457
3458
3459
3460
3461
3462
3463
3464
3465
3466
3467
3468
3469
3470
3471
3472
3473
3474
3475
3476
3477
3478
3479
3480
3481
3482
3483
3484
3485
3486
3487
3488
3489
3490
3491
3492
3493
3494
3495
3496
3497
3498
3499
3500
3501
3502
3503
3504
3505
3506
3507
3508
3509
3510
3511
3512
3513
3514
3515
3516
3517
3518
3519
3520
3521
3522
3523
3524
3525
3526
3527
3528
3529
3530
3531
3532
3533
3534
3535
3536
3537
3538
3539
3540
3541
3542
3543
3544
3545
3546
3547
3548
3549
3550
3551
3552
3553
3554
3555
3556
3557
3558
3559
3560
3561
3562
3563
3564
3565
3566
3567
3568
3569
3570
3571
3572
3573
3574
3575
3576
3577
3578
3579
3580
3581
3582
3583
3584
3585
3586
3587
3588
3589
3590
3591
3592
3593
3594
3595
3596
3597
3598
3599
3600
3601
3602
3603
3604
3605
3606
3607
3608
3609
3610
3611
3612
3613
3614
3615
3616
3617
3618
3619
3620
3621
3622
3623
3624
3625
3626
3627
3628
3629
3630
3631
3632
3633
3634
3635
3636
3637
3638
3639
3640
3641
3642
3643
3644
3645
3646
3647
3648
3649
3650
3651
3652
3653
3654
3655
3656
3657
3658
3659
3660
class HttpInterface(BaseInterface):
    """Roboflow defined HTTP interface for a general-purpose inference server.

    This class sets up the FastAPI application and adds necessary middleware,
    as well as initializes the model manager and model registry for the inference server.

    Attributes:
        app (FastAPI): The FastAPI application instance.
        model_manager (ModelManager): The manager for handling different models.
    """

    def __init__(
        self,
        model_manager: ModelManager,
        root_path: Optional[str] = None,
    ):
        """
        Initializes the HttpInterface with given model manager and model registry.

        Args:
            model_manager (ModelManager): The manager for handling different models.
            root_path (Optional[str]): The root path for the FastAPI application.

        Description:
            Deploy Roboflow trained models to nearly any compute environment!
        """

        description = "Roboflow inference server"

        app = FastAPI(
            title="Roboflow Inference Server",
            description=description,
            version=__version__,
            terms_of_service="https://roboflow.com/terms",
            contact={
                "name": "Roboflow Inc.",
                "url": "https://roboflow.com/contact",
                "email": "help@roboflow.com",
            },
            license_info={
                "name": "Apache 2.0",
                "url": "https://www.apache.org/licenses/LICENSE-2.0.html",
            },
            root_path=root_path,
        )
        # Ensure in-memory logging is initialized as early as possible for all runtimes
        try:
            from inference.core.logging.memory_handler import setup_memory_logging

            setup_memory_logging()
        except Exception:
            pass

        app.mount(
            "/static",
            StaticFiles(directory="./inference/landing/out/static", html=True),
            name="static",
        )
        app.mount(
            "/_next/static",
            StaticFiles(directory="./inference/landing/out/_next/static", html=True),
            name="_next_static",
        )

        @app.on_event("shutdown")
        async def on_shutdown():
            """Flush any pending usage telemetry before the server exits."""
            logger.info("Shutting down %s", description)
            await usage_collector.async_push_usage_payloads()

        self._instrumentator = InferenceInstrumentator(
            app, model_manager=model_manager, endpoint="/metrics"
        )
        if LAMBDA:
            app.add_middleware(LambdaMiddleware)
        if GCP_SERVERLESS:
            app.add_middleware(GCPServerlessMiddleware)

        if len(ALLOW_ORIGINS) > 0:
            # Add CORS Middleware (but not for /build**, which is controlled separately)
            app.add_middleware(
                PathAwareCORSMiddleware,
                match_paths=r"^(?!/build).*",
                allow_origins=ALLOW_ORIGINS,
                allow_credentials=True,
                allow_methods=["*"],
                allow_headers=["*"],
                expose_headers=[
                    PROCESSING_TIME_HEADER,
                    REMOTE_PROCESSING_TIME_HEADER,
                    REMOTE_PROCESSING_TIMES_HEADER,
                    MODEL_COLD_START_HEADER,
                    MODEL_LOAD_TIME_HEADER,
                    MODEL_LOAD_DETAILS_HEADER,
                    MODEL_ID_HEADER,
                    WORKFLOW_ID_HEADER,
                    WORKSPACE_ID_HEADER,
                ]
                + ([EXECUTION_ID_HEADER] if EXECUTION_ID_HEADER is not None else []),
            )

        # Optionally add middleware for profiling the FastAPI server and underlying inference API code
        if PROFILE:
            app.add_middleware(
                CProfileMiddleware,
                enable=True,
                server_app=app,
                filename="/profile/output.pstats",
                strip_dirs=False,
                sort_by="cumulative",
            )
        if API_LOGGING_ENABLED:
            app.add_middleware(
                asgi_correlation_id.CorrelationIdMiddleware,
                header_name=CORRELATION_ID_HEADER,
                update_request_header=True,
                generator=lambda: uuid4().hex,
                validator=lambda a: True,
                transformer=lambda a: a,
            )
            if STRUCTURED_API_LOGGING:
                # Suppress uvicorn's default access log to avoid duplicate
                # unstructured entries — we replace it with a structured
                # access log middleware (see structured_access_log below).
                logging.getLogger("uvicorn.access").handlers = []
                logging.getLogger("uvicorn.access").propagate = False
        else:
            app.add_middleware(asgi_correlation_id.CorrelationIdMiddleware)

        if METRICS_ENABLED:

            @app.middleware("http")
            async def count_errors(request: Request, call_next):
                """Middleware to count errors.

                Args:
                    request (Request): The incoming request.
                    call_next (Callable): The next middleware or endpoint to call.

                Returns:
                    Response: The response from the next middleware or endpoint.
                """
                response = await call_next(request)
                # Any 4xx/5xx response counts as an error, but only when pingback
                # reporting is active on the model manager.
                if self.model_manager.pingback and response.status_code >= 400:
                    self.model_manager.num_errors += 1
                return response

        if not (LAMBDA or GCP_SERVERLESS):

            @app.get("/device/stats")
            def device_stats():
                """Return Docker container statistics for this device.

                Responds 404 when DOCKER_SOCKET_PATH is not configured, 500 when the
                configured socket is not actually mounted, and 200 with the stats
                payload otherwise.
                """
                not_configured_error_message = {
                    "error": "Device statistics endpoint is not enabled.",
                    "hint": "Mount the Docker socket and point its location when running the docker "
                    "container to collect device stats "
                    "(i.e. `docker run ... -v /var/run/docker.sock:/var/run/docker.sock "
                    "-e DOCKER_SOCKET_PATH=/var/run/docker.sock ...`).",
                }
                if not DOCKER_SOCKET_PATH:
                    return JSONResponse(
                        status_code=404,
                        content=not_configured_error_message,
                    )
                if not is_docker_socket_mounted(docker_socket_path=DOCKER_SOCKET_PATH):
                    return JSONResponse(
                        status_code=500,
                        content=not_configured_error_message,
                    )
                container_stats = get_container_stats(
                    docker_socket_path=DOCKER_SOCKET_PATH
                )
                return JSONResponse(status_code=200, content=container_stats)

        cached_api_keys = dict()

        if GCP_SERVERLESS:

            @app.middleware("http")
            async def check_authorization_serverless(request: Request, call_next):
                """Require a valid Roboflow api_key on serverless deployments.

                The api_key is taken from the query string or, failing that, from a
                JSON request body. Successful workspace lookups are cached for one
                hour; failures return 401. Health/docs/static routes, non-GET/POST
                methods, and workflow-describe routes without dynamic blocks are
                exempt.
                """
                # exclusions
                skip_check = (
                    request.method not in ["GET", "POST"]
                    or request.url.path
                    in [
                        "/",
                        "/docs",
                        "/info",
                        "/healthz",  # health check endpoint for liveness probe
                        "/readiness",
                        "/metrics",
                        "/openapi.json",  # needed for /docs and /redoc
                        "/model/registry",  # don't auth this route; usually not used on serverless, but queue-based serverless uses it internally (not accessible from outside)
                    ]
                    or request.url.path.startswith("/static/")
                    or request.url.path.startswith("/_next/")
                )

                # for these routes we only want to auth if dynamic python modules are provided
                if request.url.path in [
                    "/workflows/blocks/describe",
                    "/workflows/definition/schema",
                ]:
                    if request.method == "GET":
                        skip_check = True

                    elif (
                        get_content_type(request) == "application/json"
                        and int(request.headers.get("content-length", 0)) > 0
                    ):
                        json_params = await request.json()
                        dynamic_blocks_definitions = json_params.get(
                            "dynamic_blocks_definitions", None
                        )
                        if not dynamic_blocks_definitions:
                            skip_check = True

                if skip_check:
                    return await call_next(request)

                def _unauthorized_response(msg):
                    # Uniform 401 payload shared by every rejection path below.
                    return JSONResponse(
                        status_code=401,
                        content={
                            "status": 401,
                            "message": msg,
                        },
                    )

                # api_key lookup: query string first, JSON body as fallback
                req_params = request.query_params
                json_params = dict()
                api_key = req_params.get("api_key", None)
                if (
                    api_key is None
                    and get_content_type(request) == "application/json"
                    and int(request.headers.get("content-length", 0)) > 0
                ):
                    # have to try catch here, because some legacy endpoints abuse the Content-Type header but don't actually receive json
                    try:
                        json_params = await request.json()
                    except Exception:
                        pass
                api_key = json_params.get("api_key", api_key)

                if api_key is None:
                    return _unauthorized_response("Unauthorized api_key")

                # Serve from the shared per-process cache when the entry is fresh;
                # otherwise resolve the workspace via the Roboflow API.
                cache_entry = cached_api_keys.get(api_key)
                workspace_id = None
                if cache_entry and cache_entry[0] >= time.time():
                    workspace_id = cache_entry[1]
                else:
                    try:
                        workspace_id = await get_roboflow_workspace_async(
                            api_key=api_key
                        )
                        cached_api_keys[api_key] = (
                            time.time() + 3600,
                            workspace_id,
                        )  # expired after 1 hour
                    except (RoboflowAPINotAuthorizedError, WorkspaceLoadError):
                        return _unauthorized_response("Unauthorized api_key")

                response = await call_next(request)
                if workspace_id:
                    response.headers[WORKSPACE_ID_HEADER] = workspace_id
                return response

        if DEDICATED_DEPLOYMENT_WORKSPACE_URL:

            @app.middleware("http")
            async def check_authorization(request: Request, call_next):
                """Require an api_key belonging to the dedicated deployment workspace.

                The api_key is taken from the query string or, failing that, from a
                JSON request body. The resolved workspace must match
                DEDICATED_DEPLOYMENT_WORKSPACE_URL; successful lookups are cached for
                one hour. Health/docs/static routes and non-GET/POST methods are
                exempt. Every rejection path responds 401.
                """
                # exclusions
                skip_check = (
                    request.method not in ["GET", "POST"]
                    or request.url.path
                    in [
                        "/",
                        "/docs",
                        "/redoc",
                        "/info",
                        "/healthz",  # health check endpoint for liveness probe
                        "/readiness",
                        "/metrics",
                        "/openapi.json",  # needed for /docs and /redoc
                    ]
                    or request.url.path.startswith("/static/")
                    or request.url.path.startswith("/_next/")
                )
                if skip_check:
                    return await call_next(request)

                def _unauthorized_response(msg):
                    # Uniform 401 payload shared by every rejection path below.
                    return JSONResponse(
                        status_code=401,
                        content={
                            "status": 401,
                            "message": msg,
                        },
                    )

                # check api_key: query string first, JSON body as fallback
                req_params = request.query_params
                json_params = dict()
                api_key = req_params.get("api_key", None)
                if (
                    api_key is None
                    and get_content_type(request) == "application/json"
                    and int(request.headers.get("content-length", 0)) > 0
                ):
                    # have to try catch here, because some legacy endpoints that abuse Content-Type header but dont actually receive json
                    try:
                        json_params = await request.json()
                    except Exception:
                        pass
                api_key = json_params.get("api_key", api_key)

                if api_key is None:
                    return _unauthorized_response("Unauthorized api_key")

                cache_entry = cached_api_keys.get(api_key)
                workspace_id = None
                if cache_entry and cache_entry[0] >= time.time():
                    workspace_id = cache_entry[1]
                else:
                    try:
                        # api_key is guaranteed non-None here (early return above),
                        # so the previous dead `if api_key is None` branch is gone.
                        workspace_id = await get_roboflow_workspace_async(
                            api_key=api_key
                        )

                        if workspace_id != DEDICATED_DEPLOYMENT_WORKSPACE_URL:
                            return _unauthorized_response("Unauthorized api_key")

                        cached_api_keys[api_key] = (
                            time.time() + 3600,
                            workspace_id,
                        )  # expired after 1 hour
                    except (RoboflowAPINotAuthorizedError, WorkspaceLoadError):
                        return _unauthorized_response("Unauthorized api_key")

                response = await call_next(request)
                if workspace_id:
                    response.headers[WORKSPACE_ID_HEADER] = workspace_id
                return response

        @app.middleware("http")
        async def add_inference_engine_headers(request: Request, call_next):
            """Tag each response with the inference engine implementation in use."""
            response = await call_next(request)
            if USE_INFERENCE_MODELS:
                engine_name = "inference-models"
            else:
                engine_name = "old-inference"
            response.headers["x-inference-engine"] = engine_name
            return response

        @app.middleware("http")
        async def track_model_load(request: Request, call_next):
            """Expose model-load (cold start) info and model IDs as response headers."""
            # Fresh collectors per request, published via ContextVars so deeper code
            # paths can record model loads / model IDs without explicit plumbing.
            load_collector = ModelLoadCollector()
            model_load_info.set(load_collector)
            ids_collector = RequestModelIds()
            request_model_ids.set(ids_collector)
            response = await call_next(request)
            if load_collector.has_data():
                # At least one model was loaded while serving this request,
                # so report it as a cold start together with the load timing.
                total, detail = load_collector.summarize()
                response.headers[MODEL_COLD_START_HEADER] = "true"
                response.headers[MODEL_LOAD_TIME_HEADER] = str(total)
                if detail is not None:
                    response.headers[MODEL_LOAD_DETAILS_HEADER] = detail
            else:
                response.headers[MODEL_COLD_START_HEADER] = "false"
            model_ids = ids_collector.get_ids()
            if model_ids:
                # Sorted for a deterministic header value.
                response.headers[MODEL_ID_HEADER] = ",".join(sorted(model_ids))
            wf_id = request_workflow_id.get(None)
            if wf_id:
                response.headers[WORKFLOW_ID_HEADER] = wf_id
            return response

        if API_LOGGING_ENABLED and STRUCTURED_API_LOGGING:

            @app.middleware("http")
            async def structured_access_log(request: Request, call_next):
                """Emit one structured access-log line per request.

                Replaces uvicorn's default access log (suppressed above when
                STRUCTURED_API_LOGGING is enabled).
                """
                response = await call_next(request)
                log_fields = {
                    "method": request.method,
                    "path": request.url.path,
                    "status_code": response.status_code,
                }

                # Read request_id and execution_id from response headers
                # instead of ContextVars — @app.middleware("http") uses
                # BaseHTTPMiddleware which runs the inner chain in a
                # separate asyncio task, so ContextVars set by inner
                # middlewares are not visible here.
                header_fields = {
                    "request_id": CORRELATION_ID_HEADER,
                    "processing_time": PROCESSING_TIME_HEADER,
                    "model_cold_start": MODEL_COLD_START_HEADER,
                    "model_load_time": MODEL_LOAD_TIME_HEADER,
                    "model_id": MODEL_ID_HEADER,
                    "workflow_id": WORKFLOW_ID_HEADER,
                    "workspace_id": WORKSPACE_ID_HEADER,
                }
                if EXECUTION_ID_HEADER is not None:
                    header_fields["execution_id"] = EXECUTION_ID_HEADER
                # Only include fields whose header was actually set on the response.
                for field_name, header_name in header_fields.items():
                    value = response.headers.get(header_name)
                    if value is not None:
                        log_fields[field_name] = value

                logger.info(
                    f"{request.method} {request.url.path} {response.status_code}",
                    **log_fields,
                )
                return response

        self.app = app
        self.model_manager = model_manager
        self.stream_manager_client: Optional[StreamManagerClient] = None
        self.shared_thread_pool_executor: Optional[ThreadPoolExecutor] = None
        if HTTP_API_SHARED_WORKFLOWS_THREAD_POOL_ENABLED:
            self.shared_thread_pool_executor = ThreadPoolExecutor(
                max_workers=HTTP_API_SHARED_WORKFLOWS_THREAD_POOL_WORKERS
            )

        if ENABLE_STREAM_API:
            operations_timeout = os.getenv("STREAM_MANAGER_OPERATIONS_TIMEOUT")
            if operations_timeout is not None:
                operations_timeout = float(operations_timeout)
            self.stream_manager_client = StreamManagerClient.init(
                host=os.getenv("STREAM_MANAGER_HOST", "127.0.0.1"),
                port=int(os.getenv("STREAM_MANAGER_PORT", "7070")),
                operations_timeout=operations_timeout,
            )
            self._instrumentator.set_stream_manager_client(self.stream_manager_client)

        def process_inference_request(
            inference_request: InferenceRequest,
            countinference: Optional[bool] = None,
            service_secret: Optional[str] = None,
            **kwargs,
        ) -> InferenceResponse:
            """Run an inference request against the model it references.

            Resolves any model alias, ensures the model is loaded in the model
            manager, runs inference synchronously, and serializes the result.

            Args:
                inference_request (InferenceRequest): The request containing model ID and other inference details.
                countinference (Optional[bool]): Whether to count inference for usage.
                service_secret (Optional[str]): The service secret.

            Returns:
                InferenceResponse: The response containing the inference results.
            """
            resolved_model_id = resolve_roboflow_model_alias(
                model_id=inference_request.model_id
            )
            self.model_manager.add_model(
                resolved_model_id,
                inference_request.api_key,
                countinference=countinference,
                service_secret=service_secret,
            )
            inference_result = self.model_manager.infer_from_request_sync(
                resolved_model_id, inference_request, **kwargs
            )
            return orjson_response(inference_result)

        def process_workflow_inference_request(
            workflow_request: WorkflowInferenceRequest,
            workflow_specification: dict,
            background_tasks: Optional[BackgroundTasks],
            profiler: WorkflowsProfiler,
        ) -> WorkflowInferenceResponse:
            """Execute a workflow definition and return its filtered, serialized outputs.

            Args:
                workflow_request: Request carrying api_key, inputs, optional
                    workflow_id / excluded_fields / is_preview.
                workflow_specification: The workflow definition to execute.
                background_tasks: FastAPI background tasks handed to workflow blocks.
                profiler: Profiler whose trace is attached to the response.

            Returns:
                WorkflowInferenceResponse: serialized outputs plus profiler trace.
            """
            # Publish the workflow id so response-header middleware can pick it up.
            if workflow_request.workflow_id:
                request_workflow_id.set(workflow_request.workflow_id)

            workflow_init_parameters = {
                "workflows_core.model_manager": model_manager,
                "workflows_core.api_key": workflow_request.api_key,
                "workflows_core.background_tasks": background_tasks,
            }
            execution_engine = ExecutionEngine.init(
                workflow_definition=workflow_specification,
                init_parameters=workflow_init_parameters,
                max_concurrent_steps=WORKFLOWS_MAX_CONCURRENT_STEPS,
                prevent_local_images_loading=True,
                profiler=profiler,
                executor=self.shared_thread_pool_executor,
                workflow_id=workflow_request.workflow_id,
            )
            # Not every request type carries is_preview; default to False.
            is_preview = False
            if hasattr(workflow_request, "is_preview"):
                is_preview = workflow_request.is_preview
            workflow_results = execution_engine.run(
                runtime_parameters=workflow_request.inputs,
                serialize_results=True,
                _is_preview=is_preview,
            )
            with profiler.profile_execution_phase(
                name="workflow_results_filtering",
                categories=["inference_package_operation"],
            ):
                outputs = filter_out_unwanted_workflow_outputs(
                    workflow_results=workflow_results,
                    excluded_fields=workflow_request.excluded_fields,
                )
            profiler_trace = profiler.export_trace()
            response = WorkflowInferenceResponse(
                outputs=outputs,
                profiler_trace=profiler_trace,
            )
            return orjson_response(response=response)

        def load_core_model(
            inference_request: InferenceRequest,
            api_key: Optional[str] = None,
            core_model: Optional[str] = None,
            countinference: Optional[bool] = None,
            service_secret: Optional[str] = None,
        ) -> str:
            """Loads a core model (e.g., "clip" or "sam") into the model manager.

            Args:
                inference_request (InferenceRequest): The request containing version and other details.
                api_key (Optional[str]): The API key for the request; when provided it
                    overrides the key carried by the request.
                core_model (Optional[str]): The core model type, e.g., "clip" or "sam".
                countinference (Optional[bool]): Whether to count inference or not.
                service_secret (Optional[str]): The service secret for the request.

            Returns:
                str: The core model ID ("<core_model>/<version_id>").
            """
            if api_key:
                inference_request.api_key = api_key
            # Each core-model request type exposes its version under a field
            # named "<core_model>_version_id" (e.g. "clip_version_id").
            version_id_field = f"{core_model}_version_id"
            core_model_id = (
                f"{core_model}/{getattr(inference_request, version_id_field)}"
            )
            self.model_manager.add_model(
                core_model_id,
                inference_request.api_key,
                endpoint_type=ModelEndpointType.CORE_MODEL,
                countinference=countinference,
                service_secret=service_secret,
            )
            return core_model_id

        load_clip_model = partial(load_core_model, core_model="clip")
        """Loads the CLIP model into the model manager.

        Args:
        Same as `load_core_model`.

        Returns:
        The CLIP model ID.
        """

        load_pe_model = partial(load_core_model, core_model="perception_encoder")
        """Loads the Perception Encoder model into the model manager.

        Args:
        Same as `load_core_model`.

        Returns:
        The Perception Encoder model ID.
        """

        load_sam_model = partial(load_core_model, core_model="sam")
        """Loads the SAM model into the model manager.

        Args:
        Same as `load_core_model`.

        Returns:
        The SAM model ID.
        """
        load_sam2_model = partial(load_core_model, core_model="sam2")
        """Loads the SAM2 model into the model manager.

        Args:
        Same as `load_core_model`.

        Returns:
        The SAM2 model ID.
        """

        load_gaze_model = partial(load_core_model, core_model="gaze")
        """Loads the GAZE model into the model manager.

        Args:
        Same as `load_core_model`.

        Returns:
        The GAZE model ID.
        """

        load_doctr_model = partial(load_core_model, core_model="doctr")
        """Loads the DocTR model into the model manager.

        Args:
        Same as `load_core_model`.

        Returns:
        The DocTR model ID.
        """

        load_easy_ocr_model = partial(load_core_model, core_model="easy_ocr")
        """Loads the EasyOCR model into the model manager.

        Args:
        Same as `load_core_model`.

        Returns:
        The EasyOCR model ID.
        """

        load_paligemma_model = partial(load_core_model, core_model="paligemma")

        load_grounding_dino_model = partial(
            load_core_model, core_model="grounding_dino"
        )
        """Loads the Grounding DINO model into the model manager.

        Args:
        Same as `load_core_model`.

        Returns:
        The Grounding DINO model ID.
        """

        load_yolo_world_model = partial(load_core_model, core_model="yolo_world")
        load_owlv2_model = partial(load_core_model, core_model="owlv2")
        """Loads the YOLO World model into the model manager.

        Args:
        Same as `load_core_model`.

        Returns:
        The YOLO World model ID.
        """

        load_trocr_model = partial(load_core_model, core_model="trocr")
        """Loads the TrOCR model into the model manager.

        Args:
        Same as `load_core_model`.

        Returns:
        The TrOCR model ID.
        """

        @app.get(
            "/info",
            response_model=ServerVersionInfo,
            summary="Info",
            description="Get the server name and version number",
        )
        def root() -> ServerVersionInfo:
            """Endpoint to get the server name and version number.

            Returns:
                ServerVersionInfo: The server version information.
            """
            return ServerVersionInfo(
                name="Roboflow Inference Server",
                version=__version__,
                uuid=GLOBAL_INFERENCE_SERVER_ID,
            )

        @app.get(
            "/logs",
            summary="Get Recent Logs",
            description="Get recent application logs for debugging",
        )
        def get_logs(
            limit: Optional[int] = Query(
                100, description="Maximum number of log entries to return"
            ),
            level: Optional[str] = Query(
                None,
                description="Filter by log level (DEBUG, INFO, WARNING, ERROR, CRITICAL)",
            ),
            since: Optional[str] = Query(
                None, description="Return logs since this ISO timestamp"
            ),
        ):
            """Get recent application logs from memory.

            Only available when ENABLE_IN_MEMORY_LOGS environment variable is set to 'true'.

            Args:
                limit: Maximum number of log entries (default 100)
                level: Filter by log level
                since: ISO timestamp to filter logs since

            Returns:
                List of log entries with timestamp, level, logger, and message
            """
            # Check if in-memory logging is enabled
            from inference.core.logging.memory_handler import (
                get_recent_logs,
                is_memory_logging_enabled,
            )

            if not is_memory_logging_enabled():
                raise HTTPException(
                    status_code=404, detail="Logs endpoint not available"
                )

            # NOTE(review): the import above has already succeeded by this point,
            # so this except looks like a defensive guard against import errors
            # raised inside get_recent_logs itself — confirm before removing.
            try:
                logs = get_recent_logs(limit=limit or 100, level=level, since=since)
                return {"logs": logs, "total_count": len(logs)}
            except (ImportError, ModuleNotFoundError):
                raise HTTPException(
                    status_code=500, detail="Logging system not properly initialized"
                )

        if not LAMBDA and GET_MODEL_REGISTRY_ENABLED:

            @app.get(
                "/model/registry",
                response_model=ModelsDescriptions,
                summary="Get model keys",
                description="Get the ID of each loaded model",
            )
            def registry():
                """Get the ID of each loaded model in the registry.

                Returns:
                    ModelsDescriptions: The object containing models descriptions
                """
                logger.debug(f"Reached /model/registry")
                models_descriptions = self.model_manager.describe_models()
                return ModelsDescriptions.from_models_descriptions(
                    models_descriptions=models_descriptions
                )

        # The current AWS Lambda authorizer only supports path parameters, therefore we can only use the legacy infer route. This case statement excludes routes which won't work for the current Lambda authorizer.
        if not (LAMBDA or GCP_SERVERLESS):

            @app.post(
                "/model/add",
                response_model=ModelsDescriptions,
                summary="Load a model",
                description="Load the model with the given model ID",
            )
            @with_route_exceptions
            def model_add(
                request: AddModelRequest,
                countinference: Optional[bool] = None,
                service_secret: Optional[str] = None,
            ):
                """Load the model with the given model ID into the model manager.

                Args:
                    request (AddModelRequest): The request containing the model ID and optional API key.
                    countinference (Optional[bool]): Whether to count inference or not.
                    service_secret (Optional[str]): The service secret for the request.

                Returns:
                    ModelsDescriptions: The object containing models descriptions
                """
                logger.debug(f"Reached /model/add")
                # Resolve any public alias to the canonical model id before loading.
                de_aliased_model_id = resolve_roboflow_model_alias(
                    model_id=request.model_id
                )
                logger.info(f"Loading model: {de_aliased_model_id}")
                self.model_manager.add_model(
                    de_aliased_model_id,
                    request.api_key,
                    countinference=countinference,
                    service_secret=service_secret,
                )
                models_descriptions = self.model_manager.describe_models()
                return ModelsDescriptions.from_models_descriptions(
                    models_descriptions=models_descriptions
                )

            @app.post(
                "/model/remove",
                response_model=ModelsDescriptions,
                summary="Remove a model",
                description="Remove the model with the given model ID",
            )
            @with_route_exceptions
            def model_remove(request: ClearModelRequest):
                """Unload the requested model and report what remains loaded.

                Args:
                    request (ClearModelRequest): The request containing the model ID to be removed.

                Returns:
                    ModelsDescriptions: The object containing models descriptions
                """
                logger.debug(f"Reached /model/remove")
                resolved_model_id = resolve_roboflow_model_alias(
                    model_id=request.model_id
                )
                self.model_manager.remove(resolved_model_id)
                return ModelsDescriptions.from_models_descriptions(
                    models_descriptions=self.model_manager.describe_models()
                )

            @app.post(
                "/model/clear",
                response_model=ModelsDescriptions,
                summary="Remove all models",
                description="Remove all loaded models",
            )
            @with_route_exceptions
            def model_clear():
                """Unload every model currently held by the model manager.

                Returns:
                    ModelsDescriptions: Descriptions of the (now empty) set of loaded models.
                """
                logger.debug(f"Reached /model/clear")
                self.model_manager.clear()
                return ModelsDescriptions.from_models_descriptions(
                    models_descriptions=self.model_manager.describe_models()
                )

        # these NEW endpoints need authentication protection
        if not LAMBDA and not GCP_SERVERLESS:

            @app.post(
                "/infer/object_detection",
                response_model=Union[
                    ObjectDetectionInferenceResponse,
                    List[ObjectDetectionInferenceResponse],
                    StubResponse,
                ],
                summary="Object detection infer",
                description="Run inference with the specified object detection model",
                response_model_exclude_none=True,
            )
            @with_route_exceptions
            @usage_collector("request")
            def infer_object_detection(
                inference_request: ObjectDetectionInferenceRequest,
                background_tasks: BackgroundTasks,
                countinference: Optional[bool] = None,
                service_secret: Optional[str] = None,
            ):
                """Execute the requested object detection model.

                Args:
                    inference_request (ObjectDetectionInferenceRequest): Payload describing the detection job.
                    background_tasks (BackgroundTasks): FastAPI background-task pool (used for active learning).
                    countinference (Optional[bool]): Whether to count inference or not.
                    service_secret (Optional[str]): The service secret for the request.

                Returns:
                    Union[ObjectDetectionInferenceResponse, List[ObjectDetectionInferenceResponse]]: Inference results.
                """
                logger.debug(f"Reached /infer/object_detection")
                return process_inference_request(
                    inference_request,
                    background_tasks=background_tasks,
                    active_learning_eligible=True,
                    service_secret=service_secret,
                    countinference=countinference,
                )

            @app.post(
                "/infer/instance_segmentation",
                response_model=Union[
                    InstanceSegmentationInferenceResponse, StubResponse
                ],
                summary="Instance segmentation infer",
                description="Run inference with the specified instance segmentation model",
            )
            @with_route_exceptions
            @usage_collector("request")
            def infer_instance_segmentation(
                inference_request: InstanceSegmentationInferenceRequest,
                background_tasks: BackgroundTasks,
                countinference: Optional[bool] = None,
                service_secret: Optional[str] = None,
            ):
                """Execute the requested instance segmentation model.

                Args:
                    inference_request (InstanceSegmentationInferenceRequest): Payload describing the segmentation job.
                    background_tasks (BackgroundTasks): FastAPI background-task pool (used for active learning).
                    countinference (Optional[bool]): Whether to count inference or not.
                    service_secret (Optional[str]): The service secret for the request.

                Returns:
                    InstanceSegmentationInferenceResponse: Inference results.
                """
                logger.debug(f"Reached /infer/instance_segmentation")
                return process_inference_request(
                    inference_request,
                    background_tasks=background_tasks,
                    active_learning_eligible=True,
                    service_secret=service_secret,
                    countinference=countinference,
                )

            @app.post(
                "/infer/semantic_segmentation",
                response_model=Union[
                    SemanticSegmentationInferenceResponse, StubResponse
                ],
                summary="Semantic segmentation infer",
                description="Run inference with the specified semantic segmentation model",
            )
            @with_route_exceptions
            @usage_collector("request")
            def infer_semantic_segmentation(
                inference_request: SemanticSegmentationInferenceRequest,
                background_tasks: BackgroundTasks,
                countinference: Optional[bool] = None,
                service_secret: Optional[str] = None,
            ):
                """Execute the requested semantic segmentation model.

                Args:
                    inference_request (SemanticSegmentationInferenceRequest): Payload describing the segmentation job.
                    background_tasks (BackgroundTasks): FastAPI background-task pool (used for active learning).
                    countinference (Optional[bool]): Whether to count inference or not.
                    service_secret (Optional[str]): The service secret for the request.

                Returns:
                    SemanticSegmentationInferenceResponse: Inference results.
                """
                logger.debug(f"Reached /infer/semantic_segmentation")
                return process_inference_request(
                    inference_request,
                    background_tasks=background_tasks,
                    active_learning_eligible=True,
                    service_secret=service_secret,
                    countinference=countinference,
                )

            @app.post(
                "/infer/classification",
                response_model=Union[
                    ClassificationInferenceResponse,
                    MultiLabelClassificationInferenceResponse,
                    StubResponse,
                ],
                summary="Classification infer",
                description="Run inference with the specified classification model",
            )
            @with_route_exceptions
            @usage_collector("request")
            def infer_classification(
                inference_request: ClassificationInferenceRequest,
                background_tasks: BackgroundTasks,
                countinference: Optional[bool] = None,
                service_secret: Optional[str] = None,
            ):
                """Execute the requested classification model.

                Args:
                    inference_request (ClassificationInferenceRequest): Payload describing the classification job.
                    background_tasks (BackgroundTasks): FastAPI background-task pool (used for active learning).
                    countinference (Optional[bool]): Whether to count inference or not.
                    service_secret (Optional[str]): The service secret for the request.

                Returns:
                    Union[ClassificationInferenceResponse, MultiLabelClassificationInferenceResponse]: Inference results.
                """
                logger.debug(f"Reached /infer/classification")
                return process_inference_request(
                    inference_request,
                    background_tasks=background_tasks,
                    active_learning_eligible=True,
                    service_secret=service_secret,
                    countinference=countinference,
                )

            @app.post(
                "/infer/keypoints_detection",
                response_model=Union[KeypointsDetectionInferenceResponse, StubResponse],
                summary="Keypoints detection infer",
                description="Run inference with the specified keypoints detection model",
            )
            @with_route_exceptions
            @usage_collector("request")
            def infer_keypoints(
                inference_request: KeypointsDetectionInferenceRequest,
                countinference: Optional[bool] = None,
                service_secret: Optional[str] = None,
            ):
                """Run inference with the specified keypoints detection model.

                Args:
                    inference_request (KeypointsDetectionInferenceRequest): The request containing the necessary details for keypoints detection.
                    countinference (Optional[bool]): Whether to count inference or not.
                    service_secret (Optional[str]): The service secret for the request.

                Returns:
                    KeypointsDetectionInferenceResponse: The response containing the inference results.
                """
                # NOTE(review): unlike the other /infer/* handlers in this section, this one
                # takes no background_tasks and does not pass active_learning_eligible=True —
                # confirm whether keypoints detection is intentionally excluded from active learning.
                logger.debug(f"Reached /infer/keypoints_detection")
                return process_inference_request(
                    inference_request,
                    countinference=countinference,
                    service_secret=service_secret,
                )

            if LMM_ENABLED or MOONDREAM2_ENABLED:

                @app.post(
                    "/infer/lmm",
                    response_model=Union[
                        LMMInferenceResponse,
                        List[LMMInferenceResponse],
                        StubResponse,
                    ],
                    summary="Large multi-modal model infer",
                    description="Run inference with the specified large multi-modal model",
                    response_model_exclude_none=True,
                )
                @with_route_exceptions
                @usage_collector("request")
                def infer_lmm(
                    inference_request: LMMInferenceRequest,
                    countinference: Optional[bool] = None,
                    service_secret: Optional[str] = None,
                ):
                    """Execute the requested large multi-modal model.

                    Args:
                        inference_request (LMMInferenceRequest): Payload describing the LMM inference job.
                        countinference (Optional[bool]): Whether to count inference or not.
                        service_secret (Optional[str]): The service secret for the request.

                    Returns:
                        Union[LMMInferenceResponse, List[LMMInferenceResponse]]: Inference results.
                    """
                    logger.debug(f"Reached /infer/lmm")
                    return process_inference_request(
                        inference_request,
                        service_secret=service_secret,
                        countinference=countinference,
                    )

                @app.post(
                    "/infer/lmm/{model_id:path}",
                    response_model=Union[
                        LMMInferenceResponse,
                        List[LMMInferenceResponse],
                        StubResponse,
                    ],
                    summary="Large multi-modal model infer with model ID in path",
                    description="Run inference with the specified large multi-modal model. Model ID is specified in the URL path (can contain slashes).",
                    response_model_exclude_none=True,
                )
                @with_route_exceptions
                @usage_collector("request")
                def infer_lmm_with_model_id(
                    model_id: str,
                    inference_request: LMMInferenceRequest,
                    countinference: Optional[bool] = None,
                    service_secret: Optional[str] = None,
                ):
                    """Execute an LMM identified by the model ID embedded in the URL path.

                    If the request body also carries a model_id, it must agree with the
                    path parameter; otherwise the call is rejected.

                    Args:
                        model_id (str): Model identifier taken from the URL path.
                        inference_request (LMMInferenceRequest): Payload describing the LMM inference job.
                        countinference (Optional[bool]): Whether to count inference or not.
                        service_secret (Optional[str]): The service secret for the request.

                    Returns:
                        Union[LMMInferenceResponse, List[LMMInferenceResponse]]: Inference results.

                    Raises:
                        HTTPException: When the body's model_id conflicts with the path parameter.
                    """
                    logger.debug(f"Reached /infer/lmm/{model_id}")
                    body_model_id = inference_request.model_id
                    if body_model_id is not None and body_model_id != model_id:
                        # Reject contradictory identifiers instead of silently preferring one.
                        raise HTTPException(
                            status_code=400,
                            detail=f"Model ID mismatch: path specifies '{model_id}' but request body specifies '{inference_request.model_id}'",
                        )
                    # The path parameter is authoritative from here on.
                    inference_request.model_id = model_id
                    return process_inference_request(
                        inference_request,
                        service_secret=service_secret,
                        countinference=countinference,
                    )

        if not DISABLE_WORKFLOW_ENDPOINTS:

            @app.post(
                "/{workspace_name}/workflows/{workflow_id}/describe_interface",
                response_model=DescribeInterfaceResponse,
                summary="Endpoint to describe interface of predefined workflow",
                description="Checks Roboflow API for workflow definition, once acquired - describes workflow inputs and outputs",
            )
            @with_route_exceptions
            def describe_predefined_workflow_interface(
                workspace_name: str,
                workflow_id: str,
                workflow_request: PredefinedWorkflowDescribeInterfaceRequest,
            ) -> DescribeInterfaceResponse:
                """Fetch a predefined workflow definition and describe its inputs/outputs.

                Args:
                    workspace_name (str): Workspace that owns the workflow.
                    workflow_id (str): Identifier of the workflow to describe.
                    workflow_request (PredefinedWorkflowDescribeInterfaceRequest): API key and fetch options.

                Returns:
                    DescribeInterfaceResponse: Description of the workflow interface.
                """
                specification = get_workflow_specification(
                    api_key=workflow_request.api_key,
                    workspace_id=workspace_name,
                    workflow_id=workflow_id,
                    use_cache=workflow_request.use_cache,
                    workflow_version_id=workflow_request.workflow_version_id,
                )
                return handle_describe_workflows_interface(definition=specification)

            @app.post(
                "/workflows/describe_interface",
                response_model=DescribeInterfaceResponse,
                summary="Endpoint to describe interface of workflow given in request",
                description="Parses workflow definition and retrieves describes inputs and outputs",
            )
            @with_route_exceptions
            def describe_workflow_interface(
                workflow_request: WorkflowSpecificationDescribeInterfaceRequest,
            ) -> DescribeInterfaceResponse:
                """Describe the inputs and outputs of a workflow definition supplied inline.

                Args:
                    workflow_request (WorkflowSpecificationDescribeInterfaceRequest): Request wrapping the specification.

                Returns:
                    DescribeInterfaceResponse: Description of the workflow interface.
                """
                definition = workflow_request.specification
                return handle_describe_workflows_interface(definition=definition)

            @app.post(
                "/{workspace_name}/workflows/{workflow_id}",
                response_model=WorkflowInferenceResponse,
                summary="Endpoint to run predefined workflow",
                description="Checks Roboflow API for workflow definition, once acquired - parses and executes injecting runtime parameters from request body",
            )
            @app.post(
                "/infer/workflows/{workspace_name}/{workflow_id}",
                response_model=WorkflowInferenceResponse,
                summary="[LEGACY] Endpoint to run predefined workflow",
                description="Checks Roboflow API for workflow definition, once acquired - parses and executes injecting runtime parameters from request body. This endpoint is deprecated and will be removed end of Q2 2024",
                deprecated=True,
            )
            @with_route_exceptions
            @usage_collector("request")
            def infer_from_predefined_workflow(
                workspace_name: str,
                workflow_id: str,
                workflow_request: PredefinedWorkflowInferenceRequest,
                background_tasks: BackgroundTasks,
            ) -> WorkflowInferenceResponse:
                """Fetch a predefined workflow definition and execute it with request parameters.

                Args:
                    workspace_name (str): Workspace that owns the workflow.
                    workflow_id (str): Identifier of the workflow to execute.
                    workflow_request (PredefinedWorkflowInferenceRequest): Runtime parameters and options.
                    background_tasks (BackgroundTasks): FastAPI background-task pool.

                Returns:
                    WorkflowInferenceResponse: Results of the workflow execution.
                """
                # TODO: get rid of async: https://github.com/roboflow/inference/issues/569
                profiling_requested = (
                    ENABLE_WORKFLOWS_PROFILING and workflow_request.enable_profiling
                )
                if profiling_requested:
                    profiler = BaseWorkflowsProfiler.init(
                        max_runs_in_buffer=WORKFLOWS_PROFILER_BUFFER_SIZE,
                    )
                else:
                    profiler = NullWorkflowsProfiler.init()
                with profiler.profile_execution_phase(
                    name="workflow_definition_fetching",
                    categories=["inference_package_operation"],
                ):
                    workflow_specification = get_workflow_specification(
                        api_key=workflow_request.api_key,
                        workspace_id=workspace_name,
                        workflow_id=workflow_id,
                        use_cache=workflow_request.use_cache,
                        workflow_version_id=workflow_request.workflow_version_id,
                    )
                if not workflow_request.workflow_id:
                    # Propagate the path-level identifier when the body omitted it.
                    workflow_request.workflow_id = workflow_id
                if not workflow_specification.get("id"):
                    logger.warning(
                        "Internal workflow ID missing in specification for '%s'",
                        workflow_id,
                    )
                # Background tasks are unavailable in serverless deployments.
                tasks = None if (LAMBDA or GCP_SERVERLESS) else background_tasks
                return process_workflow_inference_request(
                    workflow_request=workflow_request,
                    workflow_specification=workflow_specification,
                    background_tasks=tasks,
                    profiler=profiler,
                )

            @app.post(
                "/workflows/run",
                response_model=WorkflowInferenceResponse,
                summary="Endpoint to run workflow specification provided in payload",
                description="Parses and executes workflow specification, injecting runtime parameters from request body.",
            )
            @app.post(
                "/infer/workflows",
                response_model=WorkflowInferenceResponse,
                summary="[LEGACY] Endpoint to run workflow specification provided in payload",
                description="Parses and executes workflow specification, injecting runtime parameters from request body. This endpoint is deprecated and will be removed end of Q2 2024.",
                deprecated=True,
            )
            @with_route_exceptions
            @usage_collector("request")
            def infer_from_workflow(
                workflow_request: WorkflowSpecificationInferenceRequest,
                background_tasks: BackgroundTasks,
            ) -> WorkflowInferenceResponse:
                """Execute a workflow specification supplied inline in the request body.

                Args:
                    workflow_request (WorkflowSpecificationInferenceRequest): Specification plus runtime parameters.
                    background_tasks (BackgroundTasks): FastAPI background-task pool.

                Returns:
                    WorkflowInferenceResponse: Results of the workflow execution.
                """
                # TODO: get rid of async: https://github.com/roboflow/inference/issues/569
                profiling_requested = (
                    ENABLE_WORKFLOWS_PROFILING and workflow_request.enable_profiling
                )
                if profiling_requested:
                    profiler = BaseWorkflowsProfiler.init(
                        max_runs_in_buffer=WORKFLOWS_PROFILER_BUFFER_SIZE,
                    )
                else:
                    profiler = NullWorkflowsProfiler.init()
                # Background tasks are unavailable in serverless deployments.
                tasks = None if (LAMBDA or GCP_SERVERLESS) else background_tasks
                return process_workflow_inference_request(
                    workflow_request=workflow_request,
                    workflow_specification=workflow_request.specification,
                    background_tasks=tasks,
                    profiler=profiler,
                )

            @app.get(
                "/workflows/execution_engine/versions",
                response_model=ExecutionEngineVersions,
                summary="Returns available Execution Engine versions sorted from oldest to newest",
                description="Returns available Execution Engine versions sorted from oldest to newest",
            )
            @with_route_exceptions
            def get_execution_engine_versions() -> ExecutionEngineVersions:
                """List Execution Engine versions available on this server.

                Returns:
                    ExecutionEngineVersions: Versions sorted from oldest to newest.
                """
                # TODO: get rid of async: https://github.com/roboflow/inference/issues/569
                return ExecutionEngineVersions(versions=get_available_versions())

            @app.get(
                "/workflows/blocks/describe",
                response_model=WorkflowsBlocksDescription,
                summary="[LEGACY] Endpoint to get definition of workflows blocks that are accessible",
                description="Endpoint provides detailed information about workflows building blocks that are "
                "accessible in the inference server. This information could be used to programmatically "
                "build / display workflows.",
                deprecated=True,
            )
            @with_route_exceptions
            def describe_workflows_blocks(
                request: Request,
            ) -> Union[WorkflowsBlocksDescription, Response]:
                """Describe all workflow blocks available in this server (legacy GET variant).

                Args:
                    request (Request): Incoming request; used to honour gzip content negotiation.

                Returns:
                    Union[WorkflowsBlocksDescription, Response]: Blocks description, possibly gzip-compressed.
                """
                # NOTE(review): the POST /workflows/blocks/describe handler defined below reuses
                # this function name. FastAPI binds each route at decoration time so both routes
                # work, but the duplicated name may yield duplicated OpenAPI operation IDs — confirm.
                result = handle_describe_workflows_blocks_request()
                return gzip_response_if_requested(request=request, response=result)

            @app.post(
                "/workflows/blocks/describe",
                response_model=WorkflowsBlocksDescription,
                summary="[EXPERIMENTAL] Endpoint to get definition of workflows blocks that are accessible",
                description="Endpoint provides detailed information about workflows building blocks that are "
                "accessible in the inference server. This information could be used to programmatically "
                "build / display workflows. Additionally - in request body one can specify list of "
                "dynamic blocks definitions which will be transformed into blocks and used to generate "
                "schemas and definitions of connections",
            )
            @with_route_exceptions
            def describe_workflows_blocks(
                request: Request,
                request_payload: Optional[DescribeBlocksRequest] = None,
            ) -> Union[WorkflowsBlocksDescription, Response]:
                """Describe available workflow blocks, optionally including dynamic block definitions.

                Args:
                    request (Request): Incoming request; used for gzip negotiation and api_key fallback.
                    request_payload (Optional[DescribeBlocksRequest]): Optional dynamic blocks / version / api_key.

                Returns:
                    Union[WorkflowsBlocksDescription, Response]: Blocks description, possibly gzip-compressed.
                """
                # TODO: get rid of async: https://github.com/roboflow/inference/issues/569
                dynamic_blocks = None
                engine_version = None
                api_key = None
                if request_payload is not None:
                    dynamic_blocks = request_payload.dynamic_blocks_definitions
                    engine_version = request_payload.execution_engine_version
                    # Body api_key takes precedence; fall back to the query string.
                    api_key = request_payload.api_key or request.query_params.get(
                        "api_key", None
                    )
                result = handle_describe_workflows_blocks_request(
                    dynamic_blocks_definitions=dynamic_blocks,
                    requested_execution_engine_version=engine_version,
                    api_key=api_key,
                )
                return gzip_response_if_requested(request=request, response=result)

            @app.get(
                "/workflows/definition/schema",
                response_model=WorkflowsBlocksSchemaDescription,
                summary="Endpoint to fetch the workflows block schema",
                description="Endpoint to fetch the schema of all available blocks. This information can be "
                "used to validate workflow definitions and suggest syntax in the JSON editor.",
            )
            @with_route_exceptions
            def get_workflow_schema(
                request: Request,
            ) -> Union[WorkflowsBlocksSchemaDescription, Response]:
                """Fetch the JSON schema describing all available workflow blocks.

                Annotation widened to match the sibling describe endpoints: the gzip helper
                returns a raw ``Response`` when the client requests compression.

                Args:
                    request (Request): Incoming request; used to honour gzip content negotiation.

                Returns:
                    Union[WorkflowsBlocksSchemaDescription, Response]: Schema description, possibly gzip-compressed.
                """
                result = get_workflow_schema_description()
                return gzip_response_if_requested(request, response=result)

            @app.post(
                "/workflows/blocks/dynamic_outputs",
                response_model=List[OutputDefinition],
                summary="[EXPERIMENTAL] Endpoint to get definition of dynamic output for workflow step",
                description="Endpoint to be used when step outputs can be discovered only after "
                "filling manifest with data.",
            )
            @with_route_exceptions
            def get_dynamic_block_outputs(
                step_manifest: Dict[str, Any],
            ) -> List[OutputDefinition]:
                """Resolve the dynamic outputs of a single workflow step manifest.

                Wraps the manifest into a minimal one-step workflow definition, parses it,
                and asks the parsed step for its actual outputs.

                Args:
                    step_manifest (Dict[str, Any]): Raw manifest of the step to inspect.

                Returns:
                    List[OutputDefinition]: Outputs the step would expose.
                """
                # TODO: get rid of async: https://github.com/roboflow/inference/issues/569
                # Potentially TODO: dynamic blocks do not support dynamic outputs, but if it changes
                # we need to provide dynamic blocks manifests here
                stub_definition = {
                    "version": "1.0",
                    "inputs": [],
                    "steps": [step_manifest],
                    "outputs": [],
                }
                parsed_definition = parse_workflow_definition(
                    raw_workflow_definition=stub_definition,
                    available_blocks=load_workflow_blocks(),
                )
                return parsed_definition.steps[0].get_actual_outputs()

            @app.post(
                "/workflows/validate",
                response_model=WorkflowValidationStatus,
                summary="[EXPERIMENTAL] Endpoint to validate",
                description="Endpoint provides a way to check validity of JSON workflow definition.",
            )
            @with_route_exceptions
            def validate_workflow(
                specification: dict,
                api_key: Optional[str] = Query(
                    None,
                    description="Roboflow API Key that will be passed to the model during initialization for artifact retrieval",
                ),
            ) -> WorkflowValidationStatus:
                """Validate a JSON workflow definition by attempting to initialise an engine.

                Args:
                    specification (dict): Raw workflow definition to validate.
                    api_key (Optional[str]): Roboflow API key used during initialisation.

                Returns:
                    WorkflowValidationStatus: Status "ok" when the definition initialises cleanly.
                """
                # TODO: get rid of async: https://github.com/roboflow/inference/issues/569
                init_parameters = {
                    "workflows_core.model_manager": model_manager,
                    "workflows_core.api_key": api_key,
                    "workflows_core.background_tasks": None,
                    "workflows_core.step_execution_mode": StepExecutionMode(
                        WORKFLOWS_STEP_EXECUTION_MODE
                    ),
                }
                # Initialisation raises on an invalid definition; the instance is discarded.
                ExecutionEngine.init(
                    workflow_definition=specification,
                    init_parameters=init_parameters,
                    max_concurrent_steps=WORKFLOWS_MAX_CONCURRENT_STEPS,
                    prevent_local_images_loading=True,
                )
                return WorkflowValidationStatus(status="ok")

        if WEBRTC_WORKER_ENABLED:

            @app.post(
                "/initialise_webrtc_worker",
                response_model=InitializeWebRTCResponse,
                summary="[EXPERIMENTAL] Establishes WebRTC peer connection and processes video stream in spawned process or modal function",
                description="[EXPERIMENTAL] Establishes WebRTC peer connection and processes video stream in spawned process or modal function",
            )
            @with_route_exceptions_async
            async def initialise_webrtc_worker(
                request: WebRTCWorkerRequest,
                r: Request,
            ) -> InitializeWebRTCResponse:
                """Establish a WebRTC peer connection via a spawned worker.

                Args:
                    request (WebRTCWorkerRequest): Parameters for the WebRTC worker.
                    r (Request): Raw HTTP request, used to detect builder preview calls.

                Returns:
                    InitializeWebRTCResponse: SDP answer produced by the worker.

                Raises:
                    WorkflowSyntaxError | WorkflowError | Exception: Re-raised on the server
                        side from the error the worker reported.
                """
                # Requests originating from the Roboflow builder UI are marked as previews.
                if str(r.headers.get("origin")).lower() == BUILDER_ORIGIN.lower():
                    if re.search(
                        r"^https://[^.]+\.roboflow\.[^./]+/", str(r.url).lower()
                    ):
                        request.is_preview = True

                logger.debug("Received initialise_webrtc_worker request")
                worker_result: WebRTCWorkerResult = await start_worker(
                    webrtc_request=request,
                )
                if worker_result.exception_type is not None:
                    # Workflow-specific errors carry extra context and are re-raised as-is.
                    if worker_result.exception_type == "WorkflowSyntaxError":
                        raise WorkflowSyntaxError(
                            public_message=worker_result.error_message,
                            context=worker_result.error_context,
                            inner_error=worker_result.inner_error,
                        )
                    if worker_result.exception_type == "WorkflowError":
                        raise WorkflowError(
                            public_message=worker_result.error_message,
                            context=worker_result.error_context,
                        )
                    # Map the worker-reported exception name back to a local exception
                    # class; anything unrecognised degrades to a plain Exception.
                    expected_exceptions = {
                        "Exception": Exception,
                        "KeyError": KeyError,
                        "MissingApiKeyError": MissingApiKeyError,
                        "NotImplementedError": NotImplementedError,
                        "RoboflowAPINotAuthorizedError": RoboflowAPINotAuthorizedError,
                        "RoboflowAPINotNotFoundError": RoboflowAPINotNotFoundError,
                        "ValidationError": ValidationError,
                        "WebRTCConfigurationError": WebRTCConfigurationError,
                    }
                    exc = expected_exceptions.get(
                        worker_result.exception_type, Exception
                    )(worker_result.error_message)
                    # Fixed: the format string previously carried a spurious f-prefix while
                    # relying on logging's lazy %-style argument substitution.
                    logger.error(
                        "Initialise webrtc worker failed with %s: %s",
                        worker_result.exception_type,
                        worker_result.error_message,
                    )
                    raise exc
                logger.debug("Returning initialise_webrtc_worker response")
                return InitializeWebRTCResponse(
                    context=CommandContext(),
                    status=OperationStatus.SUCCESS,
                    sdp=worker_result.answer.sdp,
                    type=worker_result.answer.type,
                )

        if ENABLE_STREAM_API:

            @app.get(
                "/inference_pipelines/list",
                response_model=ListPipelinesResponse,
                summary="[EXPERIMENTAL] List active InferencePipelines",
                description="[EXPERIMENTAL] Listing all active InferencePipelines processing videos",
            )
            @with_route_exceptions_async
            async def list_pipelines(_: Request) -> ListPipelinesResponse:
                """Return the listing of active InferencePipelines from the stream manager."""
                pipelines = await self.stream_manager_client.list_pipelines()
                return pipelines

            @app.get(
                "/inference_pipelines/{pipeline_id}/status",
                response_model=InferencePipelineStatusResponse,
                summary="[EXPERIMENTAL] Get status of InferencePipeline",
                description="[EXPERIMENTAL] Get status of InferencePipeline",
            )
            @with_route_exceptions_async
            async def get_status(pipeline_id: str) -> InferencePipelineStatusResponse:
                """Fetch the current status of the given InferencePipeline."""
                client = self.stream_manager_client
                return await client.get_status(pipeline_id=pipeline_id)

            @app.post(
                "/inference_pipelines/initialise",
                response_model=CommandResponse,
                summary="[EXPERIMENTAL] Starts new InferencePipeline",
                description="[EXPERIMENTAL] Starts new InferencePipeline",
            )
            @with_route_exceptions_async
            async def initialise(request: InitialisePipelinePayload) -> CommandResponse:
                """Spawn a new InferencePipeline from the provided payload."""
                client = self.stream_manager_client
                return await client.initialise_pipeline(initialisation_request=request)

            @app.post(
                "/inference_pipelines/initialise_webrtc",
                response_model=InitializeWebRTCPipelineResponse,
                summary="[EXPERIMENTAL] Establishes WebRTC peer connection and starts new InferencePipeline consuming video track",
                description="[EXPERIMENTAL] Establishes WebRTC peer connection and starts new InferencePipeline consuming video track",
            )
            @with_route_exceptions_async
            async def initialise_webrtc_inference_pipeline(
                request: InitialiseWebRTCPipelinePayload,
            ) -> CommandResponse:
                """Negotiate a WebRTC peer connection and spawn a pipeline consuming its video track."""
                logger.debug("Received initialise webrtc inference pipeline request")
                response = await self.stream_manager_client.initialise_webrtc_pipeline(
                    initialisation_request=request
                )
                logger.debug("Returning initialise webrtc inference pipeline response")
                return response

            @app.post(
                "/inference_pipelines/{pipeline_id}/pause",
                response_model=CommandResponse,
                summary="[EXPERIMENTAL] Pauses the InferencePipeline",
                description="[EXPERIMENTAL] Pauses the InferencePipeline",
            )
            @with_route_exceptions_async
            async def pause(pipeline_id: str) -> CommandResponse:
                """Pause processing for the pipeline identified by `pipeline_id`."""
                command_response = await self.stream_manager_client.pause_pipeline(
                    pipeline_id=pipeline_id
                )
                return command_response

            @app.post(
                "/inference_pipelines/{pipeline_id}/resume",
                response_model=CommandResponse,
                summary="[EXPERIMENTAL] Resumes the InferencePipeline",
                description="[EXPERIMENTAL] Resumes the InferencePipeline",
            )
            @with_route_exceptions_async
            async def resume(pipeline_id: str) -> CommandResponse:
                """Resume processing for the pipeline identified by `pipeline_id`."""
                command_response = await self.stream_manager_client.resume_pipeline(
                    pipeline_id=pipeline_id
                )
                return command_response

            @app.post(
                "/inference_pipelines/{pipeline_id}/terminate",
                response_model=CommandResponse,
                summary="[EXPERIMENTAL] Terminates the InferencePipeline",
                description="[EXPERIMENTAL] Terminates the InferencePipeline",
            )
            @with_route_exceptions_async
            async def terminate(pipeline_id: str) -> CommandResponse:
                """Terminate the pipeline identified by `pipeline_id`."""
                command_response = await self.stream_manager_client.terminate_pipeline(
                    pipeline_id=pipeline_id
                )
                return command_response

            @app.get(
                "/inference_pipelines/{pipeline_id}/consume",
                response_model=ConsumePipelineResponse,
                summary="[EXPERIMENTAL] Consumes InferencePipeline result",
                description="[EXPERIMENTAL] Consumes InferencePipeline result",
            )
            @with_route_exceptions_async
            async def consume(
                pipeline_id: str,
                request: Optional[ConsumeResultsPayload] = None,
            ) -> ConsumePipelineResponse:
                """Pop the next available result from the given pipeline, honouring optional field exclusions."""
                payload = ConsumeResultsPayload() if request is None else request
                return await self.stream_manager_client.consume_pipeline_result(
                    pipeline_id=pipeline_id,
                    excluded_fields=payload.excluded_fields,
                )

        class ModelInitState:
            """Thread-safe holder for the model preloading readiness state."""

            def __init__(self):
                # Flipped to True once all preloading work has completed (or none is needed).
                self.is_ready = False
                # Guards concurrent updates coming from the loader threads.
                self.lock = Lock()
                # Accumulates (model_id, error_message) entries for failed loads.
                self.initialization_errors = []

        # Shared state consulted by the /readiness probe registered below.
        model_init_state = ModelInitState()

        # Preload only when at least one of the env-configured model lists is non-empty.
        should_preload = PRELOAD_MODELS or PINNED_MODELS
        if not should_preload:
            # Nothing to load, so the server is immediately ready.
            model_init_state.is_ready = True

        # Enable preloading models at startup
        if should_preload:

            def initialize_models(state: ModelInitState):
                """Load (and optionally pin) all configured models, then mark state ready.

                Runs on a background thread; per-model failures are recorded in
                ``state.initialization_errors`` rather than raised, so one bad
                model does not block the others.
                """

                def load_model(model_id):
                    # Load a single model; on failure, record the error and bail out
                    # before the pinning step.
                    t_start = time.perf_counter()
                    de_aliased = resolve_roboflow_model_alias(model_id=model_id)
                    logger.info(
                        f"Preload: starting model load for '{model_id}' (resolved: '{de_aliased}')"
                    )
                    try:
                        self.model_manager.add_model(
                            de_aliased,
                            PRELOAD_API_KEY,
                        )
                        load_time = time.perf_counter() - t_start
                        logger.info(
                            f"Preload: model '{model_id}' loaded successfully in {load_time:.1f}s"
                        )
                    except Exception as e:
                        load_time = time.perf_counter() - t_start
                        error_msg = f"Preload: error loading model '{model_id}' after {load_time:.1f}s: {e}"
                        logger.error(error_msg)
                        with state.lock:
                            state.initialization_errors.append((model_id, str(e)))
                        return

                    # Pin if this model is in PINNED_MODELS
                    if (
                        PINNED_MODELS
                        and model_id in PINNED_MODELS
                        and hasattr(self.model_manager, "pin_model")
                    ):
                        self.model_manager.pin_model(de_aliased)

                # De-duplicate while preserving order (dict keeps insertion order).
                all_models = list(
                    dict.fromkeys((PRELOAD_MODELS or []) + (PINNED_MODELS or []))
                )
                if all_models:
                    # Create tasks for each model to be loaded
                    model_loading_executor = ThreadPoolExecutor(max_workers=2)
                    try:
                        loaded_futures: List[Tuple[str, Future]] = []
                        for model_id in all_models:
                            future = model_loading_executor.submit(
                                load_model, model_id=model_id
                            )
                            loaded_futures.append((model_id, future))

                        for model_id, future in loaded_futures:
                            try:
                                future.result(timeout=300)
                            except (
                                TimeoutError,
                                CancelledError,
                                concurrent.futures.TimeoutError,
                            ):
                                # Take the lock, consistently with every other path
                                # that mutates initialization_errors.
                                with state.lock:
                                    state.initialization_errors.append(
                                        (
                                            model_id,
                                            "Could not finalise model loading before timeout",
                                        )
                                    )
                                future.cancel()
                            except Exception as e:
                                logger.error(
                                    f"Preload: unexpected error for model '{model_id}': {e}"
                                )
                                with state.lock:
                                    state.initialization_errors.append(
                                        (model_id, str(e))
                                    )
                    finally:
                        # Release the pool without blocking on hung loads; the
                        # original code leaked the executor's worker threads.
                        model_loading_executor.shutdown(wait=False)

                # Update the readiness state in a thread-safe manner
                with state.lock:
                    state.is_ready = True

            @app.on_event("startup")
            def startup_model_init():
                """Kick off model preloading in a daemon thread at app startup."""
                startup_thread = Thread(
                    target=initialize_models, args=(model_init_state,), daemon=True
                )
                startup_thread.start()
                logger.info("Model initialization started in the background.")

        # Attach health/readiness endpoints
        @app.get("/readiness", status_code=200)
        def readiness(
            state: ModelInitState = Depends(lambda: model_init_state),
        ):
            """Kubernetes readiness probe: 200 once model preloading has finished, 503 before."""
            with state.lock:
                ready = state.is_ready
            if ready:
                return {"status": "ready"}
            return JSONResponse(content={"status": "not ready"}, status_code=503)

        @app.get("/healthz", status_code=200)
        def healthz():
            """Kubernetes liveness probe: always reports the process as healthy."""
            payload = {"status": "healthy"}
            return payload

        if CORE_MODELS_ENABLED:
            if CORE_MODEL_CLIP_ENABLED:

                @app.post(
                    "/clip/embed_image",
                    response_model=ClipEmbeddingResponse,
                    summary="CLIP Image Embeddings",
                    description="Run the Open AI CLIP model to embed image data.",
                )
                @with_route_exceptions
                @usage_collector("request")
                def clip_embed_image(
                    inference_request: ClipImageEmbeddingRequest,
                    request: Request,
                    api_key: Optional[str] = Query(
                        None,
                        description="Roboflow API Key that will be passed to the model during initialization for artifact retrieval",
                    ),
                    countinference: Optional[bool] = None,
                    service_secret: Optional[str] = None,
                ):
                    """
                    Embeds image data using the OpenAI CLIP model.

                    Args:
                        inference_request (ClipImageEmbeddingRequest): The request containing the image to be embedded.
                        request (Request): The HTTP request.
                        api_key (Optional[str], default None): Roboflow API Key passed to the model during initialization for artifact retrieval.
                        countinference (Optional[bool], default None): Forwarded to the model loader.
                        service_secret (Optional[str], default None): Forwarded to the model loader.

                    Returns:
                        ClipEmbeddingResponse: The response containing the embedded image.
                    """
                    # Plain string: the original f-string had no placeholders (F541).
                    logger.debug("Reached /clip/embed_image")
                    clip_model_id = load_clip_model(
                        inference_request,
                        api_key=api_key,
                        countinference=countinference,
                        service_secret=service_secret,
                    )
                    response = self.model_manager.infer_from_request_sync(
                        clip_model_id, inference_request
                    )
                    if LAMBDA:
                        # On AWS Lambda, usage is attributed to the authorizer-resolved actor.
                        actor = request.scope["aws.event"]["requestContext"][
                            "authorizer"
                        ]["lambda"]["actor"]
                        trackUsage(clip_model_id, actor)
                    return response

                @app.post(
                    "/clip/embed_text",
                    response_model=ClipEmbeddingResponse,
                    summary="CLIP Text Embeddings",
                    description="Run the Open AI CLIP model to embed text data.",
                )
                @with_route_exceptions
                @usage_collector("request")
                def clip_embed_text(
                    inference_request: ClipTextEmbeddingRequest,
                    request: Request,
                    api_key: Optional[str] = Query(
                        None,
                        description="Roboflow API Key that will be passed to the model during initialization for artifact retrieval",
                    ),
                    countinference: Optional[bool] = None,
                    service_secret: Optional[str] = None,
                ):
                    """
                    Embeds text data using the OpenAI CLIP model.

                    Args:
                        inference_request (ClipTextEmbeddingRequest): The request containing the text to be embedded.
                        request (Request): The HTTP request.
                        api_key (Optional[str], default None): Roboflow API Key passed to the model during initialization for artifact retrieval.
                        countinference (Optional[bool], default None): Forwarded to the model loader.
                        service_secret (Optional[str], default None): Forwarded to the model loader.

                    Returns:
                        ClipEmbeddingResponse: The response containing the embedded text.
                    """
                    # Plain string: the original f-string had no placeholders (F541).
                    logger.debug("Reached /clip/embed_text")
                    clip_model_id = load_clip_model(
                        inference_request,
                        api_key=api_key,
                        countinference=countinference,
                        service_secret=service_secret,
                    )
                    response = self.model_manager.infer_from_request_sync(
                        clip_model_id, inference_request
                    )
                    if LAMBDA:
                        # On AWS Lambda, usage is attributed to the authorizer-resolved actor.
                        actor = request.scope["aws.event"]["requestContext"][
                            "authorizer"
                        ]["lambda"]["actor"]
                        trackUsage(clip_model_id, actor)
                    return response

                @app.post(
                    "/clip/compare",
                    response_model=ClipCompareResponse,
                    summary="CLIP Compare",
                    description="Run the Open AI CLIP model to compute similarity scores.",
                )
                @with_route_exceptions
                @usage_collector("request")
                def clip_compare(
                    inference_request: ClipCompareRequest,
                    request: Request,
                    api_key: Optional[str] = Query(
                        None,
                        description="Roboflow API Key that will be passed to the model during initialization for artifact retrieval",
                    ),
                    countinference: Optional[bool] = None,
                    service_secret: Optional[str] = None,
                ):
                    """
                    Computes similarity scores using the OpenAI CLIP model.

                    Args:
                        inference_request (ClipCompareRequest): The request containing the data to be compared.
                        request (Request): The HTTP request.
                        api_key (Optional[str], default None): Roboflow API Key passed to the model during initialization for artifact retrieval.
                        countinference (Optional[bool], default None): Forwarded to the model loader.
                        service_secret (Optional[str], default None): Forwarded to the model loader.

                    Returns:
                        ClipCompareResponse: The response containing the similarity scores.
                    """
                    # Plain string: the original f-string had no placeholders (F541).
                    logger.debug("Reached /clip/compare")
                    clip_model_id = load_clip_model(
                        inference_request,
                        api_key=api_key,
                        countinference=countinference,
                        service_secret=service_secret,
                    )
                    response = self.model_manager.infer_from_request_sync(
                        clip_model_id, inference_request
                    )
                    if LAMBDA:
                        actor = request.scope["aws.event"]["requestContext"][
                            "authorizer"
                        ]["lambda"]["actor"]
                        # n=2: a compare touches two inputs, so usage counts double.
                        trackUsage(clip_model_id, actor, n=2)
                    return response

            if CORE_MODEL_PE_ENABLED:

                @app.post(
                    "/perception_encoder/embed_image",
                    response_model=PerceptionEncoderEmbeddingResponse,
                    summary="PE Image Embeddings",
                    description="Run the Meta Perception Encoder model to embed image data.",
                )
                @with_route_exceptions
                @usage_collector("request")
                def pe_embed_image(
                    inference_request: PerceptionEncoderImageEmbeddingRequest,
                    request: Request,
                    api_key: Optional[str] = Query(
                        None,
                        description="Roboflow API Key that will be passed to the model during initialization for artifact retrieval",
                    ),
                    countinference: Optional[bool] = None,
                    service_secret: Optional[str] = None,
                ):
                    """
                    Embeds image data using the Meta Perception Encoder model.

                    Args:
                        inference_request (PerceptionEncoderImageEmbeddingRequest): The request containing the image to be embedded.
                        request (Request): The HTTP request.
                        api_key (Optional[str], default None): Roboflow API Key passed to the model during initialization for artifact retrieval.
                        countinference (Optional[bool], default None): Forwarded to the model loader.
                        service_secret (Optional[str], default None): Forwarded to the model loader.

                    Returns:
                        PerceptionEncoderEmbeddingResponse: The response containing the embedded image.
                    """
                    # Plain string: the original f-string had no placeholders (F541).
                    logger.debug("Reached /perception_encoder/embed_image")
                    pe_model_id = load_pe_model(
                        inference_request,
                        api_key=api_key,
                        countinference=countinference,
                        service_secret=service_secret,
                    )
                    response = self.model_manager.infer_from_request_sync(
                        pe_model_id, inference_request
                    )
                    if LAMBDA:
                        # On AWS Lambda, usage is attributed to the authorizer-resolved actor.
                        actor = request.scope["aws.event"]["requestContext"][
                            "authorizer"
                        ]["lambda"]["actor"]
                        trackUsage(pe_model_id, actor)
                    return response

                @app.post(
                    "/perception_encoder/embed_text",
                    response_model=PerceptionEncoderEmbeddingResponse,
                    summary="Perception Encoder Text Embeddings",
                    description="Run the Meta Perception Encoder model to embed text data.",
                )
                @with_route_exceptions
                @usage_collector("request")
                def pe_embed_text(
                    inference_request: PerceptionEncoderTextEmbeddingRequest,
                    request: Request,
                    api_key: Optional[str] = Query(
                        None,
                        description="Roboflow API Key that will be passed to the model during initialization for artifact retrieval",
                    ),
                    countinference: Optional[bool] = None,
                    service_secret: Optional[str] = None,
                ):
                    """
                    Embeds text data using the Meta Perception Encoder model.

                    Args:
                        inference_request (PerceptionEncoderTextEmbeddingRequest): The request containing the text to be embedded.
                        request (Request): The HTTP request.
                        api_key (Optional[str], default None): Roboflow API Key passed to the model during initialization for artifact retrieval.
                        countinference (Optional[bool], default None): Forwarded to the model loader.
                        service_secret (Optional[str], default None): Forwarded to the model loader.

                    Returns:
                        PerceptionEncoderEmbeddingResponse: The response containing the embedded text.
                    """
                    # Plain string: the original f-string had no placeholders (F541).
                    logger.debug("Reached /perception_encoder/embed_text")
                    pe_model_id = load_pe_model(
                        inference_request,
                        api_key=api_key,
                        countinference=countinference,
                        service_secret=service_secret,
                    )
                    response = self.model_manager.infer_from_request_sync(
                        pe_model_id, inference_request
                    )
                    if LAMBDA:
                        # On AWS Lambda, usage is attributed to the authorizer-resolved actor.
                        actor = request.scope["aws.event"]["requestContext"][
                            "authorizer"
                        ]["lambda"]["actor"]
                        trackUsage(pe_model_id, actor)
                    return response

                @app.post(
                    "/perception_encoder/compare",
                    response_model=PerceptionEncoderCompareResponse,
                    summary="Perception Encoder Compare",
                    description="Run the Meta Perception Encoder model to compute similarity scores.",
                )
                @with_route_exceptions
                @usage_collector("request")
                def pe_compare(
                    inference_request: PerceptionEncoderCompareRequest,
                    request: Request,
                    api_key: Optional[str] = Query(
                        None,
                        description="Roboflow API Key that will be passed to the model during initialization for artifact retrieval",
                    ),
                    countinference: Optional[bool] = None,
                    service_secret: Optional[str] = None,
                ):
                    """
                    Computes similarity scores using the Meta Perception Encoder model.

                    Args:
                        inference_request (PerceptionEncoderCompareRequest): The request containing the data to be compared.
                        request (Request): The HTTP request.
                        api_key (Optional[str], default None): Roboflow API Key passed to the model during initialization for artifact retrieval.
                        countinference (Optional[bool], default None): Forwarded to the model loader.
                        service_secret (Optional[str], default None): Forwarded to the model loader.

                    Returns:
                        PerceptionEncoderCompareResponse: The response containing the similarity scores.
                    """
                    # Plain string: the original f-string had no placeholders (F541).
                    logger.debug("Reached /perception_encoder/compare")
                    pe_model_id = load_pe_model(
                        inference_request,
                        api_key=api_key,
                        countinference=countinference,
                        service_secret=service_secret,
                    )
                    response = self.model_manager.infer_from_request_sync(
                        pe_model_id, inference_request
                    )
                    if LAMBDA:
                        actor = request.scope["aws.event"]["requestContext"][
                            "authorizer"
                        ]["lambda"]["actor"]
                        # n=2: a compare touches two inputs, so usage counts double.
                        trackUsage(pe_model_id, actor, n=2)
                    return response

            if CORE_MODEL_GROUNDINGDINO_ENABLED:

                @app.post(
                    "/grounding_dino/infer",
                    response_model=ObjectDetectionInferenceResponse,
                    summary="Grounding DINO inference.",
                    description="Run the Grounding DINO zero-shot object detection model.",
                )
                @with_route_exceptions
                @usage_collector("request")
                def grounding_dino_infer(
                    inference_request: GroundingDINOInferenceRequest,
                    request: Request,
                    api_key: Optional[str] = Query(
                        None,
                        description="Roboflow API Key that will be passed to the model during initialization for artifact retrieval",
                    ),
                    countinference: Optional[bool] = None,
                    service_secret: Optional[str] = None,
                ):
                    """
                    Runs zero-shot object detection using the Grounding DINO model.

                    Args:
                        inference_request (GroundingDINOInferenceRequest): The request containing the image on which to run object detection.
                        request (Request): The HTTP request.
                        api_key (Optional[str], default None): Roboflow API Key passed to the model during initialization for artifact retrieval.
                        countinference (Optional[bool], default None): Forwarded to the model loader.
                        service_secret (Optional[str], default None): Forwarded to the model loader.

                    Returns:
                        ObjectDetectionInferenceResponse: The object detection response.
                    """
                    # Plain string: the original f-string had no placeholders (F541).
                    logger.debug("Reached /grounding_dino/infer")
                    grounding_dino_model_id = load_grounding_dino_model(
                        inference_request,
                        api_key=api_key,
                        countinference=countinference,
                        service_secret=service_secret,
                    )
                    response = self.model_manager.infer_from_request_sync(
                        grounding_dino_model_id, inference_request
                    )
                    if LAMBDA:
                        # On AWS Lambda, usage is attributed to the authorizer-resolved actor.
                        actor = request.scope["aws.event"]["requestContext"][
                            "authorizer"
                        ]["lambda"]["actor"]
                        trackUsage(grounding_dino_model_id, actor)
                    return response

            if CORE_MODEL_YOLO_WORLD_ENABLED:

                @app.post(
                    "/yolo_world/infer",
                    response_model=ObjectDetectionInferenceResponse,
                    summary="YOLO-World inference.",
                    description="Run the YOLO-World zero-shot object detection model.",
                    response_model_exclude_none=True,
                )
                @with_route_exceptions
                @usage_collector("request")
                def yolo_world_infer(
                    inference_request: YOLOWorldInferenceRequest,
                    request: Request,
                    api_key: Optional[str] = Query(
                        None,
                        description="Roboflow API Key that will be passed to the model during initialization for artifact retrieval",
                    ),
                    countinference: Optional[bool] = None,
                    service_secret: Optional[str] = None,
                ):
                    """
                    Runs the YOLO-World zero-shot object detection model.

                    Args:
                        inference_request (YOLOWorldInferenceRequest): The request containing the image on which to run object detection.
                        request (Request): The HTTP request.
                        api_key (Optional[str], default None): Roboflow API Key passed to the model during initialization for artifact retrieval.
                        countinference (Optional[bool], default None): Forwarded to the model loader.
                        service_secret (Optional[str], default None): Forwarded to the model loader.

                    Returns:
                        ObjectDetectionInferenceResponse: The object detection response.
                    """
                    # Plain string: the original f-string had no placeholders (F541).
                    logger.debug("Reached /yolo_world/infer. Loading model")
                    yolo_world_model_id = load_yolo_world_model(
                        inference_request,
                        api_key=api_key,
                        countinference=countinference,
                        service_secret=service_secret,
                    )
                    # Typo fix in log message: "Staring" -> "Starting".
                    logger.debug("YOLOWorld model loaded. Starting the inference.")
                    response = self.model_manager.infer_from_request_sync(
                        yolo_world_model_id, inference_request
                    )
                    logger.debug("YOLOWorld prediction available.")
                    if LAMBDA:
                        # On AWS Lambda, usage is attributed to the authorizer-resolved actor.
                        actor = request.scope["aws.event"]["requestContext"][
                            "authorizer"
                        ]["lambda"]["actor"]
                        trackUsage(yolo_world_model_id, actor)
                        logger.debug("Usage of YOLOWorld denoted.")
                    return response

            if CORE_MODEL_DOCTR_ENABLED:

                @app.post(
                    "/doctr/ocr",
                    response_model=Union[
                        OCRInferenceResponse, List[OCRInferenceResponse]
                    ],
                    summary="DocTR OCR response",
                    description="Run the DocTR OCR model to retrieve text in an image.",
                )
                @with_route_exceptions
                @usage_collector("request")
                def doctr_retrieve_text(
                    inference_request: DoctrOCRInferenceRequest,
                    request: Request,
                    api_key: Optional[str] = Query(
                        None,
                        description="Roboflow API Key that will be passed to the model during initialization for artifact retrieval",
                    ),
                    countinference: Optional[bool] = None,
                    service_secret: Optional[str] = None,
                ):
                    """
                    Retrieves text from image data using the DocTR OCR model.

                    Args:
                        inference_request (DoctrOCRInferenceRequest): The request containing the image from which to retrieve text.
                        request (Request): The HTTP request.
                        api_key (Optional[str], default None): Roboflow API Key passed to the model during initialization for artifact retrieval.
                        countinference (Optional[bool], default None): Forwarded to the model loader.
                        service_secret (Optional[str], default None): Forwarded to the model loader.

                    Returns:
                        OCRInferenceResponse: The response containing the OCR result.
                    """
                    # Plain string: the original f-string had no placeholders (F541).
                    logger.debug("Reached /doctr/ocr")
                    doctr_model_id = load_doctr_model(
                        inference_request,
                        api_key=api_key,
                        countinference=countinference,
                        service_secret=service_secret,
                    )
                    response = self.model_manager.infer_from_request_sync(
                        doctr_model_id, inference_request
                    )
                    if LAMBDA:
                        # On AWS Lambda, usage is attributed to the authorizer-resolved actor.
                        actor = request.scope["aws.event"]["requestContext"][
                            "authorizer"
                        ]["lambda"]["actor"]
                        trackUsage(doctr_model_id, actor)
                    return orjson_response_keeping_parent_id(response)

            if CORE_MODEL_EASYOCR_ENABLED:

                @app.post(
                    "/easy_ocr/ocr",
                    response_model=Union[
                        OCRInferenceResponse, List[OCRInferenceResponse]
                    ],
                    summary="EasyOCR OCR response",
                    description="Run the EasyOCR model to retrieve text in an image.",
                )
                @with_route_exceptions
                @usage_collector("request")
                def easy_ocr_retrieve_text(
                    inference_request: EasyOCRInferenceRequest,
                    request: Request,
                    api_key: Optional[str] = Query(
                        None,
                        description="Roboflow API Key that will be passed to the model during initialization for artifact retrieval",
                    ),
                    countinference: Optional[bool] = None,
                    service_secret: Optional[str] = None,
                ):
                    """
                    Retrieves text from image data using the EasyOCR model.

                    Args:
                        inference_request (EasyOCRInferenceRequest): The request containing the image from which to retrieve text.
                        request (Request): The HTTP request.
                        api_key (Optional[str], default None): Roboflow API Key passed to the model during initialization for artifact retrieval.
                        countinference (Optional[bool], default None): Forwarded to the model loader.
                        service_secret (Optional[str], default None): Forwarded to the model loader.

                    Returns:
                        OCRInferenceResponse: The response containing the OCR result.
                    """
                    # Plain string: the original f-string had no placeholders (F541).
                    logger.debug("Reached /easy_ocr/ocr")
                    easy_ocr_model_id = load_easy_ocr_model(
                        inference_request,
                        api_key=api_key,
                        countinference=countinference,
                        service_secret=service_secret,
                    )
                    response = self.model_manager.infer_from_request_sync(
                        easy_ocr_model_id, inference_request
                    )
                    if LAMBDA:
                        # On AWS Lambda, usage is attributed to the authorizer-resolved actor.
                        actor = request.scope["aws.event"]["requestContext"][
                            "authorizer"
                        ]["lambda"]["actor"]
                        trackUsage(easy_ocr_model_id, actor)
                    return orjson_response_keeping_parent_id(response)

            if CORE_MODEL_SAM_ENABLED:

                @app.post(
                    "/sam/embed_image",
                    response_model=SamEmbeddingResponse,
                    summary="SAM Image Embeddings",
                    description="Run the Meta AI Segment Anything Model to embed image data.",
                )
                @with_route_exceptions
                @usage_collector("request")
                def sam_embed_image(
                    inference_request: SamEmbeddingRequest,
                    request: Request,
                    api_key: Optional[str] = Query(
                        None,
                        description="Roboflow API Key that will be passed to the model during initialization for artifact retrieval",
                    ),
                    countinference: Optional[bool] = None,
                    service_secret: Optional[str] = None,
                ):
                    """
                    Embeds image data using the Meta AI Segment Anything Model (SAM).

                    Args:
                        inference_request (SamEmbeddingRequest): The request containing the image to be embedded.
                        request (Request): The HTTP request.
                        api_key (Optional[str], default None): Roboflow API Key passed to the model during initialization for artifact retrieval.
                        countinference (Optional[bool], default None): Forwarded to the model loader.
                        service_secret (Optional[str], default None): Forwarded to the model loader.

                    Returns:
                        SamEmbeddingResponse or Response: The embedding response, or the raw
                        embedding bytes with an octet-stream content type when
                        ``inference_request.format == "binary"``.
                    """
                    logger.debug("Reached /sam/embed_image")
                    sam_model_id = load_sam_model(
                        inference_request,
                        api_key=api_key,
                        countinference=countinference,
                        service_secret=service_secret,
                    )
                    model_response = self.model_manager.infer_from_request_sync(
                        sam_model_id, inference_request
                    )
                    if LAMBDA:
                        # On AWS Lambda deployments, attribute usage to the actor
                        # resolved by the API Gateway authorizer.
                        actor = request.scope["aws.event"]["requestContext"][
                            "authorizer"
                        ]["lambda"]["actor"]
                        trackUsage(sam_model_id, actor)
                    if inference_request.format == "binary":
                        return Response(
                            content=model_response.embeddings,
                            headers={"Content-Type": "application/octet-stream"},
                        )
                    return model_response

                @app.post(
                    "/sam/segment_image",
                    response_model=SamSegmentationResponse,
                    summary="SAM Image Segmentation",
                    description="Run the Meta AI Segment Anything Model to generate segmentations for image data.",
                )
                @with_route_exceptions
                @usage_collector("request")
                def sam_segment_image(
                    inference_request: SamSegmentationRequest,
                    request: Request,
                    api_key: Optional[str] = Query(
                        None,
                        description="Roboflow API Key that will be passed to the model during initialization for artifact retrieval",
                    ),
                    countinference: Optional[bool] = None,
                    service_secret: Optional[str] = None,
                ):
                    """
                    Generates segmentations for image data using the Meta AI Segment Anything Model (SAM).

                    Args:
                        inference_request (SamSegmentationRequest): The request containing the image to be segmented.
                        request (Request): The HTTP request.
                        api_key (Optional[str], default None): Roboflow API Key passed to the model during initialization for artifact retrieval.
                        countinference (Optional[bool], default None): Forwarded to the model loader.
                        service_secret (Optional[str], default None): Forwarded to the model loader.

                    Returns:
                        SamSegmentationResponse or Response: The segmentation response, or the raw
                        result bytes with an octet-stream content type when
                        ``inference_request.format == "binary"``.
                    """
                    logger.debug("Reached /sam/segment_image")
                    sam_model_id = load_sam_model(
                        inference_request,
                        api_key=api_key,
                        countinference=countinference,
                        service_secret=service_secret,
                    )
                    model_response = self.model_manager.infer_from_request_sync(
                        sam_model_id, inference_request
                    )
                    if LAMBDA:
                        # On AWS Lambda deployments, attribute usage to the actor
                        # resolved by the API Gateway authorizer.
                        actor = request.scope["aws.event"]["requestContext"][
                            "authorizer"
                        ]["lambda"]["actor"]
                        trackUsage(sam_model_id, actor)
                    if inference_request.format == "binary":
                        return Response(
                            content=model_response,
                            headers={"Content-Type": "application/octet-stream"},
                        )
                    return model_response

            if CORE_MODEL_SAM2_ENABLED:

                @app.post(
                    "/sam2/embed_image",
                    response_model=Sam2EmbeddingResponse,
                    summary="SAM2 Image Embeddings",
                    description="Run the Meta AI Segment Anything 2 Model to embed image data.",
                )
                @with_route_exceptions
                @usage_collector("request")
                def sam2_embed_image(
                    inference_request: Sam2EmbeddingRequest,
                    request: Request,
                    api_key: Optional[str] = Query(
                        None,
                        description="Roboflow API Key that will be passed to the model during initialization for artifact retrieval",
                    ),
                    countinference: Optional[bool] = None,
                    service_secret: Optional[str] = None,
                ):
                    """
                    Embeds image data using the Meta AI Segment Anything 2 Model (SAM2).

                    Args:
                        inference_request (Sam2EmbeddingRequest): The request containing the image to be embedded.
                        request (Request): The HTTP request.
                        api_key (Optional[str], default None): Roboflow API Key passed to the model during initialization for artifact retrieval.
                        countinference (Optional[bool], default None): Forwarded to the model loader.
                        service_secret (Optional[str], default None): Forwarded to the model loader.

                    Returns:
                        Sam2EmbeddingResponse: The response affirming the image has been embedded.
                    """
                    logger.debug("Reached /sam2/embed_image")
                    sam2_model_id = load_sam2_model(
                        inference_request,
                        api_key=api_key,
                        countinference=countinference,
                        service_secret=service_secret,
                    )
                    model_response = self.model_manager.infer_from_request_sync(
                        sam2_model_id, inference_request
                    )
                    return model_response

                @app.post(
                    "/sam2/segment_image",
                    response_model=Sam2SegmentationResponse,
                    summary="SAM2 Image Segmentation",
                    description="Run the Meta AI Segment Anything 2 Model to generate segmenations for image data.",
                )
                @with_route_exceptions
                @usage_collector("request")
                def sam2_segment_image(
                    inference_request: Sam2SegmentationRequest,
                    request: Request,
                    api_key: Optional[str] = Query(
                        None,
                        description="Roboflow API Key that will be passed to the model during initialization for artifact retrieval",
                    ),
                    countinference: Optional[bool] = None,
                    service_secret: Optional[str] = None,
                ):
                    """
                    Generates segmentations for image data using the Meta AI Segment Anything 2 Model (SAM2).

                    Args:
                        inference_request (Sam2SegmentationRequest): The request containing the image to be segmented.
                        request (Request): The HTTP request.
                        api_key (Optional[str], default None): Roboflow API Key passed to the model during initialization for artifact retrieval.
                        countinference (Optional[bool], default None): Forwarded to the model loader.
                        service_secret (Optional[str], default None): Forwarded to the model loader.

                    Returns:
                        Sam2SegmentationResponse or Response: The segmentation response, or the raw
                        result bytes with an octet-stream content type when
                        ``inference_request.format == "binary"``.
                    """
                    logger.debug("Reached /sam2/segment_image")
                    sam2_model_id = load_sam2_model(
                        inference_request,
                        api_key=api_key,
                        countinference=countinference,
                        service_secret=service_secret,
                    )
                    model_response = self.model_manager.infer_from_request_sync(
                        sam2_model_id, inference_request
                    )
                    if inference_request.format == "binary":
                        return Response(
                            content=model_response,
                            headers={"Content-Type": "application/octet-stream"},
                        )
                    return model_response

            if CORE_MODEL_SAM3_ENABLED and not GCP_SERVERLESS:

                @app.post(
                    "/sam3/embed_image",
                    response_model=Sam3EmbeddingResponse,
                    summary="Seg preview Image Embeddings",
                    description="Run the SAM3 model to embed image data.",
                )
                @with_route_exceptions
                @usage_collector("request")
                def sam3_embed_image(
                    inference_request: Sam2EmbeddingRequest,
                    request: Request,
                    api_key: Optional[str] = Query(
                        None,
                        description="Roboflow API Key that will be passed to the model during initialization for artifact retrieval",
                    ),
                    countinference: Optional[bool] = None,
                    service_secret: Optional[str] = None,
                ):
                    """
                    Embeds image data using the locally loaded SAM3 interactive model.

                    Note: this endpoint reuses the SAM2 embedding request schema
                    (Sam2EmbeddingRequest).

                    Args:
                        inference_request (Sam2EmbeddingRequest): The request containing the image to be embedded.
                        request (Request): The HTTP request.
                        api_key (Optional[str], default None): Roboflow API Key passed to the model during initialization for artifact retrieval.
                        countinference (Optional[bool], default None): Forwarded to the model manager.
                        service_secret (Optional[str], default None): Forwarded to the model manager.

                    Returns:
                        Sam3EmbeddingResponse: The response affirming the image has been embedded.

                    Raises:
                        HTTPException: 501 when ``SAM3_EXEC_MODE == "remote"`` — embedding
                            requires a locally loaded model.
                    """
                    logger.debug("Reached /sam3/embed_image")

                    if SAM3_EXEC_MODE == "remote":
                        raise HTTPException(
                            status_code=501,
                            detail="SAM3 embedding is not supported in remote execution mode.",
                        )

                    self.model_manager.add_model(
                        "sam3/sam3_interactive",
                        api_key=api_key,
                        endpoint_type=ModelEndpointType.CORE_MODEL,
                        countinference=countinference,
                        service_secret=service_secret,
                    )

                    model_response = self.model_manager.infer_from_request_sync(
                        "sam3/sam3_interactive", inference_request
                    )
                    return model_response

            if CORE_MODEL_SAM3_ENABLED:

                @app.post(
                    "/sam3/concept_segment",
                    response_model=Sam3SegmentationResponse,
                    summary="SAM3 PCS (promptable concept segmentation)",
                    description="Run the SAM3 PCS (promptable concept segmentation) to generate segmentations for image data.",
                )
                @with_route_exceptions
                @usage_collector("request")
                def sam3_segment_image(
                    inference_request: Sam3SegmentationRequest,
                    request: Request,
                    api_key: Optional[str] = Query(
                        None,
                        description="Roboflow API Key that will be passed to the model during initialization for artifact retrieval",
                    ),
                    countinference: Optional[bool] = None,
                    service_secret: Optional[str] = None,
                ):
                    """Run SAM3 promptable concept segmentation (PCS) on an image.

                    Depending on ``SAM3_EXEC_MODE`` the request is either proxied to the
                    remote Roboflow API or served by a locally loaded model. Fine-tuned
                    (non-``sam3/``) model IDs are rejected with 501 unless
                    ``SAM3_FINE_TUNED_MODELS_ENABLED`` is set.

                    Args:
                        inference_request (Sam3SegmentationRequest): The request containing
                            the image, prompts, and model id for segmentation.
                        request (Request): The HTTP request.
                        api_key (Optional[str], default None): Roboflow API Key passed to the
                            model during initialization for artifact retrieval.
                        countinference (Optional[bool], default None): Forwarded to the model
                            manager when loading the model locally.
                        service_secret (Optional[str], default None): Forwarded to the model
                            manager when loading the model locally.

                    Returns:
                        Sam3SegmentationResponse or Response: The segmentation result, or the
                        raw result bytes with an octet-stream content type when
                        ``inference_request.format == "binary"``.

                    Raises:
                        HTTPException: 501 for unsupported fine-tuned models; 500 when the
                            remote proxy request fails.
                    """
                    if not SAM3_FINE_TUNED_MODELS_ENABLED:
                        if not inference_request.model_id.startswith("sam3/"):
                            raise HTTPException(
                                status_code=501,
                                detail="Fine-tuned SAM3 models are not supported on this deployment. Please use a workflow or self-host the server.",
                            )

                    if SAM3_EXEC_MODE == "remote":
                        endpoint = f"{API_BASE_URL}/inferenceproxy/seg-preview"

                        # Construct payload for remote API
                        # The remote API expects:
                        # {
                        #     "image": {"type": "base64", "value": ...},
                        #     "prompts": [{"type": "text", "text": ...}, ...],
                        #     "output_prob_thresh": ...
                        # }

                        # Extract prompts from request
                        http_prompts = []
                        for prompt in inference_request.prompts:
                            p_dict = prompt.dict(exclude_none=True)
                            # Ensure type is set if missing (default to text if text is present)
                            if "type" not in p_dict:
                                if "text" in p_dict:
                                    p_dict["type"] = "text"
                            http_prompts.append(p_dict)

                        # Prepare image
                        # inference_request.image is InferenceRequestImage
                        if inference_request.image.type == "base64":
                            http_image = {
                                "type": "base64",
                                "value": inference_request.image.value,
                            }
                        elif inference_request.image.type == "url":
                            http_image = {
                                "type": "url",
                                "value": inference_request.image.value,
                            }
                        elif inference_request.image.type == "numpy":
                            # Numpy not supported for remote proxy easily without serialization,
                            # but InferenceRequestImage usually comes as base64/url in HTTP API.
                            # If it is numpy, we might need to handle it, but for now assume base64/url.
                            # If it's numpy, it's likely from internal call, but this is HTTP API.
                            http_image = {
                                "type": "numpy",
                                "value": inference_request.image.value,
                            }
                        else:
                            http_image = {
                                "type": inference_request.image.type,
                                "value": inference_request.image.value,
                            }

                        payload = {
                            "image": http_image,
                            "prompts": http_prompts,
                            "output_prob_thresh": inference_request.output_prob_thresh,
                        }

                        try:
                            # Identify this service to the proxy when internal
                            # service credentials are configured.
                            headers = {"Content-Type": "application/json"}
                            if ROBOFLOW_INTERNAL_SERVICE_NAME:
                                headers["X-Roboflow-Internal-Service-Name"] = (
                                    ROBOFLOW_INTERNAL_SERVICE_NAME
                                )
                            if ROBOFLOW_INTERNAL_SERVICE_SECRET:
                                headers["X-Roboflow-Internal-Service-Secret"] = (
                                    ROBOFLOW_INTERNAL_SERVICE_SECRET
                                )

                            headers = build_roboflow_api_headers(
                                explicit_headers=headers
                            )

                            response = requests.post(
                                f"{endpoint}?api_key={api_key}",
                                json=payload,
                                headers=headers,
                                timeout=60,
                            )
                            response.raise_for_status()
                            resp_json = response.json()

                            # The remote API returns the same structure as Sam3SegmentationResponse
                            return Sam3SegmentationResponse(**resp_json)

                        except Exception as e:
                            # Any proxy failure (network, HTTP status, parsing) is
                            # surfaced to the client as a 500.
                            logger.error(f"SAM3 remote request failed: {e}")
                            raise HTTPException(
                                status_code=500,
                                detail=f"SAM3 remote request failed: {str(e)}",
                            )

                    # Local execution: base "sam3/..." models load as core models,
                    # anything else (fine-tuned ids) loads via the ORT endpoint type.
                    if inference_request.model_id.startswith("sam3/"):
                        self.model_manager.add_model(
                            inference_request.model_id,
                            api_key=api_key,
                            endpoint_type=ModelEndpointType.CORE_MODEL,
                            countinference=countinference,
                            service_secret=service_secret,
                        )
                    else:
                        self.model_manager.add_model(
                            inference_request.model_id,
                            api_key=api_key,
                            endpoint_type=ModelEndpointType.ORT,
                            countinference=countinference,
                            service_secret=service_secret,
                        )

                    model_response = self.model_manager.infer_from_request_sync(
                        inference_request.model_id, inference_request
                    )
                    if inference_request.format == "binary":
                        return Response(
                            content=model_response,
                            headers={"Content-Type": "application/octet-stream"},
                        )
                    return model_response

                @app.post(
                    "/sam3/visual_segment",
                    response_model=Sam2SegmentationResponse,
                    summary="SAM3 PVS (promptable visual segmentation)",
                    description="Run the SAM3 PVS (promptable visual segmentation) to generate segmentations for image data.",
                )
                @with_route_exceptions
                @usage_collector("request")
                def sam3_visual_segment(
                    inference_request: Sam2SegmentationRequest,
                    request: Request,
                    api_key: Optional[str] = Query(
                        None,
                        description="Roboflow API Key that will be passed to the model during initialization for artifact retrieval",
                    ),
                    countinference: Optional[bool] = None,
                    service_secret: Optional[str] = None,
                ):
                    """Run SAM3 promptable visual segmentation (PVS) on an image.

                    Note: this endpoint reuses the SAM2 request/response schemas
                    (Sam2SegmentationRequest / Sam2SegmentationResponse). Depending on
                    ``SAM3_EXEC_MODE`` the request is either proxied to the remote
                    Roboflow API or served by the locally loaded
                    ``sam3/sam3_interactive`` model.

                    Args:
                        inference_request (Sam2SegmentationRequest): The request containing
                            the image and visual prompts for segmentation.
                        request (Request): The HTTP request.
                        api_key (Optional[str], default None): Roboflow API Key passed to the
                            model during initialization for artifact retrieval.
                        countinference (Optional[bool], default None): Forwarded to the model
                            manager when loading the model locally.
                        service_secret (Optional[str], default None): Forwarded to the model
                            manager when loading the model locally.

                    Returns:
                        Sam2SegmentationResponse: The segmentation result.

                    Raises:
                        HTTPException: 500 when the remote proxy request fails.
                    """
                    logger.debug(f"Reached /sam3/visual_segment")

                    if SAM3_EXEC_MODE == "remote":
                        endpoint = f"{API_BASE_URL}/inferenceproxy/sam3-pvs"

                        http_image = {
                            "type": inference_request.image.type,
                            "value": inference_request.image.value,
                        }

                        prompts_data = (
                            inference_request.prompts.dict(exclude_none=True)
                            if inference_request.prompts
                            else None
                        )

                        payload = {
                            "image": http_image,
                            "prompts": prompts_data,
                            "multimask_output": inference_request.multimask_output,
                        }

                        try:
                            # Identify this service to the proxy when internal
                            # service credentials are configured.
                            headers = {"Content-Type": "application/json"}
                            if ROBOFLOW_INTERNAL_SERVICE_NAME:
                                headers["X-Roboflow-Internal-Service-Name"] = (
                                    ROBOFLOW_INTERNAL_SERVICE_NAME
                                )
                            if ROBOFLOW_INTERNAL_SERVICE_SECRET:
                                headers["X-Roboflow-Internal-Service-Secret"] = (
                                    ROBOFLOW_INTERNAL_SERVICE_SECRET
                                )

                            headers = build_roboflow_api_headers(
                                explicit_headers=headers
                            )

                            response = requests.post(
                                f"{endpoint}?api_key={api_key}",
                                json=payload,
                                headers=headers,
                                timeout=60,
                            )
                            response.raise_for_status()
                            resp_json = response.json()

                            return Sam2SegmentationResponse(**resp_json)

                        except Exception as e:
                            # Any proxy failure (network, HTTP status, parsing) is
                            # surfaced to the client as a 500.
                            logger.error(
                                f"SAM3 visual_segment remote request failed: {e}"
                            )
                            raise HTTPException(
                                status_code=500,
                                detail=f"SAM3 visual_segment remote request failed: {str(e)}",
                            )

                    # Local execution always goes through the interactive SAM3 model.
                    self.model_manager.add_model(
                        "sam3/sam3_interactive",
                        api_key=api_key,
                        endpoint_type=ModelEndpointType.CORE_MODEL,
                        countinference=countinference,
                        service_secret=service_secret,
                    )

                    model_response = self.model_manager.infer_from_request_sync(
                        "sam3/sam3_interactive", inference_request
                    )
                    return model_response

            if CORE_MODEL_SAM3_ENABLED and not GCP_SERVERLESS:

                @app.post(
                    "/sam3_3d/infer",
                    summary="SAM3 3D Object Generation",
                    description="Generate 3D meshes and Gaussian splatting from 2D images with mask prompts.",
                )
                @with_route_exceptions
                @usage_collector("request")
                def sam3_3d_infer(
                    inference_request: Sam3_3D_Objects_InferenceRequest,
                    request: Request,
                    api_key: Optional[str] = Query(
                        None,
                        description="Roboflow API Key that will be passed to the model during initialization for artifact retrieval",
                    ),
                    countinference: Optional[bool] = None,
                    service_secret: Optional[str] = None,
                ):
                    """Produce 3D assets from a 2D image with mask prompts.

                    Args:
                        inference_request (Sam3_3D_Objects_InferenceRequest): The request
                            containing the image and mask input for 3D generation.
                        request (Request): The HTTP request.
                        api_key (Optional[str]): Roboflow API Key for artifact retrieval.
                        countinference (Optional[bool]): Forwarded to the model manager.
                        service_secret (Optional[str]): Forwarded to the model manager.

                    Returns:
                        dict: Base64-encoded 3D outputs:
                            - mesh_glb: scene mesh in GLB format (base64)
                            - gaussian_ply: combined Gaussian splatting in PLY format (base64)
                            - objects: individual objects with their 3D data and placement metadata
                            - time: inference time in seconds
                    """
                    logger.debug("Reached /sam3_3d/infer")
                    # Fall back to the default core model when no id is supplied.
                    model_id = inference_request.model_id or "sam3-3d-objects"

                    self.model_manager.add_model(
                        model_id,
                        api_key=api_key,
                        endpoint_type=ModelEndpointType.CORE_MODEL,
                        countinference=countinference,
                        service_secret=service_secret,
                    )
                    result = self.model_manager.infer_from_request_sync(
                        model_id, inference_request
                    )

                    if LAMBDA:
                        # Attribute usage to the actor resolved by the Lambda authorizer.
                        authorizer = request.scope["aws.event"]["requestContext"][
                            "authorizer"
                        ]
                        trackUsage(model_id, authorizer["lambda"]["actor"])

                    def _as_base64(raw):
                        # bytes -> base64 str for JSON serialization; None passes through.
                        return None if raw is None else base64.b64encode(raw).decode("utf-8")

                    encoded_objects = [
                        {
                            "mesh_glb": _as_base64(item.mesh_glb),
                            "gaussian_ply": _as_base64(item.gaussian_ply),
                            "metadata": {
                                "rotation": item.metadata.rotation,
                                "translation": item.metadata.translation,
                                "scale": item.metadata.scale,
                            },
                        }
                        for item in result.objects
                    ]

                    return {
                        "mesh_glb": _as_base64(result.mesh_glb),
                        "gaussian_ply": _as_base64(result.gaussian_ply),
                        "objects": encoded_objects,
                        "time": result.time,
                    }

            if CORE_MODEL_OWLV2_ENABLED:

                @app.post(
                    "/owlv2/infer",
                    response_model=ObjectDetectionInferenceResponse,
                    summary="Owlv2 image prompting",
                    description="Run the google owlv2 model to few-shot object detect",
                )
                @with_route_exceptions
                @usage_collector("request")
                def owlv2_infer(
                    inference_request: OwlV2InferenceRequest,
                    request: Request,
                    api_key: Optional[str] = Query(
                        None,
                        description="Roboflow API Key that will be passed to the model during initialization for artifact retrieval",
                    ),
                    countinference: Optional[bool] = None,
                    service_secret: Optional[str] = None,
                ):
                    """
                    Runs few-shot object detection using Google's OWLv2 model.

                    Args:
                        inference_request (OwlV2InferenceRequest): The request containing the image and prompts for detection.
                        request (Request): The HTTP request.
                        api_key (Optional[str], default None): Roboflow API Key passed to the model during initialization for artifact retrieval.
                        countinference (Optional[bool], default None): Forwarded to the model loader.
                        service_secret (Optional[str], default None): Forwarded to the model loader.

                    Returns:
                        ObjectDetectionInferenceResponse: The response containing the detected objects.
                    """
                    logger.debug("Reached /owlv2/infer")
                    owl2_model_id = load_owlv2_model(
                        inference_request,
                        api_key=api_key,
                        countinference=countinference,
                        service_secret=service_secret,
                    )
                    model_response = self.model_manager.infer_from_request_sync(
                        owl2_model_id, inference_request
                    )
                    return model_response

            if CORE_MODEL_GAZE_ENABLED:

                @app.post(
                    "/gaze/gaze_detection",
                    response_model=List[GazeDetectionInferenceResponse],
                    summary="Gaze Detection",
                    description="Run the gaze detection model to detect gaze.",
                )
                @with_route_exceptions
                @usage_collector("request")
                def gaze_detection(
                    inference_request: GazeDetectionInferenceRequest,
                    request: Request,
                    api_key: Optional[str] = Query(
                        None,
                        description="Roboflow API Key that will be passed to the model during initialization for artifact retrieval",
                    ),
                    countinference: Optional[bool] = None,
                    service_secret: Optional[str] = None,
                ):
                    """Run face and gaze detection on the supplied image.

                    Args:
                        inference_request (GazeDetectionInferenceRequest): The request
                            holding the image to analyze.
                        request (Request): The HTTP request.
                        api_key (Optional[str], default None): Roboflow API Key passed to
                            the model during initialization for artifact retrieval.
                        countinference (Optional[bool], default None): Forwarded to the
                            model loader.
                        service_secret (Optional[str], default None): Forwarded to the
                            model loader.

                    Returns:
                        List[GazeDetectionInferenceResponse]: All detected faces together
                        with their corresponding gaze estimates.
                    """
                    logger.debug(f"Reached /gaze/gaze_detection")
                    model_id = load_gaze_model(
                        inference_request,
                        api_key=api_key,
                        countinference=countinference,
                        service_secret=service_secret,
                    )
                    gaze_response = self.model_manager.infer_from_request_sync(
                        model_id, inference_request
                    )
                    if LAMBDA:
                        # Attribute usage to the actor resolved by the Lambda authorizer.
                        authorizer = request.scope["aws.event"]["requestContext"][
                            "authorizer"
                        ]
                        trackUsage(model_id, authorizer["lambda"]["actor"])
                    return gaze_response

            if DEPTH_ESTIMATION_ENABLED:

                @app.post(
                    "/infer/depth-estimation",
                    response_model=DepthEstimationResponse,
                    summary="Depth Estimation",
                    description="Run the depth estimation model to generate a depth map.",
                )
                @with_route_exceptions
                @usage_collector("request")
                def depth_estimation(
                    inference_request: DepthEstimationRequest,
                    request: Request,
                    api_key: Optional[str] = Query(
                        None,
                        description="Roboflow API Key that will be passed to the model during initialization for artifact retrieval",
                    ),
                    countinference: Optional[bool] = None,
                    service_secret: Optional[str] = None,
                ):
                    """
                    Generate a depth map using the depth estimation model.

                    Args:
                        inference_request (DepthEstimationRequest): The request containing the image to estimate depth for.
                        request (Request): The HTTP request.
                        api_key (Optional[str], default None): Roboflow API Key passed to the model during initialization for artifact retrieval.
                        countinference (Optional[bool], default None): Whether to count this inference against usage.
                        service_secret (Optional[str], default None): Shared secret for internal service requests.

                    Returns:
                        DepthEstimationResponse: The response containing the normalized depth map and optional visualization.
                    """
                    logger.debug("Reached /infer/depth-estimation")
                    depth_model_id = inference_request.model_id
                    self.model_manager.add_model(
                        depth_model_id,
                        # Fall back to the query-string API key so it is not
                        # silently ignored when absent from the request body.
                        inference_request.api_key or api_key,
                        countinference=countinference,
                        service_secret=service_secret,
                    )
                    response = self.model_manager.infer_from_request_sync(
                        depth_model_id, inference_request
                    )
                    if LAMBDA:
                        # On Lambda, usage is attributed to the actor from the
                        # API Gateway authorizer context.
                        actor = request.scope["aws.event"]["requestContext"][
                            "authorizer"
                        ]["lambda"]["actor"]
                        trackUsage(depth_model_id, actor)

                    # Extract data from nested response structure: the model returns
                    # numpy data under `.response`; convert to JSON-serializable types.
                    depth_data = response.response
                    depth_response = DepthEstimationResponse(
                        normalized_depth=depth_data["normalized_depth"].tolist(),
                        image=depth_data["image"].base64_image,
                    )
                    return depth_response

            if CORE_MODEL_TROCR_ENABLED:

                @app.post(
                    "/ocr/trocr",
                    response_model=OCRInferenceResponse,
                    summary="TrOCR OCR response",
                    description="Run the TrOCR model to retrieve text in an image.",
                )
                @with_route_exceptions
                @usage_collector("request")
                def trocr_retrieve_text(
                    inference_request: TrOCRInferenceRequest,
                    request: Request,
                    api_key: Optional[str] = Query(
                        None,
                        description="Roboflow API Key that will be passed to the model during initialization for artifact retrieval",
                    ),
                    countinference: Optional[bool] = None,
                    service_secret: Optional[str] = None,
                ):
                    """
                    Retrieves text from image data using the TrOCR model.

                    Args:
                        inference_request (TrOCRInferenceRequest): The request containing the image from which to retrieve text.
                        request (Request): The HTTP request.
                        api_key (Optional[str], default None): Roboflow API Key passed to the model during initialization for artifact retrieval.
                        countinference (Optional[bool], default None): Whether to count this inference against usage.
                        service_secret (Optional[str], default None): Shared secret for internal service requests.

                    Returns:
                        OCRInferenceResponse: The response containing the retrieved text.
                    """
                    # Fixed: previously logged "/trocr/ocr", which does not match
                    # the registered endpoint "/ocr/trocr".
                    logger.debug("Reached /ocr/trocr")
                    trocr_model_id = load_trocr_model(
                        inference_request,
                        api_key=api_key,
                        countinference=countinference,
                        service_secret=service_secret,
                    )
                    response = self.model_manager.infer_from_request_sync(
                        trocr_model_id, inference_request
                    )
                    if LAMBDA:
                        # On Lambda, usage is attributed to the actor from the
                        # API Gateway authorizer context.
                        actor = request.scope["aws.event"]["requestContext"][
                            "authorizer"
                        ]["lambda"]["actor"]
                        trackUsage(trocr_model_id, actor)
                    return orjson_response_keeping_parent_id(response)

        if not (LAMBDA or GCP_SERVERLESS):

            @app.get(
                "/notebook/start",
                summary="Jupyter Lab Server Start",
                description="Starts a jupyter lab server for running development code",
            )
            @with_route_exceptions
            def notebook_start(browserless: bool = False):
                """Starts a jupyter lab server for running development code.

                Args:
                    browserless (bool, default False): If True, return a JSON payload
                        containing the server URL instead of redirecting the browser
                        to Jupyter Lab.

                Returns:
                    Union[dict, RedirectResponse]: A JSON status payload when
                    `browserless` is True; otherwise a redirect to the Jupyter Lab
                    quickstart notebook (or to the instructions page when notebooks
                    are disabled).
                """
                logger.debug("Reached /notebook/start")
                if NOTEBOOK_ENABLED:
                    start_notebook()
                    if browserless:
                        return {
                            "success": True,
                            "message": f"Jupyter Lab server started at http://localhost:{NOTEBOOK_PORT}?token={NOTEBOOK_PASSWORD}",
                        }
                    else:
                        # Give the notebook server a moment to boot before redirecting.
                        sleep(2)
                        return RedirectResponse(
                            f"http://localhost:{NOTEBOOK_PORT}/lab/tree/quickstart.ipynb?token={NOTEBOOK_PASSWORD}"
                        )
                else:
                    if browserless:
                        return {
                            "success": False,
                            "message": "Notebook server is not enabled. Enable notebooks via the NOTEBOOK_ENABLED environment variable.",
                        }
                    else:
                        return RedirectResponse("/notebook-instructions.html")

        if ENABLE_BUILDER:
            # Import lazily so the builder dependency is only loaded when enabled.
            from inference.core.interfaces.http.builder.routes import (
                router as builder_router,
            )

            # Allow CORS on builder API and workflow endpoints needed by the builder UI
            # Enables Private Network Access for Chrome 142+ (local development)
            app.add_middleware(
                PathAwareCORSMiddleware,
                match_paths=r"^/(build/api|workflows/).*",
                allow_origins=[BUILDER_ORIGIN],
                allow_methods=["*"],
                allow_headers=["*"],
                allow_credentials=True,
                allow_private_network=True,
            )

            # Attach all routes from builder to the /build prefix
            app.include_router(builder_router, prefix="/build", tags=["builder"])

        if LEGACY_ROUTE_ENABLED:
            # Legacy object detection inference path for backwards compatibility
            @app.get(
                "/{dataset_id}/{version_id:str}",
                # Order matters in this response model Union. It will use the first matching model. For example, Object Detection Inference Response is a subset of Instance segmentation inference response, so instance segmentation must come first in order for the matching logic to work.
                response_model=Union[
                    InstanceSegmentationInferenceResponse,
                    KeypointsDetectionInferenceResponse,
                    ObjectDetectionInferenceResponse,
                    ClassificationInferenceResponse,
                    MultiLabelClassificationInferenceResponse,
                    SemanticSegmentationInferenceResponse,
                    StubResponse,
                    Any,
                ],
                response_model_exclude_none=True,
            )
            @app.post(
                "/{dataset_id}/{version_id:str}",
                # Order matters in this response model Union. It will use the first matching model. For example, Object Detection Inference Response is a subset of Instance segmentation inference response, so instance segmentation must come first in order for the matching logic to work.
                response_model=Union[
                    InstanceSegmentationInferenceResponse,
                    KeypointsDetectionInferenceResponse,
                    ObjectDetectionInferenceResponse,
                    ClassificationInferenceResponse,
                    MultiLabelClassificationInferenceResponse,
                    SemanticSegmentationInferenceResponse,
                    StubResponse,
                    Any,
                ],
                response_model_exclude_none=True,
            )
            @with_route_exceptions
            @usage_collector("request")
            def legacy_infer_from_request(
                background_tasks: BackgroundTasks,
                request: Request,
                request_body: Annotated[
                    Optional[Union[bytes, UploadFile]],
                    Depends(parse_body_content_for_legacy_request_handler),
                ],
                dataset_id: str = Path(
                    description="ID of a Roboflow dataset corresponding to the model to use for inference OR workspace ID"
                ),
                version_id: str = Path(
                    description="ID of a Roboflow dataset version corresponding to the model to use for inference OR model ID"
                ),
                api_key: Optional[str] = Query(
                    None,
                    description="Roboflow API Key that will be passed to the model during initialization for artifact retrieval",
                ),
                confidence: float = Query(
                    0.4,
                    description="The confidence threshold used to filter out predictions",
                ),
                keypoint_confidence: float = Query(
                    0.0,
                    description="The confidence threshold used to filter out keypoints that are not visible based on model confidence",
                ),
                format: str = Query(
                    "json",
                    description="One of 'json' or 'image'. If 'json' prediction data is return as a JSON string. If 'image' prediction data is visualized and overlayed on the original input image.",
                ),
                image: Optional[str] = Query(
                    None,
                    description="The publically accessible URL of an image to use for inference.",
                ),
                image_type: Optional[str] = Query(
                    "base64",
                    description="One of base64 or numpy. Note, numpy input is not supported for Roboflow Hosted Inference.",
                ),
                labels: Optional[bool] = Query(
                    False,
                    description="If true, labels will be include in any inference visualization.",
                ),
                mask_decode_mode: Optional[str] = Query(
                    "accurate",
                    description="One of 'accurate' or 'fast'. If 'accurate' the mask will be decoded using the original image size. If 'fast' the mask will be decoded using the original mask size. 'accurate' is slower but more accurate.",
                ),
                tradeoff_factor: Optional[float] = Query(
                    0.0,
                    description="The amount to tradeoff between 0='fast' and 1='accurate'",
                ),
                max_detections: int = Query(
                    300,
                    description="The maximum number of detections to return. This is used to limit the number of predictions returned by the model. The model may return more predictions than this number, but only the top `max_detections` predictions will be returned.",
                ),
                overlap: float = Query(
                    0.3,
                    description="The IoU threhsold that must be met for a box pair to be considered duplicate during NMS",
                ),
                stroke: int = Query(
                    1, description="The stroke width used when visualizing predictions"
                ),
                countinference: Optional[bool] = Query(
                    True,
                    description="If false, does not track inference against usage.",
                    include_in_schema=False,
                ),
                service_secret: Optional[str] = Query(
                    None,
                    description="Shared secret used to authenticate requests to the inference server from internal services (e.g. to allow disabling inference usage tracking via the `countinference` query parameter)",
                    include_in_schema=False,
                ),
                disable_preproc_auto_orient: Optional[bool] = Query(
                    False, description="If true, disables automatic image orientation"
                ),
                disable_preproc_contrast: Optional[bool] = Query(
                    False, description="If true, disables automatic contrast adjustment"
                ),
                disable_preproc_grayscale: Optional[bool] = Query(
                    False,
                    description="If true, disables automatic grayscale conversion",
                ),
                disable_preproc_static_crop: Optional[bool] = Query(
                    False, description="If true, disables automatic static crop"
                ),
                disable_active_learning: Optional[bool] = Query(
                    default=False,
                    description="If true, the predictions will be prevented from registration by Active Learning (if the functionality is enabled)",
                ),
                active_learning_target_dataset: Optional[str] = Query(
                    default=None,
                    description="Parameter to be used when Active Learning data registration should happen against different dataset than the one pointed by model_id",
                ),
                source: Optional[str] = Query(
                    "external",
                    description="The source of the inference request",
                ),
                source_info: Optional[str] = Query(
                    "external",
                    description="The detailed source information of the inference request",
                ),
                disable_model_monitoring: Optional[bool] = Query(
                    False,
                    description="If true, disables model monitoring for this request",
                    include_in_schema=False,
                ),
            ):
                """
                Legacy inference endpoint for object detection, instance segmentation, and classification.

                Args:
                    background_tasks (BackgroundTasks): Pool of fastapi background tasks (used for Active Learning registration).
                    request (Request): The raw HTTP request.
                    request_body (Optional[Union[bytes, UploadFile]]): Parsed request body, used when no `image` URL query parameter is supplied.
                    dataset_id (str): ID of a Roboflow dataset corresponding to the model to use for inference OR workspace ID.
                    version_id (str): ID of a Roboflow dataset version corresponding to the model to use for inference OR model ID.
                    api_key (Optional[str], default None): Roboflow API Key passed to the model during initialization for artifact retrieval.
                    # Remaining query parameters are documented in their Query(...) descriptions above.

                Returns:
                    Union[InstanceSegmentationInferenceResponse, KeypointsDetectionInferenceResponse, ObjectDetectionInferenceResponse, ClassificationInferenceResponse, MultiLabelClassificationInferenceResponse, SemanticSegmentationInferenceResponse, Any]: The response containing the inference results, or a JPEG visualization when format == "image".
                """
                logger.debug(
                    f"Reached legacy route /:dataset_id/:version_id with {dataset_id}/{version_id}"
                )
                model_id = f"{dataset_id}/{version_id}"
                # Legacy clients may send confidence/overlap as percentages (0-100);
                # values >= 1 are normalized into the 0-1 range the models expect.
                if confidence >= 1:
                    confidence /= 100
                if confidence < CONFIDENCE_LOWER_BOUND_OOM_PREVENTION:
                    # allowing lower confidence results in RAM usage explosion
                    confidence = CONFIDENCE_LOWER_BOUND_OOM_PREVENTION

                if overlap >= 1:
                    overlap /= 100
                # Resolve the input image: a URL query parameter wins; otherwise the
                # raw request body (file upload, base64 string, or numpy bytes).
                if image is not None:
                    request_image = InferenceRequestImage(type="url", value=image)
                else:
                    if "Content-Type" not in request.headers:
                        raise ContentTypeMissing(
                            f"Request must include a Content-Type header"
                        )
                    if isinstance(request_body, UploadFile):
                        base64_image_str = request_body.file.read()
                        base64_image_str = base64.b64encode(base64_image_str)
                        request_image = InferenceRequestImage(
                            type="base64", value=base64_image_str.decode("ascii")
                        )
                    elif isinstance(request_body, bytes):
                        request_image = InferenceRequestImage(
                            type=image_type, value=request_body
                        )
                    elif request_body is None:
                        raise InputImageLoadError(
                            message="Image not found in request body.",
                            public_message="Image not found in request body.",
                        )
                    else:
                        raise ContentTypeInvalid(
                            f"Invalid Content-Type: {request.headers['Content-Type']}"
                        )

                # Disabling usage tracking requires the internal service secret.
                if not countinference and service_secret != ROBOFLOW_SERVICE_SECRET:
                    raise MissingServiceSecretError(
                        "Service secret is required to disable inference usage tracking"
                    )
                if LAMBDA:
                    # On AWS Lambda the model id and billing actor come from the
                    # API Gateway authorizer context instead of the URL path.
                    logger.debug("request.scope: %s", request.scope)
                    request_model_id = (
                        request.scope["aws.event"]["requestContext"]["authorizer"][
                            "lambda"
                        ]["model"]["endpoint"]
                        .replace("--", "/")
                        .replace("rf-", "")
                        .replace("nu-", "")
                    )
                    actor = request.scope["aws.event"]["requestContext"]["authorizer"][
                        "lambda"
                    ]["actor"]
                    if countinference:
                        trackUsage(request_model_id, actor)
                    else:
                        if service_secret != ROBOFLOW_SERVICE_SECRET:
                            raise MissingServiceSecretError(
                                "Service secret is required to disable inference usage tracking"
                            )
                        logger.info("Not counting inference for usage")
                else:
                    request_model_id = model_id
                logger.debug(
                    f"State of model registry: {self.model_manager.describe_models()}"
                )
                self.model_manager.add_model(
                    request_model_id,
                    api_key,
                    model_id_alias=model_id,
                    countinference=countinference,
                    service_secret=service_secret,
                )

                # Pick the request type (and task-specific extra args) from the
                # model's task type; object detection is the default.
                task_type = self.model_manager.get_task_type(model_id, api_key=api_key)
                inference_request_type = ObjectDetectionInferenceRequest
                args = dict()
                if task_type == "instance-segmentation":
                    inference_request_type = InstanceSegmentationInferenceRequest
                    args = {
                        "mask_decode_mode": mask_decode_mode,
                        "tradeoff_factor": tradeoff_factor,
                    }
                elif task_type == "classification":
                    inference_request_type = ClassificationInferenceRequest
                elif task_type == "keypoint-detection":
                    inference_request_type = KeypointsDetectionInferenceRequest
                    args = {"keypoint_confidence": keypoint_confidence}
                elif task_type == "semantic-segmentation":
                    inference_request_type = SemanticSegmentationInferenceRequest
                inference_request = inference_request_type(
                    api_key=api_key,
                    model_id=model_id,
                    image=request_image,
                    confidence=confidence,
                    iou_threshold=overlap,
                    max_detections=max_detections,
                    visualization_labels=labels,
                    visualization_stroke_width=stroke,
                    visualize_predictions=(
                        format == "image" or format == "image_and_json"
                    ),
                    disable_preproc_auto_orient=disable_preproc_auto_orient,
                    disable_preproc_contrast=disable_preproc_contrast,
                    disable_preproc_grayscale=disable_preproc_grayscale,
                    disable_preproc_static_crop=disable_preproc_static_crop,
                    disable_active_learning=disable_active_learning,
                    active_learning_target_dataset=active_learning_target_dataset,
                    source=source,
                    source_info=source_info,
                    usage_billable=countinference,
                    disable_model_monitoring=disable_model_monitoring,
                    **args,
                )
                inference_response = self.model_manager.infer_from_request_sync(
                    inference_request.model_id,
                    inference_request,
                    active_learning_eligible=True,
                    background_tasks=background_tasks,
                )
                logger.debug("Response ready.")
                # "image" format returns the rendered visualization as JPEG bytes;
                # anything else is serialized to JSON.
                if format == "image":
                    return Response(
                        content=inference_response.visualization,
                        media_type="image/jpeg",
                    )
                else:
                    return orjson_response(inference_response)

        if not (LAMBDA or GCP_SERVERLESS):
            # Legacy clear cache endpoint for backwards compatibility
            @app.get("/clear_cache", response_model=str)
            def legacy_clear_cache():
                """
                Clears the model cache.

                This endpoint provides a way to clear the cache of loaded models.

                Returns:
                    str: A string indicating that the cache has been cleared.
                """
                # Plain string: the previous f-string had no placeholders.
                logger.debug("Reached /clear_cache")
                model_clear()
                return "Cache Cleared"

            # Legacy add model endpoint for backwards compatibility
            @app.get("/start/{dataset_id}/{version_id}")
            def model_add_legacy(
                dataset_id: str,
                version_id: str,
                # Explicit Optional: `str = None` is an invalid implicit-Optional
                # annotation (PEP 484).
                api_key: Optional[str] = None,
                countinference: Optional[bool] = None,
                service_secret: Optional[str] = None,
            ):
                """
                Starts a model inference session.

                This endpoint initializes and starts an inference session for the specified model version.

                Args:
                    dataset_id (str): ID of a Roboflow dataset corresponding to the model.
                    version_id (str): ID of a Roboflow dataset version corresponding to the model.
                    api_key (Optional[str]): Roboflow API Key for artifact retrieval.
                    countinference (Optional[bool]): Whether to count inference or not.
                    service_secret (Optional[str]): The service secret for the request.

                Returns:
                    JSONResponse: A response object containing the status and a success message.
                """
                logger.debug(
                    f"Reached /start/{dataset_id}/{version_id} with {dataset_id}/{version_id}"
                )
                model_id = f"{dataset_id}/{version_id}"
                self.model_manager.add_model(
                    model_id,
                    api_key,
                    countinference=countinference,
                    service_secret=service_secret,
                )

                return JSONResponse(
                    {
                        "status": 200,
                        "message": "inference session started from local memory.",
                    }
                )

        if not ENABLE_DASHBOARD:

            @app.get("/dashboard.html")
            @app.head("/dashboard.html")
            async def dashboard_guard():
                """Return 404 for dashboard.html when the dashboard is disabled.

                Routes take precedence over the StaticFiles mount at "/", so this
                masks the statically served page.
                """
                return Response(status_code=404)

        @app.exception_handler(InputImageLoadError)
        async def unicorn_exception_handler(request: Request, exc: InputImageLoadError):
            # Map image-loading failures to HTTP 400, exposing only the error's
            # public details in the response body.
            return JSONResponse(
                status_code=400,
                content={
                    "message": f"Could not load input image. Cause: {exc.get_public_error_details()}"
                },
            )

        # Serve the static landing page at "/"; mounted last so the explicit
        # routes registered above take precedence.
        app.mount(
            "/",
            StaticFiles(directory="./inference/landing/out", html=True),
            name="root",
        )

    def run(self, host: str = "127.0.0.1", port: int = 8080):
        """Start the FastAPI application with uvicorn.

        Args:
            host (str, default "127.0.0.1"): Interface to bind; defaults to
                localhost-only, matching the previous hard-coded behavior.
            port (int, default 8080): TCP port to listen on.
        """
        uvicorn.run(self.app, host=host, port=port)
Functions
__init__
__init__(model_manager, root_path=None)

Initializes the HttpInterface with the given model manager.

Parameters:

Name Type Description Default
model_manager ModelManager

The manager for handling different models.

required
root_path Optional[str]

The root path for the FastAPI application.

None
Description

Deploy Roboflow trained models to nearly any compute environment!

Source code in inference/core/interfaces/http/http_api.py
 388
 389
 390
 391
 392
 393
 394
 395
 396
 397
 398
 399
 400
 401
 402
 403
 404
 405
 406
 407
 408
 409
 410
 411
 412
 413
 414
 415
 416
 417
 418
 419
 420
 421
 422
 423
 424
 425
 426
 427
 428
 429
 430
 431
 432
 433
 434
 435
 436
 437
 438
 439
 440
 441
 442
 443
 444
 445
 446
 447
 448
 449
 450
 451
 452
 453
 454
 455
 456
 457
 458
 459
 460
 461
 462
 463
 464
 465
 466
 467
 468
 469
 470
 471
 472
 473
 474
 475
 476
 477
 478
 479
 480
 481
 482
 483
 484
 485
 486
 487
 488
 489
 490
 491
 492
 493
 494
 495
 496
 497
 498
 499
 500
 501
 502
 503
 504
 505
 506
 507
 508
 509
 510
 511
 512
 513
 514
 515
 516
 517
 518
 519
 520
 521
 522
 523
 524
 525
 526
 527
 528
 529
 530
 531
 532
 533
 534
 535
 536
 537
 538
 539
 540
 541
 542
 543
 544
 545
 546
 547
 548
 549
 550
 551
 552
 553
 554
 555
 556
 557
 558
 559
 560
 561
 562
 563
 564
 565
 566
 567
 568
 569
 570
 571
 572
 573
 574
 575
 576
 577
 578
 579
 580
 581
 582
 583
 584
 585
 586
 587
 588
 589
 590
 591
 592
 593
 594
 595
 596
 597
 598
 599
 600
 601
 602
 603
 604
 605
 606
 607
 608
 609
 610
 611
 612
 613
 614
 615
 616
 617
 618
 619
 620
 621
 622
 623
 624
 625
 626
 627
 628
 629
 630
 631
 632
 633
 634
 635
 636
 637
 638
 639
 640
 641
 642
 643
 644
 645
 646
 647
 648
 649
 650
 651
 652
 653
 654
 655
 656
 657
 658
 659
 660
 661
 662
 663
 664
 665
 666
 667
 668
 669
 670
 671
 672
 673
 674
 675
 676
 677
 678
 679
 680
 681
 682
 683
 684
 685
 686
 687
 688
 689
 690
 691
 692
 693
 694
 695
 696
 697
 698
 699
 700
 701
 702
 703
 704
 705
 706
 707
 708
 709
 710
 711
 712
 713
 714
 715
 716
 717
 718
 719
 720
 721
 722
 723
 724
 725
 726
 727
 728
 729
 730
 731
 732
 733
 734
 735
 736
 737
 738
 739
 740
 741
 742
 743
 744
 745
 746
 747
 748
 749
 750
 751
 752
 753
 754
 755
 756
 757
 758
 759
 760
 761
 762
 763
 764
 765
 766
 767
 768
 769
 770
 771
 772
 773
 774
 775
 776
 777
 778
 779
 780
 781
 782
 783
 784
 785
 786
 787
 788
 789
 790
 791
 792
 793
 794
 795
 796
 797
 798
 799
 800
 801
 802
 803
 804
 805
 806
 807
 808
 809
 810
 811
 812
 813
 814
 815
 816
 817
 818
 819
 820
 821
 822
 823
 824
 825
 826
 827
 828
 829
 830
 831
 832
 833
 834
 835
 836
 837
 838
 839
 840
 841
 842
 843
 844
 845
 846
 847
 848
 849
 850
 851
 852
 853
 854
 855
 856
 857
 858
 859
 860
 861
 862
 863
 864
 865
 866
 867
 868
 869
 870
 871
 872
 873
 874
 875
 876
 877
 878
 879
 880
 881
 882
 883
 884
 885
 886
 887
 888
 889
 890
 891
 892
 893
 894
 895
 896
 897
 898
 899
 900
 901
 902
 903
 904
 905
 906
 907
 908
 909
 910
 911
 912
 913
 914
 915
 916
 917
 918
 919
 920
 921
 922
 923
 924
 925
 926
 927
 928
 929
 930
 931
 932
 933
 934
 935
 936
 937
 938
 939
 940
 941
 942
 943
 944
 945
 946
 947
 948
 949
 950
 951
 952
 953
 954
 955
 956
 957
 958
 959
 960
 961
 962
 963
 964
 965
 966
 967
 968
 969
 970
 971
 972
 973
 974
 975
 976
 977
 978
 979
 980
 981
 982
 983
 984
 985
 986
 987
 988
 989
 990
 991
 992
 993
 994
 995
 996
 997
 998
 999
1000
1001
1002
1003
1004
1005
1006
1007
1008
1009
1010
1011
1012
1013
1014
1015
1016
1017
1018
1019
1020
1021
1022
1023
1024
1025
1026
1027
1028
1029
1030
1031
1032
1033
1034
1035
1036
1037
1038
1039
1040
1041
1042
1043
1044
1045
1046
1047
1048
1049
1050
1051
1052
1053
1054
1055
1056
1057
1058
1059
1060
1061
1062
1063
1064
1065
1066
1067
1068
1069
1070
1071
1072
1073
1074
1075
1076
1077
1078
1079
1080
1081
1082
1083
1084
1085
1086
1087
1088
1089
1090
1091
1092
1093
1094
1095
1096
1097
1098
1099
1100
1101
1102
1103
1104
1105
1106
1107
1108
1109
1110
1111
1112
1113
1114
1115
1116
1117
1118
1119
1120
1121
1122
1123
1124
1125
1126
1127
1128
1129
1130
1131
1132
1133
1134
1135
1136
1137
1138
1139
1140
1141
1142
1143
1144
1145
1146
1147
1148
1149
1150
1151
1152
1153
1154
1155
1156
1157
1158
1159
1160
1161
1162
1163
1164
1165
1166
1167
1168
1169
1170
1171
1172
1173
1174
1175
1176
1177
1178
1179
1180
1181
1182
1183
1184
1185
1186
1187
1188
1189
1190
1191
1192
1193
1194
1195
1196
1197
1198
1199
1200
1201
1202
1203
1204
1205
1206
1207
1208
1209
1210
1211
1212
1213
1214
1215
1216
1217
1218
1219
1220
1221
1222
1223
1224
1225
1226
1227
1228
1229
1230
1231
1232
1233
1234
1235
1236
1237
1238
1239
1240
1241
1242
1243
1244
1245
1246
1247
1248
1249
1250
1251
1252
1253
1254
1255
1256
1257
1258
1259
1260
1261
1262
1263
1264
1265
1266
1267
1268
1269
1270
1271
1272
1273
1274
1275
1276
1277
1278
1279
1280
1281
1282
1283
1284
1285
1286
1287
1288
1289
1290
1291
1292
1293
1294
1295
1296
1297
1298
1299
1300
1301
1302
1303
1304
1305
1306
1307
1308
1309
1310
1311
1312
1313
1314
1315
1316
1317
1318
1319
1320
1321
1322
1323
1324
1325
1326
1327
1328
1329
1330
1331
1332
1333
1334
1335
1336
1337
1338
1339
1340
1341
1342
1343
1344
1345
1346
1347
1348
1349
1350
1351
1352
1353
1354
1355
1356
1357
1358
1359
1360
1361
1362
1363
1364
1365
1366
1367
1368
1369
1370
1371
1372
1373
1374
1375
1376
1377
1378
1379
1380
1381
1382
1383
1384
1385
1386
1387
1388
1389
1390
1391
1392
1393
1394
1395
1396
1397
1398
1399
1400
1401
1402
1403
1404
1405
1406
1407
1408
1409
1410
1411
1412
1413
1414
1415
1416
1417
1418
1419
1420
1421
1422
1423
1424
1425
1426
1427
1428
1429
1430
1431
1432
1433
1434
1435
1436
1437
1438
1439
1440
1441
1442
1443
1444
1445
1446
1447
1448
1449
1450
1451
1452
1453
1454
1455
1456
1457
1458
1459
1460
1461
1462
1463
1464
1465
1466
1467
1468
1469
1470
1471
1472
1473
1474
1475
1476
1477
1478
1479
1480
1481
1482
1483
1484
1485
1486
1487
1488
1489
1490
1491
1492
1493
1494
1495
1496
1497
1498
1499
1500
1501
1502
1503
1504
1505
1506
1507
1508
1509
1510
1511
1512
1513
1514
1515
1516
1517
1518
1519
1520
1521
1522
1523
1524
1525
1526
1527
1528
1529
1530
1531
1532
1533
1534
1535
1536
1537
1538
1539
1540
1541
1542
1543
1544
1545
1546
1547
1548
1549
1550
1551
1552
1553
1554
1555
1556
1557
1558
1559
1560
1561
1562
1563
1564
1565
1566
1567
1568
1569
1570
1571
1572
1573
1574
1575
1576
1577
1578
1579
1580
1581
1582
1583
1584
1585
1586
1587
1588
1589
1590
1591
1592
1593
1594
1595
1596
1597
1598
1599
1600
1601
1602
1603
1604
1605
1606
1607
1608
1609
1610
1611
1612
1613
1614
1615
1616
1617
1618
1619
1620
1621
1622
1623
1624
1625
1626
1627
1628
1629
1630
1631
1632
1633
1634
1635
1636
1637
1638
1639
1640
1641
1642
1643
1644
1645
1646
1647
1648
1649
1650
1651
1652
1653
1654
1655
1656
1657
1658
1659
1660
1661
1662
1663
1664
1665
1666
1667
1668
1669
1670
1671
1672
1673
1674
1675
1676
1677
1678
1679
1680
1681
1682
1683
1684
1685
1686
1687
1688
1689
1690
1691
1692
1693
1694
1695
1696
1697
1698
1699
1700
1701
1702
1703
1704
1705
1706
1707
1708
1709
1710
1711
1712
1713
1714
1715
1716
1717
1718
1719
1720
1721
1722
1723
1724
1725
1726
1727
1728
1729
1730
1731
1732
1733
1734
1735
1736
1737
1738
1739
1740
1741
1742
1743
1744
1745
1746
1747
1748
1749
1750
1751
1752
1753
1754
1755
1756
1757
1758
1759
1760
1761
1762
1763
1764
1765
1766
1767
1768
1769
1770
1771
1772
1773
1774
1775
1776
1777
1778
1779
1780
1781
1782
1783
1784
1785
1786
1787
1788
1789
1790
1791
1792
1793
1794
1795
1796
1797
1798
1799
1800
1801
1802
1803
1804
1805
1806
1807
1808
1809
1810
1811
1812
1813
1814
1815
1816
1817
1818
1819
1820
1821
1822
1823
1824
1825
1826
1827
1828
1829
1830
1831
1832
1833
1834
1835
1836
1837
1838
1839
1840
1841
1842
1843
1844
1845
1846
1847
1848
1849
1850
1851
1852
1853
1854
1855
1856
1857
1858
1859
1860
1861
1862
1863
1864
1865
1866
1867
1868
1869
1870
1871
1872
1873
1874
1875
1876
1877
1878
1879
1880
1881
1882
1883
1884
1885
1886
1887
1888
1889
1890
1891
1892
1893
1894
1895
1896
1897
1898
1899
1900
1901
1902
1903
1904
1905
1906
1907
1908
1909
1910
1911
1912
1913
1914
1915
1916
1917
1918
1919
1920
1921
1922
1923
1924
1925
1926
1927
1928
1929
1930
1931
1932
1933
1934
1935
1936
1937
1938
1939
1940
1941
1942
1943
1944
1945
1946
1947
1948
1949
1950
1951
1952
1953
1954
1955
1956
1957
1958
1959
1960
1961
1962
1963
1964
1965
1966
1967
1968
1969
1970
1971
1972
1973
1974
1975
1976
1977
1978
1979
1980
1981
1982
1983
1984
1985
1986
1987
1988
1989
1990
1991
1992
1993
1994
1995
1996
1997
1998
1999
2000
2001
2002
2003
2004
2005
2006
2007
2008
2009
2010
2011
2012
2013
2014
2015
2016
2017
2018
2019
2020
2021
2022
2023
2024
2025
2026
2027
2028
2029
2030
2031
2032
2033
2034
2035
2036
2037
2038
2039
2040
2041
2042
2043
2044
2045
2046
2047
2048
2049
2050
2051
2052
2053
2054
2055
2056
2057
2058
2059
2060
2061
2062
2063
2064
2065
2066
2067
2068
2069
2070
2071
2072
2073
2074
2075
2076
2077
2078
2079
2080
2081
2082
2083
2084
2085
2086
2087
2088
2089
2090
2091
2092
2093
2094
2095
2096
2097
2098
2099
2100
2101
2102
2103
2104
2105
2106
2107
2108
2109
2110
2111
2112
2113
2114
2115
2116
2117
2118
2119
2120
2121
2122
2123
2124
2125
2126
2127
2128
2129
2130
2131
2132
2133
2134
2135
2136
2137
2138
2139
2140
2141
2142
2143
2144
2145
2146
2147
2148
2149
2150
2151
2152
2153
2154
2155
2156
2157
2158
2159
2160
2161
2162
2163
2164
2165
2166
2167
2168
2169
2170
2171
2172
2173
2174
2175
2176
2177
2178
2179
2180
2181
2182
2183
2184
2185
2186
2187
2188
2189
2190
2191
2192
2193
2194
2195
2196
2197
2198
2199
2200
2201
2202
2203
2204
2205
2206
2207
2208
2209
2210
2211
2212
2213
2214
2215
2216
2217
2218
2219
2220
2221
2222
2223
2224
2225
2226
2227
2228
2229
2230
2231
2232
2233
2234
2235
2236
2237
2238
2239
2240
2241
2242
2243
2244
2245
2246
2247
2248
2249
2250
2251
2252
2253
2254
2255
2256
2257
2258
2259
2260
2261
2262
2263
2264
2265
2266
2267
2268
2269
2270
2271
2272
2273
2274
2275
2276
2277
2278
2279
2280
2281
2282
2283
2284
2285
2286
2287
2288
2289
2290
2291
2292
2293
2294
2295
2296
2297
2298
2299
2300
2301
2302
2303
2304
2305
2306
2307
2308
2309
2310
2311
2312
2313
2314
2315
2316
2317
2318
2319
2320
2321
2322
2323
2324
2325
2326
2327
2328
2329
2330
2331
2332
2333
2334
2335
2336
2337
2338
2339
2340
2341
2342
2343
2344
2345
2346
2347
2348
2349
2350
2351
2352
2353
2354
2355
2356
2357
2358
2359
2360
2361
2362
2363
2364
2365
2366
2367
2368
2369
2370
2371
2372
2373
2374
2375
2376
2377
2378
2379
2380
2381
2382
2383
2384
2385
2386
2387
2388
2389
2390
2391
2392
2393
2394
2395
2396
2397
2398
2399
2400
2401
2402
2403
2404
2405
2406
2407
2408
2409
2410
2411
2412
2413
2414
2415
2416
2417
2418
2419
2420
2421
2422
2423
2424
2425
2426
2427
2428
2429
2430
2431
2432
2433
2434
2435
2436
2437
2438
2439
2440
2441
2442
2443
2444
2445
2446
2447
2448
2449
2450
2451
2452
2453
2454
2455
2456
2457
2458
2459
2460
2461
2462
2463
2464
2465
2466
2467
2468
2469
2470
2471
2472
2473
2474
2475
2476
2477
2478
2479
2480
2481
2482
2483
2484
2485
2486
2487
2488
2489
2490
2491
2492
2493
2494
2495
2496
2497
2498
2499
2500
2501
2502
2503
2504
2505
2506
2507
2508
2509
2510
2511
2512
2513
2514
2515
2516
2517
2518
2519
2520
2521
2522
2523
2524
2525
2526
2527
2528
2529
2530
2531
2532
2533
2534
2535
2536
2537
2538
2539
2540
2541
2542
2543
2544
2545
2546
2547
2548
2549
2550
2551
2552
2553
2554
2555
2556
2557
2558
2559
2560
2561
2562
2563
2564
2565
2566
2567
2568
2569
2570
2571
2572
2573
2574
2575
2576
2577
2578
2579
2580
2581
2582
2583
2584
2585
2586
2587
2588
2589
2590
2591
2592
2593
2594
2595
2596
2597
2598
2599
2600
2601
2602
2603
2604
2605
2606
2607
2608
2609
2610
2611
2612
2613
2614
2615
2616
2617
2618
2619
2620
2621
2622
2623
2624
2625
2626
2627
2628
2629
2630
2631
2632
2633
2634
2635
2636
2637
2638
2639
2640
2641
2642
2643
2644
2645
2646
2647
2648
2649
2650
2651
2652
2653
2654
2655
2656
2657
2658
2659
2660
2661
2662
2663
2664
2665
2666
2667
2668
2669
2670
2671
2672
2673
2674
2675
2676
2677
2678
2679
2680
2681
2682
2683
2684
2685
2686
2687
2688
2689
2690
2691
2692
2693
2694
2695
2696
2697
2698
2699
2700
2701
2702
2703
2704
2705
2706
2707
2708
2709
2710
2711
2712
2713
2714
2715
2716
2717
2718
2719
2720
2721
2722
2723
2724
2725
2726
2727
2728
2729
2730
2731
2732
2733
2734
2735
2736
2737
2738
2739
2740
2741
2742
2743
2744
2745
2746
2747
2748
2749
2750
2751
2752
2753
2754
2755
2756
2757
2758
2759
2760
2761
2762
2763
2764
2765
2766
2767
2768
2769
2770
2771
2772
2773
2774
2775
2776
2777
2778
2779
2780
2781
2782
2783
2784
2785
2786
2787
2788
2789
2790
2791
2792
2793
2794
2795
2796
2797
2798
2799
2800
2801
2802
2803
2804
2805
2806
2807
2808
2809
2810
2811
2812
2813
2814
2815
2816
2817
2818
2819
2820
2821
2822
2823
2824
2825
2826
2827
2828
2829
2830
2831
2832
2833
2834
2835
2836
2837
2838
2839
2840
2841
2842
2843
2844
2845
2846
2847
2848
2849
2850
2851
2852
2853
2854
2855
2856
2857
2858
2859
2860
2861
2862
2863
2864
2865
2866
2867
2868
2869
2870
2871
2872
2873
2874
2875
2876
2877
2878
2879
2880
2881
2882
2883
2884
2885
2886
2887
2888
2889
2890
2891
2892
2893
2894
2895
2896
2897
2898
2899
2900
2901
2902
2903
2904
2905
2906
2907
2908
2909
2910
2911
2912
2913
2914
2915
2916
2917
2918
2919
2920
2921
2922
2923
2924
2925
2926
2927
2928
2929
2930
2931
2932
2933
2934
2935
2936
2937
2938
2939
2940
2941
2942
2943
2944
2945
2946
2947
2948
2949
2950
2951
2952
2953
2954
2955
2956
2957
2958
2959
2960
2961
2962
2963
2964
2965
2966
2967
2968
2969
2970
2971
2972
2973
2974
2975
2976
2977
2978
2979
2980
2981
2982
2983
2984
2985
2986
2987
2988
2989
2990
2991
2992
2993
2994
2995
2996
2997
2998
2999
3000
3001
3002
3003
3004
3005
3006
3007
3008
3009
3010
3011
3012
3013
3014
3015
3016
3017
3018
3019
3020
3021
3022
3023
3024
3025
3026
3027
3028
3029
3030
3031
3032
3033
3034
3035
3036
3037
3038
3039
3040
3041
3042
3043
3044
3045
3046
3047
3048
3049
3050
3051
3052
3053
3054
3055
3056
3057
3058
3059
3060
3061
3062
3063
3064
3065
3066
3067
3068
3069
3070
3071
3072
3073
3074
3075
3076
3077
3078
3079
3080
3081
3082
3083
3084
3085
3086
3087
3088
3089
3090
3091
3092
3093
3094
3095
3096
3097
3098
3099
3100
3101
3102
3103
3104
3105
3106
3107
3108
3109
3110
3111
3112
3113
3114
3115
3116
3117
3118
3119
3120
3121
3122
3123
3124
3125
3126
3127
3128
3129
3130
3131
3132
3133
3134
3135
3136
3137
3138
3139
3140
3141
3142
3143
3144
3145
3146
3147
3148
3149
3150
3151
3152
3153
3154
3155
3156
3157
3158
3159
3160
3161
3162
3163
3164
3165
3166
3167
3168
3169
3170
3171
3172
3173
3174
3175
3176
3177
3178
3179
3180
3181
3182
3183
3184
3185
3186
3187
3188
3189
3190
3191
3192
3193
3194
3195
3196
3197
3198
3199
3200
3201
3202
3203
3204
3205
3206
3207
3208
3209
3210
3211
3212
3213
3214
3215
3216
3217
3218
3219
3220
3221
3222
3223
3224
3225
3226
3227
3228
3229
3230
3231
3232
3233
3234
3235
3236
3237
3238
3239
3240
3241
3242
3243
3244
3245
3246
3247
3248
3249
3250
3251
3252
3253
3254
3255
3256
3257
3258
3259
3260
3261
3262
3263
3264
3265
3266
3267
3268
3269
3270
3271
3272
3273
3274
3275
3276
3277
3278
3279
3280
3281
3282
3283
3284
3285
3286
3287
3288
3289
3290
3291
3292
3293
3294
3295
3296
3297
3298
3299
3300
3301
3302
3303
3304
3305
3306
3307
3308
3309
3310
3311
3312
3313
3314
3315
3316
3317
3318
3319
3320
3321
3322
3323
3324
3325
3326
3327
3328
3329
3330
3331
3332
3333
3334
3335
3336
3337
3338
3339
3340
3341
3342
3343
3344
3345
3346
3347
3348
3349
3350
3351
3352
3353
3354
3355
3356
3357
3358
3359
3360
3361
3362
3363
3364
3365
3366
3367
3368
3369
3370
3371
3372
3373
3374
3375
3376
3377
3378
3379
3380
3381
3382
3383
3384
3385
3386
3387
3388
3389
3390
3391
3392
3393
3394
3395
3396
3397
3398
3399
3400
3401
3402
3403
3404
3405
3406
3407
3408
3409
3410
3411
3412
3413
3414
3415
3416
3417
3418
3419
3420
3421
3422
3423
3424
3425
3426
3427
3428
3429
3430
3431
3432
3433
3434
3435
3436
3437
3438
3439
3440
3441
3442
3443
3444
3445
3446
3447
3448
3449
3450
3451
3452
3453
3454
3455
3456
3457
3458
3459
3460
3461
3462
3463
3464
3465
3466
3467
3468
3469
3470
3471
3472
3473
3474
3475
3476
3477
3478
3479
3480
3481
3482
3483
3484
3485
3486
3487
3488
3489
3490
3491
3492
3493
3494
3495
3496
3497
3498
3499
3500
3501
3502
3503
3504
3505
3506
3507
3508
3509
3510
3511
3512
3513
3514
3515
3516
3517
3518
3519
3520
3521
3522
3523
3524
3525
3526
3527
3528
3529
3530
3531
3532
3533
3534
3535
3536
3537
3538
3539
3540
3541
3542
3543
3544
3545
3546
3547
3548
3549
3550
3551
3552
3553
3554
3555
3556
3557
3558
3559
3560
3561
3562
3563
3564
3565
3566
3567
3568
3569
3570
3571
3572
3573
3574
3575
3576
3577
3578
3579
3580
3581
3582
3583
3584
3585
3586
3587
3588
3589
3590
3591
3592
3593
3594
3595
3596
3597
3598
3599
3600
3601
3602
3603
3604
3605
3606
3607
3608
3609
3610
3611
3612
3613
3614
3615
3616
3617
3618
3619
3620
3621
3622
3623
3624
3625
3626
3627
3628
3629
3630
3631
3632
3633
3634
3635
3636
3637
3638
3639
3640
3641
3642
3643
3644
3645
3646
3647
3648
3649
3650
3651
3652
3653
3654
3655
3656
3657
def __init__(
    self,
    model_manager: ModelManager,
    root_path: Optional[str] = None,
):
    """
    Initializes the HttpInterface with given model manager and model registry.

    Args:
        model_manager (ModelManager): The manager for handling different models.
        root_path (Optional[str]): The root path for the FastAPI application.

    Description:
        Deploy Roboflow trained models to nearly any compute environment!
    """

    description = "Roboflow inference server"

    app = FastAPI(
        title="Roboflow Inference Server",
        description=description,
        version=__version__,
        terms_of_service="https://roboflow.com/terms",
        contact={
            "name": "Roboflow Inc.",
            "url": "https://roboflow.com/contact",
            "email": "help@roboflow.com",
        },
        license_info={
            "name": "Apache 2.0",
            "url": "https://www.apache.org/licenses/LICENSE-2.0.html",
        },
        root_path=root_path,
    )
    # Ensure in-memory logging is initialized as early as possible for all runtimes
    try:
        from inference.core.logging.memory_handler import setup_memory_logging

        setup_memory_logging()
    except Exception:
        pass

    app.mount(
        "/static",
        StaticFiles(directory="./inference/landing/out/static", html=True),
        name="static",
    )
    app.mount(
        "/_next/static",
        StaticFiles(directory="./inference/landing/out/_next/static", html=True),
        name="_next_static",
    )

    @app.on_event("shutdown")
    async def on_shutdown():
        # Flush any buffered usage telemetry before the process exits so
        # payloads collected during this server's lifetime are not lost.
        logger.info("Shutting down %s", description)
        await usage_collector.async_push_usage_payloads()

    self._instrumentator = InferenceInstrumentator(
        app, model_manager=model_manager, endpoint="/metrics"
    )
    if LAMBDA:
        app.add_middleware(LambdaMiddleware)
    if GCP_SERVERLESS:
        app.add_middleware(GCPServerlessMiddleware)

    if len(ALLOW_ORIGINS) > 0:
        # Add CORS Middleware (but not for /build**, which is controlled separately)
        app.add_middleware(
            PathAwareCORSMiddleware,
            match_paths=r"^(?!/build).*",
            allow_origins=ALLOW_ORIGINS,
            allow_credentials=True,
            allow_methods=["*"],
            allow_headers=["*"],
            expose_headers=[
                PROCESSING_TIME_HEADER,
                REMOTE_PROCESSING_TIME_HEADER,
                REMOTE_PROCESSING_TIMES_HEADER,
                MODEL_COLD_START_HEADER,
                MODEL_LOAD_TIME_HEADER,
                MODEL_LOAD_DETAILS_HEADER,
                MODEL_ID_HEADER,
                WORKFLOW_ID_HEADER,
                WORKSPACE_ID_HEADER,
            ]
            + ([EXECUTION_ID_HEADER] if EXECUTION_ID_HEADER is not None else []),
        )

    # Optionally add middleware for profiling the FastAPI server and underlying inference API code
    if PROFILE:
        app.add_middleware(
            CProfileMiddleware,
            enable=True,
            server_app=app,
            filename="/profile/output.pstats",
            strip_dirs=False,
            sort_by="cumulative",
        )
    if API_LOGGING_ENABLED:
        app.add_middleware(
            asgi_correlation_id.CorrelationIdMiddleware,
            header_name=CORRELATION_ID_HEADER,
            update_request_header=True,
            generator=lambda: uuid4().hex,
            validator=lambda a: True,
            transformer=lambda a: a,
        )
        if STRUCTURED_API_LOGGING:
            # Suppress uvicorn's default access log to avoid duplicate
            # unstructured entries — we replace it with a structured
            # access log middleware (see structured_access_log below).
            logging.getLogger("uvicorn.access").handlers = []
            logging.getLogger("uvicorn.access").propagate = False
    else:
        app.add_middleware(asgi_correlation_id.CorrelationIdMiddleware)

    if METRICS_ENABLED:

        @app.middleware("http")
        async def count_errors(request: Request, call_next):
            """Count error responses for pingback metrics.

            Args:
                request (Request): The incoming request.
                call_next (Callable): The next middleware or endpoint to call.

            Returns:
                Response: The response produced by the downstream handler.
            """
            response = await call_next(request)
            # Only track errors when pingback reporting is active.
            if self.model_manager.pingback:
                if response.status_code >= 400:
                    self.model_manager.num_errors += 1
            return response

    if not (LAMBDA or GCP_SERVERLESS):

        @app.get("/device/stats")
        def device_stats():
            """Report Docker container statistics for this device.

            Requires the Docker socket to be mounted into the container and
            DOCKER_SOCKET_PATH to point at it; otherwise an explanatory
            error payload is returned instead.
            """
            setup_hint = {
                "error": "Device statistics endpoint is not enabled.",
                "hint": "Mount the Docker socket and point its location when running the docker "
                "container to collect device stats "
                "(i.e. `docker run ... -v /var/run/docker.sock:/var/run/docker.sock "
                "-e DOCKER_SOCKET_PATH=/var/run/docker.sock ...`).",
            }
            # Not configured at all -> 404; configured but socket missing -> 500.
            if not DOCKER_SOCKET_PATH:
                return JSONResponse(status_code=404, content=setup_hint)
            if not is_docker_socket_mounted(docker_socket_path=DOCKER_SOCKET_PATH):
                return JSONResponse(status_code=500, content=setup_hint)
            stats = get_container_stats(docker_socket_path=DOCKER_SOCKET_PATH)
            return JSONResponse(status_code=200, content=stats)

    cached_api_keys = dict()

    if GCP_SERVERLESS:

        @app.middleware("http")
        async def check_authorization_serverless(request: Request, call_next):
            """Authorize requests on GCP serverless by resolving the api_key's workspace.

            Public/health/static routes are exempt. Validated keys are cached
            in ``cached_api_keys`` for one hour; on success the workspace id is
            attached to the response headers.
            """
            # exclusions
            skip_check = (
                request.method not in ["GET", "POST"]
                or request.url.path
                in [
                    "/",
                    "/docs",
                    "/info",
                    "/healthz",  # health check endpoint for liveness probe
                    "/readiness",
                    "/metrics",
                    "/openapi.json",  # needed for /docs and /redoc
                    "/model/registry",  # dont auth this route, usually not used on serverless, but queue based serverless uses it internally (not accessible from outside)
                ]
                or request.url.path.startswith("/static/")
                or request.url.path.startswith("/_next/")
            )

            # for these routes we only want to auth if dynamic python modules are provided
            if request.url.path in [
                "/workflows/blocks/describe",
                "/workflows/definition/schema",
            ]:
                if request.method == "GET":
                    skip_check = True

                elif (
                    get_content_type(request) == "application/json"
                    and int(request.headers.get("content-length", 0)) > 0
                ):
                    json_params = await request.json()
                    dynamic_blocks_definitions = json_params.get(
                        "dynamic_blocks_definitions", None
                    )
                    if not dynamic_blocks_definitions:
                        skip_check = True

            if skip_check:
                return await call_next(request)

            def _unauthorized_response(msg):
                # Uniform 401 payload for every rejection path below.
                return JSONResponse(
                    status_code=401,
                    content={
                        "status": 401,
                        "message": msg,
                    },
                )

            # api_key may arrive as a query parameter or inside the JSON body;
            # the body value (if present) takes precedence.
            req_params = request.query_params
            json_params = dict()
            api_key = req_params.get("api_key", None)
            if (
                api_key is None
                and get_content_type(request) == "application/json"
                and int(request.headers.get("content-length", 0)) > 0
            ):
                # have to try catch here, because some legacy endpoints that abuse Content-Type header but dont actually receive json
                try:
                    json_params = await request.json()
                except Exception:
                    pass
            api_key = json_params.get("api_key", api_key)

            if api_key is None:
                return _unauthorized_response("Unauthorized api_key")

            cache_entry = cached_api_keys.get(api_key)
            workspace_id = None
            if cache_entry and cache_entry[0] >= time.time():
                workspace_id = cache_entry[1]
            else:
                try:
                    workspace_id = await get_roboflow_workspace_async(
                        api_key=api_key
                    )
                    cached_api_keys[api_key] = (
                        time.time() + 3600,
                        workspace_id,
                    )  # expired after 1 hour
                except (RoboflowAPINotAuthorizedError, WorkspaceLoadError):
                    return _unauthorized_response("Unauthorized api_key")

            response = await call_next(request)
            if workspace_id:
                response.headers[WORKSPACE_ID_HEADER] = workspace_id
            return response

    if DEDICATED_DEPLOYMENT_WORKSPACE_URL:

        @app.middleware("http")
        async def check_authorization(request: Request, call_next):
            """Authorize requests against the dedicated deployment's workspace.

            The api_key is read from query parameters or the JSON body; keys
            whose workspace does not match DEDICATED_DEPLOYMENT_WORKSPACE_URL
            are rejected with 401. Validated keys are cached in
            ``cached_api_keys`` for one hour to avoid re-querying the
            Roboflow API on every request.
            """
            # exclusions: non-GET/POST methods and public/static routes skip auth
            skip_check = (
                request.method not in ["GET", "POST"]
                or request.url.path
                in [
                    "/",
                    "/docs",
                    "/redoc",
                    "/info",
                    "/healthz",  # health check endpoint for liveness probe
                    "/readiness",
                    "/metrics",
                    "/openapi.json",  # needed for /docs and /redoc
                ]
                or request.url.path.startswith("/static/")
                or request.url.path.startswith("/_next/")
            )
            if skip_check:
                return await call_next(request)

            def _unauthorized_response(msg):
                return JSONResponse(
                    status_code=401,
                    content={
                        "status": 401,
                        "message": msg,
                    },
                )

            # check api_key
            req_params = request.query_params
            json_params = dict()
            api_key = req_params.get("api_key", None)
            if (
                api_key is None
                and get_content_type(request) == "application/json"
                and int(request.headers.get("content-length", 0)) > 0
            ):
                # have to try catch here, because some legacy endpoints that abuse Content-Type header but dont actually receive json
                try:
                    json_params = await request.json()
                except Exception:
                    pass
            api_key = json_params.get("api_key", api_key)

            if api_key is None:
                return _unauthorized_response("Unauthorized api_key")

            cache_entry = cached_api_keys.get(api_key)
            workspace_id = None
            if cache_entry and cache_entry[0] >= time.time():
                workspace_id = cache_entry[1]
            else:
                try:
                    # api_key is guaranteed non-None here (early return above),
                    # so resolve its workspace directly.
                    workspace_id = await get_roboflow_workspace_async(
                        api_key=api_key
                    )

                    if workspace_id != DEDICATED_DEPLOYMENT_WORKSPACE_URL:
                        return _unauthorized_response("Unauthorized api_key")

                    cached_api_keys[api_key] = (
                        time.time() + 3600,
                        workspace_id,
                    )  # expired after 1 hour
                except (RoboflowAPINotAuthorizedError, WorkspaceLoadError):
                    return _unauthorized_response("Unauthorized api_key")

            response = await call_next(request)
            if workspace_id:
                response.headers[WORKSPACE_ID_HEADER] = workspace_id
            return response

    @app.middleware("http")
    async def add_inference_engine_headers(request: Request, call_next):
        """Tag every response with the inference engine implementation in use."""
        response = await call_next(request)
        if USE_INFERENCE_MODELS:
            engine_name = "inference-models"
        else:
            engine_name = "old-inference"
        response.headers["x-inference-engine"] = engine_name
        return response

    @app.middleware("http")
    async def track_model_load(request: Request, call_next):
        """Attach model cold-start, load-time and model/workflow-id headers.

        Fresh collectors are installed into the request's ContextVars before
        the downstream handler runs, then summarized into response headers.
        """
        loads = ModelLoadCollector()
        model_load_info.set(loads)
        seen_ids = RequestModelIds()
        request_model_ids.set(seen_ids)
        response = await call_next(request)
        if not loads.has_data():
            # No model was loaded while serving this request.
            response.headers[MODEL_COLD_START_HEADER] = "false"
        else:
            total_time, detail = loads.summarize()
            response.headers[MODEL_COLD_START_HEADER] = "true"
            response.headers[MODEL_LOAD_TIME_HEADER] = str(total_time)
            if detail is not None:
                response.headers[MODEL_LOAD_DETAILS_HEADER] = detail
        ids = seen_ids.get_ids()
        if ids:
            response.headers[MODEL_ID_HEADER] = ",".join(sorted(ids))
        workflow_id = request_workflow_id.get(None)
        if workflow_id:
            response.headers[WORKFLOW_ID_HEADER] = workflow_id
        return response

    if API_LOGGING_ENABLED and STRUCTURED_API_LOGGING:

        @app.middleware("http")
        async def structured_access_log(request: Request, call_next):
            """Emit one structured access-log record per handled request."""
            response = await call_next(request)
            fields = {
                "method": request.method,
                "path": request.url.path,
                "status_code": response.status_code,
            }

            # request_id / execution_id are read from response headers rather
            # than ContextVars: @app.middleware("http") is built on
            # BaseHTTPMiddleware, which runs the inner chain in a separate
            # asyncio task, so ContextVars set by inner middlewares are not
            # visible here.
            header_sources = [
                ("request_id", CORRELATION_ID_HEADER),
                ("processing_time", PROCESSING_TIME_HEADER),
                ("model_cold_start", MODEL_COLD_START_HEADER),
                ("model_load_time", MODEL_LOAD_TIME_HEADER),
                ("model_id", MODEL_ID_HEADER),
                ("workflow_id", WORKFLOW_ID_HEADER),
                ("workspace_id", WORKSPACE_ID_HEADER),
            ]
            if EXECUTION_ID_HEADER is not None:
                header_sources.append(("execution_id", EXECUTION_ID_HEADER))
            for field_name, header_name in header_sources:
                header_value = response.headers.get(header_name)
                if header_value is not None:
                    fields[field_name] = header_value

            logger.info(
                f"{request.method} {request.url.path} {response.status_code}",
                **fields,
            )
            return response

    self.app = app
    self.model_manager = model_manager
    self.stream_manager_client: Optional[StreamManagerClient] = None
    self.shared_thread_pool_executor: Optional[ThreadPoolExecutor] = None
    if HTTP_API_SHARED_WORKFLOWS_THREAD_POOL_ENABLED:
        self.shared_thread_pool_executor = ThreadPoolExecutor(
            max_workers=HTTP_API_SHARED_WORKFLOWS_THREAD_POOL_WORKERS
        )

    if ENABLE_STREAM_API:
        operations_timeout = os.getenv("STREAM_MANAGER_OPERATIONS_TIMEOUT")
        if operations_timeout is not None:
            operations_timeout = float(operations_timeout)
        self.stream_manager_client = StreamManagerClient.init(
            host=os.getenv("STREAM_MANAGER_HOST", "127.0.0.1"),
            port=int(os.getenv("STREAM_MANAGER_PORT", "7070")),
            operations_timeout=operations_timeout,
        )
        self._instrumentator.set_stream_manager_client(self.stream_manager_client)

    def process_inference_request(
        inference_request: InferenceRequest,
        countinference: Optional[bool] = None,
        service_secret: Optional[str] = None,
        **kwargs,
    ) -> InferenceResponse:
        """Run an inference request against the model it references.

        Args:
            inference_request (InferenceRequest): The request containing model ID and other inference details.
            countinference (Optional[bool]): Whether to count inference for usage.
            service_secret (Optional[str]): The service secret.

        Returns:
            InferenceResponse: The response containing the inference results.
        """
        # Aliases (e.g. short model names) are resolved to canonical ids first.
        model_id = resolve_roboflow_model_alias(
            model_id=inference_request.model_id
        )
        self.model_manager.add_model(
            model_id,
            inference_request.api_key,
            countinference=countinference,
            service_secret=service_secret,
        )
        inference_response = self.model_manager.infer_from_request_sync(
            model_id, inference_request, **kwargs
        )
        return orjson_response(inference_response)

    def process_workflow_inference_request(
        workflow_request: WorkflowInferenceRequest,
        workflow_specification: dict,
        background_tasks: Optional[BackgroundTasks],
        profiler: WorkflowsProfiler,
    ) -> WorkflowInferenceResponse:
        """Execute a workflow definition and return its filtered, serialized outputs.

        Args:
            workflow_request (WorkflowInferenceRequest): Request carrying the
                runtime inputs, api key, optional workflow id and the output
                fields to exclude.
            workflow_specification (dict): The workflow definition to execute.
            profiler (WorkflowsProfiler): Profiler whose trace is returned in
                the response.

        Returns:
            WorkflowInferenceResponse: Filtered workflow outputs plus the
            exported profiler trace.
        """
        # Expose the workflow id via ContextVar so outer middleware can emit
        # it as a response header.
        if workflow_request.workflow_id:
            request_workflow_id.set(workflow_request.workflow_id)

        workflow_init_parameters = {
            "workflows_core.model_manager": model_manager,
            "workflows_core.api_key": workflow_request.api_key,
            "workflows_core.background_tasks": background_tasks,
        }
        execution_engine = ExecutionEngine.init(
            workflow_definition=workflow_specification,
            init_parameters=workflow_init_parameters,
            max_concurrent_steps=WORKFLOWS_MAX_CONCURRENT_STEPS,
            prevent_local_images_loading=True,
            profiler=profiler,
            executor=self.shared_thread_pool_executor,
            workflow_id=workflow_request.workflow_id,
        )
        # Not every request type defines is_preview — default to False.
        is_preview = False
        if hasattr(workflow_request, "is_preview"):
            is_preview = workflow_request.is_preview
        workflow_results = execution_engine.run(
            runtime_parameters=workflow_request.inputs,
            serialize_results=True,
            _is_preview=is_preview,
        )
        with profiler.profile_execution_phase(
            name="workflow_results_filtering",
            categories=["inference_package_operation"],
        ):
            outputs = filter_out_unwanted_workflow_outputs(
                workflow_results=workflow_results,
                excluded_fields=workflow_request.excluded_fields,
            )
        profiler_trace = profiler.export_trace()
        response = WorkflowInferenceResponse(
            outputs=outputs,
            profiler_trace=profiler_trace,
        )
        return orjson_response(response=response)

    def load_core_model(
        inference_request: InferenceRequest,
        api_key: Optional[str] = None,
        core_model: Optional[str] = None,
        countinference: Optional[bool] = None,
        service_secret: Optional[str] = None,
    ) -> str:
        """Loads a core model (e.g., "clip" or "sam") into the model manager.

        Args:
            inference_request (InferenceRequest): The request containing version and other details.
            api_key (Optional[str]): The API key for the request.
            core_model (Optional[str]): The core model type, e.g., "clip" or "sam".
            countinference (Optional[bool]): Whether to count inference or not.
            service_secret (Optional[str]): The service secret for the request.

        Returns:
            str: The core model ID ("<core_model>/<version>").
        """
        if api_key:
            inference_request.api_key = api_key
        # The request carries the version under "<core_model>_version_id";
        # getattr is the idiomatic way to read a dynamically-named attribute.
        version_id_field = f"{core_model}_version_id"
        core_model_id = (
            f"{core_model}/{getattr(inference_request, version_id_field)}"
        )
        self.model_manager.add_model(
            core_model_id,
            inference_request.api_key,
            endpoint_type=ModelEndpointType.CORE_MODEL,
            countinference=countinference,
            service_secret=service_secret,
        )
        return core_model_id

    # NOTE: the bare triple-quoted strings that previously followed these
    # assignments were no-op expression statements — a string literal after an
    # assignment does not attach as a docstring — so they are comments now.
    # Each partial forwards to `load_core_model` with a fixed `core_model`
    # and returns the corresponding core model ID.

    # Loads the CLIP model into the model manager.
    load_clip_model = partial(load_core_model, core_model="clip")

    # Loads the Perception Encoder model into the model manager.
    load_pe_model = partial(load_core_model, core_model="perception_encoder")

    # Loads the SAM model into the model manager.
    load_sam_model = partial(load_core_model, core_model="sam")

    # Loads the SAM2 model into the model manager.
    load_sam2_model = partial(load_core_model, core_model="sam2")

    # Loads the GAZE model into the model manager.
    load_gaze_model = partial(load_core_model, core_model="gaze")

    # Loads the DocTR model into the model manager.
    load_doctr_model = partial(load_core_model, core_model="doctr")

    # Loads the EasyOCR model into the model manager.
    load_easy_ocr_model = partial(load_core_model, core_model="easy_ocr")

    # Loads the PaliGemma model into the model manager.
    load_paligemma_model = partial(load_core_model, core_model="paligemma")

    # Loads the Grounding DINO model into the model manager.
    load_grounding_dino_model = partial(
        load_core_model, core_model="grounding_dino"
    )

    # Loads the YOLO World model into the model manager.
    load_yolo_world_model = partial(load_core_model, core_model="yolo_world")

    # Loads the OWLv2 model into the model manager.
    load_owlv2_model = partial(load_core_model, core_model="owlv2")

    # Loads the TrOCR model into the model manager.
    load_trocr_model = partial(load_core_model, core_model="trocr")

    @app.get(
        "/info",
        response_model=ServerVersionInfo,
        summary="Info",
        description="Get the server name and version number",
    )
    def root():
        """Report the server name, package version, and unique instance ID.

        Returns:
            ServerVersionInfo: The server version information.
        """
        info = ServerVersionInfo(
            name="Roboflow Inference Server",
            version=__version__,
            uuid=GLOBAL_INFERENCE_SERVER_ID,
        )
        return info

    @app.get(
        "/logs",
        summary="Get Recent Logs",
        description="Get recent application logs for debugging",
    )
    def get_logs(
        limit: Optional[int] = Query(
            100, description="Maximum number of log entries to return"
        ),
        level: Optional[str] = Query(
            None,
            description="Filter by log level (DEBUG, INFO, WARNING, ERROR, CRITICAL)",
        ),
        since: Optional[str] = Query(
            None, description="Return logs since this ISO timestamp"
        ),
    ):
        """Get recent application logs from memory.

        Only available when ENABLE_IN_MEMORY_LOGS environment variable is set to 'true'.

        Args:
            limit: Maximum number of log entries (default 100)
            level: Filter by log level
            since: ISO timestamp to filter logs since

        Returns:
            dict: {"logs": [...], "total_count": int} with timestamp, level,
            logger, and message per entry.

        Raises:
            HTTPException: 404 when in-memory logging is disabled; 500 when the
            logging subsystem cannot be loaded.
        """
        # Imported lazily so the endpoint can report a clean 404/500 even when
        # the optional in-memory logging machinery is unavailable.
        from inference.core.logging.memory_handler import (
            get_recent_logs,
            is_memory_logging_enabled,
        )

        if not is_memory_logging_enabled():
            raise HTTPException(
                status_code=404, detail="Logs endpoint not available"
            )

        try:
            # `limit or 100` also coerces an explicit 0 to the default.
            logs = get_recent_logs(limit=limit or 100, level=level, since=since)
            return {"logs": logs, "total_count": len(logs)}
        except ImportError as error:
            # ModuleNotFoundError is a subclass of ImportError, so a single
            # clause covers both; chain the cause for easier debugging.
            raise HTTPException(
                status_code=500, detail="Logging system not properly initialized"
            ) from error

    if not LAMBDA and GET_MODEL_REGISTRY_ENABLED:

        @app.get(
            "/model/registry",
            response_model=ModelsDescriptions,
            summary="Get model keys",
            description="Get the ID of each loaded model",
        )
        def registry():
            """List every model currently loaded in the model manager.

            Returns:
                ModelsDescriptions: The object containing models descriptions
            """
            logger.debug(f"Reached /model/registry")
            descriptions = self.model_manager.describe_models()
            return ModelsDescriptions.from_models_descriptions(
                models_descriptions=descriptions
            )

    # The current AWS Lambda authorizer only supports path parameters, therefore we can only use the legacy infer route. This case statement excludes routes which won't work for the current Lambda authorizer.
    if not (LAMBDA or GCP_SERVERLESS):

        def _describe_loaded_models() -> ModelsDescriptions:
            """Snapshot the model manager state as a ModelsDescriptions response."""
            models_descriptions = self.model_manager.describe_models()
            return ModelsDescriptions.from_models_descriptions(
                models_descriptions=models_descriptions
            )

        @app.post(
            "/model/add",
            response_model=ModelsDescriptions,
            summary="Load a model",
            description="Load the model with the given model ID",
        )
        @with_route_exceptions
        def model_add(
            request: AddModelRequest,
            countinference: Optional[bool] = None,
            service_secret: Optional[str] = None,
        ):
            """Load the model with the given model ID into the model manager.

            Args:
                request (AddModelRequest): The request containing the model ID and optional API key.
                countinference (Optional[bool]): Whether to count inference or not.
                service_secret (Optional[str]): The service secret for the request.

            Returns:
                ModelsDescriptions: The object containing models descriptions
            """
            logger.debug("Reached /model/add")
            # Model aliases must be resolved to canonical IDs before loading.
            de_aliased_model_id = resolve_roboflow_model_alias(
                model_id=request.model_id
            )
            logger.info(f"Loading model: {de_aliased_model_id}")
            self.model_manager.add_model(
                de_aliased_model_id,
                request.api_key,
                countinference=countinference,
                service_secret=service_secret,
            )
            return _describe_loaded_models()

        @app.post(
            "/model/remove",
            response_model=ModelsDescriptions,
            summary="Remove a model",
            description="Remove the model with the given model ID",
        )
        @with_route_exceptions
        def model_remove(request: ClearModelRequest):
            """Remove the model with the given model ID from the model manager.

            Args:
                request (ClearModelRequest): The request containing the model ID to be removed.

            Returns:
                ModelsDescriptions: The object containing models descriptions
            """
            logger.debug("Reached /model/remove")
            de_aliased_model_id = resolve_roboflow_model_alias(
                model_id=request.model_id
            )
            self.model_manager.remove(de_aliased_model_id)
            return _describe_loaded_models()

        @app.post(
            "/model/clear",
            response_model=ModelsDescriptions,
            summary="Remove all models",
            description="Remove all loaded models",
        )
        @with_route_exceptions
        def model_clear():
            """Remove all loaded models from the model manager.

            Returns:
                ModelsDescriptions: The object containing models descriptions
            """
            logger.debug("Reached /model/clear")
            self.model_manager.clear()
            return _describe_loaded_models()

    # these NEW endpoints need authentication protection
    if not LAMBDA and not GCP_SERVERLESS:

        @app.post(
            "/infer/object_detection",
            response_model=Union[
                ObjectDetectionInferenceResponse,
                List[ObjectDetectionInferenceResponse],
                StubResponse,
            ],
            summary="Object detection infer",
            description="Run inference with the specified object detection model",
            response_model_exclude_none=True,
        )
        @with_route_exceptions
        @usage_collector("request")
        def infer_object_detection(
            inference_request: ObjectDetectionInferenceRequest,
            background_tasks: BackgroundTasks,
            countinference: Optional[bool] = None,
            service_secret: Optional[str] = None,
        ):
            """Run inference with the specified object detection model.

            Args:
                inference_request (ObjectDetectionInferenceRequest): The request containing the necessary details for object detection.
                background_tasks (BackgroundTasks): Pool of FastAPI background tasks.
                countinference (Optional[bool]): Whether to count inference or not.
                service_secret (Optional[str]): The service secret for the request.

            Returns:
                Union[ObjectDetectionInferenceResponse, List[ObjectDetectionInferenceResponse]]: The response containing the inference results.
            """
            logger.debug("Reached /infer/object_detection")
            return process_inference_request(
                inference_request,
                active_learning_eligible=True,
                background_tasks=background_tasks,
                countinference=countinference,
                service_secret=service_secret,
            )

        @app.post(
            "/infer/instance_segmentation",
            response_model=Union[
                InstanceSegmentationInferenceResponse, StubResponse
            ],
            summary="Instance segmentation infer",
            description="Run inference with the specified instance segmentation model",
        )
        @with_route_exceptions
        @usage_collector("request")
        def infer_instance_segmentation(
            inference_request: InstanceSegmentationInferenceRequest,
            background_tasks: BackgroundTasks,
            countinference: Optional[bool] = None,
            service_secret: Optional[str] = None,
        ):
            """Run inference with the specified instance segmentation model.

            Args:
                inference_request (InstanceSegmentationInferenceRequest): The request containing the necessary details for instance segmentation.
                background_tasks (BackgroundTasks): Pool of FastAPI background tasks.
                countinference (Optional[bool]): Whether to count inference or not.
                service_secret (Optional[str]): The service secret for the request.

            Returns:
                InstanceSegmentationInferenceResponse: The response containing the inference results.
            """
            logger.debug("Reached /infer/instance_segmentation")
            return process_inference_request(
                inference_request,
                active_learning_eligible=True,
                background_tasks=background_tasks,
                countinference=countinference,
                service_secret=service_secret,
            )

        @app.post(
            "/infer/semantic_segmentation",
            response_model=Union[
                SemanticSegmentationInferenceResponse, StubResponse
            ],
            summary="Semantic segmentation infer",
            description="Run inference with the specified semantic segmentation model",
        )
        @with_route_exceptions
        @usage_collector("request")
        def infer_semantic_segmentation(
            inference_request: SemanticSegmentationInferenceRequest,
            background_tasks: BackgroundTasks,
            countinference: Optional[bool] = None,
            service_secret: Optional[str] = None,
        ):
            """Run inference with the specified semantic segmentation model.

            Args:
                inference_request (SemanticSegmentationInferenceRequest): The request containing the necessary details for semantic segmentation.
                background_tasks (BackgroundTasks): Pool of FastAPI background tasks.
                countinference (Optional[bool]): Whether to count inference or not.
                service_secret (Optional[str]): The service secret for the request.

            Returns:
                SemanticSegmentationInferenceResponse: The response containing the inference results.
            """
            logger.debug("Reached /infer/semantic_segmentation")
            return process_inference_request(
                inference_request,
                active_learning_eligible=True,
                background_tasks=background_tasks,
                countinference=countinference,
                service_secret=service_secret,
            )

        @app.post(
            "/infer/classification",
            response_model=Union[
                ClassificationInferenceResponse,
                MultiLabelClassificationInferenceResponse,
                StubResponse,
            ],
            summary="Classification infer",
            description="Run inference with the specified classification model",
        )
        @with_route_exceptions
        @usage_collector("request")
        def infer_classification(
            inference_request: ClassificationInferenceRequest,
            background_tasks: BackgroundTasks,
            countinference: Optional[bool] = None,
            service_secret: Optional[str] = None,
        ):
            """Run inference with the specified classification model.

            Args:
                inference_request (ClassificationInferenceRequest): The request containing the necessary details for classification.
                background_tasks (BackgroundTasks): Pool of FastAPI background tasks.
                countinference (Optional[bool]): Whether to count inference or not.
                service_secret (Optional[str]): The service secret for the request.

            Returns:
                Union[ClassificationInferenceResponse, MultiLabelClassificationInferenceResponse]: The response containing the inference results.
            """
            logger.debug("Reached /infer/classification")
            return process_inference_request(
                inference_request,
                active_learning_eligible=True,
                background_tasks=background_tasks,
                countinference=countinference,
                service_secret=service_secret,
            )

        @app.post(
            "/infer/keypoints_detection",
            response_model=Union[KeypointsDetectionInferenceResponse, StubResponse],
            summary="Keypoints detection infer",
            description="Run inference with the specified keypoints detection model",
        )
        @with_route_exceptions
        @usage_collector("request")
        def infer_keypoints(
            inference_request: KeypointsDetectionInferenceRequest,
            countinference: Optional[bool] = None,
            service_secret: Optional[str] = None,
        ):
            """Run inference with the specified keypoints detection model.

            Args:
                inference_request (KeypointsDetectionInferenceRequest): The request containing the necessary details for keypoints detection.
                countinference (Optional[bool]): Whether to count inference or not.
                service_secret (Optional[str]): The service secret for the request.

            Returns:
                Union[KeypointsDetectionInferenceResponse, StubResponse]: The response containing the inference results.
            """
            logger.debug("Reached /infer/keypoints_detection")
            # NOTE(review): unlike the other /infer endpoints, this one does not
            # pass active_learning_eligible / background_tasks — confirm intended.
            return process_inference_request(
                inference_request,
                countinference=countinference,
                service_secret=service_secret,
            )

        if LMM_ENABLED or MOONDREAM2_ENABLED:

            @app.post(
                "/infer/lmm",
                response_model=Union[
                    LMMInferenceResponse,
                    List[LMMInferenceResponse],
                    StubResponse,
                ],
                summary="Large multi-modal model infer",
                description="Run inference with the specified large multi-modal model",
                response_model_exclude_none=True,
            )
            @with_route_exceptions
            @usage_collector("request")
            def infer_lmm(
                inference_request: LMMInferenceRequest,
                countinference: Optional[bool] = None,
                service_secret: Optional[str] = None,
            ):
                """Run inference with the specified large multi-modal model.

                Args:
                    inference_request (LMMInferenceRequest): The request containing the necessary details for LMM inference.
                    countinference (Optional[bool]): Whether to count inference or not.
                    service_secret (Optional[str]): The service secret for the request.

                Returns:
                    Union[LMMInferenceResponse, List[LMMInferenceResponse]]: The response containing the inference results.
                """
                logger.debug("Reached /infer/lmm")
                return process_inference_request(
                    inference_request,
                    countinference=countinference,
                    service_secret=service_secret,
                )

            @app.post(
                "/infer/lmm/{model_id:path}",
                response_model=Union[
                    LMMInferenceResponse,
                    List[LMMInferenceResponse],
                    StubResponse,
                ],
                summary="Large multi-modal model infer with model ID in path",
                description="Run inference with the specified large multi-modal model. Model ID is specified in the URL path (can contain slashes).",
                response_model_exclude_none=True,
            )
            @with_route_exceptions
            @usage_collector("request")
            def infer_lmm_with_model_id(
                model_id: str,
                inference_request: LMMInferenceRequest,
                countinference: Optional[bool] = None,
                service_secret: Optional[str] = None,
            ):
                """Run inference with the specified large multi-modal model.

                The model_id can be specified in the URL path. If model_id is also provided
                in the request body, it must match the path parameter.

                Args:
                    model_id (str): The model identifier from the URL path.
                    inference_request (LMMInferenceRequest): The request containing the necessary details for LMM inference.
                    countinference (Optional[bool]): Whether to count inference or not.
                    service_secret (Optional[str]): The service secret for the request.

                Returns:
                    Union[LMMInferenceResponse, List[LMMInferenceResponse]]: The response containing the inference results.

                Raises:
                    HTTPException: If model_id in path and request body don't match.
                """
                logger.debug(f"Reached /infer/lmm/{model_id}")

                # Validate model_id consistency between path and request body
                if (
                    inference_request.model_id is not None
                    and inference_request.model_id != model_id
                ):
                    raise HTTPException(
                        status_code=400,
                        detail=f"Model ID mismatch: path specifies '{model_id}' but request body specifies '{inference_request.model_id}'",
                    )

                # Unconditional assignment is safe: at this point the body's
                # model_id is either None or already equal to the path value.
                inference_request.model_id = model_id

                return process_inference_request(
                    inference_request,
                    countinference=countinference,
                    service_secret=service_secret,
                )

    if not DISABLE_WORKFLOW_ENDPOINTS:

        @app.post(
            "/{workspace_name}/workflows/{workflow_id}/describe_interface",
            response_model=DescribeInterfaceResponse,
            summary="Endpoint to describe interface of predefined workflow",
            description="Checks Roboflow API for workflow definition, once acquired - describes workflow inputs and outputs",
        )
        @with_route_exceptions
        def describe_predefined_workflow_interface(
            workspace_name: str,
            workflow_id: str,
            workflow_request: PredefinedWorkflowDescribeInterfaceRequest,
        ) -> DescribeInterfaceResponse:
            """Fetch a predefined workflow definition and describe its inputs and outputs."""
            specification = get_workflow_specification(
                api_key=workflow_request.api_key,
                workspace_id=workspace_name,
                workflow_id=workflow_id,
                use_cache=workflow_request.use_cache,
                workflow_version_id=workflow_request.workflow_version_id,
            )
            return handle_describe_workflows_interface(definition=specification)

        @app.post(
            "/workflows/describe_interface",
            response_model=DescribeInterfaceResponse,
            summary="Endpoint to describe interface of workflow given in request",
            description="Parses workflow definition and retrieves describes inputs and outputs",
        )
        @with_route_exceptions
        def describe_workflow_interface(
            workflow_request: WorkflowSpecificationDescribeInterfaceRequest,
        ) -> DescribeInterfaceResponse:
            """Describe inputs and outputs of the workflow specification sent in the request body."""
            definition = workflow_request.specification
            return handle_describe_workflows_interface(definition=definition)

        @app.post(
            "/{workspace_name}/workflows/{workflow_id}",
            response_model=WorkflowInferenceResponse,
            summary="Endpoint to run predefined workflow",
            description="Checks Roboflow API for workflow definition, once acquired - parses and executes injecting runtime parameters from request body",
        )
        @app.post(
            "/infer/workflows/{workspace_name}/{workflow_id}",
            response_model=WorkflowInferenceResponse,
            summary="[LEGACY] Endpoint to run predefined workflow",
            description="Checks Roboflow API for workflow definition, once acquired - parses and executes injecting runtime parameters from request body. This endpoint is deprecated and will be removed end of Q2 2024",
            deprecated=True,
        )
        @with_route_exceptions
        @usage_collector("request")
        def infer_from_predefined_workflow(
            workspace_name: str,
            workflow_id: str,
            workflow_request: PredefinedWorkflowInferenceRequest,
            background_tasks: BackgroundTasks,
        ) -> WorkflowInferenceResponse:
            """Fetch a predefined workflow by ID and execute it with the request's runtime parameters."""
            # TODO: get rid of async: https://github.com/roboflow/inference/issues/569
            profiling_requested = (
                ENABLE_WORKFLOWS_PROFILING and workflow_request.enable_profiling
            )
            profiler = (
                BaseWorkflowsProfiler.init(
                    max_runs_in_buffer=WORKFLOWS_PROFILER_BUFFER_SIZE
                )
                if profiling_requested
                else NullWorkflowsProfiler.init()
            )
            with profiler.profile_execution_phase(
                name="workflow_definition_fetching",
                categories=["inference_package_operation"],
            ):
                workflow_specification = get_workflow_specification(
                    api_key=workflow_request.api_key,
                    workspace_id=workspace_name,
                    workflow_id=workflow_id,
                    use_cache=workflow_request.use_cache,
                    workflow_version_id=workflow_request.workflow_version_id,
                )
            # Backfill the workflow ID on the request if the client omitted it.
            if not workflow_request.workflow_id:
                workflow_request.workflow_id = workflow_id
            if not workflow_specification.get("id"):
                logger.warning(
                    "Internal workflow ID missing in specification for '%s'",
                    workflow_id,
                )
            serverless = LAMBDA or GCP_SERVERLESS
            return process_workflow_inference_request(
                workflow_request=workflow_request,
                workflow_specification=workflow_specification,
                background_tasks=None if serverless else background_tasks,
                profiler=profiler,
            )

        @app.post(
            "/workflows/run",
            response_model=WorkflowInferenceResponse,
            summary="Endpoint to run workflow specification provided in payload",
            description="Parses and executes workflow specification, injecting runtime parameters from request body.",
        )
        @app.post(
            "/infer/workflows",
            response_model=WorkflowInferenceResponse,
            summary="[LEGACY] Endpoint to run workflow specification provided in payload",
            description="Parses and executes workflow specification, injecting runtime parameters from request body. This endpoint is deprecated and will be removed end of Q2 2024.",
            deprecated=True,
        )
        @with_route_exceptions
        @usage_collector("request")
        def infer_from_workflow(
            workflow_request: WorkflowSpecificationInferenceRequest,
            background_tasks: BackgroundTasks,
        ) -> WorkflowInferenceResponse:
            """Execute the workflow specification supplied directly in the request body."""
            # TODO: get rid of async: https://github.com/roboflow/inference/issues/569
            profiling_requested = (
                ENABLE_WORKFLOWS_PROFILING and workflow_request.enable_profiling
            )
            profiler = (
                BaseWorkflowsProfiler.init(
                    max_runs_in_buffer=WORKFLOWS_PROFILER_BUFFER_SIZE
                )
                if profiling_requested
                else NullWorkflowsProfiler.init()
            )
            serverless = LAMBDA or GCP_SERVERLESS
            return process_workflow_inference_request(
                workflow_request=workflow_request,
                workflow_specification=workflow_request.specification,
                background_tasks=None if serverless else background_tasks,
                profiler=profiler,
            )

        @app.get(
            "/workflows/execution_engine/versions",
            response_model=ExecutionEngineVersions,
            summary="Returns available Execution Engine versions sorted from oldest to newest",
            description="Returns available Execution Engine versions sorted from oldest to newest",
        )
        @with_route_exceptions
        def get_execution_engine_versions() -> ExecutionEngineVersions:
            """List available Execution Engine versions, oldest first."""
            # TODO: get rid of async: https://github.com/roboflow/inference/issues/569
            return ExecutionEngineVersions(versions=get_available_versions())

        @app.get(
            "/workflows/blocks/describe",
            response_model=WorkflowsBlocksDescription,
            summary="[LEGACY] Endpoint to get definition of workflows blocks that are accessible",
            description="Endpoint provides detailed information about workflows building blocks that are "
            "accessible in the inference server. This information could be used to programmatically "
            "build / display workflows.",
            deprecated=True,
        )
        @with_route_exceptions
        def describe_workflows_blocks(
            request: Request,
        ) -> Union[WorkflowsBlocksDescription, Response]:
            """Describe all accessible workflow blocks (legacy GET variant, no dynamic blocks)."""
            blocks_description = handle_describe_workflows_blocks_request()
            return gzip_response_if_requested(
                request=request, response=blocks_description
            )

        @app.post(
            "/workflows/blocks/describe",
            response_model=WorkflowsBlocksDescription,
            summary="[EXPERIMENTAL] Endpoint to get definition of workflows blocks that are accessible",
            description="Endpoint provides detailed information about workflows building blocks that are "
            "accessible in the inference server. This information could be used to programmatically "
            "build / display workflows. Additionally - in request body one can specify list of "
            "dynamic blocks definitions which will be transformed into blocks and used to generate "
            "schemas and definitions of connections",
        )
        @with_route_exceptions
        def describe_workflows_blocks(
            request: Request,
            request_payload: Optional[DescribeBlocksRequest] = None,
        ) -> Union[WorkflowsBlocksDescription, Response]:
            """Describe accessible workflow blocks, optionally expanding dynamic block definitions."""
            # TODO: get rid of async: https://github.com/roboflow/inference/issues/569
            dynamic_blocks_definitions = None
            requested_execution_engine_version = None
            api_key = None
            if request_payload is not None:
                dynamic_blocks_definitions = request_payload.dynamic_blocks_definitions
                requested_execution_engine_version = (
                    request_payload.execution_engine_version
                )
                # Body api_key takes precedence; query-string api_key is the fallback.
                api_key = request_payload.api_key or request.query_params.get(
                    "api_key", None
                )
            blocks_description = handle_describe_workflows_blocks_request(
                dynamic_blocks_definitions=dynamic_blocks_definitions,
                requested_execution_engine_version=requested_execution_engine_version,
                api_key=api_key,
            )
            return gzip_response_if_requested(
                request=request, response=blocks_description
            )

        @app.get(
            "/workflows/definition/schema",
            response_model=WorkflowsBlocksSchemaDescription,
            summary="Endpoint to fetch the workflows block schema",
            description="Endpoint to fetch the schema of all available blocks. This information can be "
            "used to validate workflow definitions and suggest syntax in the JSON editor.",
        )
        @with_route_exceptions
        def get_workflow_schema(
            request: Request,
        ) -> WorkflowsBlocksSchemaDescription:
            """Return the schema describing all available workflow blocks."""
            schema_description = get_workflow_schema_description()
            return gzip_response_if_requested(request, response=schema_description)

        @app.post(
            "/workflows/blocks/dynamic_outputs",
            response_model=List[OutputDefinition],
            summary="[EXPERIMENTAL] Endpoint to get definition of dynamic output for workflow step",
            description="Endpoint to be used when step outputs can be discovered only after "
            "filling manifest with data.",
        )
        @with_route_exceptions
        def get_dynamic_block_outputs(
            step_manifest: Dict[str, Any],
        ) -> List[OutputDefinition]:
            """Resolve the actual outputs of a single step manifest by parsing it inside a stub workflow."""
            # TODO: get rid of async: https://github.com/roboflow/inference/issues/569
            # Potentially TODO: dynamic blocks do not support dynamic outputs, but if it changes
            # we need to provide dynamic blocks manifests here
            stub_workflow_definition = {
                "version": "1.0",
                "inputs": [],
                "steps": [step_manifest],
                "outputs": [],
            }
            parsed_definition = parse_workflow_definition(
                raw_workflow_definition=stub_workflow_definition,
                available_blocks=load_workflow_blocks(),
            )
            return parsed_definition.steps[0].get_actual_outputs()

        @app.post(
            "/workflows/validate",
            response_model=WorkflowValidationStatus,
            summary="[EXPERIMENTAL] Endpoint to validate",
            description="Endpoint provides a way to check validity of JSON workflow definition.",
        )
        @with_route_exceptions
        def validate_workflow(
            specification: dict,
            api_key: Optional[str] = Query(
                None,
                description="Roboflow API Key that will be passed to the model during initialization for artifact retrieval",
            ),
        ) -> WorkflowValidationStatus:
            """Validate a JSON workflow definition by attempting to initialize an Execution Engine."""
            # TODO: get rid of async: https://github.com/roboflow/inference/issues/569
            init_parameters = {
                "workflows_core.model_manager": model_manager,
                "workflows_core.api_key": api_key,
                "workflows_core.background_tasks": None,
                "workflows_core.step_execution_mode": StepExecutionMode(
                    WORKFLOWS_STEP_EXECUTION_MODE
                ),
            }
            # Initialization performs full validation; the engine itself is discarded.
            ExecutionEngine.init(
                workflow_definition=specification,
                init_parameters=init_parameters,
                max_concurrent_steps=WORKFLOWS_MAX_CONCURRENT_STEPS,
                prevent_local_images_loading=True,
            )
            return WorkflowValidationStatus(status="ok")

    if WEBRTC_WORKER_ENABLED:

        @app.post(
            "/initialise_webrtc_worker",
            response_model=InitializeWebRTCResponse,
            summary="[EXPERIMENTAL] Establishes WebRTC peer connection and processes video stream in spawned process or modal function",
            description="[EXPERIMENTAL] Establishes WebRTC peer connection and processes video stream in spawned process or modal function",
        )
        @with_route_exceptions_async
        async def initialise_webrtc_worker(
            request: WebRTCWorkerRequest,
            r: Request,
        ) -> InitializeWebRTCResponse:
            """Spawn a WebRTC worker and return the negotiated SDP answer.

            Requests whose origin matches the builder UI and whose URL lives on a
            roboflow domain are flagged as previews. Failures reported by the
            worker are re-raised locally as the exception type the worker named.

            Args:
                request: WebRTC worker initialisation payload.
                r: The raw HTTP request (used for origin/URL inspection).

            Returns:
                InitializeWebRTCResponse: SDP answer plus command context/status.
            """
            # Mark builder-originated requests served from *.roboflow.* as previews.
            if str(r.headers.get("origin")).lower() == BUILDER_ORIGIN.lower():
                if re.search(
                    r"^https://[^.]+\.roboflow\.[^./]+/", str(r.url).lower()
                ):
                    request.is_preview = True

            logger.debug("Received initialise_webrtc_worker request")
            worker_result: WebRTCWorkerResult = await start_worker(
                webrtc_request=request,
            )
            if worker_result.exception_type is not None:
                # Workflow errors carry extra context, so rebuild them faithfully.
                if worker_result.exception_type == "WorkflowSyntaxError":
                    raise WorkflowSyntaxError(
                        public_message=worker_result.error_message,
                        context=worker_result.error_context,
                        inner_error=worker_result.inner_error,
                    )
                if worker_result.exception_type == "WorkflowError":
                    raise WorkflowError(
                        public_message=worker_result.error_message,
                        context=worker_result.error_context,
                    )
                # Other reported types map onto known exception classes,
                # defaulting to Exception for unrecognised names.
                expected_exceptions = {
                    "Exception": Exception,
                    "KeyError": KeyError,
                    "MissingApiKeyError": MissingApiKeyError,
                    "NotImplementedError": NotImplementedError,
                    "RoboflowAPINotAuthorizedError": RoboflowAPINotAuthorizedError,
                    "RoboflowAPINotNotFoundError": RoboflowAPINotNotFoundError,
                    "ValidationError": ValidationError,
                    "WebRTCConfigurationError": WebRTCConfigurationError,
                }
                exc = expected_exceptions.get(
                    worker_result.exception_type, Exception
                )(worker_result.error_message)
                # Fix: dropped stray f-prefix on this %-style format string (F541).
                logger.error(
                    "Initialise webrtc worker failed with %s: %s",
                    worker_result.exception_type,
                    worker_result.error_message,
                )
                raise exc
            logger.debug("Returning initialise_webrtc_worker response")
            return InitializeWebRTCResponse(
                context=CommandContext(),
                status=OperationStatus.SUCCESS,
                sdp=worker_result.answer.sdp,
                type=worker_result.answer.type,
            )

    if ENABLE_STREAM_API:

        @app.get(
            "/inference_pipelines/list",
            response_model=ListPipelinesResponse,
            summary="[EXPERIMENTAL] List active InferencePipelines",
            description="[EXPERIMENTAL] Listing all active InferencePipelines processing videos",
        )
        @with_route_exceptions_async
        async def list_pipelines(_: Request) -> ListPipelinesResponse:
            """Proxy the pipeline-listing request to the stream manager."""
            manager = self.stream_manager_client
            return await manager.list_pipelines()

        @app.get(
            "/inference_pipelines/{pipeline_id}/status",
            response_model=InferencePipelineStatusResponse,
            summary="[EXPERIMENTAL] Get status of InferencePipeline",
            description="[EXPERIMENTAL] Get status of InferencePipeline",
        )
        @with_route_exceptions_async
        async def get_status(pipeline_id: str) -> InferencePipelineStatusResponse:
            """Fetch the status of one InferencePipeline from the stream manager."""
            status = await self.stream_manager_client.get_status(
                pipeline_id=pipeline_id
            )
            return status

        @app.post(
            "/inference_pipelines/initialise",
            response_model=CommandResponse,
            summary="[EXPERIMENTAL] Starts new InferencePipeline",
            description="[EXPERIMENTAL] Starts new InferencePipeline",
        )
        @with_route_exceptions_async
        async def initialise(request: InitialisePipelinePayload) -> CommandResponse:
            """Forward a pipeline initialisation request to the stream manager."""
            manager = self.stream_manager_client
            return await manager.initialise_pipeline(initialisation_request=request)

        @app.post(
            "/inference_pipelines/initialise_webrtc",
            response_model=InitializeWebRTCPipelineResponse,
            summary="[EXPERIMENTAL] Establishes WebRTC peer connection and starts new InferencePipeline consuming video track",
            description="[EXPERIMENTAL] Establishes WebRTC peer connection and starts new InferencePipeline consuming video track",
        )
        @with_route_exceptions_async
        async def initialise_webrtc_inference_pipeline(
            request: InitialiseWebRTCPipelinePayload,
        ) -> CommandResponse:
            """Start a WebRTC-fed InferencePipeline through the stream manager."""
            logger.debug("Received initialise webrtc inference pipeline request")
            response = await self.stream_manager_client.initialise_webrtc_pipeline(
                initialisation_request=request
            )
            logger.debug("Returning initialise webrtc inference pipeline response")
            return response

        @app.post(
            "/inference_pipelines/{pipeline_id}/pause",
            response_model=CommandResponse,
            summary="[EXPERIMENTAL] Pauses the InferencePipeline",
            description="[EXPERIMENTAL] Pauses the InferencePipeline",
        )
        @with_route_exceptions_async
        async def pause(pipeline_id: str) -> CommandResponse:
            """Pause the InferencePipeline identified by ``pipeline_id``."""
            manager = self.stream_manager_client
            return await manager.pause_pipeline(pipeline_id=pipeline_id)

        @app.post(
            "/inference_pipelines/{pipeline_id}/resume",
            response_model=CommandResponse,
            summary="[EXPERIMENTAL] Resumes the InferencePipeline",
            description="[EXPERIMENTAL] Resumes the InferencePipeline",
        )
        @with_route_exceptions_async
        async def resume(pipeline_id: str) -> CommandResponse:
            """Resume the InferencePipeline identified by ``pipeline_id``."""
            manager = self.stream_manager_client
            return await manager.resume_pipeline(pipeline_id=pipeline_id)

        @app.post(
            "/inference_pipelines/{pipeline_id}/terminate",
            response_model=CommandResponse,
            summary="[EXPERIMENTAL] Terminates the InferencePipeline",
            description="[EXPERIMENTAL] Terminates the InferencePipeline",
        )
        @with_route_exceptions_async
        async def terminate(pipeline_id: str) -> CommandResponse:
            """Terminate the InferencePipeline identified by ``pipeline_id``."""
            manager = self.stream_manager_client
            return await manager.terminate_pipeline(pipeline_id=pipeline_id)

        @app.get(
            "/inference_pipelines/{pipeline_id}/consume",
            response_model=ConsumePipelineResponse,
            summary="[EXPERIMENTAL] Consumes InferencePipeline result",
            description="[EXPERIMENTAL] Consumes InferencePipeline result",
        )
        @with_route_exceptions_async
        async def consume(
            pipeline_id: str,
            request: Optional[ConsumeResultsPayload] = None,
        ) -> ConsumePipelineResponse:
            """Consume a single result from an InferencePipeline.

            A missing body is treated as a default ``ConsumeResultsPayload``.
            """
            payload = ConsumeResultsPayload() if request is None else request
            return await self.stream_manager_client.consume_pipeline_result(
                pipeline_id=pipeline_id,
                excluded_fields=payload.excluded_fields,
            )

    class ModelInitState:
        """Tracks the progress of background model preloading.

        Attributes (mutated by loader threads):
            is_ready: flips to True once preloading has finished.
            lock: guards concurrent updates to this state.
            initialization_errors: (model_id, message) pairs for failed loads.
        """

        def __init__(self):
            # Serialize updates coming from concurrent loader threads.
            self.lock = Lock()
            # Collected per-model failures, appended under `lock`.
            self.initialization_errors = []
            # Not ready until preloading (if any) completes.
            self.is_ready = False

    model_init_state = ModelInitState()

    # Preloading is requested when either env-configured model list is non-empty;
    # with nothing to preload the server is ready immediately.
    should_preload = PRELOAD_MODELS or PINNED_MODELS
    if not should_preload:
        model_init_state.is_ready = True

    # Enable preloading models at startup
    if should_preload:

        def initialize_models(state: ModelInitState):
            """Load (and optionally pin) all configured models, then mark readiness.

            Models from PRELOAD_MODELS and PINNED_MODELS are loaded on a small
            thread pool. Failures and timeouts are recorded on ``state`` (under
            ``state.lock``) instead of raised; ``state.is_ready`` becomes True
            once every load attempt has been resolved.

            Args:
                state: Shared ModelInitState updated as loading progresses.
            """

            def load_model(model_id):
                # Load one model into the manager, recording any failure on `state`.
                t_start = time.perf_counter()
                de_aliased = resolve_roboflow_model_alias(model_id=model_id)
                logger.info(
                    f"Preload: starting model load for '{model_id}' (resolved: '{de_aliased}')"
                )
                try:
                    self.model_manager.add_model(
                        de_aliased,
                        PRELOAD_API_KEY,
                    )
                    load_time = time.perf_counter() - t_start
                    logger.info(
                        f"Preload: model '{model_id}' loaded successfully in {load_time:.1f}s"
                    )
                except Exception as e:
                    load_time = time.perf_counter() - t_start
                    error_msg = f"Preload: error loading model '{model_id}' after {load_time:.1f}s: {e}"
                    logger.error(error_msg)
                    with state.lock:
                        state.initialization_errors.append((model_id, str(e)))
                    return

                # Pin if this model is in PINNED_MODELS
                if (
                    PINNED_MODELS
                    and model_id in PINNED_MODELS
                    and hasattr(self.model_manager, "pin_model")
                ):
                    self.model_manager.pin_model(de_aliased)

            # De-duplicate while preserving order (preload list first, then pinned).
            all_models = list(
                dict.fromkeys((PRELOAD_MODELS or []) + (PINNED_MODELS or []))
            )
            if all_models:
                # Create tasks for each model to be loaded
                model_loading_executor = ThreadPoolExecutor(max_workers=2)
                loaded_futures: List[Tuple[str, Future]] = []
                for model_id in all_models:
                    future = model_loading_executor.submit(
                        load_model, model_id=model_id
                    )
                    loaded_futures.append((model_id, future))

                for model_id, future in loaded_futures:
                    try:
                        future.result(timeout=300)
                    except (
                        TimeoutError,
                        CancelledError,
                        concurrent.futures.TimeoutError,
                    ):
                        # Fix: take the lock here too - loader threads may still
                        # be appending to this list from load_model.
                        with state.lock:
                            state.initialization_errors.append(
                                (
                                    model_id,
                                    "Could not finalise model loading before timeout",
                                )
                            )
                        future.cancel()
                    except Exception as e:
                        logger.error(
                            f"Preload: unexpected error for model '{model_id}': {e}"
                        )
                        with state.lock:
                            state.initialization_errors.append((model_id, str(e)))

                # Fix: release pool threads; don't block startup on stragglers.
                model_loading_executor.shutdown(wait=False)

            # Update the readiness state in a thread-safe manner
            with state.lock:
                state.is_ready = True

        @app.on_event("startup")
        def startup_model_init():
            """Kick off background model preloading when the server starts."""
            Thread(
                target=initialize_models, args=(model_init_state,), daemon=True
            ).start()
            logger.info("Model initialization started in the background.")

    # Attach health/readiness endpoints
    @app.get("/readiness", status_code=200)
    def readiness(
        state: ModelInitState = Depends(lambda: model_init_state),
    ):
        """Readiness probe: 200 once model preloading finished, else 503."""
        with state.lock:
            ready = state.is_ready
        if ready:
            return {"status": "ready"}
        return JSONResponse(content={"status": "not ready"}, status_code=503)

    @app.get("/healthz", status_code=200)
    def healthz():
        """Liveness probe for Kubernetes: always reports the process as healthy."""
        return {"status": "healthy"}

    if CORE_MODELS_ENABLED:
        if CORE_MODEL_CLIP_ENABLED:

            @app.post(
                "/clip/embed_image",
                response_model=ClipEmbeddingResponse,
                summary="CLIP Image Embeddings",
                description="Run the Open AI CLIP model to embed image data.",
            )
            @with_route_exceptions
            @usage_collector("request")
            def clip_embed_image(
                inference_request: ClipImageEmbeddingRequest,
                request: Request,
                api_key: Optional[str] = Query(
                    None,
                    description="Roboflow API Key that will be passed to the model during initialization for artifact retrieval",
                ),
                countinference: Optional[bool] = None,
                service_secret: Optional[str] = None,
            ):
                """
                Embeds image data using the OpenAI CLIP model.

                Args:
                    inference_request (ClipImageEmbeddingRequest): The request containing the image to be embedded.
                    request (Request): The HTTP request.
                    api_key (Optional[str], default None): Roboflow API Key passed to the model during initialization for artifact retrieval.
                    countinference (Optional[bool], default None): Forwarded to model loading.
                    service_secret (Optional[str], default None): Forwarded to model loading.

                Returns:
                    ClipEmbeddingResponse: The response containing the embedded image.
                """
                # Fix: plain string - the f-prefix had no placeholders (F541).
                logger.debug("Reached /clip/embed_image")
                clip_model_id = load_clip_model(
                    inference_request,
                    api_key=api_key,
                    countinference=countinference,
                    service_secret=service_secret,
                )
                response = self.model_manager.infer_from_request_sync(
                    clip_model_id, inference_request
                )
                if LAMBDA:
                    # On AWS Lambda, attribute usage to the authorizer-provided actor.
                    actor = request.scope["aws.event"]["requestContext"][
                        "authorizer"
                    ]["lambda"]["actor"]
                    trackUsage(clip_model_id, actor)
                return response

            @app.post(
                "/clip/embed_text",
                response_model=ClipEmbeddingResponse,
                summary="CLIP Text Embeddings",
                description="Run the Open AI CLIP model to embed text data.",
            )
            @with_route_exceptions
            @usage_collector("request")
            def clip_embed_text(
                inference_request: ClipTextEmbeddingRequest,
                request: Request,
                api_key: Optional[str] = Query(
                    None,
                    description="Roboflow API Key that will be passed to the model during initialization for artifact retrieval",
                ),
                countinference: Optional[bool] = None,
                service_secret: Optional[str] = None,
            ):
                """
                Embeds text data using the OpenAI CLIP model.

                Args:
                    inference_request (ClipTextEmbeddingRequest): The request containing the text to be embedded.
                    request (Request): The HTTP request.
                    api_key (Optional[str], default None): Roboflow API Key passed to the model during initialization for artifact retrieval.
                    countinference (Optional[bool], default None): Forwarded to model loading.
                    service_secret (Optional[str], default None): Forwarded to model loading.

                Returns:
                    ClipEmbeddingResponse: The response containing the embedded text.
                """
                # Fix: plain string - the f-prefix had no placeholders (F541).
                logger.debug("Reached /clip/embed_text")
                clip_model_id = load_clip_model(
                    inference_request,
                    api_key=api_key,
                    countinference=countinference,
                    service_secret=service_secret,
                )
                response = self.model_manager.infer_from_request_sync(
                    clip_model_id, inference_request
                )
                if LAMBDA:
                    # On AWS Lambda, attribute usage to the authorizer-provided actor.
                    actor = request.scope["aws.event"]["requestContext"][
                        "authorizer"
                    ]["lambda"]["actor"]
                    trackUsage(clip_model_id, actor)
                return response

            @app.post(
                "/clip/compare",
                response_model=ClipCompareResponse,
                summary="CLIP Compare",
                description="Run the Open AI CLIP model to compute similarity scores.",
            )
            @with_route_exceptions
            @usage_collector("request")
            def clip_compare(
                inference_request: ClipCompareRequest,
                request: Request,
                api_key: Optional[str] = Query(
                    None,
                    description="Roboflow API Key that will be passed to the model during initialization for artifact retrieval",
                ),
                countinference: Optional[bool] = None,
                service_secret: Optional[str] = None,
            ):
                """
                Computes similarity scores using the OpenAI CLIP model.

                Args:
                    inference_request (ClipCompareRequest): The request containing the data to be compared.
                    request (Request): The HTTP request.
                    api_key (Optional[str], default None): Roboflow API Key passed to the model during initialization for artifact retrieval.
                    countinference (Optional[bool], default None): Forwarded to model loading.
                    service_secret (Optional[str], default None): Forwarded to model loading.

                Returns:
                    ClipCompareResponse: The response containing the similarity scores.
                """
                # Fix: plain string - the f-prefix had no placeholders (F541).
                logger.debug("Reached /clip/compare")
                clip_model_id = load_clip_model(
                    inference_request,
                    api_key=api_key,
                    countinference=countinference,
                    service_secret=service_secret,
                )
                response = self.model_manager.infer_from_request_sync(
                    clip_model_id, inference_request
                )
                if LAMBDA:
                    actor = request.scope["aws.event"]["requestContext"][
                        "authorizer"
                    ]["lambda"]["actor"]
                    # n=2: compare is billed as two inferences - presumably one
                    # per compared side; confirm against billing conventions.
                    trackUsage(clip_model_id, actor, n=2)
                return response

        if CORE_MODEL_PE_ENABLED:

            @app.post(
                "/perception_encoder/embed_image",
                response_model=PerceptionEncoderEmbeddingResponse,
                summary="PE Image Embeddings",
                description="Run the Meta Perception Encoder model to embed image data.",
            )
            @with_route_exceptions
            @usage_collector("request")
            def pe_embed_image(
                inference_request: PerceptionEncoderImageEmbeddingRequest,
                request: Request,
                api_key: Optional[str] = Query(
                    None,
                    description="Roboflow API Key that will be passed to the model during initialization for artifact retrieval",
                ),
                countinference: Optional[bool] = None,
                service_secret: Optional[str] = None,
            ):
                """
                Embeds image data using the Meta Perception Encoder model.

                Args:
                    inference_request (PerceptionEncoderImageEmbeddingRequest): The request containing the image to be embedded.
                    request (Request): The HTTP request.
                    api_key (Optional[str], default None): Roboflow API Key passed to the model during initialization for artifact retrieval.
                    countinference (Optional[bool], default None): Forwarded to model loading.
                    service_secret (Optional[str], default None): Forwarded to model loading.

                Returns:
                    PerceptionEncoderEmbeddingResponse: The response containing the embedded image.
                """
                # Fix: plain string - the f-prefix had no placeholders (F541).
                logger.debug("Reached /perception_encoder/embed_image")
                pe_model_id = load_pe_model(
                    inference_request,
                    api_key=api_key,
                    countinference=countinference,
                    service_secret=service_secret,
                )
                response = self.model_manager.infer_from_request_sync(
                    pe_model_id, inference_request
                )
                if LAMBDA:
                    # On AWS Lambda, attribute usage to the authorizer-provided actor.
                    actor = request.scope["aws.event"]["requestContext"][
                        "authorizer"
                    ]["lambda"]["actor"]
                    trackUsage(pe_model_id, actor)
                return response

            @app.post(
                "/perception_encoder/embed_text",
                response_model=PerceptionEncoderEmbeddingResponse,
                summary="Perception Encoder Text Embeddings",
                description="Run the Meta Perception Encoder model to embed text data.",
            )
            @with_route_exceptions
            @usage_collector("request")
            def pe_embed_text(
                inference_request: PerceptionEncoderTextEmbeddingRequest,
                request: Request,
                api_key: Optional[str] = Query(
                    None,
                    description="Roboflow API Key that will be passed to the model during initialization for artifact retrieval",
                ),
                countinference: Optional[bool] = None,
                service_secret: Optional[str] = None,
            ):
                """
                Embeds text data using the Meta Perception Encoder model.

                Args:
                    inference_request (PerceptionEncoderTextEmbeddingRequest): The request containing the text to be embedded.
                    request (Request): The HTTP request.
                    api_key (Optional[str], default None): Roboflow API Key passed to the model during initialization for artifact retrieval.
                    countinference (Optional[bool], default None): Forwarded to model loading.
                    service_secret (Optional[str], default None): Forwarded to model loading.

                Returns:
                    PerceptionEncoderEmbeddingResponse: The response containing the embedded text.
                """
                # Fix: plain string - the f-prefix had no placeholders (F541).
                logger.debug("Reached /perception_encoder/embed_text")
                pe_model_id = load_pe_model(
                    inference_request,
                    api_key=api_key,
                    countinference=countinference,
                    service_secret=service_secret,
                )
                response = self.model_manager.infer_from_request_sync(
                    pe_model_id, inference_request
                )
                if LAMBDA:
                    # On AWS Lambda, attribute usage to the authorizer-provided actor.
                    actor = request.scope["aws.event"]["requestContext"][
                        "authorizer"
                    ]["lambda"]["actor"]
                    trackUsage(pe_model_id, actor)
                return response

            @app.post(
                "/perception_encoder/compare",
                response_model=PerceptionEncoderCompareResponse,
                summary="Perception Encoder Compare",
                description="Run the Meta Perception Encoder model to compute similarity scores.",
            )
            @with_route_exceptions
            @usage_collector("request")
            def pe_compare(
                inference_request: PerceptionEncoderCompareRequest,
                request: Request,
                api_key: Optional[str] = Query(
                    None,
                    description="Roboflow API Key that will be passed to the model during initialization for artifact retrieval",
                ),
                countinference: Optional[bool] = None,
                service_secret: Optional[str] = None,
            ):
                """
                Computes similarity scores using the Meta Perception Encoder model.

                Args:
                    inference_request (PerceptionEncoderCompareRequest): The request containing the data to be compared.
                    request (Request): The HTTP request.
                    api_key (Optional[str], default None): Roboflow API Key passed to the model during initialization for artifact retrieval.
                    countinference (Optional[bool], default None): Forwarded to model loading.
                    service_secret (Optional[str], default None): Forwarded to model loading.

                Returns:
                    PerceptionEncoderCompareResponse: The response containing the similarity scores.
                """
                # Fix: plain string - the f-prefix had no placeholders (F541).
                logger.debug("Reached /perception_encoder/compare")
                pe_model_id = load_pe_model(
                    inference_request,
                    api_key=api_key,
                    countinference=countinference,
                    service_secret=service_secret,
                )
                response = self.model_manager.infer_from_request_sync(
                    pe_model_id, inference_request
                )
                if LAMBDA:
                    actor = request.scope["aws.event"]["requestContext"][
                        "authorizer"
                    ]["lambda"]["actor"]
                    # n=2: compare is billed as two inferences - presumably one
                    # per compared side; confirm against billing conventions.
                    trackUsage(pe_model_id, actor, n=2)
                return response

        if CORE_MODEL_GROUNDINGDINO_ENABLED:

            @app.post(
                "/grounding_dino/infer",
                response_model=ObjectDetectionInferenceResponse,
                summary="Grounding DINO inference.",
                description="Run the Grounding DINO zero-shot object detection model.",
            )
            @with_route_exceptions
            @usage_collector("request")
            def grounding_dino_infer(
                inference_request: GroundingDINOInferenceRequest,
                request: Request,
                api_key: Optional[str] = Query(
                    None,
                    description="Roboflow API Key that will be passed to the model during initialization for artifact retrieval",
                ),
                countinference: Optional[bool] = None,
                service_secret: Optional[str] = None,
            ):
                """
                Runs zero-shot object detection using the Grounding DINO model.

                Args:
                    inference_request (GroundingDINOInferenceRequest): The request containing the image on which to run object detection.
                    request (Request): The HTTP request.
                    api_key (Optional[str], default None): Roboflow API Key passed to the model during initialization for artifact retrieval.
                    countinference (Optional[bool], default None): Forwarded to model loading.
                    service_secret (Optional[str], default None): Forwarded to model loading.

                Returns:
                    ObjectDetectionInferenceResponse: The object detection response.
                """
                # Fix: plain string - the f-prefix had no placeholders (F541).
                logger.debug("Reached /grounding_dino/infer")
                grounding_dino_model_id = load_grounding_dino_model(
                    inference_request,
                    api_key=api_key,
                    countinference=countinference,
                    service_secret=service_secret,
                )
                response = self.model_manager.infer_from_request_sync(
                    grounding_dino_model_id, inference_request
                )
                if LAMBDA:
                    # On AWS Lambda, attribute usage to the authorizer-provided actor.
                    actor = request.scope["aws.event"]["requestContext"][
                        "authorizer"
                    ]["lambda"]["actor"]
                    trackUsage(grounding_dino_model_id, actor)
                return response

        if CORE_MODEL_YOLO_WORLD_ENABLED:

            @app.post(
                "/yolo_world/infer",
                response_model=ObjectDetectionInferenceResponse,
                summary="YOLO-World inference.",
                description="Run the YOLO-World zero-shot object detection model.",
                response_model_exclude_none=True,
            )
            @with_route_exceptions
            @usage_collector("request")
            def yolo_world_infer(
                inference_request: YOLOWorldInferenceRequest,
                request: Request,
                api_key: Optional[str] = Query(
                    None,
                    description="Roboflow API Key that will be passed to the model during initialization for artifact retrieval",
                ),
                countinference: Optional[bool] = None,
                service_secret: Optional[str] = None,
            ):
                """
                Runs the YOLO-World zero-shot object detection model.

                Args:
                    inference_request (YOLOWorldInferenceRequest): The request containing the image on which to run object detection.
                    request (Request): The HTTP request.
                    api_key (Optional[str], default None): Roboflow API Key passed to the model during initialization for artifact retrieval.
                    countinference (Optional[bool], default None): Forwarded to model loading.
                    service_secret (Optional[str], default None): Forwarded to model loading.

                Returns:
                    ObjectDetectionInferenceResponse: The object detection response.
                """
                # Fix: plain string - the f-prefix had no placeholders (F541).
                logger.debug("Reached /yolo_world/infer. Loading model")
                yolo_world_model_id = load_yolo_world_model(
                    inference_request,
                    api_key=api_key,
                    countinference=countinference,
                    service_secret=service_secret,
                )
                # Fix: corrected "Staring" typo in the log message.
                logger.debug("YOLOWorld model loaded. Starting the inference.")
                response = self.model_manager.infer_from_request_sync(
                    yolo_world_model_id, inference_request
                )
                logger.debug("YOLOWorld prediction available.")
                if LAMBDA:
                    # On AWS Lambda, attribute usage to the authorizer-provided actor.
                    actor = request.scope["aws.event"]["requestContext"][
                        "authorizer"
                    ]["lambda"]["actor"]
                    trackUsage(yolo_world_model_id, actor)
                    logger.debug("Usage of YOLOWorld denoted.")
                return response

        if CORE_MODEL_DOCTR_ENABLED:

            @app.post(
                "/doctr/ocr",
                response_model=Union[
                    OCRInferenceResponse, List[OCRInferenceResponse]
                ],
                summary="DocTR OCR response",
                description="Run the DocTR OCR model to retrieve text in an image.",
            )
            @with_route_exceptions
            @usage_collector("request")
            def doctr_retrieve_text(
                inference_request: DoctrOCRInferenceRequest,
                request: Request,
                api_key: Optional[str] = Query(
                    None,
                    description="Roboflow API Key that will be passed to the model during initialization for artifact retrieval",
                ),
                countinference: Optional[bool] = None,
                service_secret: Optional[str] = None,
            ):
                """
                Retrieves text from image data using the DocTR OCR model.

                Args:
                    inference_request (DoctrOCRInferenceRequest): The request containing the image from which to retrieve text.
                    request (Request): The HTTP request.
                    api_key (Optional[str], default None): Roboflow API Key passed to the model during initialization for artifact retrieval.
                    countinference (Optional[bool], default None): Forwarded to model loading.
                    service_secret (Optional[str], default None): Forwarded to model loading.

                Returns:
                    OCRInferenceResponse: The response containing the retrieved text.
                """
                # Fix: plain string - the f-prefix had no placeholders (F541).
                logger.debug("Reached /doctr/ocr")
                doctr_model_id = load_doctr_model(
                    inference_request,
                    api_key=api_key,
                    countinference=countinference,
                    service_secret=service_secret,
                )
                response = self.model_manager.infer_from_request_sync(
                    doctr_model_id, inference_request
                )
                if LAMBDA:
                    # On AWS Lambda, attribute usage to the authorizer-provided actor.
                    actor = request.scope["aws.event"]["requestContext"][
                        "authorizer"
                    ]["lambda"]["actor"]
                    trackUsage(doctr_model_id, actor)
                return orjson_response_keeping_parent_id(response)

        if CORE_MODEL_EASYOCR_ENABLED:

            @app.post(
                "/easy_ocr/ocr",
                response_model=Union[
                    OCRInferenceResponse, List[OCRInferenceResponse]
                ],
                summary="EasyOCR OCR response",
                description="Run the EasyOCR model to retrieve text in an image.",
            )
            @with_route_exceptions
            @usage_collector("request")
            def easy_ocr_retrieve_text(
                inference_request: EasyOCRInferenceRequest,
                request: Request,
                api_key: Optional[str] = Query(
                    None,
                    description="Roboflow API Key that will be passed to the model during initialization for artifact retrieval",
                ),
                countinference: Optional[bool] = None,
                service_secret: Optional[str] = None,
            ):
                """
                Retrieves text from image data using the EasyOCR model.

                Args:
                    inference_request (EasyOCRInferenceRequest): The request containing the image from which to retrieve text.
                    request (Request, default Body()): The HTTP request.
                    api_key (Optional[str], default None): Roboflow API Key passed to the model during initialization for artifact retrieval.
                    countinference (Optional[bool], default None): Forwarded to model loading; presumably toggles inference counting — confirm upstream.
                    service_secret (Optional[str], default None): Internal service secret forwarded to model loading.

                Returns:
                    OCRInferenceResponse: The response containing the text retrieved from the image.
                """
                # "f" prefix removed: the message contains no placeholders.
                logger.debug("Reached /easy_ocr/ocr")
                easy_ocr_model_id = load_easy_ocr_model(
                    inference_request,
                    api_key=api_key,
                    countinference=countinference,
                    service_secret=service_secret,
                )
                response = self.model_manager.infer_from_request_sync(
                    easy_ocr_model_id, inference_request
                )
                if LAMBDA:
                    # On AWS Lambda deployments, attribute usage to the actor resolved
                    # by the API Gateway authorizer.
                    actor = request.scope["aws.event"]["requestContext"][
                        "authorizer"
                    ]["lambda"]["actor"]
                    trackUsage(easy_ocr_model_id, actor)
                return orjson_response_keeping_parent_id(response)

        if CORE_MODEL_SAM_ENABLED:

            @app.post(
                "/sam/embed_image",
                response_model=SamEmbeddingResponse,
                summary="SAM Image Embeddings",
                description="Run the Meta AI Segment Anything Model to embed image data.",
            )
            @with_route_exceptions
            @usage_collector("request")
            def sam_embed_image(
                inference_request: SamEmbeddingRequest,
                request: Request,
                api_key: Optional[str] = Query(
                    None,
                    description="Roboflow API Key that will be passed to the model during initialization for artifact retrieval",
                ),
                countinference: Optional[bool] = None,
                service_secret: Optional[str] = None,
            ):
                """
                Embeds image data using the Meta AI Segment Anything Model (SAM).

                Args:
                    inference_request (SamEmbeddingRequest): The request containing the image to be embedded.
                    request (Request, default Body()): The HTTP request.
                    api_key (Optional[str], default None): Roboflow API Key passed to the model during initialization for artifact retrieval.
                    countinference (Optional[bool], default None): Forwarded to model loading; presumably toggles inference counting — confirm upstream.
                    service_secret (Optional[str], default None): Internal service secret forwarded to model loading.

                Returns:
                    M.SamEmbeddingResponse or Response: The response containing the image embeddings
                    (raw bytes when the request asks for format == "binary").
                """
                # "f" prefix removed: the message contains no placeholders.
                logger.debug("Reached /sam/embed_image")
                sam_model_id = load_sam_model(
                    inference_request,
                    api_key=api_key,
                    countinference=countinference,
                    service_secret=service_secret,
                )
                model_response = self.model_manager.infer_from_request_sync(
                    sam_model_id, inference_request
                )
                if LAMBDA:
                    # On AWS Lambda deployments, attribute usage to the actor resolved
                    # by the API Gateway authorizer.
                    actor = request.scope["aws.event"]["requestContext"][
                        "authorizer"
                    ]["lambda"]["actor"]
                    trackUsage(sam_model_id, actor)
                if inference_request.format == "binary":
                    # Binary mode returns the raw embedding bytes instead of JSON.
                    return Response(
                        content=model_response.embeddings,
                        headers={"Content-Type": "application/octet-stream"},
                    )
                return model_response

            @app.post(
                "/sam/segment_image",
                response_model=SamSegmentationResponse,
                summary="SAM Image Segmentation",
                description="Run the Meta AI Segment Anything Model to generate segmentations for image data.",
            )
            @with_route_exceptions
            @usage_collector("request")
            def sam_segment_image(
                inference_request: SamSegmentationRequest,
                request: Request,
                api_key: Optional[str] = Query(
                    None,
                    description="Roboflow API Key that will be passed to the model during initialization for artifact retrieval",
                ),
                countinference: Optional[bool] = None,
                service_secret: Optional[str] = None,
            ):
                """
                Generates segmentations for image data using the Meta AI Segment Anything Model (SAM).

                Args:
                    inference_request (SamSegmentationRequest): The request containing the image to be segmented.
                    request (Request, default Body()): The HTTP request.
                    api_key (Optional[str], default None): Roboflow API Key passed to the model during initialization for artifact retrieval.
                    countinference (Optional[bool], default None): Forwarded to model loading; presumably toggles inference counting — confirm upstream.
                    service_secret (Optional[str], default None): Internal service secret forwarded to model loading.

                Returns:
                    M.SamSegmentationResponse or Response: The response containing the segmentation
                    (raw bytes when the request asks for format == "binary").
                """
                # "f" prefix removed: the message contains no placeholders.
                logger.debug("Reached /sam/segment_image")
                sam_model_id = load_sam_model(
                    inference_request,
                    api_key=api_key,
                    countinference=countinference,
                    service_secret=service_secret,
                )
                model_response = self.model_manager.infer_from_request_sync(
                    sam_model_id, inference_request
                )
                if LAMBDA:
                    # On AWS Lambda deployments, attribute usage to the actor resolved
                    # by the API Gateway authorizer.
                    actor = request.scope["aws.event"]["requestContext"][
                        "authorizer"
                    ]["lambda"]["actor"]
                    trackUsage(sam_model_id, actor)
                if inference_request.format == "binary":
                    # Binary mode returns the raw segmentation payload instead of JSON.
                    return Response(
                        content=model_response,
                        headers={"Content-Type": "application/octet-stream"},
                    )
                return model_response

        if CORE_MODEL_SAM2_ENABLED:

            @app.post(
                "/sam2/embed_image",
                response_model=Sam2EmbeddingResponse,
                summary="SAM2 Image Embeddings",
                description="Run the Meta AI Segment Anything 2 Model to embed image data.",
            )
            @with_route_exceptions
            @usage_collector("request")
            def sam2_embed_image(
                inference_request: Sam2EmbeddingRequest,
                request: Request,
                api_key: Optional[str] = Query(
                    None,
                    description="Roboflow API Key that will be passed to the model during initialization for artifact retrieval",
                ),
                countinference: Optional[bool] = None,
                service_secret: Optional[str] = None,
            ):
                """
                Embeds image data using the Meta AI Segment Anything 2 Model (SAM2).

                Args:
                    inference_request (Sam2EmbeddingRequest): The request containing the image to be embedded.
                    request (Request, default Body()): The HTTP request.
                    api_key (Optional[str], default None): Roboflow API Key passed to the model during initialization for artifact retrieval.
                    countinference (Optional[bool], default None): Forwarded to model loading; presumably toggles inference counting — confirm upstream.
                    service_secret (Optional[str], default None): Internal service secret forwarded to model loading.

                Returns:
                    M.Sam2EmbeddingResponse: The response affirming the image has been embedded.
                """
                # "f" prefix removed: the message contains no placeholders.
                logger.debug("Reached /sam2/embed_image")
                sam2_model_id = load_sam2_model(
                    inference_request,
                    api_key=api_key,
                    countinference=countinference,
                    service_secret=service_secret,
                )
                model_response = self.model_manager.infer_from_request_sync(
                    sam2_model_id, inference_request
                )
                return model_response

            @app.post(
                "/sam2/segment_image",
                response_model=Sam2SegmentationResponse,
                summary="SAM2 Image Segmentation",
                description="Run the Meta AI Segment Anything 2 Model to generate segmentations for image data.",
            )
            @with_route_exceptions
            @usage_collector("request")
            def sam2_segment_image(
                inference_request: Sam2SegmentationRequest,
                request: Request,
                api_key: Optional[str] = Query(
                    None,
                    description="Roboflow API Key that will be passed to the model during initialization for artifact retrieval",
                ),
                countinference: Optional[bool] = None,
                service_secret: Optional[str] = None,
            ):
                """
                Generates segmentations for image data using the Meta AI Segment Anything 2 Model (SAM2).

                Args:
                    inference_request (Sam2SegmentationRequest): The request containing the image to be segmented.
                    request (Request, default Body()): The HTTP request.
                    api_key (Optional[str], default None): Roboflow API Key passed to the model during initialization for artifact retrieval.
                    countinference (Optional[bool], default None): Forwarded to model loading; presumably toggles inference counting — confirm upstream.
                    service_secret (Optional[str], default None): Internal service secret forwarded to model loading.

                Returns:
                    M.Sam2SegmentationResponse or Response: The response containing the segmentation
                    (raw bytes when the request asks for format == "binary").
                """
                # "f" prefix removed: the message contains no placeholders.
                logger.debug("Reached /sam2/segment_image")
                sam2_model_id = load_sam2_model(
                    inference_request,
                    api_key=api_key,
                    countinference=countinference,
                    service_secret=service_secret,
                )
                model_response = self.model_manager.infer_from_request_sync(
                    sam2_model_id, inference_request
                )
                if inference_request.format == "binary":
                    # Binary mode returns the raw segmentation payload instead of JSON.
                    return Response(
                        content=model_response,
                        headers={"Content-Type": "application/octet-stream"},
                    )
                return model_response

        if CORE_MODEL_SAM3_ENABLED and not GCP_SERVERLESS:

            @app.post(
                "/sam3/embed_image",
                response_model=Sam3EmbeddingResponse,
                summary="SAM3 Image Embeddings",
                description="Run the Meta AI Segment Anything 3 Model to embed image data.",
            )
            @with_route_exceptions
            @usage_collector("request")
            def sam3_embed_image(
                inference_request: Sam2EmbeddingRequest,
                request: Request,
                api_key: Optional[str] = Query(
                    None,
                    description="Roboflow API Key that will be passed to the model during initialization for artifact retrieval",
                ),
                countinference: Optional[bool] = None,
                service_secret: Optional[str] = None,
            ):
                """
                Embeds image data using the SAM3 interactive model.

                Args:
                    inference_request (Sam2EmbeddingRequest): The request containing the image
                        to be embedded (SAM3 reuses the SAM2 embedding request schema).
                    request (Request, default Body()): The HTTP request.
                    api_key (Optional[str], default None): Roboflow API Key passed to the model during initialization for artifact retrieval.
                    countinference (Optional[bool], default None): Forwarded to model registration; presumably toggles inference counting — confirm upstream.
                    service_secret (Optional[str], default None): Internal service secret forwarded to model registration.

                Returns:
                    M.Sam3EmbeddingResponse: The response affirming the image has been embedded.

                Raises:
                    HTTPException: 501 when the server runs SAM3 in remote execution mode,
                        where local embedding is unavailable.
                """
                # "f" prefix removed: the message contains no placeholders.
                logger.debug("Reached /sam3/embed_image")

                if SAM3_EXEC_MODE == "remote":
                    raise HTTPException(
                        status_code=501,
                        detail="SAM3 embedding is not supported in remote execution mode.",
                    )

                # Embedding is always served by the interactive SAM3 core model.
                self.model_manager.add_model(
                    "sam3/sam3_interactive",
                    api_key=api_key,
                    endpoint_type=ModelEndpointType.CORE_MODEL,
                    countinference=countinference,
                    service_secret=service_secret,
                )

                model_response = self.model_manager.infer_from_request_sync(
                    "sam3/sam3_interactive", inference_request
                )
                return model_response

        if CORE_MODEL_SAM3_ENABLED:

            @app.post(
                "/sam3/concept_segment",
                response_model=Sam3SegmentationResponse,
                summary="SAM3 PCS (promptable concept segmentation)",
                description="Run the SAM3 PCS (promptable concept segmentation) to generate segmentations for image data.",
            )
            @with_route_exceptions
            @usage_collector("request")
            def sam3_segment_image(
                inference_request: Sam3SegmentationRequest,
                request: Request,
                api_key: Optional[str] = Query(
                    None,
                    description="Roboflow API Key that will be passed to the model during initialization for artifact retrieval",
                ),
                countinference: Optional[bool] = None,
                service_secret: Optional[str] = None,
            ):
                """Run SAM3 promptable concept segmentation (PCS) on an image.

                Depending on server configuration the request is either proxied to a
                remote Roboflow endpoint (SAM3_EXEC_MODE == "remote") or served by a
                locally registered model.

                Args:
                    inference_request (Sam3SegmentationRequest): Image, prompts and
                        probability threshold for concept segmentation.
                    request (Request): The HTTP request.
                    api_key (Optional[str]): Roboflow API Key passed to the model for
                        artifact retrieval (also forwarded to the remote proxy).
                    countinference (Optional[bool]): Forwarded to model registration;
                        presumably toggles inference counting — confirm upstream.
                    service_secret (Optional[str]): Internal service secret forwarded
                        to model registration.

                Returns:
                    Sam3SegmentationResponse or Response: Segmentation results, or the
                    raw payload when the request asks for format == "binary".

                Raises:
                    HTTPException: 501 for fine-tuned model ids on deployments that do
                        not support them; 500 when the remote proxy call fails.
                """
                # Fine-tuned (non "sam3/"-prefixed) model ids are rejected unless
                # explicitly enabled for this deployment.
                if not SAM3_FINE_TUNED_MODELS_ENABLED:
                    if not inference_request.model_id.startswith("sam3/"):
                        raise HTTPException(
                            status_code=501,
                            detail="Fine-tuned SAM3 models are not supported on this deployment. Please use a workflow or self-host the server.",
                        )

                if SAM3_EXEC_MODE == "remote":
                    endpoint = f"{API_BASE_URL}/inferenceproxy/seg-preview"

                    # Construct payload for remote API
                    # The remote API expects:
                    # {
                    #     "image": {"type": "base64", "value": ...},
                    #     "prompts": [{"type": "text", "text": ...}, ...],
                    #     "output_prob_thresh": ...
                    # }

                    # Extract prompts from request
                    http_prompts = []
                    for prompt in inference_request.prompts:
                        p_dict = prompt.dict(exclude_none=True)
                        # Ensure type is set if missing (default to text if text is present)
                        if "type" not in p_dict:
                            if "text" in p_dict:
                                p_dict["type"] = "text"
                        http_prompts.append(p_dict)

                    # Prepare image
                    # inference_request.image is InferenceRequestImage
                    if inference_request.image.type == "base64":
                        http_image = {
                            "type": "base64",
                            "value": inference_request.image.value,
                        }
                    elif inference_request.image.type == "url":
                        http_image = {
                            "type": "url",
                            "value": inference_request.image.value,
                        }
                    elif inference_request.image.type == "numpy":
                        # Numpy not supported for remote proxy easily without serialization,
                        # but InferenceRequestImage usually comes as base64/url in HTTP API.
                        # If it is numpy, we might need to handle it, but for now assume base64/url.
                        # If it's numpy, it's likely from internal call, but this is HTTP API.
                        http_image = {
                            "type": "numpy",
                            "value": inference_request.image.value,
                        }
                    else:
                        # Any other image type is forwarded verbatim.
                        http_image = {
                            "type": inference_request.image.type,
                            "value": inference_request.image.value,
                        }

                    payload = {
                        "image": http_image,
                        "prompts": http_prompts,
                        "output_prob_thresh": inference_request.output_prob_thresh,
                    }

                    try:
                        headers = {"Content-Type": "application/json"}
                        # Internal service identity headers, when configured.
                        if ROBOFLOW_INTERNAL_SERVICE_NAME:
                            headers["X-Roboflow-Internal-Service-Name"] = (
                                ROBOFLOW_INTERNAL_SERVICE_NAME
                            )
                        if ROBOFLOW_INTERNAL_SERVICE_SECRET:
                            headers["X-Roboflow-Internal-Service-Secret"] = (
                                ROBOFLOW_INTERNAL_SERVICE_SECRET
                            )

                        headers = build_roboflow_api_headers(
                            explicit_headers=headers
                        )

                        # NOTE(review): api_key is interpolated unencoded into the query
                        # string (and serializes as "None" when absent) — consider
                        # requests' params= for proper encoding.
                        response = requests.post(
                            f"{endpoint}?api_key={api_key}",
                            json=payload,
                            headers=headers,
                            timeout=60,
                        )
                        response.raise_for_status()
                        resp_json = response.json()

                        # The remote API returns the same structure as Sam3SegmentationResponse
                        return Sam3SegmentationResponse(**resp_json)

                    except Exception as e:
                        # Any proxy failure (network, HTTP status, bad JSON) is
                        # surfaced as a 500 with the underlying message.
                        logger.error(f"SAM3 remote request failed: {e}")
                        raise HTTPException(
                            status_code=500,
                            detail=f"SAM3 remote request failed: {str(e)}",
                        )

                # Local execution: core "sam3/" models vs fine-tuned (ORT) models are
                # registered through different endpoint types.
                if inference_request.model_id.startswith("sam3/"):
                    self.model_manager.add_model(
                        inference_request.model_id,
                        api_key=api_key,
                        endpoint_type=ModelEndpointType.CORE_MODEL,
                        countinference=countinference,
                        service_secret=service_secret,
                    )
                else:
                    self.model_manager.add_model(
                        inference_request.model_id,
                        api_key=api_key,
                        endpoint_type=ModelEndpointType.ORT,
                        countinference=countinference,
                        service_secret=service_secret,
                    )

                model_response = self.model_manager.infer_from_request_sync(
                    inference_request.model_id, inference_request
                )
                if inference_request.format == "binary":
                    # Binary mode returns the raw payload instead of JSON.
                    return Response(
                        content=model_response,
                        headers={"Content-Type": "application/octet-stream"},
                    )
                return model_response

            @app.post(
                "/sam3/visual_segment",
                response_model=Sam2SegmentationResponse,
                summary="SAM3 PVS (promptable visual segmentation)",
                description="Run the SAM3 PVS (promptable visual segmentation) to generate segmentations for image data.",
            )
            @with_route_exceptions
            @usage_collector("request")
            def sam3_visual_segment(
                inference_request: Sam2SegmentationRequest,
                request: Request,
                api_key: Optional[str] = Query(
                    None,
                    description="Roboflow API Key that will be passed to the model during initialization for artifact retrieval",
                ),
                countinference: Optional[bool] = None,
                service_secret: Optional[str] = None,
            ):
                """Run SAM3 promptable visual segmentation (PVS) on an image.

                Reuses the SAM2 request/response schemas. When SAM3_EXEC_MODE is
                "remote" the request is proxied to the Roboflow API; otherwise it is
                served by the local "sam3/sam3_interactive" model.

                Args:
                    inference_request (Sam2SegmentationRequest): Image and visual
                        prompts for segmentation.
                    request (Request): The HTTP request.
                    api_key (Optional[str]): Roboflow API Key for artifact retrieval
                        (also forwarded to the remote proxy).
                    countinference (Optional[bool]): Forwarded to model registration;
                        presumably toggles inference counting — confirm upstream.
                    service_secret (Optional[str]): Internal service secret forwarded
                        to model registration.

                Returns:
                    Sam2SegmentationResponse: The segmentation results.

                Raises:
                    HTTPException: 500 when the remote proxy call fails.
                """
                logger.debug(f"Reached /sam3/visual_segment")

                if SAM3_EXEC_MODE == "remote":
                    endpoint = f"{API_BASE_URL}/inferenceproxy/sam3-pvs"

                    # Image is forwarded verbatim (type + value pass-through).
                    http_image = {
                        "type": inference_request.image.type,
                        "value": inference_request.image.value,
                    }

                    # Prompts are optional; serialize without None fields when present.
                    prompts_data = (
                        inference_request.prompts.dict(exclude_none=True)
                        if inference_request.prompts
                        else None
                    )

                    payload = {
                        "image": http_image,
                        "prompts": prompts_data,
                        "multimask_output": inference_request.multimask_output,
                    }

                    try:
                        headers = {"Content-Type": "application/json"}
                        # Internal service identity headers, when configured.
                        if ROBOFLOW_INTERNAL_SERVICE_NAME:
                            headers["X-Roboflow-Internal-Service-Name"] = (
                                ROBOFLOW_INTERNAL_SERVICE_NAME
                            )
                        if ROBOFLOW_INTERNAL_SERVICE_SECRET:
                            headers["X-Roboflow-Internal-Service-Secret"] = (
                                ROBOFLOW_INTERNAL_SERVICE_SECRET
                            )

                        headers = build_roboflow_api_headers(
                            explicit_headers=headers
                        )

                        # NOTE(review): api_key is interpolated unencoded into the query
                        # string (and serializes as "None" when absent) — consider
                        # requests' params= for proper encoding.
                        response = requests.post(
                            f"{endpoint}?api_key={api_key}",
                            json=payload,
                            headers=headers,
                            timeout=60,
                        )
                        response.raise_for_status()
                        resp_json = response.json()

                        return Sam2SegmentationResponse(**resp_json)

                    except Exception as e:
                        # Any proxy failure (network, HTTP status, bad JSON) is
                        # surfaced as a 500 with the underlying message.
                        logger.error(
                            f"SAM3 visual_segment remote request failed: {e}"
                        )
                        raise HTTPException(
                            status_code=500,
                            detail=f"SAM3 visual_segment remote request failed: {str(e)}",
                        )

                # Local execution: PVS is always served by the interactive SAM3 model.
                self.model_manager.add_model(
                    "sam3/sam3_interactive",
                    api_key=api_key,
                    endpoint_type=ModelEndpointType.CORE_MODEL,
                    countinference=countinference,
                    service_secret=service_secret,
                )

                model_response = self.model_manager.infer_from_request_sync(
                    "sam3/sam3_interactive", inference_request
                )
                return model_response

        if CORE_MODEL_SAM3_ENABLED and not GCP_SERVERLESS:

            @app.post(
                "/sam3_3d/infer",
                summary="SAM3 3D Object Generation",
                description="Generate 3D meshes and Gaussian splatting from 2D images with mask prompts.",
            )
            @with_route_exceptions
            @usage_collector("request")
            def sam3_3d_infer(
                inference_request: Sam3_3D_Objects_InferenceRequest,
                request: Request,
                api_key: Optional[str] = Query(
                    None,
                    description="Roboflow API Key that will be passed to the model during initialization for artifact retrieval",
                ),
                countinference: Optional[bool] = None,
                service_secret: Optional[str] = None,
            ):
                """Generate 3D meshes and Gaussian splatting from 2D images with mask prompts.

                Args:
                    inference_request (Sam3_3D_Objects_InferenceRequest): The request containing
                        the image and mask input for 3D generation.
                    request (Request): The HTTP request.
                    api_key (Optional[str]): Roboflow API Key for artifact retrieval.
                    countinference (Optional[bool]): Forwarded to model registration.
                    service_secret (Optional[str]): Internal service secret forwarded to
                        model registration.

                Returns:
                    dict: Response containing base64-encoded 3D outputs:
                        - mesh_glb: Scene mesh in GLB format (base64)
                        - gaussian_ply: Combined Gaussian splatting in PLY format (base64)
                        - objects: List of individual objects with their 3D data
                        - time: Inference time in seconds
                """
                logger.debug("Reached /sam3_3d/infer")

                # Fall back to the default core model when the request leaves model_id unset.
                requested_model = inference_request.model_id or "sam3-3d-objects"

                self.model_manager.add_model(
                    requested_model,
                    api_key=api_key,
                    endpoint_type=ModelEndpointType.CORE_MODEL,
                    countinference=countinference,
                    service_secret=service_secret,
                )
                model_response = self.model_manager.infer_from_request_sync(
                    requested_model, inference_request
                )

                if LAMBDA:
                    # Attribute usage to the actor resolved by the API Gateway authorizer.
                    actor = request.scope["aws.event"]["requestContext"][
                        "authorizer"
                    ]["lambda"]["actor"]
                    trackUsage(requested_model, actor)

                def encode_bytes(raw):
                    # None passes through; bytes become a UTF-8 base64 string so the
                    # payload is JSON-serializable.
                    return None if raw is None else base64.b64encode(raw).decode("utf-8")

                serialized_objects = [
                    {
                        "mesh_glb": encode_bytes(item.mesh_glb),
                        "gaussian_ply": encode_bytes(item.gaussian_ply),
                        "metadata": {
                            "rotation": item.metadata.rotation,
                            "translation": item.metadata.translation,
                            "scale": item.metadata.scale,
                        },
                    }
                    for item in model_response.objects
                ]

                return {
                    "mesh_glb": encode_bytes(model_response.mesh_glb),
                    "gaussian_ply": encode_bytes(model_response.gaussian_ply),
                    "objects": serialized_objects,
                    "time": model_response.time,
                }

        if CORE_MODEL_OWLV2_ENABLED:

            @app.post(
                "/owlv2/infer",
                response_model=ObjectDetectionInferenceResponse,
                summary="Owlv2 image prompting",
                description="Run the google owlv2 model to few-shot object detect",
            )
            @with_route_exceptions
            @usage_collector("request")
            def owlv2_infer(
                inference_request: OwlV2InferenceRequest,
                request: Request,
                api_key: Optional[str] = Query(
                    None,
                    description="Roboflow API Key that will be passed to the model during initialization for artifact retrieval",
                ),
                countinference: Optional[bool] = None,
                service_secret: Optional[str] = None,
            ):
                """
                Runs few-shot object detection using the Google OWLv2 model.

                Args:
                    inference_request (OwlV2InferenceRequest): The request containing the image
                        and visual prompts for few-shot detection.
                    request (Request, default Body()): The HTTP request.
                    api_key (Optional[str], default None): Roboflow API Key passed to the model during initialization for artifact retrieval.
                    countinference (Optional[bool], default None): Forwarded to model loading; presumably toggles inference counting — confirm upstream.
                    service_secret (Optional[str], default None): Internal service secret forwarded to model loading.

                Returns:
                    ObjectDetectionInferenceResponse: The detections produced by OWLv2.
                """
                # "f" prefix removed: the message contains no placeholders.
                logger.debug("Reached /owlv2/infer")
                owl2_model_id = load_owlv2_model(
                    inference_request,
                    api_key=api_key,
                    countinference=countinference,
                    service_secret=service_secret,
                )
                model_response = self.model_manager.infer_from_request_sync(
                    owl2_model_id, inference_request
                )
                return model_response

        if CORE_MODEL_GAZE_ENABLED:

            @app.post(
                "/gaze/gaze_detection",
                response_model=List[GazeDetectionInferenceResponse],
                summary="Gaze Detection",
                description="Run the gaze detection model to detect gaze.",
            )
            @with_route_exceptions
            @usage_collector("request")
            def gaze_detection(
                inference_request: GazeDetectionInferenceRequest,
                request: Request,
                api_key: Optional[str] = Query(
                    None,
                    description="Roboflow API Key that will be passed to the model during initialization for artifact retrieval",
                ),
                countinference: Optional[bool] = None,
                service_secret: Optional[str] = None,
            ):
                """
                Detect gaze using the gaze detection model.

                Args:
                    inference_request (GazeDetectionInferenceRequest): The request containing the image to be analyzed.
                    request (Request, default Body()): The HTTP request.
                    api_key (Optional[str], default None): Roboflow API Key passed to the model during initialization for artifact retrieval.
                    countinference (Optional[bool], default None): Forwarded to model loading; presumably toggles inference counting — confirm upstream.
                    service_secret (Optional[str], default None): Internal service secret forwarded to model loading.

                Returns:
                    List[GazeDetectionInferenceResponse]: The response containing all the detected faces and the corresponding gazes.
                """
                # "f" prefix removed: the message contains no placeholders.
                logger.debug("Reached /gaze/gaze_detection")
                gaze_model_id = load_gaze_model(
                    inference_request,
                    api_key=api_key,
                    countinference=countinference,
                    service_secret=service_secret,
                )
                response = self.model_manager.infer_from_request_sync(
                    gaze_model_id, inference_request
                )
                if LAMBDA:
                    # On AWS Lambda deployments, attribute usage to the actor resolved
                    # by the API Gateway authorizer.
                    actor = request.scope["aws.event"]["requestContext"][
                        "authorizer"
                    ]["lambda"]["actor"]
                    trackUsage(gaze_model_id, actor)
                return response

        if DEPTH_ESTIMATION_ENABLED:

            @app.post(
                "/infer/depth-estimation",
                response_model=DepthEstimationResponse,
                summary="Depth Estimation",
                description="Run the depth estimation model to generate a depth map.",
            )
            @with_route_exceptions
            @usage_collector("request")
            def depth_estimation(
                inference_request: DepthEstimationRequest,
                request: Request,
                api_key: Optional[str] = Query(
                    None,
                    description="Roboflow API Key that will be passed to the model during initialization for artifact retrieval",
                ),
                countinference: Optional[bool] = None,
                service_secret: Optional[str] = None,
            ):
                """
                Generate a depth map using the depth estimation model.

                Args:
                    inference_request (DepthEstimationRequest): The request containing the image to estimate depth for.
                    request (Request): The raw HTTP request; used to read the AWS Lambda authorizer context.
                    api_key (Optional[str], default None): Roboflow API Key passed to the model during initialization for artifact retrieval.
                    countinference (Optional[bool], default None): Usage-counting flag forwarded to model registration.
                    service_secret (Optional[str], default None): Internal shared secret forwarded to model registration.

                Returns:
                    DepthEstimationResponse: The response containing the normalized depth map and optional visualization.
                """
                logger.debug(f"Reached /infer/depth-estimation")
                depth_model_id = inference_request.model_id
                # NOTE(review): registration uses inference_request.api_key rather
                # than the api_key query parameter, unlike the sibling handlers —
                # confirm this is intentional.
                self.model_manager.add_model(
                    depth_model_id,
                    inference_request.api_key,
                    countinference=countinference,
                    service_secret=service_secret,
                )
                response = self.model_manager.infer_from_request_sync(
                    depth_model_id, inference_request
                )
                if LAMBDA:
                    # On AWS Lambda, usage is attributed to the actor resolved
                    # by the API Gateway authorizer.
                    actor = request.scope["aws.event"]["requestContext"][
                        "authorizer"
                    ]["lambda"]["actor"]
                    trackUsage(depth_model_id, actor)

                # Extract data from nested response structure
                depth_data = response.response
                depth_response = DepthEstimationResponse(
                    normalized_depth=depth_data["normalized_depth"].tolist(),
                    image=depth_data["image"].base64_image,
                )
                return depth_response

        if CORE_MODEL_TROCR_ENABLED:

            @app.post(
                "/ocr/trocr",
                response_model=OCRInferenceResponse,
                summary="TrOCR OCR response",
                description="Run the TrOCR model to retrieve text in an image.",
            )
            @with_route_exceptions
            @usage_collector("request")
            def trocr_retrieve_text(
                inference_request: TrOCRInferenceRequest,
                request: Request,
                api_key: Optional[str] = Query(
                    None,
                    description="Roboflow API Key that will be passed to the model during initialization for artifact retrieval",
                ),
                countinference: Optional[bool] = None,
                service_secret: Optional[str] = None,
            ):
                """
                Retrieves text from image data using the TrOCR model.

                Args:
                    inference_request (TrOCRInferenceRequest): The request containing the image from which to retrieve text.
                    request (Request): The raw HTTP request; used to read the AWS Lambda authorizer context.
                    api_key (Optional[str], default None): Roboflow API Key passed to the model during initialization for artifact retrieval.
                    countinference (Optional[bool], default None): Usage-counting flag forwarded to model loading.
                    service_secret (Optional[str], default None): Internal shared secret forwarded to model loading.

                Returns:
                    OCRInferenceResponse: The response containing the retrieved text.
                """
                # Fixed: the previous message logged "/trocr/ocr", which does not
                # match the registered route "/ocr/trocr".
                logger.debug("Reached /ocr/trocr")
                trocr_model_id = load_trocr_model(
                    inference_request,
                    api_key=api_key,
                    countinference=countinference,
                    service_secret=service_secret,
                )
                response = self.model_manager.infer_from_request_sync(
                    trocr_model_id, inference_request
                )
                if LAMBDA:
                    # On AWS Lambda, usage is attributed to the actor resolved
                    # by the API Gateway authorizer.
                    actor = request.scope["aws.event"]["requestContext"][
                        "authorizer"
                    ]["lambda"]["actor"]
                    trackUsage(trocr_model_id, actor)
                # Serialize with the orjson-based helper that retains parent_id
                # fields which default response serialization would drop.
                return orjson_response_keeping_parent_id(response)

    if not (LAMBDA or GCP_SERVERLESS):

        @app.get(
            "/notebook/start",
            summary="Jupyter Lab Server Start",
            description="Starts a jupyter lab server for running development code",
        )
        @with_route_exceptions
        def notebook_start(browserless: bool = False):
            """Starts a jupyter lab server for running development code.

            Args:
                browserless (bool, default False): When true, return a JSON payload
                    containing the server URL instead of redirecting the caller's
                    browser to the Jupyter Lab UI.

            Returns:
                dict or RedirectResponse: A JSON status payload (browserless mode),
                a redirect to the quickstart notebook, or a redirect to the
                instructions page when notebooks are disabled.
            """
            logger.debug(f"Reached /notebook/start")
            if NOTEBOOK_ENABLED:
                start_notebook()
                if browserless:
                    return {
                        "success": True,
                        "message": f"Jupyter Lab server started at http://localhost:{NOTEBOOK_PORT}?token={NOTEBOOK_PASSWORD}",
                    }
                else:
                    # Give the notebook server a moment to come up before
                    # redirecting the browser to it.
                    sleep(2)
                    return RedirectResponse(
                        f"http://localhost:{NOTEBOOK_PORT}/lab/tree/quickstart.ipynb?token={NOTEBOOK_PASSWORD}"
                    )
            else:
                if browserless:
                    return {
                        "success": False,
                        "message": "Notebook server is not enabled. Enable notebooks via the NOTEBOOK_ENABLED environment variable.",
                    }
                else:
                    return RedirectResponse(f"/notebook-instructions.html")

    if ENABLE_BUILDER:
        # Imported lazily so the builder dependency is only loaded when the
        # feature flag is enabled.
        from inference.core.interfaces.http.builder.routes import (
            router as builder_router,
        )

        # Allow CORS on builder API and workflow endpoints needed by the builder UI
        # Enables Private Network Access for Chrome 142+ (local development)
        app.add_middleware(
            PathAwareCORSMiddleware,
            match_paths=r"^/(build/api|workflows/).*",
            allow_origins=[BUILDER_ORIGIN],
            allow_methods=["*"],
            allow_headers=["*"],
            allow_credentials=True,
            allow_private_network=True,
        )

        # Attach all routes from builder to the /build prefix
        app.include_router(builder_router, prefix="/build", tags=["builder"])

    if LEGACY_ROUTE_ENABLED:
        # Legacy object detection inference path for backwards compatibility
        @app.get(
            "/{dataset_id}/{version_id:str}",
            # Order matters in this response model Union. It will use the first matching model. For example, Object Detection Inference Response is a subset of Instance segmentation inference response, so instance segmentation must come first in order for the matching logic to work.
            response_model=Union[
                InstanceSegmentationInferenceResponse,
                KeypointsDetectionInferenceResponse,
                ObjectDetectionInferenceResponse,
                ClassificationInferenceResponse,
                MultiLabelClassificationInferenceResponse,
                SemanticSegmentationInferenceResponse,
                StubResponse,
                Any,
            ],
            response_model_exclude_none=True,
        )
        @app.post(
            "/{dataset_id}/{version_id:str}",
            # Order matters in this response model Union. It will use the first matching model. For example, Object Detection Inference Response is a subset of Instance segmentation inference response, so instance segmentation must come first in order for the matching logic to work.
            response_model=Union[
                InstanceSegmentationInferenceResponse,
                KeypointsDetectionInferenceResponse,
                ObjectDetectionInferenceResponse,
                ClassificationInferenceResponse,
                MultiLabelClassificationInferenceResponse,
                SemanticSegmentationInferenceResponse,
                StubResponse,
                Any,
            ],
            response_model_exclude_none=True,
        )
        @with_route_exceptions
        @usage_collector("request")
        def legacy_infer_from_request(
            background_tasks: BackgroundTasks,
            request: Request,
            request_body: Annotated[
                Optional[Union[bytes, UploadFile]],
                Depends(parse_body_content_for_legacy_request_handler),
            ],
            dataset_id: str = Path(
                description="ID of a Roboflow dataset corresponding to the model to use for inference OR workspace ID"
            ),
            version_id: str = Path(
                description="ID of a Roboflow dataset version corresponding to the model to use for inference OR model ID"
            ),
            api_key: Optional[str] = Query(
                None,
                description="Roboflow API Key that will be passed to the model during initialization for artifact retrieval",
            ),
            confidence: float = Query(
                0.4,
                description="The confidence threshold used to filter out predictions",
            ),
            keypoint_confidence: float = Query(
                0.0,
                description="The confidence threshold used to filter out keypoints that are not visible based on model confidence",
            ),
            # NOTE: parameter name `format` shadows the builtin, but it is part of
            # the public query-string interface and must not be renamed.
            format: str = Query(
                "json",
                description="One of 'json' or 'image'. If 'json' prediction data is returned as a JSON string. If 'image' prediction data is visualized and overlayed on the original input image.",
            ),
            image: Optional[str] = Query(
                None,
                description="The publicly accessible URL of an image to use for inference.",
            ),
            image_type: Optional[str] = Query(
                "base64",
                description="One of base64 or numpy. Note, numpy input is not supported for Roboflow Hosted Inference.",
            ),
            labels: Optional[bool] = Query(
                False,
                description="If true, labels will be included in any inference visualization.",
            ),
            mask_decode_mode: Optional[str] = Query(
                "accurate",
                description="One of 'accurate' or 'fast'. If 'accurate' the mask will be decoded using the original image size. If 'fast' the mask will be decoded using the original mask size. 'accurate' is slower but more accurate.",
            ),
            tradeoff_factor: Optional[float] = Query(
                0.0,
                description="The amount to tradeoff between 0='fast' and 1='accurate'",
            ),
            max_detections: int = Query(
                300,
                description="The maximum number of detections to return. This is used to limit the number of predictions returned by the model. The model may return more predictions than this number, but only the top `max_detections` predictions will be returned.",
            ),
            overlap: float = Query(
                0.3,
                description="The IoU threshold that must be met for a box pair to be considered duplicate during NMS",
            ),
            stroke: int = Query(
                1, description="The stroke width used when visualizing predictions"
            ),
            countinference: Optional[bool] = Query(
                True,
                description="If false, does not track inference against usage.",
                include_in_schema=False,
            ),
            service_secret: Optional[str] = Query(
                None,
                description="Shared secret used to authenticate requests to the inference server from internal services (e.g. to allow disabling inference usage tracking via the `countinference` query parameter)",
                include_in_schema=False,
            ),
            disable_preproc_auto_orient: Optional[bool] = Query(
                False, description="If true, disables automatic image orientation"
            ),
            disable_preproc_contrast: Optional[bool] = Query(
                False, description="If true, disables automatic contrast adjustment"
            ),
            disable_preproc_grayscale: Optional[bool] = Query(
                False,
                description="If true, disables automatic grayscale conversion",
            ),
            disable_preproc_static_crop: Optional[bool] = Query(
                False, description="If true, disables automatic static crop"
            ),
            disable_active_learning: Optional[bool] = Query(
                default=False,
                description="If true, the predictions will be prevented from registration by Active Learning (if the functionality is enabled)",
            ),
            active_learning_target_dataset: Optional[str] = Query(
                default=None,
                description="Parameter to be used when Active Learning data registration should happen against different dataset than the one pointed by model_id",
            ),
            source: Optional[str] = Query(
                "external",
                description="The source of the inference request",
            ),
            source_info: Optional[str] = Query(
                "external",
                description="The detailed source information of the inference request",
            ),
            disable_model_monitoring: Optional[bool] = Query(
                False,
                description="If true, disables model monitoring for this request",
                include_in_schema=False,
            ),
        ):
            """
            Legacy inference endpoint for object detection, instance segmentation, and classification.

            Args:
                background_tasks: (BackgroundTasks) pool of fastapi background tasks
                dataset_id (str): ID of a Roboflow dataset corresponding to the model to use for inference OR workspace ID
                version_id (str): ID of a Roboflow dataset version corresponding to the model to use for inference OR model ID
                api_key (Optional[str], default None): Roboflow API Key passed to the model during initialization for artifact retrieval.
                # Other parameters described in the function signature...

            Returns:
                Union[InstanceSegmentationInferenceResponse, KeypointsDetectionInferenceResponse, ObjectDetectionInferenceResponse, ClassificationInferenceResponse, MultiLabelClassificationInferenceResponse, SemanticSegmentationInferenceResponse, Any]: The response containing the inference results.
            """
            logger.debug(
                f"Reached legacy route /:dataset_id/:version_id with {dataset_id}/{version_id}"
            )
            model_id = f"{dataset_id}/{version_id}"
            # Legacy clients may send confidence/overlap as percentages (>= 1);
            # normalize those to fractions.
            if confidence >= 1:
                confidence /= 100
            if confidence < CONFIDENCE_LOWER_BOUND_OOM_PREVENTION:
                # allowing lower confidence results in RAM usage explosion
                confidence = CONFIDENCE_LOWER_BOUND_OOM_PREVENTION

            if overlap >= 1:
                overlap /= 100
            # Image source priority: explicit URL query param, then request body
            # (multipart upload or raw bytes).
            if image is not None:
                request_image = InferenceRequestImage(type="url", value=image)
            else:
                if "Content-Type" not in request.headers:
                    raise ContentTypeMissing(
                        f"Request must include a Content-Type header"
                    )
                if isinstance(request_body, UploadFile):
                    base64_image_str = request_body.file.read()
                    base64_image_str = base64.b64encode(base64_image_str)
                    request_image = InferenceRequestImage(
                        type="base64", value=base64_image_str.decode("ascii")
                    )
                elif isinstance(request_body, bytes):
                    request_image = InferenceRequestImage(
                        type=image_type, value=request_body
                    )
                elif request_body is None:
                    raise InputImageLoadError(
                        message="Image not found in request body.",
                        public_message="Image not found in request body.",
                    )
                else:
                    raise ContentTypeInvalid(
                        f"Invalid Content-Type: {request.headers['Content-Type']}"
                    )

            # Disabling usage tracking requires the internal service secret.
            if not countinference and service_secret != ROBOFLOW_SERVICE_SECRET:
                raise MissingServiceSecretError(
                    "Service secret is required to disable inference usage tracking"
                )
            if LAMBDA:
                # On AWS Lambda the real model id comes from the API Gateway
                # authorizer context (mangled separators are restored here).
                logger.debug("request.scope: %s", request.scope)
                request_model_id = (
                    request.scope["aws.event"]["requestContext"]["authorizer"][
                        "lambda"
                    ]["model"]["endpoint"]
                    .replace("--", "/")
                    .replace("rf-", "")
                    .replace("nu-", "")
                )
                actor = request.scope["aws.event"]["requestContext"]["authorizer"][
                    "lambda"
                ]["actor"]
                if countinference:
                    trackUsage(request_model_id, actor)
                else:
                    if service_secret != ROBOFLOW_SERVICE_SECRET:
                        raise MissingServiceSecretError(
                            "Service secret is required to disable inference usage tracking"
                        )
                    logger.info("Not counting inference for usage")
            else:
                request_model_id = model_id
            logger.debug(
                f"State of model registry: {self.model_manager.describe_models()}"
            )
            self.model_manager.add_model(
                request_model_id,
                api_key,
                model_id_alias=model_id,
                countinference=countinference,
                service_secret=service_secret,
            )

            # Select the request type (and any task-specific arguments) based on
            # the task the registered model performs.
            task_type = self.model_manager.get_task_type(model_id, api_key=api_key)
            inference_request_type = ObjectDetectionInferenceRequest
            args = dict()
            if task_type == "instance-segmentation":
                inference_request_type = InstanceSegmentationInferenceRequest
                args = {
                    "mask_decode_mode": mask_decode_mode,
                    "tradeoff_factor": tradeoff_factor,
                }
            elif task_type == "classification":
                inference_request_type = ClassificationInferenceRequest
            elif task_type == "keypoint-detection":
                inference_request_type = KeypointsDetectionInferenceRequest
                args = {"keypoint_confidence": keypoint_confidence}
            elif task_type == "semantic-segmentation":
                inference_request_type = SemanticSegmentationInferenceRequest
            inference_request = inference_request_type(
                api_key=api_key,
                model_id=model_id,
                image=request_image,
                confidence=confidence,
                iou_threshold=overlap,
                max_detections=max_detections,
                visualization_labels=labels,
                visualization_stroke_width=stroke,
                visualize_predictions=(
                    format == "image" or format == "image_and_json"
                ),
                disable_preproc_auto_orient=disable_preproc_auto_orient,
                disable_preproc_contrast=disable_preproc_contrast,
                disable_preproc_grayscale=disable_preproc_grayscale,
                disable_preproc_static_crop=disable_preproc_static_crop,
                disable_active_learning=disable_active_learning,
                active_learning_target_dataset=active_learning_target_dataset,
                source=source,
                source_info=source_info,
                usage_billable=countinference,
                disable_model_monitoring=disable_model_monitoring,
                **args,
            )
            inference_response = self.model_manager.infer_from_request_sync(
                inference_request.model_id,
                inference_request,
                active_learning_eligible=True,
                background_tasks=background_tasks,
            )
            logger.debug("Response ready.")
            # 'image' format returns the rendered visualization; anything else
            # returns the JSON prediction payload.
            if format == "image":
                return Response(
                    content=inference_response.visualization,
                    media_type="image/jpeg",
                )
            else:
                return orjson_response(inference_response)

    if not (LAMBDA or GCP_SERVERLESS):
        # Legacy clear cache endpoint for backwards compatibility
        @app.get("/clear_cache", response_model=str)
        def legacy_clear_cache():
            """
            Clears the model cache.

            This endpoint provides a way to clear the cache of loaded models.

            Returns:
                str: A string indicating that the cache has been cleared.
            """
            logger.debug(f"Reached /clear_cache")
            model_clear()
            return "Cache Cleared"

        # Legacy add model endpoint for backwards compatibility
        @app.get("/start/{dataset_id}/{version_id}")
        def model_add_legacy(
            dataset_id: str,
            version_id: str,
            api_key: Optional[str] = None,
            countinference: Optional[bool] = None,
            service_secret: Optional[str] = None,
        ):
            """
            Starts a model inference session.

            This endpoint initializes and starts an inference session for the specified model version.

            Args:
                dataset_id (str): ID of a Roboflow dataset corresponding to the model.
                version_id (str): ID of a Roboflow dataset version corresponding to the model.
                api_key (Optional[str]): Roboflow API Key for artifact retrieval.
                countinference (Optional[bool]): Whether to count inference or not.
                service_secret (Optional[str]): The service secret for the request.

            Returns:
                JSONResponse: A response object containing the status and a success message.
            """
            logger.debug(
                f"Reached /start/{dataset_id}/{version_id} with {dataset_id}/{version_id}"
            )
            # Eagerly register the model so later inference calls hit a warm cache.
            model_id = f"{dataset_id}/{version_id}"
            self.model_manager.add_model(
                model_id,
                api_key,
                countinference=countinference,
                service_secret=service_secret,
            )

            return JSONResponse(
                {
                    "status": 200,
                    "message": "inference session started from local memory.",
                }
            )

    if not ENABLE_DASHBOARD:

        @app.get("/dashboard.html")
        @app.head("/dashboard.html")
        async def dashboard_guard():
            """Return 404 for the dashboard page when the dashboard feature is disabled."""
            return Response(status_code=404)

    @app.exception_handler(InputImageLoadError)
    async def unicorn_exception_handler(request: Request, exc: InputImageLoadError):
        """Map image-loading failures to a 400 response exposing only the public error details."""
        return JSONResponse(
            status_code=400,
            content={
                "message": f"Could not load input image. Cause: {exc.get_public_error_details()}"
            },
        )

    # Serve the static landing page at the root; mounted last so it only handles
    # paths not claimed by the API routes registered above.
    app.mount(
        "/",
        StaticFiles(directory="./inference/landing/out", html=True),
        name="root",
    )

Functions

load_gaze_model

load_gaze_model(inference_request, api_key=None)

Loads the gaze detection model.

Parameters:

Name Type Description Default
inference_request GazeDetectionInferenceRequest

The inference request.

required
api_key Optional[str], default None

The Roboflow API key.

None

Returns:

Name Type Description
str str

The model ID.

Source code in inference/core/interfaces/http/http_api.py
3663
3664
3665
3666
3667
3668
3669
3670
3671
3672
3673
3674
3675
def load_gaze_model(
    inference_request: "GazeDetectionInferenceRequest",
    api_key: Optional[str] = None,
    countinference: Optional[bool] = None,
    service_secret: Optional[str] = None,
) -> str:
    """Loads the gaze detection model.

    Resolves the model id from the request. The extra keyword parameters are
    accepted (with no effect here) so that callers which forward
    ``countinference``/``service_secret`` — as the /gaze/gaze_detection route
    handler does — keep working; previously such calls would raise TypeError.

    Args:
        inference_request (GazeDetectionInferenceRequest): The inference request.
        api_key (Optional[str], default None): The Roboflow API key.
        countinference (Optional[bool], default None): Usage-counting flag,
            accepted for interface parity with the other model loaders; unused.
        service_secret (Optional[str], default None): Internal service secret,
            accepted for interface parity with the other model loaders; unused.

    Returns:
        str: The model ID.
    """
    return inference_request.model_id

core/interfaces/http/middlewares

inference.core.interfaces.http.middlewares.cors

Classes

PathAwareCORSMiddleware

Bases: CORSMiddleware

Extends Starlette's CORSMiddleware to allow specifying a regex of paths that this middleware should apply to. If 'match_paths' is given, only requests matching that regex will have CORS headers applied.

Also supports Private Network Access (PNA) for local development, allowing requests from public websites to localhost.

Source code in inference/core/interfaces/http/middlewares/cors.py
 12
 13
 14
 15
 16
 17
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
class PathAwareCORSMiddleware(StarletteCORSMiddleware):
    """
    Extends Starlette's CORSMiddleware to allow specifying a regex of paths that
    this middleware should apply to.
    If 'match_paths' is given, only requests matching that regex will have CORS
    headers applied.

    Also supports Private Network Access (PNA) for local development, allowing
    requests from public websites to localhost.
    """

    def __init__(
        self,
        app: ASGIApp,
        match_paths: str | None = None,
        allow_origins: typing.Sequence[str] = (),
        allow_methods: typing.Sequence[str] = ("GET",),
        allow_headers: typing.Sequence[str] = (),
        allow_credentials: bool = False,
        allow_origin_regex: str | None = None,
        expose_headers: typing.Sequence[str] = (),
        max_age: int = 600,
        allow_private_network: bool = False,
    ) -> None:
        """Initialize the middleware.

        Args:
            app: The wrapped ASGI application.
            match_paths: Optional regex; when given, CORS handling applies only
                to request paths matching it (checked with ``re.match``, i.e.
                anchored at the start of the path).
            allow_origins, allow_methods, allow_headers, allow_credentials,
                allow_origin_regex, expose_headers, max_age: Forwarded verbatim
                to Starlette's CORSMiddleware.
            allow_private_network: When true, OPTIONS preflights that carry the
                ``Access-Control-Request-Private-Network`` header are answered
                with ``Access-Control-Allow-Private-Network: true``.
        """
        super().__init__(
            app=app,
            allow_origins=allow_origins,
            allow_methods=allow_methods,
            allow_headers=allow_headers,
            allow_credentials=allow_credentials,
            allow_origin_regex=allow_origin_regex,
            expose_headers=expose_headers,
            max_age=max_age,
        )
        self.match_paths_regex = re.compile(match_paths) if match_paths else None
        self.allow_private_network = allow_private_network
        # Store these for PNA preflight handling (not exposed by parent class)
        self._max_age = max_age
        self._allow_methods = allow_methods
        self._allow_headers = allow_headers
        self._allow_credentials = allow_credentials

    async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
        """
        Only apply the CORS logic if the path matches self.match_paths_regex
        (when provided). Otherwise, just call the wrapped 'app'.
        """
        # If it's not an HTTP request, skip the CORS processing:
        if scope["type"] != "http":
            await self.app(scope, receive, send)
            return

        # If match_paths was supplied, check if the current path matches
        if self.match_paths_regex is not None:
            path = scope.get("path", "")
            if not self.match_paths_regex.match(path):
                # If it does NOT match, just run the app without CORS
                await self.app(scope, receive, send)
                return

        # Handle Private Network Access preflight requests
        # (must short-circuit before the parent class, which does not know
        # about the PNA response header).
        if self.allow_private_network:
            headers = Headers(scope=scope)
            if (
                scope["method"] == "OPTIONS"
                and "access-control-request-private-network" in headers
            ):
                await self._handle_pna_preflight(scope, receive, send, headers)
                return

        # If we got here, apply the normal Starlette CORSMiddleware behavior
        await super().__call__(scope, receive, send)

    async def _handle_pna_preflight(
        self, scope: Scope, receive: Receive, send: Send, request_headers: Headers
    ) -> None:
        """
        Handle preflight requests that include Private Network Access header.
        """
        origin = request_headers.get("origin", "")
        if self.is_allowed_origin(origin=origin):
            response_headers = {
                "access-control-allow-origin": origin,
                "access-control-allow-private-network": "true",
                "access-control-allow-methods": ", ".join(self._allow_methods),
                "access-control-max-age": str(self._max_age),
            }
            # Explicit header allow-list takes precedence; with a wildcard we
            # echo back whatever headers the preflight asked for.
            if self._allow_headers and "*" not in self._allow_headers:
                response_headers["access-control-allow-headers"] = ", ".join(
                    self._allow_headers
                )
            elif "*" in self._allow_headers:
                requested_headers = request_headers.get(
                    "access-control-request-headers", ""
                )
                if requested_headers:
                    response_headers["access-control-allow-headers"] = requested_headers
            if self._allow_credentials:
                response_headers["access-control-allow-credentials"] = "true"
            response = PlainTextResponse(
                "OK", status_code=200, headers=response_headers
            )
        else:
            # Disallowed origin: reject and mark the response as origin-varying
            # so caches do not reuse it across origins.
            response = PlainTextResponse(
                "Disallowed CORS origin", status_code=400, headers={"vary": "Origin"}
            )
        await response(scope, receive, send)
Functions
__call__ async
__call__(scope, receive, send)

Only apply the CORS logic if the path matches self.match_paths_regex (when provided). Otherwise, just call the wrapped 'app'.

Source code in inference/core/interfaces/http/middlewares/cors.py
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
    """
    Only apply the CORS logic if the path matches self.match_paths_regex
    (when provided). Otherwise, just call the wrapped 'app'.
    """
    # Non-HTTP traffic (e.g. lifespan, websockets) bypasses CORS entirely.
    if scope["type"] != "http":
        await self.app(scope, receive, send)
        return

    # When a path filter is configured, non-matching paths are served by the
    # wrapped app without any CORS processing.
    if self.match_paths_regex is not None and not self.match_paths_regex.match(
        scope.get("path", "")
    ):
        await self.app(scope, receive, send)
        return

    # Private Network Access preflights get a dedicated handler.
    if self.allow_private_network:
        request_headers = Headers(scope=scope)
        is_pna_preflight = (
            scope["method"] == "OPTIONS"
            and "access-control-request-private-network" in request_headers
        )
        if is_pna_preflight:
            await self._handle_pna_preflight(scope, receive, send, request_headers)
            return

    # Fall back to standard Starlette CORSMiddleware behaviour.
    await super().__call__(scope, receive, send)

core/interfaces/stream

inference.core.interfaces.stream.sinks

Classes

UDPSink

Source code in inference/core/interfaces/stream/sinks.py
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
class UDPSink:
    """`InferencePipeline` predictions sink emitting JSON datagrams over UDP."""

    @classmethod
    def init(cls, ip_address: str, port: int) -> "UDPSink":
        """Create a UDP-based predictions sink for `InferencePipeline`.

        Prefer this factory over the constructor: it builds and configures
        the underlying datagram socket.

        Args:
            ip_address (str): IP address to send predictions
            port (int): Port to send predictions

        Returns: Initialised object of `UDPSink` class.
        """
        sock = socket.socket(family=socket.AF_INET, type=socket.SOCK_DGRAM)
        # NOTE(review): options mirror the established sink behaviour,
        # including the 1-byte receive buffer (this sink never reads).
        for option, value in (
            (socket.SO_REUSEADDR, 1),
            (socket.SO_BROADCAST, 1),
            (socket.SO_RCVBUF, 1),
            (socket.SO_SNDBUF, 65536),
        ):
            sock.setsockopt(socket.SOL_SOCKET, option, value)
        return cls(ip_address=ip_address, port=port, udp_socket=sock)

    def __init__(self, ip_address: str, port: int, udp_socket: socket.socket):
        self._ip_address = ip_address
        self._port = port
        self._socket = udp_socket

    def send_predictions(
        self,
        predictions: Union[dict, List[Optional[dict]]],
        video_frame: Union[VideoFrame, List[Optional[VideoFrame]]],
    ) -> None:
        """Send predictions via the UDP socket, as a sink for `InferencePipeline`.

        Accepts either a single (predictions, video_frame) pair or aligned
        batches (batch support since `0.9.18`); `None` entries in the frame
        batch are skipped, and order is expected to match between lists.

        Args:
            predictions (Union[dict, List[Optional[dict]]]): Roboflow predictions,
                single or batched, aligned positionally with `video_frame`.
            video_frame (Union[VideoFrame, List[Optional[VideoFrame]]]): frame(s)
                of video with basic metadata emitted by `VideoSource`, aligned
                positionally with `predictions`.

        Returns: None
        Side effects: Serialises each `predictions` dict with its frame metadata
            to JSON and sends it over UDP. Mutates each `predictions` dict by
            inserting an "inference_metadata" key holding the frame id, frame
            grabbing timestamp and emission time in ISO datetime format.

        Example:
            ```python
            import cv2
            from inference.core.interfaces.stream.inference_pipeline import InferencePipeline
            from inference.core.interfaces.stream.sinks import UDPSink

            udp_sink = UDPSink.init(ip_address="127.0.0.1", port=9090)

            pipeline = InferencePipeline.init(
                 model_id="your-model/3",
                 video_reference="./some_file.mp4",
                 on_prediction=udp_sink.send_predictions,
            )
            pipeline.start()
            pipeline.join()
            ```
            `UDPSink` used in this way will emit predictions to receiver automatically.
        """
        frames = wrap_in_list(element=video_frame)
        batch_predictions = wrap_in_list(element=predictions)
        destination = (self._ip_address, self._port)
        for frame, frame_predictions in zip(frames, batch_predictions):
            if frame is None:
                # Empty slot in the batch - nothing to emit for this position.
                continue
            frame_predictions["inference_metadata"] = {
                "source_id": frame.source_id,
                "frame_id": frame.frame_id,
                "frame_decoding_time": frame.frame_timestamp.isoformat(),
                "emission_time": datetime.now().isoformat(),
            }
            payload = json.dumps(frame_predictions).encode("utf-8")
            self._socket.sendto(payload, destination)
Functions
init classmethod
init(ip_address, port)

Creates InferencePipeline predictions sink capable of sending model predictions over network using UDP socket.

As an inference user, please use .init() method instead of constructor to instantiate objects. Args: ip_address (str): IP address to send predictions port (int): Port to send predictions

Returns: Initialised object of UDPSink class.

Source code in inference/core/interfaces/stream/sinks.py
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
@classmethod
def init(cls, ip_address: str, port: int) -> "UDPSink":
    """Create a UDP-based predictions sink for `InferencePipeline`.

    Prefer this factory over the constructor: it builds and configures the
    datagram socket used to emit predictions.

    Args:
        ip_address (str): IP address to send predictions
        port (int): Port to send predictions

    Returns: Initialised object of `UDPSink` class.
    """
    sock = socket.socket(family=socket.AF_INET, type=socket.SOCK_DGRAM)
    # NOTE(review): options mirror the established sink behaviour, including
    # the 1-byte receive buffer (this sink never reads).
    for option, value in (
        (socket.SO_REUSEADDR, 1),
        (socket.SO_BROADCAST, 1),
        (socket.SO_RCVBUF, 1),
        (socket.SO_SNDBUF, 65536),
    ):
        sock.setsockopt(socket.SOL_SOCKET, option, value)
    return cls(ip_address=ip_address, port=port, udp_socket=sock)
send_predictions
send_predictions(predictions, video_frame)

Method to send predictions via UDP socket. Useful in combination with InferencePipeline as a sink for predictions.

Parameters:

Name Type Description Default
predictions Union[dict, List[Optional[dict]]]

Roboflow predictions, the function support single prediction processing and batch processing since version 0.9.18. Batch predictions elements are optional, but should occur at the same position as video_frame list. Order is expected to match with video_frame.

required
video_frame Union[VideoFrame, List[Optional[VideoFrame]]]

frame of video with its basic metadata emitted by VideoSource or list of frames from (it is possible for empty batch frames at corresponding positions to predictions list). Order is expected to match with predictions

required

Side effects: Sends serialised predictions and video_frame metadata via the UDP socket as JSON string. It adds key named "inference_metadata" into predictions dict (mutating its state). "inference_metadata" contain id of the frame, frame grabbing timestamp and message emission time in datetime iso format.

Example

import cv2
from inference.core.interfaces.stream.inference_pipeline import InferencePipeline
from inference.core.interfaces.stream.sinks import UDPSink

udp_sink = UDPSink.init(ip_address="127.0.0.1", port=9090)

pipeline = InferencePipeline.init(
     model_id="your-model/3",
     video_reference="./some_file.mp4",
     on_prediction=udp_sink.send_predictions,
)
pipeline.start()
pipeline.join()
UDPSink used in this way will emit predictions to the receiver automatically.

Source code in inference/core/interfaces/stream/sinks.py
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
def send_predictions(
    self,
    predictions: Union[dict, List[Optional[dict]]],
    video_frame: Union[VideoFrame, List[Optional[VideoFrame]]],
) -> None:
    """Send predictions via the UDP socket, as a sink for `InferencePipeline`.

    Accepts either a single (predictions, video_frame) pair or aligned batches
    (batch support since `0.9.18`); `None` entries in the frame batch are
    skipped, and order is expected to match between lists.

    Args:
        predictions (Union[dict, List[Optional[dict]]]): Roboflow predictions,
            single or batched, aligned positionally with `video_frame`.
        video_frame (Union[VideoFrame, List[Optional[VideoFrame]]]): frame(s)
            of video with basic metadata emitted by `VideoSource`, aligned
            positionally with `predictions`.

    Returns: None
    Side effects: Serialises each `predictions` dict with its frame metadata
        to JSON and sends it over UDP. Mutates each `predictions` dict by
        inserting an "inference_metadata" key holding the frame id, frame
        grabbing timestamp and emission time in ISO datetime format.

    Example:
        ```python
        import cv2
        from inference.core.interfaces.stream.inference_pipeline import InferencePipeline
        from inference.core.interfaces.stream.sinks import UDPSink

        udp_sink = UDPSink.init(ip_address="127.0.0.1", port=9090)

        pipeline = InferencePipeline.init(
             model_id="your-model/3",
             video_reference="./some_file.mp4",
             on_prediction=udp_sink.send_predictions,
        )
        pipeline.start()
        pipeline.join()
        ```
        `UDPSink` used in this way will emit predictions to receiver automatically.
    """
    frames = wrap_in_list(element=video_frame)
    batch_predictions = wrap_in_list(element=predictions)
    destination = (self._ip_address, self._port)
    for frame, frame_predictions in zip(frames, batch_predictions):
        if frame is None:
            # Empty slot in the batch - nothing to emit for this position.
            continue
        frame_predictions["inference_metadata"] = {
            "source_id": frame.source_id,
            "frame_id": frame.frame_id,
            "frame_decoding_time": frame.frame_timestamp.isoformat(),
            "emission_time": datetime.now().isoformat(),
        }
        payload = json.dumps(frame_predictions).encode("utf-8")
        self._socket.sendto(payload, destination)

VideoFileSink

Source code in inference/core/interfaces/stream/sinks.py
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
class VideoFileSink:
    """`InferencePipeline` predictions sink rendering predictions into a video file."""

    @classmethod
    def init(
        cls,
        video_file_name: str,
        annotator: Optional[Union[BaseAnnotator, List[BaseAnnotator]]] = None,
        display_size: Optional[Tuple[int, int]] = (1280, 720),
        fps_monitor: Optional[sv.FPSMonitor] = DEFAULT_FPS_MONITOR,
        display_statistics: bool = False,
        output_fps: int = 25,
        quiet: bool = False,
        video_frame_size: Tuple[int, int] = (1280, 720),
    ) -> "VideoFileSink":
        """
        Creates `InferencePipeline` predictions sink capable of saving model predictions into video file.
        It works both for pipelines with single input video and multiple ones.

        As an `inference` user, please use .init() method instead of constructor to instantiate objects.
        Args:
            video_file_name (str): name of the video file to save predictions
            annotator (Union[BaseAnnotator, List[BaseAnnotator]]): instance of class inheriting from supervision BaseAnnotator
                or list of such instances. If nothing is passed chain of `sv.BoxAnnotator()` and `sv.LabelAnnotator()` is used.
            display_size (Tuple[int, int]): tuple in format (width, height) to resize visualisation output. Should
                be set to the same value as `display_size` for InferencePipeline with single video source, otherwise
                it represents the size of single visualisation tile (whole tiles mosaic will be scaled to
                `video_frame_size`)
            fps_monitor (Optional[sv.FPSMonitor]): FPS monitor used to monitor throughput
            display_statistics (bool): Flag to decide if throughput and latency can be displayed in the result image,
                if enabled, throughput will only be presented if `fps_monitor` is not None
            output_fps (int): desired FPS of output file
            quiet (bool): Flag to decide whether to log progress
            video_frame_size (Tuple[int, int]): The size of frame in target video file.

        Attributes:
            on_prediction (Callable[[dict, VideoFrame], None]): callable to be used as a sink for predictions

        Returns: Initialized object of `VideoFileSink` class.

        Example:
            ```python
            import cv2
            from inference import InferencePipeline
            from inference.core.interfaces.stream.sinks import VideoFileSink

            video_sink = VideoFileSink.init(video_file_name="output.avi")

            pipeline = InferencePipeline.init(
                model_id="your-model/3",
                video_reference="./some_file.mp4",
                on_prediction=video_sink.on_prediction,
            )
            pipeline.start()
            pipeline.join()
            video_sink.release()
            ```

            `VideoFileSink` used in this way will save predictions to video file automatically.
        """
        return cls(
            video_file_name=video_file_name,
            annotator=annotator,
            display_size=display_size,
            fps_monitor=fps_monitor,
            display_statistics=display_statistics,
            output_fps=output_fps,
            quiet=quiet,
            video_frame_size=video_frame_size,
        )

    def __init__(
        self,
        video_file_name: str,
        annotator: Union[BaseAnnotator, List[BaseAnnotator]],
        display_size: Optional[Tuple[int, int]],
        fps_monitor: Optional[sv.FPSMonitor],
        display_statistics: bool,
        output_fps: int,
        quiet: bool,
        video_frame_size: Tuple[int, int],
    ):
        self._video_file_name = video_file_name
        self._annotator = annotator
        self._display_size = display_size
        self._fps_monitor = fps_monitor
        self._display_statistics = display_statistics
        self._output_fps = output_fps
        self._quiet = quiet
        self._frame_idx = 0
        self._video_frame_size = video_frame_size
        # Writer is created lazily on first frame (see _save_predictions).
        self._video_writer: Optional[cv2.VideoWriter] = None
        # `render_boxes` does the annotation work; this sink only persists
        # the rendered frames via `_save_predictions`.
        self.on_prediction = partial(
            render_boxes,
            annotator=self._annotator,
            display_size=self._display_size,
            fps_monitor=self._fps_monitor,
            display_statistics=self._display_statistics,
            on_frame_rendered=self._save_predictions,
        )

    def release(self) -> None:
        """
        Releases VideoWriter object.
        """
        if self._video_writer is not None and self._video_writer.isOpened():
            self._video_writer.release()

    def _save_predictions(
        self,
        frame: Union[ImageWithSourceID, List[ImageWithSourceID]],
    ) -> None:
        # Write a rendered frame (or tile mosaic for batch input) to the file.
        if self._video_writer is None:
            self._initialise_sink()
        # Fixed: use the idiomatic isinstance() check instead of
        # issubclass(type(frame), list) - semantics are identical here.
        if isinstance(frame, list):
            # Batch input: combine all rendered images into a single mosaic.
            frame = create_tiles(images=[i[1] for i in frame])
        else:
            # Single input: (source_id, image) tuple - keep the image only.
            frame = frame[1]
        if (frame.shape[1], frame.shape[0]) != self._video_frame_size:
            # Keep aspect ratio; pad to the target frame size.
            frame = letterbox_image(image=frame, desired_size=self._video_frame_size)
        self._video_writer.write(frame)
        if not self._quiet:
            print(f"Writing frame {self._frame_idx}", end="\r")
        self._frame_idx += 1

    def _initialise_sink(self) -> None:
        # Lazy construction of the MJPG writer with the configured FPS/size.
        self._video_writer = cv2.VideoWriter(
            self._video_file_name,
            cv2.VideoWriter_fourcc(*"MJPG"),
            self._output_fps,
            self._video_frame_size,
        )

    def __enter__(self) -> "VideoFileSink":
        return self

    def __exit__(self, exc_type, exc_val, exc_tb) -> None:
        # Context-manager exit guarantees the writer is released.
        self.release()
Functions
init classmethod
init(
    video_file_name,
    annotator=None,
    display_size=(1280, 720),
    fps_monitor=DEFAULT_FPS_MONITOR,
    display_statistics=False,
    output_fps=25,
    quiet=False,
    video_frame_size=(1280, 720),
)

Creates InferencePipeline predictions sink capable of saving model predictions into video file. It works both for pipelines with single input video and multiple ones.

As an inference user, please use .init() method instead of constructor to instantiate objects. Args: video_file_name (str): name of the video file to save predictions annotator (Union[BaseAnnotator, List[BaseAnnotator]]): instance of class inheriting from supervision BaseAnnotator or list of such instances. If nothing is passed chain of sv.BoxAnnotator() and sv.LabelAnnotator() is used. display_size (Tuple[int, int]): tuple in format (width, height) to resize visualisation output. Should be set to the same value as display_size for InferencePipeline with single video source, otherwise it represents the size of single visualisation tile (whole tiles mosaic will be scaled to video_frame_size) fps_monitor (Optional[sv.FPSMonitor]): FPS monitor used to monitor throughput display_statistics (bool): Flag to decide if throughput and latency can be displayed in the result image, if enabled, throughput will only be presented if fps_monitor is not None output_fps (int): desired FPS of output file quiet (bool): Flag to decide whether to log progress video_frame_size (Tuple[int, int]): The size of frame in target video file.

Attributes:

Name Type Description
on_prediction Callable[[dict, VideoFrame], None]

callable to be used as a sink for predictions

Example
import cv2
from inference import InferencePipeline
from inference.core.interfaces.stream.sinks import VideoFileSink

video_sink = VideoFileSink.init(video_file_name="output.avi")

pipeline = InferencePipeline.init(
    model_id="your-model/3",
    video_reference="./some_file.mp4",
    on_prediction=video_sink.on_prediction,
)
pipeline.start()
pipeline.join()
video_sink.release()

VideoFileSink used in this way will save predictions to video file automatically.

Source code in inference/core/interfaces/stream/sinks.py
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
@classmethod
def init(
    cls,
    video_file_name: str,
    annotator: Optional[Union[BaseAnnotator, List[BaseAnnotator]]] = None,
    display_size: Optional[Tuple[int, int]] = (1280, 720),
    fps_monitor: Optional[sv.FPSMonitor] = DEFAULT_FPS_MONITOR,
    display_statistics: bool = False,
    output_fps: int = 25,
    quiet: bool = False,
    video_frame_size: Tuple[int, int] = (1280, 720),
) -> "VideoFileSink":
    """Create an `InferencePipeline` sink that saves rendered predictions to a video file.

    Works for pipelines with a single input video as well as multiple ones.
    As an `inference` user, please use .init() instead of the constructor.

    Args:
        video_file_name (str): name of the video file to save predictions
        annotator (Union[BaseAnnotator, List[BaseAnnotator]]): supervision
            annotator(s); when omitted, a chain of `sv.BoxAnnotator()` and
            `sv.LabelAnnotator()` is used.
        display_size (Tuple[int, int]): (width, height) of visualisation output;
            for multi-source pipelines this is the size of a single tile (the
            tile mosaic is scaled to `video_frame_size`).
        fps_monitor (Optional[sv.FPSMonitor]): FPS monitor used to monitor throughput
        display_statistics (bool): whether to overlay throughput/latency stats;
            throughput is only shown when `fps_monitor` is not None.
        output_fps (int): desired FPS of output file
        quiet (bool): whether to suppress progress logging
        video_frame_size (Tuple[int, int]): frame size in the target video file.

    Attributes:
        on_prediction (Callable[[dict, VideoFrame], None]): callable to be used as a sink for predictions

    Returns: Initialized object of `VideoFileSink` class.

    Example:
        ```python
        import cv2
        from inference import InferencePipeline
        from inference.core.interfaces.stream.sinks import VideoFileSink

        video_sink = VideoFileSink.init(video_file_name="output.avi")

        pipeline = InferencePipeline.init(
            model_id="your-model/3",
            video_reference="./some_file.mp4",
            on_prediction=video_sink.on_prediction,
        )
        pipeline.start()
        pipeline.join()
        video_sink.release()
        ```

        `VideoFileSink` used in this way will save predictions to video file automatically.
    """
    configuration = {
        "video_file_name": video_file_name,
        "annotator": annotator,
        "display_size": display_size,
        "fps_monitor": fps_monitor,
        "display_statistics": display_statistics,
        "output_fps": output_fps,
        "quiet": quiet,
        "video_frame_size": video_frame_size,
    }
    return cls(**configuration)
release
release()

Releases VideoWriter object.

Source code in inference/core/interfaces/stream/sinks.py
505
506
507
508
509
510
def release(self) -> None:
    """Close the underlying video writer, if one was opened."""
    writer = self._video_writer
    if writer is not None and writer.isOpened():
        writer.release()

Functions

active_learning_sink

active_learning_sink(
    predictions,
    video_frame,
    active_learning_middleware,
    model_type,
    disable_preproc_auto_orient=False,
)

Function to serve as Active Learning sink for InferencePipeline.

Parameters:

Name Type Description Default
predictions Union[dict, List[Optional[dict]]]

Roboflow predictions, the function support single prediction processing and batch processing since version 0.9.18. Batch predictions elements are optional, but should occur at the same position as video_frame list. Order is expected to match with video_frame.

required
video_frame Union[VideoFrame, List[Optional[VideoFrame]]]

frame of video with its basic metadata emitted by VideoSource or list of frames from (it is possible for empty batch frames at corresponding positions to predictions list). Order is expected to match with predictions

required
active_learning_middleware ActiveLearningMiddleware

instance of middleware to register data.

required
model_type str

Type of Roboflow model in use

required
disable_preproc_auto_orient bool

Flag to denote how image is preprocessed which is important in Active Learning.

False

Side effects: Can register data and predictions in Roboflow backend if that's the evaluation of sampling engine.

Source code in inference/core/interfaces/stream/sinks.py
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
def active_learning_sink(
    predictions: Union[dict, List[Optional[dict]]],
    video_frame: Union[VideoFrame, List[Optional[VideoFrame]]],
    active_learning_middleware: ActiveLearningMiddleware,
    model_type: str,
    disable_preproc_auto_orient: bool = False,
) -> None:
    """Active Learning sink for `InferencePipeline`.

    Accepts either a single (predictions, video_frame) pair or aligned batches
    (batch support since `0.9.18`); `None` entries are filtered out before
    registration, and order is expected to match between lists.

    Args:
        predictions (Union[dict, List[Optional[dict]]]): Roboflow predictions,
            single or batched, aligned positionally with `video_frame`.
        video_frame (Union[VideoFrame, List[Optional[VideoFrame]]]): frame(s)
            of video with basic metadata emitted by `VideoSource`, aligned
            positionally with `predictions`.
        active_learning_middleware (ActiveLearningMiddleware): instance of middleware to register data.
        model_type (str): Type of Roboflow model in use
        disable_preproc_auto_orient (bool): Flag to denote how image is preprocessed which is important in
            Active Learning.

    Returns: None
    Side effects: Can register data and predictions in Roboflow backend if that's the evaluation of sampling engine.
    """
    frames = wrap_in_list(element=video_frame)
    raw_predictions = wrap_in_list(element=predictions)
    active_learning_middleware.register_batch(
        inference_inputs=[frame.image for frame in frames if frame is not None],
        predictions=[p for p in raw_predictions if p is not None],
        prediction_type=model_type,
        disable_preproc_auto_orient=disable_preproc_auto_orient,
    )

multi_sink

multi_sink(predictions, video_frame, sinks)

Helper util useful to combine multiple sinks together, while using InferencePipeline.

Parameters:

Name Type Description Default
video_frame VideoFrame

frame of video with its basic metadata emitted by VideoSource

required
predictions dict

Roboflow object detection predictions with Bounding Boxes

required
sinks List[Callable[[VideoFrame, dict], None]]

list of sinks to be used. Each will be executed one-by-one in the order pointed in input list, all errors will be caught and reported via logger, without re-raising.

required

Side effects: Invokes every sink with the (video_frame, predictions) input.

Example
from functools import partial
import cv2
from inference import InferencePipeline
from inference.core.interfaces.stream.sinks import UDPSink, render_boxes

udp_sink = UDPSink(ip_address="127.0.0.1", port=9090)
on_prediction = partial(multi_sink, sinks=[udp_sink.send_predictions, render_boxes])

pipeline = InferencePipeline.init(
    model_id="your-model/3",
    video_reference="./some_file.mp4",
    on_prediction=on_prediction,
)
pipeline.start()
pipeline.join()

As a result, predictions will both be sent via UDP socket and displayed in the screen.

Source code in inference/core/interfaces/stream/sinks.py
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
def multi_sink(
    predictions: Union[dict, List[Optional[dict]]],
    video_frame: Union[VideoFrame, List[Optional[VideoFrame]]],
    sinks: List[SinkHandler],
) -> None:
    """Combine multiple sinks into one, for use with `InferencePipeline`.

    Args:
        video_frame (VideoFrame): frame of video with its basic metadata emitted by `VideoSource`
        predictions (dict): Roboflow object detection predictions with Bounding Boxes
        sinks (List[Callable[[VideoFrame, dict], None]]): list of sinks to be used. Each will be executed
            one-by-one in the order pointed in input list, all errors will be caught and reported via logger,
            without re-raising.

    Returns: None
    Side effects: Invokes every sink with the (predictions, video_frame) input.

    Example:
        ```python
        from functools import partial
        import cv2
        from inference import InferencePipeline
        from inference.core.interfaces.stream.sinks import UDPSink, render_boxes

        udp_sink = UDPSink(ip_address="127.0.0.1", port=9090)
        on_prediction = partial(multi_sink, sinks=[udp_sink.send_predictions, render_boxes])

        pipeline = InferencePipeline.init(
            model_id="your-model/3",
            video_reference="./some_file.mp4",
            on_prediction=on_prediction,
        )
        pipeline.start()
        pipeline.join()
        ```

        As a result, predictions will both be sent via UDP socket and displayed in the screen.
    """
    for handler in sinks:
        try:
            handler(predictions, video_frame)
        except Exception as error:
            # A failing sink must not stop the remaining sinks from running.
            logger.error(
                f"Could not send prediction and/or frame to sink due to error: {error}."
            )

render_boxes

render_boxes(
    predictions,
    video_frame,
    annotator=None,
    display_size=(1280, 720),
    fps_monitor=DEFAULT_FPS_MONITOR,
    display_statistics=False,
    on_frame_rendered=display_image,
)

Helper tool to render object detection predictions on top of video frame. It is designed to be used with InferencePipeline, as sink for predictions. By default, it uses standard sv.BoxAnnotator() chained with sv.LabelAnnotator() to draw bounding boxes and resizes prediction to 1280x720 (keeping aspect ratio and adding black padding). One may configure default behaviour, for instance to display latency and throughput statistics. In batch mode it will display tiles of frames and overlay predictions.

This sink is only partially compatible with stubs and classification models (it will not fail, although predictions will not be displayed).

Since version 0.9.18, when multi-source InferencePipeline was introduced - it support batch input, without changes to old functionality when single (predictions, video_frame) is used.

Parameters:

Name Type Description Default
predictions Union[dict, List[Optional[dict]]]

Roboflow predictions, the function support single prediction processing and batch processing since version 0.9.18. Batch predictions elements are optional, but should occur at the same position as video_frame list. Order is expected to match with video_frame.

required
video_frame Union[VideoFrame, List[Optional[VideoFrame]]]

frame of video with its basic metadata emitted by VideoSource or list of frames from (it is possible for empty batch frames at corresponding positions to predictions list). Order is expected to match with predictions

required
annotator Union[BaseAnnotator, List[BaseAnnotator]]

instance of class inheriting from supervision BaseAnnotator or list of such instances. If nothing is passed chain of sv.BoxAnnotator() and sv.LabelAnnotator() is used.

None
display_size Tuple[int, int]

tuple in format (width, height) to resize visualisation output

(1280, 720)
fps_monitor Optional[FPSMonitor]

FPS monitor used to monitor throughput

DEFAULT_FPS_MONITOR
display_statistics bool

Flag to decide if throughput and latency can be displayed in the result image, if enabled, throughput will only be presented if fps_monitor is not None

False
on_frame_rendered Callable[[Union[ImageWithSourceID, List[ImageWithSourceID]]], None]

callback to be called once frame is rendered - by default, function will display OpenCV window. It expects optional integer identifier with np.ndarray or list of those elements. Identifier is supposed to refer to either source_id (for sequential input) or position in the batch (from 0 to batch_size-1).

display_image

Side effects: on_frame_rendered() is called against the tuple (stream_id, np.ndarray) produced from video frame and predictions.

Example
from functools import partial
import cv2
from inference import InferencePipeline
from inference.core.interfaces.stream.sinks import render_boxes

output_size = (640, 480)
video_sink = cv2.VideoWriter("output.avi", cv2.VideoWriter_fourcc(*"MJPG"), 25.0, output_size)
on_prediction = partial(
    render_boxes,
    display_size=output_size,
    on_frame_rendered=lambda frame_data: video_sink.write(frame_data[1])
)

pipeline = InferencePipeline.init(
     model_id="your-model/3",
     video_reference="./some_file.mp4",
     on_prediction=on_prediction,
)
pipeline.start()
pipeline.join()
video_sink.release()

In this example, render_boxes() is used as a sink for InferencePipeline predictions - making frames with predictions displayed to be saved into video file. Please note that this is oversimplified example of usage which will not be robust against multiple streams - better implementation available in VideoFileSink class.

Source code in inference/core/interfaces/stream/sinks.py
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
def render_boxes(
    predictions: Union[dict, List[Optional[dict]]],
    video_frame: Union[VideoFrame, List[Optional[VideoFrame]]],
    annotator: Union[BaseAnnotator, List[BaseAnnotator]] = None,
    display_size: Optional[Tuple[int, int]] = (1280, 720),
    fps_monitor: Optional[sv.FPSMonitor] = DEFAULT_FPS_MONITOR,
    display_statistics: bool = False,
    on_frame_rendered: Callable[
        [Union[ImageWithSourceID, List[ImageWithSourceID]]], None
    ] = display_image,
) -> None:
    """
    Render object detection predictions on top of a video frame (or a batch of frames).

    Designed to be plugged into `InferencePipeline` as a prediction sink. By default the
    predictions are drawn with `sv.BoxAnnotator()` chained with `sv.LabelAnnotator()` and the
    visualisation is resized to 1280x720 (keeping aspect ratio and adding black padding).
    Throughput and latency statistics can optionally be painted onto the output. In batch
    mode, tiles of frames with overlaid predictions are produced.

    This sink is only partially compatible with stubs and classification models - it will
    not fail, but predictions will not be displayed.

    Since version `0.9.18` (when the multi-source InferencePipeline was introduced) the
    function accepts batch input while remaining backward compatible with a single
    (predictions, video_frame) pair.

    Args:
        predictions (Union[dict, List[Optional[dict]]]): Roboflow predictions - a single
            dict or, since `0.9.18`, a batch. Batch entries may be None and their order
            must match `video_frame`.
        video_frame (Union[VideoFrame, List[Optional[VideoFrame]]]): frame of video with
            its basic metadata emitted by `VideoSource`, or a list of such frames (entries
            may be empty at positions matching empty `predictions` entries). Order must
            match `predictions`.
        annotator (Union[BaseAnnotator, List[BaseAnnotator]]): supervision annotator
            instance or a list of them. Defaults to a chain of `sv.BoxAnnotator()` and
            `sv.LabelAnnotator()`.
        display_size (Tuple[int, int]): (width, height) the visualisation output is
            resized to.
        fps_monitor (Optional[sv.FPSMonitor]): FPS monitor used to monitor throughput.
        display_statistics (bool): whether to paint throughput/latency onto the result
            image; throughput is only presented when `fps_monitor` is not None.
        on_frame_rendered (Callable[[Union[ImageWithSourceID, List[ImageWithSourceID]]], None]):
            callback invoked once a frame is rendered - by default an OpenCV window is
            displayed. The integer identifier refers to `source_id` for sequential input,
            or to the position in the batch (0 to batch_size-1) for batch input.

    Returns: None
    Side effects: `on_frame_rendered()` is invoked with (stream_id, np.ndarray) tuples
        built from the video frames and predictions.

    Example:
        ```python
        from functools import partial
        import cv2
        from inference import InferencePipeline
        from inference.core.interfaces.stream.sinks import render_boxes

        output_size = (640, 480)
        video_sink = cv2.VideoWriter("output.avi", cv2.VideoWriter_fourcc(*"MJPG"), 25.0, output_size)
        on_prediction = partial(
            render_boxes,
            display_size=output_size,
            on_frame_rendered=lambda frame_data: video_sink.write(frame_data[1])
        )

        pipeline = InferencePipeline.init(
             model_id="your-model/3",
             video_reference="./some_file.mp4",
             on_prediction=on_prediction,
        )
        pipeline.start()
        pipeline.join()
        video_sink.release()
        ```

        Here `render_boxes()` saves annotated frames into a video file. This simplified
        usage is not robust against multiple streams - see `VideoFileSink` for a better
        implementation.
    """
    single_input = not isinstance(video_frame, list)
    frames = wrap_in_list(element=video_frame)
    frame_predictions = wrap_in_list(element=predictions)
    if annotator is None:
        annotator = [
            DEFAULT_BBOX_ANNOTATOR,
            DEFAULT_LABEL_ANNOTATOR,
        ]
    annotators = annotator if isinstance(annotator, list) else [annotator]
    fps_value = None
    if fps_monitor is not None:
        # tick once per non-empty frame in the batch
        for frame in frames:
            if frame is not None:
                fps_monitor.tick()
        # newer supervision exposes an `fps` property, older versions are callable
        fps_value = fps_monitor.fps if hasattr(fps_monitor, "fps") else fps_monitor()
    rendered: List[ImageWithSourceID] = [
        (
            position,
            _handle_frame_rendering(
                frame=frame,
                prediction=frame_prediction,
                annotators=annotators,
                display_size=display_size,
                display_statistics=display_statistics,
                fps_value=fps_value,
            ),
        )
        for position, (frame, frame_prediction) in enumerate(
            zip(frames, frame_predictions)
        )
    ]
    if single_input:
        on_frame_rendered((frames[0].source_id, rendered[0][1]))
    else:
        on_frame_rendered(rendered)

inference.core.interfaces.stream.stream

Classes

Stream

Bases: BaseInterface

Roboflow defined stream interface for a general-purpose inference server.

Attributes:

Name Type Description
model_manager ModelManager

The manager that handles model inference tasks.

model_registry RoboflowModelRegistry

The registry to fetch model instances.

api_key str

The API key for accessing models.

class_agnostic_nms bool

Flag for class-agnostic non-maximum suppression.

confidence float

Confidence threshold for inference.

iou_threshold float

The intersection-over-union threshold for detection.

json_response bool

Flag to toggle JSON response format.

max_candidates float

The maximum number of candidates for detection.

max_detections float

The maximum number of detections.

model str | Callable

The model to be used.

stream_id str

The ID of the stream to be used.

use_bytetrack bool

Flag to use ByteTrack tracking.

Methods:

Name Description
init_infer

Initialize the inference with a test frame.

preprocess_thread

Preprocess incoming frames for inference.

inference_request_thread

Manage the inference requests.

run_thread

Run the preprocessing and inference threads.

Source code in inference/core/interfaces/stream/stream.py
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
class Stream(BaseInterface):
    """Roboflow defined stream interface for a general-purpose inference server.

    Attributes:
        model_manager (ModelManager): The manager that handles model inference tasks.
        model_registry (RoboflowModelRegistry): The registry to fetch model instances.
        api_key (str): The API key for accessing models.
        class_agnostic_nms (bool): Flag for class-agnostic non-maximum suppression.
        confidence (float): Confidence threshold for inference.
        iou_threshold (float): The intersection-over-union threshold for detection.
        json_response (bool): Flag to toggle JSON response format.
        max_candidates (float): The maximum number of candidates for detection.
        max_detections (float): The maximum number of detections.
        model (str|Callable): The model to be used.
        stream_id (str): The ID of the stream to be used.
        use_bytetrack (bool): Flag to use ByteTrack tracking.

    Methods:
        init_infer: Initialize the inference with a test frame.
        preprocess_thread: Preprocess incoming frames for inference.
        inference_request_thread: Manage the inference requests.
        run_thread: Run the preprocessing and inference threads.
    """

    def __init__(
        self,
        api_key: str = API_KEY,
        class_agnostic_nms: bool = CLASS_AGNOSTIC_NMS,
        confidence: float = CONFIDENCE,
        enforce_fps: bool = ENFORCE_FPS,
        iou_threshold: float = IOU_THRESHOLD,
        max_candidates: float = MAX_CANDIDATES,
        max_detections: float = MAX_DETECTIONS,
        model: Union[str, Callable] = MODEL_ID,
        source: Union[int, str] = STREAM_ID,
        use_bytetrack: bool = ENABLE_BYTE_TRACK,
        use_main_thread: bool = False,
        output_channel_order: str = "RGB",
        on_prediction: Callable = None,
        on_start: Callable = None,
        on_stop: Callable = None,
    ):
        """Initialize the stream with the given parameters.

        Prints the server settings and initializes the inference with a test frame.

        Args:
            api_key: The API key for accessing models.
            class_agnostic_nms: Flag for class-agnostic non-maximum suppression.
            confidence: Confidence threshold for inference.
            enforce_fps: Whether the webcam stream should enforce the source FPS.
            iou_threshold: The intersection-over-union threshold for detection.
            max_candidates: The maximum number of candidates for detection.
            max_detections: The maximum number of detections.
            model: Model ID string (resolved via `get_model`) or a ready callable model.
            source: Capture device index, stream reference, or the alias "webcam".
            use_bytetrack: Flag to enable ByteTrack object tracking.
            use_main_thread: If True, run the inference loop on the calling thread.
            output_channel_order: "RGB" or "BGR" - channel order of frames passed to
                prediction callbacks.
            on_prediction: Optional callback invoked with every (predictions, frame) pair.
            on_start: Optional callback invoked once when streaming starts.
            on_stop: Optional callback invoked once when streaming stops.

        Raises:
            ValueError: If the stream ID or model ID resolves to an empty value.
        """
        logger.info("Initializing server")

        self.frame_count = 0
        self.byte_tracker = sv.ByteTrack() if use_bytetrack else None
        self.use_bytetrack = use_bytetrack

        # "webcam" is an alias for the default capture device (index 0)
        if source == "webcam":
            stream_id = 0
        else:
            stream_id = source

        self.stream_id = stream_id
        if self.stream_id is None:
            raise ValueError("STREAM_ID is not defined")
        self.model_id = model
        if not self.model_id:
            raise ValueError("MODEL_ID is not defined")
        self.api_key = api_key

        # active learning stays a no-op unless enabled AND a model ID was supplied
        self.active_learning_middleware = NullActiveLearningMiddleware()
        if isinstance(model, str):
            self.model = get_model(model, self.api_key)
            if ACTIVE_LEARNING_ENABLED:
                self.active_learning_middleware = (
                    ThreadingActiveLearningMiddleware.init(
                        api_key=self.api_key,
                        model_id=self.model_id,
                        cache=cache,
                    )
                )
            self.task_type = get_model_type(
                model_id=self.model_id, api_key=self.api_key
            )[0]
        else:
            # a callable model cannot be introspected for its task type
            self.model = model
            self.task_type = "unknown"

        self.class_agnostic_nms = class_agnostic_nms
        self.confidence = confidence
        self.iou_threshold = iou_threshold
        self.max_candidates = max_candidates
        self.max_detections = max_detections
        self.use_main_thread = use_main_thread
        self.output_channel_order = output_channel_order

        self.inference_request_type = (
            inference.core.entities.requests.inference.ObjectDetectionInferenceRequest
        )

        self.webcam_stream = WebcamStream(
            stream_id=self.stream_id, enforce_fps=enforce_fps
        )
        logger.info(
            f"Streaming from device with resolution: {self.webcam_stream.width} x {self.webcam_stream.height}"
        )

        self.on_start_callbacks = []
        # always stop the active-learning registration thread on shutdown
        self.on_stop_callbacks = [
            lambda: self.active_learning_middleware.stop_registration_thread()
        ]
        self.on_prediction_callbacks = []

        if on_prediction:
            self.on_prediction_callbacks.append(on_prediction)

        if on_start:
            self.on_start_callbacks.append(on_start)

        if on_stop:
            self.on_stop_callbacks.append(on_stop)

        # warm up the model before accepting real frames
        self.init_infer()
        self.preproc_result = None
        self.inference_request_obj = None
        self.queue_control = False
        self.inference_response = None
        self.stop = False

        self.frame = None
        self.frame_cv = None
        self.frame_id = None
        logger.info("Server initialized with settings:")
        logger.info(f"Stream ID: {self.stream_id}")
        logger.info(f"Model ID: {self.model_id}")
        logger.info(f"Enforce FPS: {enforce_fps}")
        logger.info(f"Confidence: {self.confidence}")
        logger.info(f"Class Agnostic NMS: {self.class_agnostic_nms}")
        logger.info(f"IOU Threshold: {self.iou_threshold}")
        logger.info(f"Max Candidates: {self.max_candidates}")
        logger.info(f"Max Detections: {self.max_detections}")

        self.run_thread()

    def on_start(self, callback):
        """Register a callback fired once when streaming starts; returns an unsubscribe function."""
        self.on_start_callbacks.append(callback)

        unsubscribe = lambda: self.on_start_callbacks.remove(callback)
        return unsubscribe

    def on_stop(self, callback):
        """Register a callback fired once when streaming stops; returns an unsubscribe function."""
        self.on_stop_callbacks.append(callback)

        unsubscribe = lambda: self.on_stop_callbacks.remove(callback)
        return unsubscribe

    def on_prediction(self, callback):
        """Register a callback fired for every prediction; returns an unsubscribe function."""
        self.on_prediction_callbacks.append(callback)

        unsubscribe = lambda: self.on_prediction_callbacks.remove(callback)
        return unsubscribe

    def init_infer(self):
        """Initialize the inference with a test frame.

        Creates a test frame and runs it through the entire inference process to ensure everything is working.
        """
        frame = Image.new("RGB", (640, 640), color="black")
        self.model.infer(
            frame, confidence=self.confidence, iou_threshold=self.iou_threshold
        )
        self.active_learning_middleware.start_registration_thread()

    def preprocess_thread(self):
        """Preprocess incoming frames for inference.

        Reads frames from the webcam stream, converts them into the proper format, and preprocesses them for
        inference.
        """
        webcam_stream = self.webcam_stream
        webcam_stream.start()
        # processing frames in input stream
        try:
            while True:
                if webcam_stream.stopped is True or self.stop:
                    break
                else:
                    self.frame_cv, frame_id = webcam_stream.read_opencv()
                    # only process frames not handled before
                    if frame_id > 0 and frame_id != self.frame_id:
                        self.frame_id = frame_id
                        self.frame = cv2.cvtColor(self.frame_cv, cv2.COLOR_BGR2RGB)
                        self.preproc_result = self.model.preprocess(self.frame_cv)
                        self.img_in, self.img_dims = self.preproc_result
                        # signal the inference loop that a fresh frame is ready
                        self.queue_control = True

        except Exception as e:
            logger.exception(e)

    def inference_request_thread(self):
        """Manage the inference requests.

        Processes preprocessed frames for inference, post-processes the predictions, and sends the results
        to registered callbacks.
        """
        while True:
            if self.webcam_stream.stopped is True or self.stop:
                while len(self.on_stop_callbacks) > 0:
                    # run each onStop callback only once from this thread
                    cb = self.on_stop_callbacks.pop()
                    cb()
                break
            if self.queue_control:
                while len(self.on_start_callbacks) > 0:
                    # run each onStart callback only once from this thread
                    cb = self.on_start_callbacks.pop()
                    cb()

                self.queue_control = False
                frame_id = self.frame_id
                inference_input = np.copy(self.frame_cv)
                start = time.perf_counter()
                predictions = self.model.predict(
                    self.img_in,
                )
                predictions = self.model.postprocess(
                    predictions,
                    self.img_dims,
                    class_agnostic_nms=self.class_agnostic_nms,
                    confidence=self.confidence,
                    iou_threshold=self.iou_threshold,
                    max_candidates=self.max_candidates,
                    max_detections=self.max_detections,
                )[0]

                self.active_learning_middleware.register(
                    inference_input=inference_input,
                    prediction=predictions.dict(by_alias=True, exclude_none=True),
                    prediction_type=self.task_type,
                )
                if self.use_bytetrack:
                    # both branches of the previous hasattr() dispatch called the very
                    # same constructor, so the check was dead code - collapsed to one call
                    detections = sv.Detections.from_inference(
                        predictions.dict(by_alias=True, exclude_none=True)
                    )
                    detections = self.byte_tracker.update_with_detections(detections)

                    if detections.tracker_id is None:
                        detections.tracker_id = np.array([], dtype=int)

                    # index 4 of the detection iteration tuple carries the tracker id
                    for pred, detect in zip(predictions.predictions, detections):
                        pred.tracker_id = int(detect[4])
                predictions.frame_id = frame_id
                predictions = predictions.dict(by_alias=True, exclude_none=True)

                self.inference_response = predictions
                self.frame_count += 1

                for cb in self.on_prediction_callbacks:
                    if self.output_channel_order == "BGR":
                        cb(predictions, self.frame_cv)
                    else:
                        cb(predictions, np.asarray(self.frame))

                elapsed = time.perf_counter() - start
                if elapsed > 0:  # guard against ZeroDivisionError on a degenerate timer reading
                    self.webcam_stream.max_fps = 1 / elapsed
                    logger.debug(f"FPS: {self.webcam_stream.max_fps:.2f}")

    def run_thread(self):
        """Run the preprocessing and inference threads.

        Starts the preprocessing and inference threads, and handles graceful shutdown on KeyboardInterrupt.
        """
        preprocess_thread = threading.Thread(target=self.preprocess_thread)
        preprocess_thread.start()

        if self.use_main_thread:
            self.inference_request_thread()
        else:
            # start a thread that looks for the predictions
            # and call the callbacks
            inference_request_thread = threading.Thread(
                target=self.inference_request_thread
            )
            inference_request_thread.start()
Functions
__init__
__init__(
    api_key=API_KEY,
    class_agnostic_nms=CLASS_AGNOSTIC_NMS,
    confidence=CONFIDENCE,
    enforce_fps=ENFORCE_FPS,
    iou_threshold=IOU_THRESHOLD,
    max_candidates=MAX_CANDIDATES,
    max_detections=MAX_DETECTIONS,
    model=MODEL_ID,
    source=STREAM_ID,
    use_bytetrack=ENABLE_BYTE_TRACK,
    use_main_thread=False,
    output_channel_order="RGB",
    on_prediction=None,
    on_start=None,
    on_stop=None,
)

Initialize the stream with the given parameters. Prints the server settings and initializes the inference with a test frame.

Source code in inference/core/interfaces/stream/stream.py
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
def __init__(
    self,
    api_key: str = API_KEY,
    class_agnostic_nms: bool = CLASS_AGNOSTIC_NMS,
    confidence: float = CONFIDENCE,
    enforce_fps: bool = ENFORCE_FPS,
    iou_threshold: float = IOU_THRESHOLD,
    max_candidates: float = MAX_CANDIDATES,
    max_detections: float = MAX_DETECTIONS,
    model: Union[str, Callable] = MODEL_ID,
    source: Union[int, str] = STREAM_ID,
    use_bytetrack: bool = ENABLE_BYTE_TRACK,
    use_main_thread: bool = False,
    output_channel_order: str = "RGB",
    on_prediction: Callable = None,
    on_start: Callable = None,
    on_stop: Callable = None,
):
    """Initialize the stream with the given parameters.
    Prints the server settings and initializes the inference with a test frame.

    Args:
        api_key: The API key for accessing models.
        class_agnostic_nms: Flag for class-agnostic non-maximum suppression.
        confidence: Confidence threshold for inference.
        enforce_fps: Whether the webcam stream should enforce the source FPS.
        iou_threshold: The intersection-over-union threshold for detection.
        max_candidates: The maximum number of candidates for detection.
        max_detections: The maximum number of detections.
        model: Model ID string (resolved via `get_model`) or a ready callable model.
        source: Capture device index, stream reference, or the alias "webcam".
        use_bytetrack: Flag to enable ByteTrack object tracking.
        use_main_thread: If True, run the inference loop on the calling thread.
        output_channel_order: "RGB" or "BGR" - channel order of frames passed to
            prediction callbacks.
        on_prediction: Optional callback invoked with every (predictions, frame) pair.
        on_start: Optional callback invoked once when streaming starts.
        on_stop: Optional callback invoked once when streaming stops.

    Raises:
        ValueError: If the stream ID or model ID resolves to an empty value.
    """
    logger.info("Initializing server")

    self.frame_count = 0
    self.byte_tracker = sv.ByteTrack() if use_bytetrack else None
    self.use_bytetrack = use_bytetrack

    # "webcam" is an alias for the default capture device (index 0)
    if source == "webcam":
        stream_id = 0
    else:
        stream_id = source

    self.stream_id = stream_id
    if self.stream_id is None:
        raise ValueError("STREAM_ID is not defined")
    self.model_id = model
    if not self.model_id:
        raise ValueError("MODEL_ID is not defined")
    self.api_key = api_key

    # active learning stays a no-op unless enabled AND a model ID was supplied
    self.active_learning_middleware = NullActiveLearningMiddleware()
    if isinstance(model, str):
        self.model = get_model(model, self.api_key)
        if ACTIVE_LEARNING_ENABLED:
            self.active_learning_middleware = (
                ThreadingActiveLearningMiddleware.init(
                    api_key=self.api_key,
                    model_id=self.model_id,
                    cache=cache,
                )
            )
        self.task_type = get_model_type(
            model_id=self.model_id, api_key=self.api_key
        )[0]
    else:
        # a callable model cannot be introspected for its task type
        self.model = model
        self.task_type = "unknown"

    self.class_agnostic_nms = class_agnostic_nms
    self.confidence = confidence
    self.iou_threshold = iou_threshold
    self.max_candidates = max_candidates
    self.max_detections = max_detections
    self.use_main_thread = use_main_thread
    self.output_channel_order = output_channel_order

    self.inference_request_type = (
        inference.core.entities.requests.inference.ObjectDetectionInferenceRequest
    )

    self.webcam_stream = WebcamStream(
        stream_id=self.stream_id, enforce_fps=enforce_fps
    )
    logger.info(
        f"Streaming from device with resolution: {self.webcam_stream.width} x {self.webcam_stream.height}"
    )

    self.on_start_callbacks = []
    # always stop the active-learning registration thread on shutdown
    self.on_stop_callbacks = [
        lambda: self.active_learning_middleware.stop_registration_thread()
    ]
    self.on_prediction_callbacks = []

    if on_prediction:
        self.on_prediction_callbacks.append(on_prediction)

    if on_start:
        self.on_start_callbacks.append(on_start)

    if on_stop:
        self.on_stop_callbacks.append(on_stop)

    # warm up the model, then initialize the shared state used by the worker threads
    self.init_infer()
    self.preproc_result = None
    self.inference_request_obj = None
    self.queue_control = False
    self.inference_response = None
    self.stop = False

    self.frame = None
    self.frame_cv = None
    self.frame_id = None
    logger.info("Server initialized with settings:")
    logger.info(f"Stream ID: {self.stream_id}")
    logger.info(f"Model ID: {self.model_id}")
    logger.info(f"Enforce FPS: {enforce_fps}")
    logger.info(f"Confidence: {self.confidence}")
    logger.info(f"Class Agnostic NMS: {self.class_agnostic_nms}")
    logger.info(f"IOU Threshold: {self.iou_threshold}")
    logger.info(f"Max Candidates: {self.max_candidates}")
    logger.info(f"Max Detections: {self.max_detections}")

    self.run_thread()
inference_request_thread
inference_request_thread()

Manage the inference requests.

Processes preprocessed frames for inference, post-processes the predictions, and sends the results to registered callbacks.

Source code in inference/core/interfaces/stream/stream.py
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
def inference_request_thread(self):
    """Manage the inference requests.

    Processes preprocessed frames for inference, post-processes the predictions, and sends the results
    to registered callbacks. Runs until the webcam stream stops or `self.stop` is set, at
    which point the registered on-stop callbacks are drained and fired.
    """
    last_print = time.perf_counter()
    print_ind = 0
    while True:
        if self.webcam_stream.stopped is True or self.stop:
            while len(self.on_stop_callbacks) > 0:
                # run each onStop callback only once from this thread
                cb = self.on_stop_callbacks.pop()
                cb()
            break
        # queue_control is set by preprocess_thread when a fresh frame is ready
        if self.queue_control:
            while len(self.on_start_callbacks) > 0:
                # run each onStart callback only once from this thread
                cb = self.on_start_callbacks.pop()
                cb()

            self.queue_control = False
            frame_id = self.frame_id
            # copy the frame so active learning keeps it even if it is overwritten
            inference_input = np.copy(self.frame_cv)
            start = time.perf_counter()
            predictions = self.model.predict(
                self.img_in,
            )
            predictions = self.model.postprocess(
                predictions,
                self.img_dims,
                class_agnostic_nms=self.class_agnostic_nms,
                confidence=self.confidence,
                iou_threshold=self.iou_threshold,
                max_candidates=self.max_candidates,
                max_detections=self.max_detections,
            )[0]

            self.active_learning_middleware.register(
                inference_input=inference_input,
                prediction=predictions.dict(by_alias=True, exclude_none=True),
                prediction_type=self.task_type,
            )
            if self.use_bytetrack:
                # NOTE(review): both branches below are identical - the else branch
                # presumably was meant to call a legacy constructor; confirm and deduplicate.
                if hasattr(sv.Detections, "from_inference"):
                    detections = sv.Detections.from_inference(
                        predictions.dict(by_alias=True, exclude_none=True)
                    )
                else:
                    detections = sv.Detections.from_inference(
                        predictions.dict(by_alias=True, exclude_none=True)
                    )
                detections = self.byte_tracker.update_with_detections(detections)

                if detections.tracker_id is None:
                    detections.tracker_id = np.array([], dtype=int)

                # copy the tracker id from each detection back onto its prediction
                for pred, detect in zip(predictions.predictions, detections):
                    pred.tracker_id = int(detect[4])
            predictions.frame_id = frame_id
            predictions = predictions.dict(by_alias=True, exclude_none=True)

            self.inference_response = predictions
            self.frame_count += 1

            for cb in self.on_prediction_callbacks:
                if self.output_channel_order == "BGR":
                    cb(predictions, self.frame_cv)
                else:
                    cb(predictions, np.asarray(self.frame))

            current = time.perf_counter()
            # throttle the reader to roughly the achieved inference rate
            self.webcam_stream.max_fps = 1 / (current - start)
            logger.debug(f"FPS: {self.webcam_stream.max_fps:.2f}")

            # NOTE(review): print_ind/last_print are updated but never read - vestigial?
            if time.perf_counter() - last_print > 1:
                print_ind = (print_ind + 1) % 4
                last_print = time.perf_counter()
init_infer
init_infer()

Initialize the inference with a test frame.

Creates a test frame and runs it through the entire inference process to ensure everything is working.

Source code in inference/core/interfaces/stream/stream.py
196
197
198
199
200
201
202
203
204
205
def init_infer(self):
    """Warm up the model with a dummy frame.

    Pushes a black 640x640 test image through inference so any lazy setup happens
    before real frames arrive, then starts the active-learning registration thread.
    """
    warmup_frame = Image.new("RGB", (640, 640), color="black")
    self.model.infer(
        warmup_frame,
        confidence=self.confidence,
        iou_threshold=self.iou_threshold,
    )
    self.active_learning_middleware.start_registration_thread()
preprocess_thread
preprocess_thread()

Preprocess incoming frames for inference.

Reads frames from the webcam stream, converts them into the proper format, and preprocesses them for inference.

Source code in inference/core/interfaces/stream/stream.py
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
def preprocess_thread(self):
    """Preprocess incoming frames for inference.

    Reads frames from the webcam stream, converts them into the proper format, and preprocesses them for
    inference. Loops until the stream stops or `self.stop` is set; any exception is
    logged and terminates the loop.
    """
    webcam_stream = self.webcam_stream
    webcam_stream.start()
    # processing frames in input stream
    try:
        while True:
            if webcam_stream.stopped is True or self.stop:
                break
            else:
                self.frame_cv, frame_id = webcam_stream.read_opencv()
                # only process frames not handled before
                if frame_id > 0 and frame_id != self.frame_id:
                    self.frame_id = frame_id
                    # keep an RGB copy alongside the original BGR frame
                    self.frame = cv2.cvtColor(self.frame_cv, cv2.COLOR_BGR2RGB)
                    self.preproc_result = self.model.preprocess(self.frame_cv)
                    self.img_in, self.img_dims = self.preproc_result
                    # signal the inference loop that a fresh frame is ready
                    self.queue_control = True

    except Exception as e:
        logger.exception(e)
run_thread
run_thread()

Run the preprocessing and inference threads.

Starts the preprocessing and inference threads, and handles graceful shutdown on KeyboardInterrupt.

Source code in inference/core/interfaces/stream/stream.py
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
def run_thread(self):
    """Launch the preprocessing and inference loops.

    The preprocessing loop always runs on a background thread. The inference
    loop runs on the calling thread when ``use_main_thread`` is set,
    otherwise on its own background thread that watches for predictions and
    fires the callbacks.
    """
    threading.Thread(target=self.preprocess_thread).start()

    if self.use_main_thread:
        # Block the caller: run inference inline.
        self.inference_request_thread()
        return
    # Otherwise poll for predictions on a separate thread
    # and invoke the callbacks from there.
    threading.Thread(target=self.inference_request_thread).start()

Functions

inference.core.interfaces.stream.watchdog

This module contains components intended to be used in combination with InferencePipeline to ensure observability. Please consider them internal implementation details.

Classes

BasePipelineWatchDog

Bases: PipelineWatchDog

Implementation to be used from single inference thread, as it keeps state assumed to represent status of consecutive stage of prediction process in latency monitor.

Source code in inference/core/interfaces/stream/watchdog.py
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
class BasePipelineWatchDog(PipelineWatchDog):
    """Watchdog meant to be driven from a single inference thread.

    Keeps per-source latency monitors plus a throughput monitor whose state
    reflects the consecutive stages of the prediction process, so it is not
    safe to share across threads.
    """

    def __init__(self):
        super().__init__()
        self._video_sources: Optional[List[VideoSource]] = None
        self._inference_throughput_monitor = sv.FPSMonitor()
        self._latency_monitors: Dict[Optional[int], LatencyMonitor] = {}
        self._stream_updates = deque(maxlen=MAX_UPDATES_CONTEXT)

    def register_video_sources(self, video_sources: List[VideoSource]) -> None:
        self._video_sources = video_sources
        # One latency monitor per source, keyed by the source id.
        for source in video_sources:
            monitor = LatencyMonitor(source_id=source.source_id)
            self._latency_monitors[source.source_id] = monitor

    def on_status_update(self, status_update: StatusUpdate) -> None:
        # Debug-level updates are too noisy to retain.
        if status_update.severity.value > UpdateSeverity.DEBUG.value:
            self._stream_updates.append(status_update)

    def on_model_inference_started(self, frames: List[VideoFrame]) -> None:
        for video_frame in frames:
            monitor = self._latency_monitors[video_frame.source_id]
            monitor.register_inference_start(
                frame_timestamp=video_frame.frame_timestamp,
                frame_id=video_frame.frame_id,
            )

    def on_model_prediction_ready(self, frames: List[VideoFrame]) -> None:
        for video_frame in frames:
            monitor = self._latency_monitors[video_frame.source_id]
            monitor.register_prediction_ready(
                frame_timestamp=video_frame.frame_timestamp,
                frame_id=video_frame.frame_id,
            )
            self._inference_throughput_monitor.tick()

    def get_report(self) -> PipelineStateReport:
        sources_metadata = []
        if self._video_sources is not None:
            sources_metadata = [
                source.describe_source() for source in self._video_sources
            ]
        latency_reports = [
            monitor.summarise_reports()
            for monitor in self._latency_monitors.values()
        ]
        throughput = self._inference_throughput_monitor
        # Newer supervision exposes FPS as a property, older versions as a callable.
        fps = throughput.fps if hasattr(throughput, "fps") else throughput()
        return PipelineStateReport(
            video_source_status_updates=list(self._stream_updates),
            latency_reports=latency_reports,
            inference_throughput=fps,
            sources_metadata=sources_metadata,
        )

core/interfaces/udp

inference.core.interfaces.udp.udp_stream

Classes

UdpStream

Bases: BaseInterface

Roboflow defined UDP interface for a general-purpose inference server.

Attributes:

Name Type Description
model_manager ModelManager

The manager that handles model inference tasks.

model_registry RoboflowModelRegistry

The registry to fetch model instances.

api_key str

The API key for accessing models.

class_agnostic_nms bool

Flag for class-agnostic non-maximum suppression.

confidence float

Confidence threshold for inference.

ip_broadcast_addr str

The IP address to broadcast to.

ip_broadcast_port int

The port to broadcast on.

iou_threshold float

The intersection-over-union threshold for detection.

max_candidates float

The maximum number of candidates for detection.

max_detections float

The maximum number of detections.

model_id str

The ID of the model to be used.

stream_id str

The ID of the stream to be used.

use_bytetrack bool

Flag to use bytetrack.

Methods:

Name Description
init_infer

Initialize the inference with a test frame.

preprocess_thread

Preprocess incoming frames for inference.

inference_request_thread

Manage the inference requests.

run_thread

Run the preprocessing and inference threads.

Source code in inference/core/interfaces/udp/udp_stream.py
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
class UdpStream(BaseInterface):
    """Roboflow defined UDP interface for a general-purpose inference server.

    Attributes:
        model_manager (ModelManager): The manager that handles model inference tasks.
        model_registry (RoboflowModelRegistry): The registry to fetch model instances.
        api_key (str): The API key for accessing models.
        class_agnostic_nms (bool): Flag for class-agnostic non-maximum suppression.
        confidence (float): Confidence threshold for inference.
        ip_broadcast_addr (str): The IP address to broadcast to.
        ip_broadcast_port (int): The port to broadcast on.
        iou_threshold (float): The intersection-over-union threshold for detection.
        max_candidates (float): The maximum number of candidates for detection.
        max_detections (float): The maximum number of detections.
        model_id (str): The ID of the model to be used.
        stream_id (str): The ID of the stream to be used.
        use_bytetrack (bool): Flag to use bytetrack.

    Methods:
        init_infer: Initialize the inference with a test frame.
        preprocess_thread: Preprocess incoming frames for inference.
        inference_request_thread: Manage the inference requests.
        run_thread: Run the preprocessing and inference threads.
    """

    def __init__(
        self,
        api_key: str = API_KEY,
        class_agnostic_nms: bool = CLASS_AGNOSTIC_NMS,
        confidence: float = CONFIDENCE,
        enforce_fps: bool = ENFORCE_FPS,
        ip_broadcast_addr: str = IP_BROADCAST_ADDR,
        ip_broadcast_port: int = IP_BROADCAST_PORT,
        iou_threshold: float = IOU_THRESHOLD,
        max_candidates: float = MAX_CANDIDATES,
        max_detections: float = MAX_DETECTIONS,
        model_id: str = MODEL_ID,
        stream_id: Union[int, str] = STREAM_ID,
        use_bytetrack: bool = ENABLE_BYTE_TRACK,
    ):
        """Initialize the UDP stream with the given parameters.
        Logs the server settings and initializes the inference with a test frame.

        Raises:
            ValueError: If stream_id, model_id, or api_key is missing.
        """
        logger.info("Initializing server")

        self.frame_count = 0
        self.byte_tracker = sv.ByteTrack() if use_bytetrack else None
        self.use_bytetrack = use_bytetrack

        # Validate required configuration before doing any expensive work.
        self.stream_id = stream_id
        if self.stream_id is None:
            raise ValueError("STREAM_ID is not defined")
        self.model_id = model_id
        if not self.model_id:
            raise ValueError("MODEL_ID is not defined")
        self.api_key = api_key
        if not self.api_key:
            raise ValueError(
                f"API key is missing. Either pass it explicitly to constructor, or use one of env variables: "
                f"{API_KEY_ENV_NAMES}. Visit "
                f"https://docs.roboflow.com/api-reference/authentication#retrieve-an-api-key to learn how to generate "
                f"the key."
            )

        self.model = get_model(self.model_id, self.api_key)
        self.task_type = get_model_type(model_id=self.model_id, api_key=self.api_key)[0]
        # Active learning is a no-op unless explicitly enabled via configuration.
        self.active_learning_middleware = NullActiveLearningMiddleware()
        if ACTIVE_LEARNING_ENABLED:
            self.active_learning_middleware = ThreadingActiveLearningMiddleware.init(
                api_key=self.api_key,
                model_id=self.model_id,
                cache=cache,
            )
        self.class_agnostic_nms = class_agnostic_nms
        self.confidence = confidence
        self.iou_threshold = iou_threshold
        self.max_candidates = max_candidates
        self.max_detections = max_detections
        self.ip_broadcast_addr = ip_broadcast_addr
        self.ip_broadcast_port = ip_broadcast_port

        self.inference_request_type = (
            inference.core.entities.requests.inference.ObjectDetectionInferenceRequest
        )

        self.UDPServerSocket = socket.socket(
            family=socket.AF_INET, type=socket.SOCK_DGRAM
        )
        self.UDPServerSocket.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
        self.UDPServerSocket.setsockopt(socket.SOL_SOCKET, socket.SO_BROADCAST, 1)
        # NOTE(review): SO_RCVBUF is set to 1 byte, presumably to discard inbound
        # datagrams since this socket only sends — confirm this is intentional.
        self.UDPServerSocket.setsockopt(socket.SOL_SOCKET, socket.SO_RCVBUF, 1)
        self.UDPServerSocket.setsockopt(socket.SOL_SOCKET, socket.SO_SNDBUF, 65536)

        self.webcam_stream = WebcamStream(
            stream_id=self.stream_id, enforce_fps=enforce_fps
        )
        logger.info(
            f"Streaming from device with resolution: {self.webcam_stream.width} x {self.webcam_stream.height}"
        )

        self.init_infer()
        self.preproc_result = None
        self.inference_request_obj = None
        self.queue_control = False
        self.inference_response = None
        self.stop = False

        self.frame_cv = None
        self.frame_id = None
        logger.info("Server initialized with settings:")
        logger.info(f"Stream ID: {self.stream_id}")
        logger.info(f"Model ID: {self.model_id}")
        logger.info(f"Confidence: {self.confidence}")
        logger.info(f"Class Agnostic NMS: {self.class_agnostic_nms}")
        logger.info(f"IOU Threshold: {self.iou_threshold}")
        logger.info(f"Max Candidates: {self.max_candidates}")
        logger.info(f"Max Detections: {self.max_detections}")

    def init_infer(self):
        """Initialize the inference with a test frame.

        Creates a test frame and runs it through the entire inference process to ensure everything is working.
        """
        frame = Image.new("RGB", (640, 640), color="black")
        self.model.infer(
            frame, confidence=self.confidence, iou_threshold=self.iou_threshold
        )
        self.active_learning_middleware.start_registration_thread()

    def preprocess_thread(self):
        """Preprocess incoming frames for inference.

        Reads frames from the webcam stream, converts them into the proper format, and preprocesses them for
        inference.
        """
        webcam_stream = self.webcam_stream
        webcam_stream.start()
        # processing frames in input stream
        try:
            while True:
                if webcam_stream.stopped is True or self.stop:
                    break
                else:
                    self.frame_cv, frame_id = webcam_stream.read_opencv()
                    # Skip frames already handed to the inference thread.
                    if frame_id != self.frame_id:
                        self.frame_id = frame_id
                        self.preproc_result = self.model.preprocess(self.frame_cv)
                        self.img_in, self.img_dims = self.preproc_result
                        self.queue_control = True

        except Exception as e:
            # logger.exception keeps the traceback (consistent with stream.py).
            logger.exception(e)

    def inference_request_thread(self):
        """Manage the inference requests.

        Processes preprocessed frames for inference, post-processes the predictions, and sends the results
        as a UDP broadcast.
        """
        last_print = time.perf_counter()
        print_ind = 0
        print_chars = ["|", "/", "-", "\\"]
        while True:
            if self.stop:
                break
            if self.queue_control:
                self.queue_control = False
                frame_id = self.frame_id
                # Copy the frame so active learning gets a stable snapshot even
                # if the preprocessing thread overwrites self.frame_cv meanwhile.
                inference_input = np.copy(self.frame_cv)
                predictions = self.model.predict(
                    self.img_in,
                )
                predictions = self.model.postprocess(
                    predictions,
                    self.img_dims,
                    class_agnostic_nms=self.class_agnostic_nms,
                    confidence=self.confidence,
                    iou_threshold=self.iou_threshold,
                    max_candidates=self.max_candidates,
                    max_detections=self.max_detections,
                )[0]
                self.active_learning_middleware.register(
                    inference_input=inference_input,
                    prediction=predictions.dict(by_alias=True, exclude_none=True),
                    prediction_type=self.task_type,
                )
                if self.use_bytetrack:
                    # Both branches of the previous hasattr(sv.Detections,
                    # "from_inference") check called the exact same method,
                    # so the dead branch was removed.
                    detections = sv.Detections.from_inference(
                        predictions.dict(by_alias=True), self.model.class_names
                    )
                    detections = self.byte_tracker.update_with_detections(detections)
                    for pred, detect in zip(predictions.predictions, detections):
                        pred.tracker_id = int(detect[4])
                predictions.frame_id = frame_id
                predictions = predictions.json(exclude_none=True, by_alias=True)

                self.inference_response = predictions
                self.frame_count += 1

                bytesToSend = predictions.encode("utf-8")
                self.UDPServerSocket.sendto(
                    bytesToSend,
                    (
                        self.ip_broadcast_addr,
                        self.ip_broadcast_port,
                    ),
                )
                # Lightweight console heartbeat, updated at most once a second.
                if time.perf_counter() - last_print > 1:
                    print(f"Streaming {print_chars[print_ind]}", end="\r")
                    print_ind = (print_ind + 1) % 4
                    last_print = time.perf_counter()

    def run_thread(self):
        """Run the preprocessing and inference threads.

        Starts the preprocessing and inference threads, and handles graceful shutdown on KeyboardInterrupt.
        """
        preprocess_thread = threading.Thread(target=self.preprocess_thread)
        inference_request_thread = threading.Thread(
            target=self.inference_request_thread
        )

        preprocess_thread.start()
        inference_request_thread.start()

        while True:
            try:
                time.sleep(10)
            except KeyboardInterrupt:
                logger.info("Stopping server...")
                self.stop = True
                self.active_learning_middleware.stop_registration_thread()
                # Allow worker threads time to observe the stop flag.
                time.sleep(3)
                sys.exit(0)
Functions
__init__
__init__(
    api_key=API_KEY,
    class_agnostic_nms=CLASS_AGNOSTIC_NMS,
    confidence=CONFIDENCE,
    enforce_fps=ENFORCE_FPS,
    ip_broadcast_addr=IP_BROADCAST_ADDR,
    ip_broadcast_port=IP_BROADCAST_PORT,
    iou_threshold=IOU_THRESHOLD,
    max_candidates=MAX_CANDIDATES,
    max_detections=MAX_DETECTIONS,
    model_id=MODEL_ID,
    stream_id=STREAM_ID,
    use_bytetrack=ENABLE_BYTE_TRACK,
)

Initialize the UDP stream with the given parameters. Prints the server settings and initializes the inference with a test frame.

Source code in inference/core/interfaces/udp/udp_stream.py
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
def __init__(
    self,
    api_key: str = API_KEY,
    class_agnostic_nms: bool = CLASS_AGNOSTIC_NMS,
    confidence: float = CONFIDENCE,
    enforce_fps: bool = ENFORCE_FPS,
    ip_broadcast_addr: str = IP_BROADCAST_ADDR,
    ip_broadcast_port: int = IP_BROADCAST_PORT,
    iou_threshold: float = IOU_THRESHOLD,
    max_candidates: float = MAX_CANDIDATES,
    max_detections: float = MAX_DETECTIONS,
    model_id: str = MODEL_ID,
    stream_id: Union[int, str] = STREAM_ID,
    use_bytetrack: bool = ENABLE_BYTE_TRACK,
):
    """Initialize the UDP stream with the given parameters.
    Prints the server settings and initializes the inference with a test frame.

    Raises:
        ValueError: If stream_id, model_id, or api_key is missing.
    """
    logger.info("Initializing server")

    self.frame_count = 0
    self.byte_tracker = sv.ByteTrack() if use_bytetrack else None
    self.use_bytetrack = use_bytetrack

    # Validate required configuration before doing any expensive work.
    self.stream_id = stream_id
    if self.stream_id is None:
        raise ValueError("STREAM_ID is not defined")
    self.model_id = model_id
    if not self.model_id:
        raise ValueError("MODEL_ID is not defined")
    self.api_key = api_key
    if not self.api_key:
        raise ValueError(
            f"API key is missing. Either pass it explicitly to constructor, or use one of env variables: "
            f"{API_KEY_ENV_NAMES}. Visit "
            f"https://docs.roboflow.com/api-reference/authentication#retrieve-an-api-key to learn how to generate "
            f"the key."
        )

    self.model = get_model(self.model_id, self.api_key)
    self.task_type = get_model_type(model_id=self.model_id, api_key=self.api_key)[0]
    # Active learning is a no-op unless explicitly enabled via configuration.
    self.active_learning_middleware = NullActiveLearningMiddleware()
    if ACTIVE_LEARNING_ENABLED:
        self.active_learning_middleware = ThreadingActiveLearningMiddleware.init(
            api_key=self.api_key,
            model_id=self.model_id,
            cache=cache,
        )
    self.class_agnostic_nms = class_agnostic_nms
    self.confidence = confidence
    self.iou_threshold = iou_threshold
    self.max_candidates = max_candidates
    self.max_detections = max_detections
    self.ip_broadcast_addr = ip_broadcast_addr
    self.ip_broadcast_port = ip_broadcast_port

    self.inference_request_type = (
        inference.core.entities.requests.inference.ObjectDetectionInferenceRequest
    )

    # UDP socket used only for broadcasting predictions.
    self.UDPServerSocket = socket.socket(
        family=socket.AF_INET, type=socket.SOCK_DGRAM
    )
    self.UDPServerSocket.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
    self.UDPServerSocket.setsockopt(socket.SOL_SOCKET, socket.SO_BROADCAST, 1)
    # NOTE(review): SO_RCVBUF is set to 1 byte, presumably to discard inbound
    # datagrams since this socket only sends — confirm this is intentional.
    self.UDPServerSocket.setsockopt(socket.SOL_SOCKET, socket.SO_RCVBUF, 1)
    self.UDPServerSocket.setsockopt(socket.SOL_SOCKET, socket.SO_SNDBUF, 65536)

    self.webcam_stream = WebcamStream(
        stream_id=self.stream_id, enforce_fps=enforce_fps
    )
    logger.info(
        f"Streaming from device with resolution: {self.webcam_stream.width} x {self.webcam_stream.height}"
    )

    # Warm-up inference before starting; also starts active-learning thread.
    self.init_infer()
    self.preproc_result = None
    self.inference_request_obj = None
    self.queue_control = False
    self.inference_response = None
    self.stop = False

    self.frame_cv = None
    self.frame_id = None
    logger.info("Server initialized with settings:")
    logger.info(f"Stream ID: {self.stream_id}")
    logger.info(f"Model ID: {self.model_id}")
    logger.info(f"Confidence: {self.confidence}")
    logger.info(f"Class Agnostic NMS: {self.class_agnostic_nms}")
    logger.info(f"IOU Threshold: {self.iou_threshold}")
    logger.info(f"Max Candidates: {self.max_candidates}")
    logger.info(f"Max Detections: {self.max_detections}")
inference_request_thread
inference_request_thread()

Manage the inference requests.

Processes preprocessed frames for inference, post-processes the predictions, and sends the results as a UDP broadcast.

Source code in inference/core/interfaces/udp/udp_stream.py
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
def inference_request_thread(self):
    """Manage the inference requests.

    Processes preprocessed frames for inference, post-processes the predictions, and sends the results
    as a UDP broadcast. Loops until ``self.stop`` is set.
    """
    last_print = time.perf_counter()
    print_ind = 0
    print_chars = ["|", "/", "-", "\\"]
    while True:
        if self.stop:
            break
        if self.queue_control:
            self.queue_control = False
            frame_id = self.frame_id
            # Copy the frame so active learning gets a stable snapshot even
            # if the preprocessing thread overwrites self.frame_cv meanwhile.
            inference_input = np.copy(self.frame_cv)
            predictions = self.model.predict(
                self.img_in,
            )
            predictions = self.model.postprocess(
                predictions,
                self.img_dims,
                class_agnostic_nms=self.class_agnostic_nms,
                confidence=self.confidence,
                iou_threshold=self.iou_threshold,
                max_candidates=self.max_candidates,
                max_detections=self.max_detections,
            )[0]
            self.active_learning_middleware.register(
                inference_input=inference_input,
                prediction=predictions.dict(by_alias=True, exclude_none=True),
                prediction_type=self.task_type,
            )
            if self.use_bytetrack:
                # Both branches of the previous hasattr(sv.Detections,
                # "from_inference") check called the exact same method,
                # so the dead branch was removed.
                detections = sv.Detections.from_inference(
                    predictions.dict(by_alias=True), self.model.class_names
                )
                detections = self.byte_tracker.update_with_detections(detections)
                for pred, detect in zip(predictions.predictions, detections):
                    pred.tracker_id = int(detect[4])
            predictions.frame_id = frame_id
            predictions = predictions.json(exclude_none=True, by_alias=True)

            self.inference_response = predictions
            self.frame_count += 1

            bytesToSend = predictions.encode("utf-8")
            self.UDPServerSocket.sendto(
                bytesToSend,
                (
                    self.ip_broadcast_addr,
                    self.ip_broadcast_port,
                ),
            )
            # Lightweight console heartbeat, updated at most once a second.
            if time.perf_counter() - last_print > 1:
                print(f"Streaming {print_chars[print_ind]}", end="\r")
                print_ind = (print_ind + 1) % 4
                last_print = time.perf_counter()
init_infer
init_infer()

Initialize the inference with a test frame.

Creates a test frame and runs it through the entire inference process to ensure everything is working.

Source code in inference/core/interfaces/udp/udp_stream.py
161
162
163
164
165
166
167
168
169
170
def init_infer(self):
    """Warm up inference on a blank test frame.

    Pushes a black 640x640 image through the full inference path so any lazy
    setup completes before live frames arrive, then starts the
    active-learning registration thread.
    """
    test_frame = Image.new("RGB", (640, 640), color="black")
    self.model.infer(
        test_frame,
        confidence=self.confidence,
        iou_threshold=self.iou_threshold,
    )
    self.active_learning_middleware.start_registration_thread()
preprocess_thread
preprocess_thread()

Preprocess incoming frames for inference.

Reads frames from the webcam stream, converts them into the proper format, and preprocesses them for inference.

Source code in inference/core/interfaces/udp/udp_stream.py
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
def preprocess_thread(self):
    """Preprocess incoming frames for inference.

    Reads frames from the webcam stream and runs the model's preprocessing
    on each new frame, flagging ``queue_control`` so the inference thread
    picks it up. Loops until the stream stops or ``self.stop`` is set.
    """
    webcam_stream = self.webcam_stream
    webcam_stream.start()
    # processing frames in input stream
    try:
        while True:
            if webcam_stream.stopped is True or self.stop:
                break
            else:
                self.frame_cv, frame_id = webcam_stream.read_opencv()
                # Skip frames already handed to the inference thread.
                if frame_id != self.frame_id:
                    self.frame_id = frame_id
                    self.preproc_result = self.model.preprocess(self.frame_cv)
                    self.img_in, self.img_dims = self.preproc_result
                    self.queue_control = True

    except Exception as e:
        # logger.exception keeps the traceback (the stream.py sibling already
        # does this); logger.error(e) logged only the message.
        logger.exception(e)
run_thread
run_thread()

Run the preprocessing and inference threads.

Starts the preprocessing and inference threads, and handles graceful shutdown on KeyboardInterrupt.

Source code in inference/core/interfaces/udp/udp_stream.py
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
def run_thread(self):
    """Start the preprocessing and inference threads and wait for shutdown.

    Blocks the calling thread until a KeyboardInterrupt arrives, then signals
    both worker loops to stop, stops the active-learning registration thread,
    and exits the process.
    """
    workers = [
        threading.Thread(target=self.preprocess_thread),
        threading.Thread(target=self.inference_request_thread),
    ]
    for worker in workers:
        worker.start()

    while True:
        try:
            time.sleep(10)
        except KeyboardInterrupt:
            logger.info("Stopping server...")
            self.stop = True
            self.active_learning_middleware.stop_registration_thread()
            # Give worker threads a moment to observe the stop flag.
            time.sleep(3)
            sys.exit(0)

Functions

core/interfaces/webrtc_worker

inference.core.interfaces.webrtc_worker.entities

Classes

VideoFileUploadState

Bases: str, Enum

State of video file upload.

Source code in inference/core/interfaces/webrtc_worker/entities.py
106
107
108
109
110
111
112
113
class VideoFileUploadState(str, Enum):
    """State of video file upload.

    A ``str`` subclass, so members compare equal to and serialize as their
    plain string values.
    """

    IDLE = "idle"
    UPLOADING = "uploading"
    COMPLETE = "complete"
    PROCESSING = "processing"
    ERROR = "error"

WebRTCOutput

Bases: BaseModel

Output sent via WebRTC data channel.

serialized_output_data contains a dictionary with workflow outputs: - If data_output is None or []: no data sent (only metadata) - If data_output is ["*"]: all workflow outputs (excluding images, unless explicitly named) - If data_output is ["field1", "field2"]: only those fields (including images if explicitly named)

Source code in inference/core/interfaces/webrtc_worker/entities.py
64
65
66
67
68
69
70
71
72
73
74
75
76
77
class WebRTCOutput(BaseModel):
    """Output sent via WebRTC data channel.

    serialized_output_data contains a dictionary with workflow outputs:
    - If data_output is None or []: no data sent (only metadata)
    - If data_output is ["*"]: all workflow outputs (excluding images, unless explicitly named)
    - If data_output is ["field1", "field2"]: only those fields (including images if explicitly named)
    """

    # Workflow outputs serialized per the data_output selection described above.
    serialized_output_data: Optional[Dict[str, Any]] = None
    # Metadata for the video frame this output corresponds to, when available.
    video_metadata: Optional[WebRTCVideoMetadata] = None
    # Error messages to surface to the client; empty when nothing went wrong.
    errors: List[str] = Field(default_factory=list)
    processing_complete: bool = False  # Signals end of video file processing
    # Reason processing was terminated, when applicable.
    termination_reason: Optional[str] = None

inference.core.interfaces.webrtc_worker.serializers

Classes

Functions

compress_image_for_webrtc

compress_image_for_webrtc(image)

Serialize image with low JPEG quality for efficient WebRTC transmission.

Source code in inference/core/interfaces/webrtc_worker/serializers.py
12
13
14
15
16
17
18
19
20
21
def compress_image_for_webrtc(image: WorkflowImageData) -> Dict[str, Any]:
    """Serialize an image with low JPEG quality for efficient WebRTC transmission.

    Returns a dict with the base64-encoded JPEG bytes plus the image's video
    metadata (or None when the image carries none).
    """
    encoded = encode_image_to_jpeg_bytes(
        image.numpy_image, jpeg_quality=WEBRTC_PREVIEW_FRAME_JPEG_QUALITY
    )
    metadata = image.video_metadata
    return {
        "type": "base64",
        "value": base64.b64encode(encoded).decode("ascii"),
        "video_metadata": metadata.dict() if metadata else None,
    }

serialize_for_webrtc

serialize_for_webrtc(value)

Serialize for WebRTC, compressing images with low JPEG quality.

Source code in inference/core/interfaces/webrtc_worker/serializers.py
24
25
26
27
28
29
30
31
32
def serialize_for_webrtc(value: Any) -> Any:
    """Serialize for WebRTC, compressing images with low JPEG quality.

    Recurses through dicts and lists; images are compressed, everything else
    is delegated to the generic wildcard serializer.
    """
    if isinstance(value, WorkflowImageData):
        return compress_image_for_webrtc(value)
    if isinstance(value, dict):
        serialized = {}
        for key, nested in value.items():
            serialized[key] = serialize_for_webrtc(nested)
        return serialized
    if isinstance(value, list):
        return list(map(serialize_for_webrtc, value))
    return serialize_wildcard_kind(value)

inference.core.interfaces.webrtc_worker.utils

Classes

Functions

detect_image_output

detect_image_output(workflow_output)

Detect the first available image output field in workflow output.

Source code in inference/core/interfaces/webrtc_worker/utils.py
25
26
27
28
29
30
31
32
33
34
35
36
37
38
def detect_image_output(
    workflow_output: Dict[str, Union[WorkflowImageData, Any]],
) -> Optional[str]:
    """Return the name of the first workflow output field holding an image.

    Returns None when no output field yields a frame.
    """
    return next(
        (
            output_name
            for output_name in workflow_output
            if get_frame_from_workflow_output(
                workflow_output=workflow_output,
                frame_output_key=output_name,
            )
            is not None
        ),
        None,
    )

get_cv2_rotation_code

get_cv2_rotation_code(rotation)

Get OpenCV rotation code to correct a given rotation.

Parameters:

Name Type Description Default
rotation int

Rotation angle in degrees from metadata

required

Returns:

Type Description
Optional[int]

cv2 rotation constant or None if no correction needed

Source code in inference/core/interfaces/webrtc_worker/utils.py
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
def get_cv2_rotation_code(rotation: int) -> Optional[int]:
    """Get OpenCV rotation code to correct a given rotation.

    Args:
        rotation: Rotation angle in degrees from metadata

    Returns:
        cv2 rotation constant or None if no correction needed
    """
    # The displaymatrix rotation indicates how the video is rotated.
    # To correct it, we apply the OPPOSITE rotation; unknown angles
    # (including 0) need no correction and map to None.
    correction_for_rotation = {
        -90: cv.ROTATE_90_CLOCKWISE,
        270: cv.ROTATE_90_CLOCKWISE,
        90: cv.ROTATE_90_COUNTERCLOCKWISE,
        -270: cv.ROTATE_90_COUNTERCLOCKWISE,
        180: cv.ROTATE_180,
        -180: cv.ROTATE_180,
    }
    return correction_for_rotation.get(rotation)

get_video_fps

get_video_fps(filepath)

Detect video FPS from container metadata.

Parameters:

Name Type Description Default
filepath str

Path to the video file

required

Returns:

Type Description
Optional[float]

FPS as float, or None if detection fails

Source code in inference/core/interfaces/webrtc_worker/utils.py
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
def get_video_fps(filepath: str) -> Optional[float]:
    """Detect video FPS from container metadata.

    Runs ``ffprobe`` on the file and parses the frame-rate fields of the
    first video stream.

    Args:
        filepath: Path to the video file

    Returns:
        FPS as float, or None if detection fails
    """
    import json
    import subprocess

    try:
        result = subprocess.run(
            [
                "ffprobe",
                "-v",
                "error",
                "-select_streams",
                "v:0",
                "-show_entries",
                "stream=r_frame_rate,avg_frame_rate",
                "-of",
                "json",
                filepath,
            ],
            capture_output=True,
            text=True,
            timeout=5,
        )
        if result.returncode == 0:
            data = json.loads(result.stdout)
            streams = data.get("streams", [])
            if streams:
                stream = streams[0]
                # Prefer avg_frame_rate (actual average) over r_frame_rate (container rate)
                for rate_key in ["avg_frame_rate", "r_frame_rate"]:
                    rate_str = stream.get(rate_key, "0/1")
                    try:
                        # ffprobe normally reports rates as fractions
                        # ("30000/1001"), but accept a plain number ("30")
                        # as well. A malformed entry just moves on to the
                        # next rate key instead of aborting detection.
                        num, _, den = str(rate_str).partition("/")
                        fps = float(num) / (float(den) if den else 1.0)
                    except (ValueError, ZeroDivisionError):
                        continue
                    if fps > 0:
                        logger.info(
                            "Video FPS detected: %.2f from %s", fps, rate_key
                        )
                        return fps
        else:
            logger.warning("ffprobe FPS detection failed: %s", result.stderr.strip())
    except FileNotFoundError:
        logger.warning("ffprobe not available for FPS detection")
    except subprocess.TimeoutExpired:
        logger.warning("ffprobe timed out during FPS detection")
    except Exception as e:
        logger.warning("ffprobe FPS detection failed: %s", e)

    return None

get_video_rotation

get_video_rotation(filepath)

Detect video rotation from metadata (displaymatrix or rotate tag).

Parameters:

Name Type Description Default
filepath str

Path to the video file

required

Returns:

Type Description
int

Rotation in degrees (-90, 0, 90, 180, 270) or 0 if not found.

int

Negative values indicate counter-clockwise rotation.

Source code in inference/core/interfaces/webrtc_worker/utils.py
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
def get_video_rotation(filepath: str) -> int:
    """Detect video rotation from metadata (displaymatrix or rotate tag).

    Args:
        filepath: Path to the video file

    Returns:
        Rotation in degrees (-90, 0, 90, 180, 270) or 0 if not found.
        Negative values indicate counter-clockwise rotation.
    """
    import json
    import subprocess

    try:
        # Use -show_streams which is compatible with all ffprobe versions
        probe = subprocess.run(
            [
                "ffprobe",
                "-v",
                "error",
                "-select_streams",
                "v:0",
                "-show_streams",
                "-of",
                "json",
                filepath,
            ],
            capture_output=True,
            text=True,
            timeout=5,
        )
        if probe.returncode != 0:
            logger.warning("ffprobe failed: %s", probe.stderr.strip())
            return 0
        streams = json.loads(probe.stdout).get("streams", [])
        if streams:
            video_stream = streams[0]
            # displaymatrix side data takes precedence over the legacy tag
            for side_data in video_stream.get("side_data_list", []):
                if "rotation" in side_data:
                    rotation = int(side_data["rotation"])
                    logger.info("Video rotation detected: %d°", rotation)
                    return rotation
            # Fall back to the legacy "rotate" tag in stream tags
            rotation = int(video_stream.get("tags", {}).get("rotate", "0"))
            if rotation != 0:
                logger.info("Video rotation detected: %d°", rotation)
                return rotation
    except FileNotFoundError:
        logger.warning("ffprobe not available")
    except subprocess.TimeoutExpired:
        logger.warning("ffprobe timed out")
    except Exception as e:
        logger.warning("ffprobe rotation detection failed: %s", e)

    return 0

parse_video_file_chunk

parse_video_file_chunk(message)

Parse video file chunk message.

Returns: (chunk_index, total_chunks, payload)

Source code in inference/core/interfaces/webrtc_worker/utils.py
196
197
198
199
200
201
202
203
204
def parse_video_file_chunk(message: bytes) -> Tuple[int, int, bytes]:
    """Parse video file chunk message.

    The wire format is a fixed-size little-endian header of two uint32
    values (chunk_index, total_chunks) followed by the raw payload bytes.

    Returns: (chunk_index, total_chunks, payload)

    Raises:
        ValueError: If the message is shorter than the fixed header.
    """
    if len(message) < VIDEO_FILE_HEADER_SIZE:
        raise ValueError(f"Message too short: {len(message)} bytes")
    # Slice via the shared header-size constant (instead of a hard-coded 8)
    # so the payload offset cannot drift out of sync with the length check.
    chunk_index, total_chunks = struct.unpack_from("<II", message)
    return chunk_index, total_chunks, message[VIDEO_FILE_HEADER_SIZE:]

rotate_video_frame

rotate_video_frame(frame, rotation_code)

Apply rotation to a video frame using OpenCV.

Parameters:

Name Type Description Default
frame VideoFrame

Input VideoFrame

required
rotation_code int

cv2 rotation constant (ROTATE_90_CLOCKWISE, etc.)

required

Returns:

Type Description
VideoFrame

Rotated VideoFrame

Source code in inference/core/interfaces/webrtc_worker/utils.py
382
383
384
385
386
387
388
389
390
391
392
393
394
def rotate_video_frame(frame: VideoFrame, rotation_code: int) -> VideoFrame:
    """Apply rotation to a video frame using OpenCV.

    Args:
        frame: Input VideoFrame
        rotation_code: cv2 rotation constant (ROTATE_90_CLOCKWISE, etc.)

    Returns:
        Rotated VideoFrame
    """
    # Round-trip through a BGR ndarray: decode, rotate in place, re-wrap.
    pixels = frame.to_ndarray(format="bgr24")
    rotated = cv.rotate(pixels, rotation_code)
    return VideoFrame.from_ndarray(rotated, format="bgr24")

inference.core.interfaces.webrtc_worker.webrtc

Classes

VideoFrameProcessor

Base class for processing video frames through workflow.

Can be used independently for data-only processing (no video track output) or as a base for VideoTransformTrackWithLoop when video output is needed.

Source code in inference/core/interfaces/webrtc_worker/webrtc.py
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
class VideoFrameProcessor:
    """Base class for processing video frames through workflow.

    Can be used independently for data-only processing (no video track output)
    or as a base for VideoTransformTrackWithLoop when video output is needed.
    """

    def __init__(
        self,
        asyncio_loop: asyncio.AbstractEventLoop,
        workflow_configuration: WorkflowConfiguration,
        api_key: str,
        model_manager: Optional[ModelManager] = None,
        data_output: Optional[List[str]] = None,
        stream_output: Optional[str] = None,
        has_video_track: bool = True,
        declared_fps: float = 30,
        termination_date: Optional[datetime.datetime] = None,
        terminate_event: Optional[asyncio.Event] = None,
        heartbeat_callback: Optional[Callable[[], None]] = None,
        realtime_processing: bool = True,
        is_preview: bool = False,
    ):
        """Initialize processor state and build the underlying inference pipeline.

        Args:
            asyncio_loop: Event loop this processor runs on.
            workflow_configuration: Workflow spec/identifiers plus tuning knobs
                forwarded to ``InferencePipeline.init_with_workflow``.
            api_key: API key forwarded to the inference pipeline.
            model_manager: Optional shared model manager.
            data_output: Output field names to send over the data channel.
                ``["*"]`` selects all fields, ``None`` or an empty list
                disables data output entirely.
            stream_output: Output field rendered onto the outgoing video track.
            has_video_track: Whether a video track is produced alongside data.
            declared_fps: FPS declared by the client (default 30).
            termination_date: Absolute deadline after which processing stops.
            terminate_event: External event used to signal shutdown.
            heartbeat_callback: Invoked periodically to signal liveness.
            realtime_processing: When True, stale frames are dropped to keep
                up; when False, receiver-paced ACK flow control may engage.
            is_preview: Marks the pipeline run as a preview.

        Raises:
            WebRTCConfigurationError: If ``data_output`` has an invalid type or
                references unknown workflow outputs.
        """
        self._file_processing = False
        self._loop = asyncio_loop
        self._termination_date = termination_date
        self._terminate_event = terminate_event
        self.track: Optional[MediaStreamTrack] = None
        self._track_active: bool = False
        self._av_logging_set: bool = False
        self._received_frames = 0
        self._declared_fps = declared_fps
        self._stop_processing = False
        self._termination_reason: Optional[str] = None
        self._processing_complete_sent = False
        self.heartbeat_callback = heartbeat_callback

        self.has_video_track = has_video_track
        self.stream_output = stream_output
        self.data_channel: Optional[RTCDataChannel] = None

        # Video file upload support
        self.video_upload_handler: Optional[VideoFileUploadHandler] = None
        self._track_ready_event: asyncio.Event = asyncio.Event()
        self.realtime_processing = realtime_processing
        self._rotation_code: Optional[int] = None

        # Optional receiver-paced flow control (enabled only after first ACK is received)
        self._ack_last: int = 0
        # If ack=1 and window=4, server may produce/send up to frame 5.
        # Window size comes from the WEBRTC_DATA_CHANNEL_ACK_WINDOW constant
        # (presumably env-configurable — confirm against config module).
        self._ack_window: int = WEBRTC_DATA_CHANNEL_ACK_WINDOW
        self._ack_event: asyncio.Event = asyncio.Event()

        if data_output is None:
            self.data_output = None
            self._data_mode = DataOutputMode.NONE
        elif isinstance(data_output, list):
            # Drop falsy entries (e.g. empty strings) before classifying the mode.
            self.data_output = [f for f in data_output if f]
            if self.data_output == ["*"]:
                self._data_mode = DataOutputMode.ALL
            elif len(self.data_output) == 0:
                self._data_mode = DataOutputMode.NONE
            else:
                self._data_mode = DataOutputMode.SPECIFIC
        else:
            raise WebRTCConfigurationError(
                f"data_output must be list or None, got {type(data_output).__name__}"
            )

        # Fail fast on output names that don't exist in the workflow spec.
        self._validate_output_fields(workflow_configuration)

        self._inference_pipeline = InferencePipeline.init_with_workflow(
            video_reference=VideoFrameProducer,
            workflow_specification=workflow_configuration.workflow_specification,
            workspace_name=workflow_configuration.workspace_name,
            workflow_id=workflow_configuration.workflow_id,
            api_key=api_key,
            image_input_name=workflow_configuration.image_input_name,
            workflows_parameters=workflow_configuration.workflows_parameters,
            workflows_thread_pool_workers=workflow_configuration.workflows_thread_pool_workers,
            cancel_thread_pool_tasks_on_exit=workflow_configuration.cancel_thread_pool_tasks_on_exit,
            video_metadata_input_name=workflow_configuration.video_metadata_input_name,
            model_manager=model_manager,
            _is_preview=is_preview,
            workflow_version_id=workflow_configuration.workflow_version_id,
        )

    def set_track(self, track: MediaStreamTrack, rotation_code: Optional[int] = None):
        """Attach the incoming media track (first call wins) and wake waiters."""
        if not self.track:
            self.track = track
            self._rotation_code = rotation_code
            # Unblocks anything awaiting _track_ready_event.
            self._track_ready_event.set()

    async def close(self):
        """Stop frame processing and release the upload handler, if any."""
        self._track_active = False
        self._stop_processing = True
        # Clean up video upload handler if present
        if self.video_upload_handler is not None:
            await self.video_upload_handler.cleanup()

    def record_ack(self, ack: int) -> None:
        """Record cumulative ACK from the client.

        ACK semantics: client has fully handled all frames <= ack.
        Backwards compatible: pacing is disabled until we receive the first ACK.
        """
        try:
            ack_int = int(ack)
        except (TypeError, ValueError):
            logger.warning("Invalid ACK value: %s", ack)
            return
        if ack_int < 0:
            logger.warning("Invalid ACK value: %s", ack)
            return
        # Only advance monotonically; stale/duplicate ACKs are ignored.
        if ack_int > self._ack_last:
            # Log roughly once per 100 frames to keep log volume low.
            if ack_int % 100 == 1:
                logger.info("ACK received: %s", ack_int)
            self._ack_last = ack_int
            self._ack_event.set()

    async def _wait_for_ack_window(self, next_frame_id: int) -> None:
        """Block frame production when too far ahead of client ACKs."""
        # Pacing applies only in non-realtime mode and only once the first
        # ACK has arrived (_ack_last stays 0 until then).
        if self.realtime_processing or self._ack_last == 0:
            return

        wait_counter = 0
        while not self._stop_processing and next_frame_id > (
            self._ack_last + self._ack_window
        ):
            if self._check_termination():
                return
            if self.heartbeat_callback:
                self.heartbeat_callback()

            # Short timeout so termination/heartbeat are re-checked even if
            # the client stops ACKing entirely.
            self._ack_event.clear()
            try:
                await asyncio.wait_for(self._ack_event.wait(), timeout=0.2)
            except asyncio.TimeoutError:
                wait_counter += 1
                if wait_counter % 5 == 1:
                    logger.info(
                        "Waiting for ACK window (next=%d, ack_last=%d, window=%d)",
                        next_frame_id,
                        self._ack_last,
                        self._ack_window,
                    )

    def _check_termination(self) -> bool:
        """Check if we should terminate based on timeout.

        Does NOT set terminate_event — callers must call _signal_termination()
        after sending final data-channel messages to avoid a race with the
        cleanup task closing the peer connection.
        """
        if self._termination_date and self._termination_date < datetime.datetime.now():
            logger.info("Timeout reached, terminating inference pipeline")
            self._termination_reason = "timeout_reached"
            return True
        if self._terminate_event and self._terminate_event.is_set():
            logger.info("Terminate event set, terminating inference pipeline")
            return True
        return False

    def _signal_termination(self) -> None:
        # Idempotent: setting an already-set event is a no-op.
        if self._terminate_event:
            self._terminate_event.set()

    @staticmethod
    def serialize_outputs_sync(
        fields_to_send: List[str],
        workflow_output: Dict[str, Any],
        data_output_mode: DataOutputMode,
    ) -> Tuple[Dict[str, Any], List[str]]:
        """Serialize workflow outputs for WebRTC transmission."""
        serialized = {}
        serialization_errors = []

        for field_name in fields_to_send:
            if field_name not in workflow_output:
                serialization_errors.append(f"Output '{field_name}' not found")
                continue

            output_data = workflow_output[field_name]

            # In ALL mode images are skipped on the data channel (they are
            # delivered via the video track / stream output instead).
            if data_output_mode == DataOutputMode.ALL and isinstance(
                output_data, WorkflowImageData
            ):
                continue

            try:
                serialized[field_name] = serialize_for_webrtc(output_data)
            except Exception as e:
                # Keep the field present with an error marker so the client
                # can see which output failed rather than it vanishing.
                serialization_errors.append(f"{field_name}: {e}")
                serialized[field_name] = {"__serialization_error__": str(e)}
                logger.error("[SERIALIZE] Error: %s - %s", field_name, e)

        return serialized, serialization_errors

    async def _send_data_output(
        self,
        workflow_output: Dict[str, Any],
        frame_timestamp: datetime.datetime,
        frame: VideoFrame,
        errors: List[str],
    ):
        """Serialize one frame's workflow output and send it over the data channel.

        Silently returns if the data channel is absent or not open.
        """
        frame_id = self._received_frames

        if not self.data_channel or self.data_channel.readyState != "open":
            return

        video_metadata = WebRTCVideoMetadata(
            frame_id=frame_id,
            received_at=frame_timestamp.isoformat(),
            pts=frame.pts,
            time_base=frame.time_base,
            declared_fps=self._declared_fps,
            height=frame.height,
            width=frame.width,
        )

        webrtc_output = WebRTCOutput(
            serialized_output_data=None,
            video_metadata=video_metadata,
            errors=errors.copy(),
        )

        # NONE mode: ship metadata/errors only, no serialized outputs.
        if self._data_mode == DataOutputMode.NONE:
            json_bytes = await asyncio.to_thread(
                lambda: json.dumps(webrtc_output.model_dump()).encode("utf-8")
            )
            await send_chunked_data(
                self.data_channel,
                frame_id,
                json_bytes,
                heartbeat_callback=self.heartbeat_callback,
            )
            return

        if self._data_mode == DataOutputMode.ALL:
            fields_to_send = list(workflow_output.keys())
        else:
            fields_to_send = self.data_output

        # Serialization can be CPU-heavy; run it off the event loop.
        serialized_outputs, serialization_errors = await asyncio.to_thread(
            VideoFrameProcessor.serialize_outputs_sync,
            fields_to_send,
            workflow_output,
            self._data_mode,
        )

        webrtc_output.errors.extend(serialization_errors)
        if serialized_outputs:
            webrtc_output.serialized_output_data = serialized_outputs

        # TODO: use orjson
        json_bytes = await asyncio.to_thread(
            lambda: json.dumps(webrtc_output.model_dump(mode="json")).encode("utf-8")
        )

        if WEBRTC_GZIP_PREVIEW_FRAME_COMPRESSION:

            def compress_json():
                return gzip.compress(json_bytes, compresslevel=6)

            output_bytes = await asyncio.to_thread(compress_json)
        else:
            output_bytes = json_bytes

        success = await send_chunked_data(
            self.data_channel,
            frame_id,
            output_bytes,
            heartbeat_callback=self.heartbeat_callback,
        )
        if not success:
            logger.error("[SEND_OUTPUT] Frame %d failed", frame_id)

    async def _send_processing_complete(self):
        """Send final message indicating processing is complete.

        Also drains the data channel buffer to ensure delivery before the
        connection is closed.
        """
        # Guard: send at most once, and only while the channel is open.
        if self._processing_complete_sent:
            return
        if not self.data_channel or self.data_channel.readyState != "open":
            return

        self._processing_complete_sent = True
        completion_output = WebRTCOutput(
            processing_complete=True,
            termination_reason=self._termination_reason,
            video_metadata=WebRTCVideoMetadata(
                frame_id=self._received_frames,
                received_at=datetime.datetime.now().isoformat(),
            ),
        )
        json_bytes = json.dumps(completion_output.model_dump()).encode("utf-8")
        await send_chunked_data(
            self.data_channel, self._received_frames + 1, json_bytes
        )
        if not await wait_for_buffer_drain(
            self.data_channel, timeout=2.0, low_threshold=0
        ):
            logger.warning(
                "Buffer drain timed out, processing_complete may not reach client"
            )

    async def process_frames_data_only(self):
        """Process frames for data extraction only, without video track output."""
        if not self._av_logging_set:
            av_logging.set_libav_level(av_logging.ERROR)
            self._av_logging_set = True

        try:
            while not self._stop_processing:
                await self._wait_for_ack_window(next_frame_id=self._received_frames + 1)
                if self._check_termination():
                    # Send the final message BEFORE signalling termination so
                    # the cleanup task cannot close the connection first.
                    await self._send_processing_complete()
                    self._signal_termination()
                    break
                if self.heartbeat_callback:
                    self.heartbeat_callback()
                if not self.track or self.track.readyState == "ended":
                    break

                # Drain queue for realtime RTSP
                if (
                    isinstance(self.track, PlayerStreamTrack)
                    and self.realtime_processing
                ):
                    while self.track._queue.qsize() > 30:
                        self.track._queue.get_nowait()

                frame = await self.track.recv()
                self._received_frames += 1
                frame_timestamp = datetime.datetime.now()

                workflow_output, _, errors = await self._process_frame_async(
                    frame=frame,
                    frame_id=self._received_frames,
                    render_output=False,
                    include_errors_on_frame=False,
                )

                await self._send_data_output(
                    workflow_output, frame_timestamp, frame, errors
                )

        except asyncio.CancelledError as exc:
            # No one will catch this exception as it's executed in a create_task
            logger.info("[DATA_ONLY] Processing cancelled: %s", exc)
        except MediaStreamError as exc:
            logger.info("[DATA_ONLY] Media stream ended: %s", exc)
        except Exception as exc:
            logger.error(
                "[DATA_ONLY] Error at frame %d: %s", self._received_frames, exc
            )
        finally:
            await self._send_processing_complete()

    @staticmethod
    def _ensure_workflow_specification(
        workflow_configuration: WorkflowConfiguration, api_key: str
    ) -> None:
        """Ensure a workflow specification is present, fetching it from the API
        (via workspace_name + workflow_id) when only identifiers were given.

        Raises:
            WebRTCConfigurationError: If neither a spec nor both identifiers
                are provided, or if the API fetch fails.
        """
        has_specification = workflow_configuration.workflow_specification is not None
        has_workspace_and_workflow_id = (
            workflow_configuration.workspace_name is not None
            and workflow_configuration.workflow_id is not None
        )

        if not has_specification and not has_workspace_and_workflow_id:
            raise WebRTCConfigurationError(
                "Either 'workflow_specification' or both 'workspace_name' and 'workflow_id' must be provided"
            )

        if not has_specification and has_workspace_and_workflow_id:
            try:
                workflow_configuration.workflow_specification = (
                    get_workflow_specification(
                        api_key=api_key,
                        workspace_id=workflow_configuration.workspace_name,
                        workflow_id=workflow_configuration.workflow_id,
                        workflow_version_id=workflow_configuration.workflow_version_id,
                    )
                )
                # Identifiers are cleared once resolved to a concrete spec.
                workflow_configuration.workspace_name = None
                workflow_configuration.workflow_id = None
            except Exception as e:
                raise WebRTCConfigurationError(
                    f"Failed to fetch workflow specification from API: {str(e)}"
                )

    def _validate_output_fields(
        self, workflow_configuration: WorkflowConfiguration
    ) -> None:
        """Validate data_output / stream_output names against the workflow spec.

        Skipped entirely when no specification is available locally.

        Raises:
            WebRTCConfigurationError: If any requested output name is not
                declared by the workflow.
        """
        if workflow_configuration.workflow_specification is None:
            return

        workflow_outputs = workflow_configuration.workflow_specification.get(
            "outputs", []
        )
        available_output_names = [o.get("name") for o in workflow_outputs]

        if self._data_mode == DataOutputMode.SPECIFIC:
            invalid_fields = [
                field
                for field in self.data_output
                if field not in available_output_names
            ]
            if invalid_fields:
                raise WebRTCConfigurationError(
                    f"Invalid data_output fields: {invalid_fields}. "
                    f"Available workflow outputs: {available_output_names}"
                )

        if self.stream_output and self.stream_output not in available_output_names:
            raise WebRTCConfigurationError(
                f"Invalid stream_output field: '{self.stream_output}'. "
                f"Available workflow outputs: {available_output_names}"
            )

    async def _process_frame_async(
        self,
        frame: VideoFrame,
        frame_id: int,
        stream_output: Optional[str] = None,
        render_output: bool = True,
        include_errors_on_frame: bool = True,
    ) -> Tuple[Dict[str, Any], Optional[VideoFrame], List[str]]:
        """Async wrapper for process_frame using executor."""

        # Correct source rotation before the workflow sees the frame.
        if self._rotation_code is not None:
            frame = rotate_video_frame(frame, self._rotation_code)

        loop = asyncio.get_running_loop()
        return await loop.run_in_executor(
            None,
            process_frame,
            frame,
            frame_id,
            self._declared_fps,
            self._declared_fps,  # TODO: measure fps
            self._file_processing,
            self._inference_pipeline,
            stream_output,
            render_output,
            include_errors_on_frame,
        )
Functions
process_frames_data_only async
process_frames_data_only()

Process frames for data extraction only, without video track output.

Source code in inference/core/interfaces/webrtc_worker/webrtc.py
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
async def process_frames_data_only(self):
    """Process frames for data extraction only, without video track output."""
    # Quiet libav's logging once per processor instance.
    if not self._av_logging_set:
        av_logging.set_libav_level(av_logging.ERROR)
        self._av_logging_set = True

    try:
        while not self._stop_processing:
            await self._wait_for_ack_window(next_frame_id=self._received_frames + 1)
            if self._check_termination():
                # Send the final message BEFORE signalling termination so the
                # cleanup task cannot close the connection first.
                await self._send_processing_complete()
                self._signal_termination()
                break
            if self.heartbeat_callback:
                self.heartbeat_callback()
            if not self.track or self.track.readyState == "ended":
                break

            # Drain queue for realtime RTSP
            if (
                isinstance(self.track, PlayerStreamTrack)
                and self.realtime_processing
            ):
                while self.track._queue.qsize() > 30:
                    self.track._queue.get_nowait()

            frame = await self.track.recv()
            self._received_frames += 1
            frame_timestamp = datetime.datetime.now()

            workflow_output, _, errors = await self._process_frame_async(
                frame=frame,
                frame_id=self._received_frames,
                render_output=False,
                include_errors_on_frame=False,
            )

            await self._send_data_output(
                workflow_output, frame_timestamp, frame, errors
            )

    except asyncio.CancelledError as exc:
        # No one will catch this exception as it's executed in a create_task
        logger.info("[DATA_ONLY] Processing cancelled: %s", exc)
    except MediaStreamError as exc:
        logger.info("[DATA_ONLY] Media stream ended: %s", exc)
    except Exception as exc:
        logger.error(
            "[DATA_ONLY] Error at frame %d: %s", self._received_frames, exc
        )
    finally:
        # Best-effort completion notice even on error paths (guarded against
        # double-send inside _send_processing_complete).
        await self._send_processing_complete()
record_ack
record_ack(ack)

Record cumulative ACK from the client.

ACK semantics: client has fully handled all frames <= ack. Backwards compatible: pacing is disabled until we receive the first ACK.

Source code in inference/core/interfaces/webrtc_worker/webrtc.py
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
def record_ack(self, ack: int) -> None:
    """Record cumulative ACK from the client.

    ACK semantics: client has fully handled all frames <= ack.
    Backwards compatible: pacing is disabled until we receive the first ACK.
    """
    try:
        acknowledged = int(ack)
    except (TypeError, ValueError):
        logger.warning("Invalid ACK value: %s", ack)
        return
    if acknowledged < 0:
        logger.warning("Invalid ACK value: %s", ack)
        return
    # Ignore stale or duplicate ACKs — the counter only moves forward.
    if acknowledged <= self._ack_last:
        return
    # Log roughly once per 100 frames to keep log volume low.
    if acknowledged % 100 == 1:
        logger.info("ACK received: %s", acknowledged)
    self._ack_last = acknowledged
    self._ack_event.set()
serialize_outputs_sync staticmethod
serialize_outputs_sync(
    fields_to_send, workflow_output, data_output_mode
)

Serialize workflow outputs for WebRTC transmission.

Source code in inference/core/interfaces/webrtc_worker/webrtc.py
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
@staticmethod
def serialize_outputs_sync(
    fields_to_send: List[str],
    workflow_output: Dict[str, Any],
    data_output_mode: DataOutputMode,
) -> Tuple[Dict[str, Any], List[str]]:
    """Serialize workflow outputs for WebRTC transmission."""
    serialized: Dict[str, Any] = {}
    serialization_errors: List[str] = []

    for field_name in fields_to_send:
        try:
            output_data = workflow_output[field_name]
        except KeyError:
            serialization_errors.append(f"Output '{field_name}' not found")
            continue

        # In ALL mode images are skipped on the data channel; they travel
        # via the video track / stream output instead.
        is_image = isinstance(output_data, WorkflowImageData)
        if is_image and data_output_mode == DataOutputMode.ALL:
            continue

        try:
            serialized[field_name] = serialize_for_webrtc(output_data)
        except Exception as e:
            # Record the failure and keep a marker in place of the value so
            # the client can tell which field failed.
            serialization_errors.append(f"{field_name}: {e}")
            serialized[field_name] = {"__serialization_error__": str(e)}
            logger.error("[SERIALIZE] Error: %s - %s", field_name, e)

    return serialized, serialization_errors

VideoTransformTrackWithLoop

Bases: VideoStreamTrack, VideoFrameProcessor

Video track that processes frames through workflow and sends video back.

Inherits from both VideoStreamTrack (for WebRTC video track functionality) and VideoFrameProcessor (for workflow processing logic).

Source code in inference/core/interfaces/webrtc_worker/webrtc.py
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
class VideoTransformTrackWithLoop(VideoStreamTrack, VideoFrameProcessor):
    """Video track that processes frames through workflow and sends video back.

    Inherits from both VideoStreamTrack (for WebRTC video track functionality)
    and VideoFrameProcessor (for workflow processing logic).
    """

    def __init__(
        self,
        asyncio_loop: asyncio.AbstractEventLoop,
        workflow_configuration: WorkflowConfiguration,
        api_key: str,
        model_manager: Optional[ModelManager] = None,
        data_output: Optional[List[str]] = None,
        stream_output: Optional[str] = None,
        has_video_track: bool = True,
        declared_fps: float = 30,
        termination_date: Optional[datetime.datetime] = None,
        terminate_event: Optional[asyncio.Event] = None,
        heartbeat_callback: Optional[Callable[[], None]] = None,
        realtime_processing: bool = True,
        is_preview: bool = False,
        *args,
        **kwargs,
    ):
        """Initialize both base classes.

        The two bases are initialized explicitly (not via cooperative
        ``super()``) because they take disjoint constructor arguments:
        ``*args``/``**kwargs`` go to VideoStreamTrack, everything else to
        VideoFrameProcessor.
        """
        VideoStreamTrack.__init__(self, *args, **kwargs)
        VideoFrameProcessor.__init__(
            self,
            asyncio_loop=asyncio_loop,
            workflow_configuration=workflow_configuration,
            api_key=api_key,
            data_output=data_output,
            stream_output=stream_output,
            has_video_track=has_video_track,
            declared_fps=declared_fps,
            termination_date=termination_date,
            terminate_event=terminate_event,
            model_manager=model_manager,
            heartbeat_callback=heartbeat_callback,
            realtime_processing=realtime_processing,
            is_preview=is_preview,
        )

    async def _auto_detect_stream_output(
        self, frame: VideoFrame, frame_id: int
    ) -> None:
        """Run the workflow once on this frame to discover which output field
        holds an image, and remember it in ``self.stream_output``."""
        workflow_output_for_detect, _, _ = await self._process_frame_async(
            frame=frame,
            frame_id=frame_id,
            render_output=False,
            include_errors_on_frame=False,
        )
        detected_output = detect_image_output(workflow_output_for_detect)
        if detected_output:
            self.stream_output = detected_output
            logger.info(f"Auto-detected stream_output: {detected_output}")
        else:
            # Empty string (rather than None) marks detection as attempted but
            # unsuccessful, so it is not retried on subsequent frames.
            logger.warning("No image output detected, will use fallback")
            self.stream_output = ""

    async def recv(self):
        """Pull one frame from the source track, run it through the workflow,
        and return the rendered frame to the WebRTC peer.

        Raises:
            MediaStreamError: when the source track ends, the track never
                becomes available, or processing is terminated.
        """
        # Silencing swscaler warnings in multi-threading environment
        if not self._av_logging_set:
            av_logging.set_libav_level(av_logging.ERROR)
            self._av_logging_set = True

        if self.heartbeat_callback:
            self.heartbeat_callback()

        # Wait for track to be ready (video file upload case)
        if self.track is None:
            logger.info("[RECV] Track is None, waiting for track_ready_event...")
            await self._track_ready_event.wait()
            if self.track is None:
                logger.error("[RECV] Track still None after wait!")
                raise MediaStreamError("Track not available after wait")

        # Optional ACK pacing: block producing the next frame if we're too far ahead.
        await self._wait_for_ack_window(next_frame_id=self._received_frames + 1)

        if self._check_termination():
            # Tell the client we are done before tearing the stream down.
            logger.warning("[RECV] Termination triggered, closing gracefully")
            await self._send_processing_complete()
            self._signal_termination()
            reason = self._termination_reason or "terminate_event"
            raise MediaStreamError(f"Processing terminated: {reason}")

        # Drain queue if using PlayerStreamTrack (RTSP/video file)
        if isinstance(self.track, PlayerStreamTrack) and self.realtime_processing:
            queue_size = self.track._queue.qsize()
            if queue_size > 30:
                # NOTE(review): the 30-frame high-water mark is hard-coded here
                # and in the loop below — presumably tuned for realtime mode;
                # confirm before changing.
                drained = 0
                while self.track._queue.qsize() > 30:
                    self.track._queue.get_nowait()
                    drained += 1
                logger.info(
                    "[RECV] Drained %d frames from queue (was %d)", drained, queue_size
                )

        try:
            frame: VideoFrame = await self.track.recv()
        except MediaStreamError:
            # Source ended normally; notify the client, then propagate.
            logger.info("[RECV] Track ended after %d frames", self._received_frames)
            await self._send_processing_complete()
            raise

        self._received_frames += 1
        frame_id = self._received_frames
        frame_timestamp = datetime.datetime.now()

        # First frame only: if no stream output was configured, run a one-off
        # pass to auto-detect which workflow output carries the image.
        if self.stream_output is None and frame_id == 1:
            await self._auto_detect_stream_output(frame, frame_id)

        workflow_output, new_frame, errors = await self._process_frame_async(
            frame=frame,
            frame_id=frame_id,
            stream_output=self.stream_output,
            render_output=True,
            include_errors_on_frame=True,
        )

        # Carry over the source frame's timing so the outgoing stream keeps
        # the original presentation timestamps.
        new_frame.pts = frame.pts
        new_frame.time_base = frame.time_base

        await self._send_data_output(workflow_output, frame_timestamp, frame, errors)

        if errors:
            logger.warning("[RECV] Frame %d errors: %s", frame_id, errors)

        return new_frame

Functions

create_chunked_binary_message

create_chunked_binary_message(
    frame_id, chunk_index, total_chunks, payload
)

Create a binary message with standard 12-byte header.

Format: [frame_id: 4][chunk_index: 4][total_chunks: 4][payload: N] All integers are uint32 little-endian.

Source code in inference/core/interfaces/webrtc_worker/webrtc.py
83
84
85
86
87
88
89
90
91
92
def create_chunked_binary_message(
    frame_id: int, chunk_index: int, total_chunks: int, payload: bytes
) -> bytes:
    """Build one data-channel message: a 12-byte header followed by payload.

    Layout: [frame_id: 4][chunk_index: 4][total_chunks: 4][payload: N],
    with all three integers encoded as uint32 little-endian.
    """
    return struct.pack("<III", frame_id, chunk_index, total_chunks) + payload

send_chunked_data async

send_chunked_data(
    data_channel,
    frame_id,
    payload_bytes,
    chunk_size=CHUNK_SIZE,
    heartbeat_callback=None,
    buffer_timeout=120.0,
)

Send payload via data channel with chunking and backpressure.

We chunk large payloads because WebRTC data channels have message size limits. We apply backpressure (wait for buffer to drain) to avoid overwhelming the network and causing ICE connection failures.

Heads up: buffer_timeout needs to be higher than WEBRTC_DATA_CHANNEL_BUFFER_DRAINING_DELAY! Otherwise we will timeout ourselves.

Source code in inference/core/interfaces/webrtc_worker/webrtc.py
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
async def send_chunked_data(
    data_channel: RTCDataChannel,
    frame_id: int,
    payload_bytes: bytes,
    chunk_size: int = CHUNK_SIZE,
    heartbeat_callback: Optional[Callable[[], None]] = None,
    buffer_timeout: float = 120.0,
) -> bool:
    """Send a payload over the data channel in chunks, with backpressure.

    Chunking is required because WebRTC data channels cap the message size;
    waiting for the buffer to drain avoids flooding the network and causing
    ICE connection failures.

    Heads up: buffer_timeout must exceed WEBRTC_DATA_CHANNEL_BUFFER_DRAINING_DELAY,
    otherwise we time ourselves out while draining.

    Returns:
        True if every chunk was sent, False on a closed channel or drain failure.
    """

    if buffer_timeout <= WEBRTC_DATA_CHANNEL_BUFFER_DRAINING_DELAY:
        logger.warning(
            "[SEND_CHUNKED] buffer_timeout (%.2fs) <= WEBRTC_DATA_CHANNEL_BUFFER_DRAINING_DELAY (%.2fs), "
            "this will likely cause immediate timeouts during buffer drain",
            buffer_timeout,
            WEBRTC_DATA_CHANNEL_BUFFER_DRAINING_DELAY,
        )

    if data_channel.readyState != "open":
        return False

    total_size = len(payload_bytes)
    num_chunks = (total_size + chunk_size - 1) // chunk_size
    # memoryview slices are zero-copy, so large payloads aren't duplicated.
    buf = memoryview(payload_bytes)
    limit = WEBRTC_DATA_CHANNEL_BUFFER_SIZE_LIMIT

    for idx in range(num_chunks):
        if data_channel.readyState != "open":
            logger.error(
                "[SEND_CHUNKED] Channel closed at chunk %d/%d",
                idx,
                num_chunks,
            )
            return False

        offset = idx * chunk_size
        piece = buf[offset : min(offset + chunk_size, total_size)]
        message = create_chunked_binary_message(frame_id, idx, num_chunks, piece)

        # Backpressure: only send once the buffered amount is under the limit.
        if data_channel.bufferedAmount > limit:
            drained = await wait_for_buffer_drain(
                data_channel, buffer_timeout, heartbeat_callback
            )
            if not drained:
                logger.error(
                    "[SEND_CHUNKED] Buffer drain failed at chunk %d/%d",
                    idx,
                    num_chunks,
                )
                return False

        data_channel.send(message)

        if heartbeat_callback:
            heartbeat_callback()
        # Yield to the event loop between chunks so other tasks can run.
        await asyncio.sleep(0.001)

    return True

wait_for_buffer_drain async

wait_for_buffer_drain(
    data_channel,
    timeout=30.0,
    heartbeat_callback=None,
    low_threshold=None,
)

Wait for data channel buffer to drain below threshold, with timeout.

We use a low threshold (1/4 of the limit) instead of one just below the limit to add hysteresis — without the gap, this wait would re-trigger after sending just a few chunks.

And we wait WEBRTC_DATA_CHANNEL_BUFFER_DRAINING_DELAY to avoid starving the event loop.

Source code in inference/core/interfaces/webrtc_worker/webrtc.py
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
async def wait_for_buffer_drain(
    data_channel: RTCDataChannel,
    timeout: float = 30.0,
    heartbeat_callback: Optional[Callable[[], None]] = None,
    low_threshold: Optional[int] = None,
) -> bool:
    """Wait until the data channel's buffered amount drops below a threshold.

    The threshold defaults to 1/4 of the buffer limit rather than the limit
    itself: that gap adds hysteresis, so the wait does not immediately
    re-trigger after only a few chunks are sent.

    Sleeps WEBRTC_DATA_CHANNEL_BUFFER_DRAINING_DELAY between polls to avoid
    starving the event loop.

    Returns:
        True once drained; False on timeout or if the channel closes.
    """
    threshold = (
        WEBRTC_DATA_CHANNEL_BUFFER_SIZE_LIMIT // 4
        if low_threshold is None
        else low_threshold
    )

    loop = asyncio.get_event_loop()
    deadline = loop.time() + timeout

    while data_channel.bufferedAmount > threshold:
        if loop.time() > deadline:
            logger.error("[BUFFER_DRAIN] Timeout after %.1fs", timeout)
            return False
        if data_channel.readyState != "open":
            logger.error("[BUFFER_DRAIN] Channel closed: %s", data_channel.readyState)
            return False
        if heartbeat_callback:
            heartbeat_callback()
        await asyncio.sleep(WEBRTC_DATA_CHANNEL_BUFFER_DRAINING_DELAY)

    return True

core/interfaces/webrtc_worker/sources

inference.core.interfaces.webrtc_worker.sources.file

Video file source for WebRTC - handles uploaded video files.

Classes

ThreadedVideoFileTrack

Bases: MediaStreamTrack

Video track that decodes frames from a file in a background thread.

Uses a dedicated thread with a queue to avoid deadlocks with the event loop.

Source code in inference/core/interfaces/webrtc_worker/sources/file.py
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
class ThreadedVideoFileTrack(MediaStreamTrack):
    """Media track yielding frames decoded from a video file.

    Decoding runs on a dedicated daemon thread feeding a bounded queue, so
    blocking file/codec work never runs on (or deadlocks) the event loop.
    """

    kind = "video"

    def __init__(self, filepath: str, queue_size: int = 60):
        # TODO: add parameter queue size in settings
        super().__init__()
        self._stop_event = threading.Event()
        self._queue = queue.Queue(maxsize=queue_size)
        self._decode_thread = threading.Thread(
            target=_decode_worker,
            args=(filepath, self._queue, self._stop_event),
            daemon=True,
        )
        self._decode_thread.start()

    async def recv(self) -> VideoFrame:
        # Poll without blocking the loop; yield briefly while the queue is empty.
        while True:
            try:
                item = self._queue.get_nowait()
            except queue.Empty:
                await asyncio.sleep(0.001)
                continue
            break

        if item is None:
            # None is the worker's end-of-file sentinel.
            self.stop()
            raise MediaStreamError("End of video file")
        if isinstance(item, dict):
            # A dict payload carries decode-error details from the worker.
            logger.error("[ThreadedVideoTrack] Decode error: %s", item)
            self.stop()
            raise MediaStreamError(item.get("error", "Unknown decode error"))

        return item

    def stop(self):
        super().stop()
        self._stop_event.set()

VideoFileUploadHandler

Handles video file uploads via data channel.

Protocol: [chunk_index:u32][total_chunks:u32][payload] Auto-completes when all chunks received.

Source code in inference/core/interfaces/webrtc_worker/sources/file.py
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
class VideoFileUploadHandler:
    """Handles video file uploads via data channel.

    Protocol: [chunk_index:u32][total_chunks:u32][payload]
    Auto-completes when all chunks received.
    """

    def __init__(self):
        # Received chunks keyed by index; cleared once written to disk.
        self._chunks: Dict[int, bytes] = {}
        # Expected total, learned from the first chunk received.
        self._total_chunks: Optional[int] = None
        self._temp_file_path: Optional[str] = None
        self._state = VideoFileUploadState.IDLE
        # Set once the full file has been reassembled and written to disk.
        self.upload_complete_event = asyncio.Event()

    @property
    def temp_file_path(self) -> Optional[str]:
        """Path of the reassembled temp file, or None if not written yet."""
        return self._temp_file_path

    def handle_chunk(self, chunk_index: int, total_chunks: int, data: bytes) -> None:
        """Handle a chunk. Auto-completes when all chunks received.

        Args:
            chunk_index: Zero-based index of this chunk.
            total_chunks: Total chunk count (taken from the first message;
                values on later messages are ignored).
            data: Raw chunk payload.
        """
        # TODO: we need to refactor this...
        if self._total_chunks is None:
            self._total_chunks = total_chunks
            self._state = VideoFileUploadState.UPLOADING

        self._chunks[chunk_index] = data

        # Compare against the recorded total (not the per-message value) so a
        # stray message carrying an inconsistent total_chunks cannot trigger
        # an early or late completion.
        if len(self._chunks) == self._total_chunks:
            self._write_to_temp_file()
            self._state = VideoFileUploadState.COMPLETE
            self.upload_complete_event.set()

    def _write_to_temp_file(self) -> None:
        """Reassemble chunks in index order and write them to a temp file."""
        import tempfile

        # TODO: we need to refactor this...
        with tempfile.NamedTemporaryFile(mode="wb", suffix=".mp4", delete=False) as f:
            for i in range(self._total_chunks):
                f.write(self._chunks[i])
            self._temp_file_path = f.name

        self._chunks.clear()

    def try_start_processing(self) -> Optional[str]:
        """Check if upload complete and transition to PROCESSING. Returns path or None."""
        if self._state == VideoFileUploadState.COMPLETE:
            self._state = VideoFileUploadState.PROCESSING
            return self._temp_file_path
        return None

    async def cleanup(self) -> None:
        """Clean up temp file (best-effort; errors are ignored)."""
        # TODO: we need to refactor this...
        if self._temp_file_path:
            import os

            path = self._temp_file_path
            self._temp_file_path = None
            try:
                # Off-thread unlink keeps the event loop responsive.
                await asyncio.to_thread(os.unlink, path)
            except Exception:
                pass
Functions
cleanup async
cleanup()

Clean up temp file.

Source code in inference/core/interfaces/webrtc_worker/sources/file.py
161
162
163
164
165
166
167
168
169
170
171
172
async def cleanup(self) -> None:
    """Clean up temp file."""
    # TODO: we need to refactor this...
    if self._temp_file_path:
        import os

        # Clear the stored path before unlinking so a concurrent caller
        # cannot attempt a second delete of the same file.
        path = self._temp_file_path
        self._temp_file_path = None
        try:
            # Off-thread unlink keeps the event loop responsive.
            await asyncio.to_thread(os.unlink, path)
        except Exception:
            # Best-effort cleanup: a missing file is not worth surfacing.
            pass
handle_chunk
handle_chunk(chunk_index, total_chunks, data)

Handle a chunk. Auto-completes when all chunks received.

Source code in inference/core/interfaces/webrtc_worker/sources/file.py
128
129
130
131
132
133
134
135
136
137
138
139
140
def handle_chunk(self, chunk_index: int, total_chunks: int, data: bytes) -> None:
    """Handle a chunk. Auto-completes when all chunks received."""
    # TODO: we need to refactor this...
    # First chunk seen fixes the expected total and moves state to UPLOADING.
    if self._total_chunks is None:
        self._total_chunks = total_chunks
        self._state = VideoFileUploadState.UPLOADING

    self._chunks[chunk_index] = data

    # NOTE(review): completion compares against the per-message total_chunks,
    # not self._total_chunks — a message with an inconsistent total could
    # complete early or never; confirm whether this is intentional.
    if len(self._chunks) == total_chunks:
        self._write_to_temp_file()
        self._state = VideoFileUploadState.COMPLETE
        self.upload_complete_event.set()
try_start_processing
try_start_processing()

Check if upload complete and transition to PROCESSING. Returns path or None.

Source code in inference/core/interfaces/webrtc_worker/sources/file.py
154
155
156
157
158
159
def try_start_processing(self) -> Optional[str]:
    """Check if upload complete and transition to PROCESSING. Returns path or None."""
    # Only a COMPLETE upload may move to PROCESSING; any other state is a no-op,
    # which also makes repeated calls after the transition return None.
    if self._state == VideoFileUploadState.COMPLETE:
        self._state = VideoFileUploadState.PROCESSING
        return self._temp_file_path
    return None

core/logging

inference.core.logging.memory_handler

In-memory logging handler for dashboard log viewing.

This module provides a custom logging handler that stores log records in memory for retrieval via the /logs API endpoint. It's designed to be used when ENABLE_IN_MEMORY_LOGS environment variable is set to 'true'.

Classes

MemoryLogHandler

Bases: Handler

Custom log handler that stores log records in memory for dashboard access

Source code in inference/core/logging/memory_handler.py
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
class MemoryLogHandler(logging.Handler):
    """Custom log handler that stores log records in memory for dashboard access"""

    def emit(self, record):
        """Convert *record* to a JSON-friendly dict and append it to the
        shared in-memory store; never raises."""
        try:
            created = datetime.fromtimestamp(record.created)
            entry = {
                "timestamp": created.isoformat(),
                "level": record.levelname,
                "logger": record.name,
                "message": self.format(record),
                "module": record.module or "",
                "line": record.lineno,
            }

            with _log_lock:
                _log_entries.append(entry)
        except Exception:
            # Swallow everything: raising from a handler could recurse
            # back into logging.
            pass

Functions

get_recent_logs

get_recent_logs(limit=100, level=None, since=None)

Get recent log entries from memory

Source code in inference/core/logging/memory_handler.py
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
def get_recent_logs(
    limit: int = 100, level: "str | None" = None, since: "str | None" = None
) -> List[Dict[str, Any]]:
    """Get recent log entries from memory.

    Args:
        limit: Maximum number of entries to return (most recent kept).
            A falsy value (e.g. 0) disables the limit.
        level: If given, keep only entries whose level matches
            (case-insensitive).
        since: ISO-8601 timestamp; keep only entries strictly newer.
            An unparsable value silently disables this filter.

    Returns:
        The filtered log entries, oldest first.
    """
    # Snapshot under the lock so filtering runs on a stable copy.
    with _log_lock:
        logs = list(_log_entries)

    # Filter by log level if specified
    if level:
        level_upper = level.upper()
        logs = [log for log in logs if log["level"] == level_upper]

    # Filter by timestamp if specified
    if since:
        try:
            # Normalize a trailing "Z" to an explicit UTC offset, which
            # datetime.fromisoformat accepts.
            since_dt = datetime.fromisoformat(since.replace("Z", "+00:00"))
            logs = [
                log
                for log in logs
                if datetime.fromisoformat(log["timestamp"]) > since_dt
            ]
        except ValueError:
            pass  # Invalid since timestamp, ignore filter

    # Limit results
    return logs[-limit:] if limit else logs

setup_memory_logging

setup_memory_logging()

Set up memory logging handler for the current logger hierarchy

Source code in inference/core/logging/memory_handler.py
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
def setup_memory_logging() -> "logging.Handler | None":
    """Set up memory logging handler for the current logger hierarchy.

    Returns:
        The MemoryLogHandler that was created when in-memory logging is
        enabled, otherwise None. (The previous ``-> None`` annotation was
        wrong: the function has always returned the handler.)
    """
    if not is_memory_logging_enabled():
        return None
    logger.info("Setting up memory logging")
    memory_handler = MemoryLogHandler()
    memory_handler.setLevel(logging.DEBUG)  # Capture all levels
    memory_formatter = logging.Formatter(
        "%(asctime)s %(levelname)s %(name)s: %(message)s"
    )
    memory_handler.setFormatter(memory_formatter)

    # Add to root logger to capture all logs immediately.
    # Guard with isinstance(): the previous code compared the freshly created
    # handler instance against the handler list, which can never match, so
    # repeated calls stacked duplicate handlers (duplicated log entries).
    root_logger = logging.getLogger()
    if not any(isinstance(h, MemoryLogHandler) for h in root_logger.handlers):
        root_logger.addHandler(memory_handler)

    # Specifically add to uvicorn.access logger to ensure access logs are captured now
    access_logger = logging.getLogger("uvicorn.access")
    if not any(isinstance(h, MemoryLogHandler) for h in access_logger.handlers):
        access_logger.addHandler(memory_handler)

    # Also patch uvicorn's default LOGGING_CONFIG so when uvicorn applies dictConfig,
    # our in-memory handler remains attached
    global _uvicorn_config_patched
    if not _uvicorn_config_patched:
        try:
            from uvicorn.config import LOGGING_CONFIG as UVICORN_LOGGING_CONFIG

            # Modify in-place (safe: uvicorn makes a deep copy later)
            log_config = UVICORN_LOGGING_CONFIG

            log_config.setdefault("formatters", {})
            if "default" not in log_config["formatters"]:
                log_config["formatters"]["default"] = {
                    "()": "uvicorn.logging.DefaultFormatter",
                    "fmt": "%(levelprefix)s %(message)s",
                    "use_colors": None,
                }

            # Register our handler class so dictConfig can instantiate it.
            log_config.setdefault("handlers", {})["inmemory"] = {
                "class": "inference.core.logging.memory_handler.MemoryLogHandler",
                "level": "DEBUG",
                "formatter": "default",
            }

            log_config.setdefault("loggers", {})
            log_config["loggers"].setdefault(
                "uvicorn.access",
                {
                    "handlers": ["default"],
                    "level": "INFO",
                    "propagate": False,
                },
            )
            if "inmemory" not in log_config["loggers"]["uvicorn.access"]["handlers"]:
                log_config["loggers"]["uvicorn.access"]["handlers"].append("inmemory")

            log_config["loggers"].setdefault(
                "uvicorn", {"handlers": ["default"], "level": "INFO"}
            )
            log_config["loggers"].setdefault("uvicorn.error", {"level": "INFO"})

            root_cfg = log_config.setdefault(
                "root", {"handlers": ["default"], "level": "INFO"}
            )
            if "inmemory" not in root_cfg.get("handlers", []):
                root_cfg.setdefault("handlers", []).append("inmemory")

            _uvicorn_config_patched = True
            logger.info("Patched uvicorn LOGGING_CONFIG to include MemoryLogHandler")
        except Exception:
            # Avoid hard failure if uvicorn is not available
            pass

    return memory_handler

core/managers

Model lifecycle managers: loading, unloading, registry, and resolution.

inference.core.managers.base

Classes

ModelManager

Model managers keep track of a dictionary of Model objects and is responsible for passing requests to the right model using the infer method.

Source code in inference/core/managers/base.py
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
class ModelManager:
    """Keeps a dictionary of loaded :class:`Model` objects and routes requests
    to the right model via the ``infer`` methods.

    Concurrency: loading (``add_model``) and removal (``remove``) of each model
    are guarded by a per-model lock; the lock registry itself is guarded by
    ``_state_lock``. Lock acquisition uses ``acquire_with_timeout`` so a busy
    model raises ``ModelManagerLockAcquisitionError`` instead of blocking
    forever.
    """

    def __init__(self, model_registry: ModelRegistry, models: Optional[dict] = None):
        """Initializes the manager.

        Args:
            model_registry (ModelRegistry): Registry used to resolve model
                classes from model identifiers.
            models (Optional[dict]): Pre-populated mapping of model_id -> Model;
                an empty dict is used when not provided.
        """
        self.model_registry = model_registry
        self._models: Dict[str, Model] = models if models is not None else {}
        self.pingback = None
        self._state_lock = Lock()
        self._models_state_locks: Dict[str, Lock] = {}

    def init_pingback(self):
        """Initializes pingback mechanism (metrics reporting), if enabled."""
        self.num_errors = 0  # in the device
        self.uuid = ROBOFLOW_SERVER_UUID
        if METRICS_ENABLED:
            self.pingback = PingbackInfo(self)
            self.pingback.start()

    def add_model(
        self,
        model_id: str,
        api_key: str,
        model_id_alias: Optional[str] = None,
        endpoint_type: ModelEndpointType = ModelEndpointType.ORT,
        countinference: Optional[bool] = None,
        service_secret: Optional[str] = None,
    ) -> None:
        """Loads a model and registers it in the manager (no-op if loaded).

        Args:
            model_id (str): The identifier of the model.
            api_key (str): API key used to authorize and download the model.
            model_id_alias (Optional[str]): Alternative identifier under which
                the model is registered; when provided it replaces ``model_id``
                as the registry key.
            endpoint_type (ModelEndpointType, optional): The endpoint type to
                use for the model.
            countinference (Optional[bool]): Forwarded to the model / weights
                provider to control inference counting.
            service_secret (Optional[str]): Internal service secret forwarded
                to the model / weights provider.

        Raises:
            RoboflowAPINotAuthorizedError: If cache auth is enabled and the API
                key has no access to the model.
            ModelManagerLockAcquisitionError: If the per-model lock cannot be
                acquired in time (model busy).
        """
        if MODELS_CACHE_AUTH_ENABLED:
            if not _check_if_api_key_has_access_to_model(
                api_key=api_key,
                model_id=model_id,
                endpoint_type=endpoint_type,
                countinference=countinference,
                service_secret=service_secret,
            ):
                raise RoboflowAPINotAuthorizedError(
                    f"API key {api_key} does not have access to model {model_id}"
                )

        logger.debug(
            f"ModelManager - Adding model with model_id={model_id}, model_id_alias={model_id_alias}"
        )
        resolved_identifier = model_id if model_id_alias is None else model_id_alias
        ids_collector = request_model_ids.get(None)
        if ids_collector is not None:
            ids_collector.add(resolved_identifier)
        model_lock = self._get_lock_for_a_model(model_id=resolved_identifier)
        with acquire_with_timeout(lock=model_lock) as acquired:
            if not acquired:
                # if failed to acquire - then in use, no need to purge lock
                raise ModelManagerLockAcquisitionError(
                    f"Could not acquire lock for model with id={resolved_identifier}."
                )
            if resolved_identifier in self._models:
                logger.debug(
                    f"ModelManager - model with model_id={resolved_identifier} is already loaded."
                )
                return
            try:
                logger.debug("ModelManager - model initialisation...")
                t_load_start = time.perf_counter()
                model_class = self.model_registry.get_model(
                    resolved_identifier,
                    api_key,
                    countinference=countinference,
                    service_secret=service_secret,
                )

                model = model_class(
                    model_id=model_id,
                    api_key=api_key,
                    countinference=countinference,
                    service_secret=service_secret,
                )

                # Pass countinference and service_secret to download_model_artifacts_from_roboflow_api if available
                if (
                    hasattr(model, "download_model_artifacts_from_roboflow_api")
                    and INTERNAL_WEIGHTS_URL_SUFFIX == "serverless"
                ):
                    # Only pass these parameters if INTERNAL_WEIGHTS_URL_SUFFIX is "serverless"
                    if (
                        hasattr(model, "cache_model_artefacts")
                        and not model.has_model_metadata
                    ):
                        # Override the download_model_artifacts_from_roboflow_api method with parameters
                        original_method = (
                            model.download_model_artifacts_from_roboflow_api
                        )
                        model.download_model_artifacts_from_roboflow_api = (
                            lambda: original_method(
                                countinference=countinference,
                                service_secret=service_secret,
                            )
                        )

                load_time = time.perf_counter() - t_load_start
                logger.debug(
                    f"ModelManager - model successfully loaded in {load_time:.2f}s."
                )
                self._models[resolved_identifier] = model
                collector = model_load_info.get(None)
                if collector is not None:
                    collector.record(model_id=resolved_identifier, load_time=load_time)
            except Exception:
                # Loading failed - drop the lock entry so a retry starts clean.
                self._dispose_model_lock(model_id=resolved_identifier)
                raise

    def check_for_model(self, model_id: str) -> None:
        """Checks whether the model with the given ID is in the manager.

        Args:
            model_id (str): The identifier of the model.

        Raises:
            InferenceModelNotFound: If the model is not found in the manager.
        """
        if model_id not in self:
            raise InferenceModelNotFound(f"Model with id {model_id} not loaded.")

    async def infer_from_request(
        self, model_id: str, request: InferenceRequest, **kwargs
    ) -> InferenceResponse:
        """Runs inference on the specified model with the given request.

        Args:
            model_id (str): The identifier of the model.
            request (InferenceRequest): The request to process.

        Returns:
            InferenceResponse: The response from the inference.
        """
        logger.debug(
            f"ModelManager - inference from request started for model_id={model_id}."
        )
        enable_model_monitoring = not getattr(
            request, "disable_model_monitoring", False
        )
        if METRICS_ENABLED and self.pingback and enable_model_monitoring:
            logger.debug("ModelManager - setting pingback fallback api key...")
            self.pingback.fallback_api_key = request.api_key
        try:
            rtn_val = await self.model_infer(
                model_id=model_id, request=request, **kwargs
            )
            logger.debug(
                f"ModelManager - inference from request finished for model_id={model_id}."
            )
            finish_time = time.time()
            if not DISABLE_INFERENCE_CACHE and enable_model_monitoring:
                self._cache_inference_success(
                    model_id=model_id,
                    request=request,
                    response=rtn_val,
                    finish_time=finish_time,
                )
            return rtn_val
        except Exception as e:
            finish_time = time.time()
            if not DISABLE_INFERENCE_CACHE and enable_model_monitoring:
                self._cache_inference_error(
                    model_id=model_id,
                    request=request,
                    error=e,
                    finish_time=finish_time,
                )
            raise

    def infer_from_request_sync(
        self, model_id: str, request: InferenceRequest, **kwargs
    ) -> InferenceResponse:
        """Runs inference on the specified model with the given request.

        Synchronous counterpart of :meth:`infer_from_request`.

        Args:
            model_id (str): The identifier of the model.
            request (InferenceRequest): The request to process.

        Returns:
            InferenceResponse: The response from the inference.
        """
        logger.debug(
            f"ModelManager - inference from request started for model_id={model_id}."
        )
        enable_model_monitoring = not getattr(
            request, "disable_model_monitoring", False
        )
        if METRICS_ENABLED and self.pingback and enable_model_monitoring:
            logger.debug("ModelManager - setting pingback fallback api key...")
            self.pingback.fallback_api_key = request.api_key
        try:
            rtn_val = self.model_infer_sync(
                model_id=model_id, request=request, **kwargs
            )
            logger.debug(
                f"ModelManager - inference from request finished for model_id={model_id}."
            )
            finish_time = time.time()
            if not DISABLE_INFERENCE_CACHE and enable_model_monitoring:
                self._cache_inference_success(
                    model_id=model_id,
                    request=request,
                    response=rtn_val,
                    finish_time=finish_time,
                )
            return rtn_val
        except Exception as e:
            finish_time = time.time()
            if not DISABLE_INFERENCE_CACHE and enable_model_monitoring:
                self._cache_inference_error(
                    model_id=model_id,
                    request=request,
                    error=e,
                    finish_time=finish_time,
                )
            raise

    def _cache_inference_success(
        self,
        model_id: str,
        request: InferenceRequest,
        response: InferenceResponse,
        finish_time: float,
    ) -> None:
        """Best-effort recording of a successful inference in the metrics cache.

        Never raises: cache failures are logged as warnings so they cannot
        break the inference response path. Note: numpy images are stringified
        in-place on the request before caching.
        """
        try:
            logger.debug(
                f"ModelManager - caching inference request started for model_id={model_id}"
            )
            cache.zadd(
                "models",
                value=f"{GLOBAL_INFERENCE_SERVER_ID}:{request.api_key}:{model_id}",
                score=finish_time,
                expire=METRICS_INTERVAL * 2,
            )
            if (
                hasattr(request, "image")
                and hasattr(request.image, "type")
                and request.image.type == "numpy"
            ):
                request.image.value = str(request.image.value)
            cache.zadd(
                f"inference:{GLOBAL_INFERENCE_SERVER_ID}:{model_id}",
                value=to_cachable_inference_item(request, response),
                score=finish_time,
                expire=METRICS_INTERVAL * 2,
            )
            logger.debug(
                f"ModelManager - caching inference request finished for model_id={model_id}"
            )
        except Exception as cache_error:
            logger.warning(
                f"Failed to cache inference data for model {model_id}: {cache_error}"
            )

    def _cache_inference_error(
        self,
        model_id: str,
        request: InferenceRequest,
        error: Exception,
        finish_time: float,
    ) -> None:
        """Best-effort recording of a failed inference in the metrics cache.

        Never raises: cache failures are logged as warnings so the original
        inference error is the one propagated to the caller.
        """
        try:
            cache.zadd(
                "models",
                value=f"{GLOBAL_INFERENCE_SERVER_ID}:{request.api_key}:{model_id}",
                score=finish_time,
                expire=METRICS_INTERVAL * 2,
            )
            cache.zadd(
                f"error:{GLOBAL_INFERENCE_SERVER_ID}:{model_id}",
                value={
                    "request": jsonable_encoder(
                        request.dict(exclude={"image", "subject", "prompt"})
                    ),
                    "error": str(error),
                },
                score=finish_time,
                expire=METRICS_INTERVAL * 2,
            )
        except Exception as cache_error:
            logger.warning(
                f"Failed to cache error data for model {model_id}: {cache_error}"
            )

    async def model_infer(self, model_id: str, request: InferenceRequest, **kwargs):
        """Dispatches the request to the model's own ``infer_from_request``."""
        model = self._get_model_reference(model_id=model_id)
        return model.infer_from_request(request)

    def model_infer_sync(
        self, model_id: str, request: InferenceRequest, **kwargs
    ) -> Union[List[InferenceResponse], InferenceResponse]:
        """Synchronous dispatch to the model's own ``infer_from_request``."""
        model = self._get_model_reference(model_id=model_id)
        return model.infer_from_request(request)

    def make_response(
        self, model_id: str, predictions: List[List[float]], *args, **kwargs
    ) -> InferenceResponse:
        """Creates a response object from the model's predictions.

        Args:
            model_id (str): The identifier of the model.
            predictions (List[List[float]]): The model's predictions.

        Returns:
            InferenceResponse: The created response object.
        """
        model = self._get_model_reference(model_id=model_id)
        return model.make_response(predictions, *args, **kwargs)

    def postprocess(
        self,
        model_id: str,
        predictions: Tuple[np.ndarray, ...],
        preprocess_return_metadata: PreprocessReturnMetadata,
        *args,
        **kwargs,
    ) -> List[List[float]]:
        """Processes the model's predictions after inference.

        Args:
            model_id (str): The identifier of the model.
            predictions (Tuple[np.ndarray, ...]): The model's raw predictions.
            preprocess_return_metadata (PreprocessReturnMetadata): Metadata
                produced by :meth:`preprocess`, needed to undo preprocessing.

        Returns:
            List[List[float]]: The post-processed predictions.
        """
        model = self._get_model_reference(model_id=model_id)
        return model.postprocess(
            predictions, preprocess_return_metadata, *args, **kwargs
        )

    def predict(self, model_id: str, *args, **kwargs) -> Tuple[np.ndarray, ...]:
        """Runs prediction on the specified model and updates its metrics.

        Args:
            model_id (str): The identifier of the model.

        Returns:
            Tuple[np.ndarray, ...]: The predictions from the model.
        """
        model = self._get_model_reference(model_id=model_id)
        model.metrics["num_inferences"] += 1
        tic = time.perf_counter()
        res = model.predict(*args, **kwargs)
        toc = time.perf_counter()
        # NOTE: this accumulates total inference time; averaging appears to
        # happen wherever the metric is consumed.
        model.metrics["avg_inference_time"] += toc - tic
        return res

    def preprocess(
        self, model_id: str, request: InferenceRequest
    ) -> Tuple[np.ndarray, PreprocessReturnMetadata]:
        """Preprocesses the request before inference.

        Args:
            model_id (str): The identifier of the model.
            request (InferenceRequest): The request to preprocess.

        Returns:
            Tuple[np.ndarray, PreprocessReturnMetadata]: The preprocessed data
            and the metadata required by :meth:`postprocess`.
        """
        model = self._get_model_reference(model_id=model_id)
        return model.preprocess(**request.dict())

    def get_class_names(self, model_id):
        """Retrieves the class names for a given model.

        Args:
            model_id (str): The identifier of the model.

        Returns:
            List[str]: The class names of the model.
        """
        model = self._get_model_reference(model_id=model_id)
        return model.class_names

    def get_task_type(self, model_id: str, api_key: Optional[str] = None) -> str:
        """Retrieves the task type for a given model.

        Args:
            model_id (str): The identifier of the model.
            api_key (Optional[str]): Unused here; kept for interface
                compatibility with overriding managers.

        Returns:
            str: The task type of the model.
        """
        model = self._get_model_reference(model_id=model_id)
        return model.task_type

    def remove(self, model_id: str, delete_from_disk: bool = True) -> None:
        """Removes a model from the manager, clearing its cache and lock.

        Args:
            model_id (str): The identifier of the model.
            delete_from_disk (bool): Whether to also delete the model's
                on-disk cache.

        Raises:
            ModelManagerLockAcquisitionError: If the per-model lock cannot be
                acquired in time (model busy).
        """
        try:
            logger.debug(f"Removing model {model_id} from base model manager")
            model_lock = self._get_lock_for_a_model(model_id=model_id)
            with acquire_with_timeout(lock=model_lock) as acquired:
                if not acquired:
                    raise ModelManagerLockAcquisitionError(
                        f"Could not acquire lock for model with id={model_id}."
                    )
                if model_id not in self._models:
                    return None
                self._models[model_id].clear_cache(delete_from_disk=delete_from_disk)
                del self._models[model_id]
                self._dispose_model_lock(model_id=model_id)
                try_releasing_cuda_memory()
        except InferenceModelNotFound:
            logger.warning(
                f"Attempted to remove model with id {model_id}, but it is not loaded. Skipping..."
            )

    def clear(self) -> None:
        """Removes all models from the manager."""
        model_ids = list(self.keys())
        for model_id in model_ids:
            self.remove(model_id)

    def _get_model_reference(self, model_id: str) -> Model:
        """Returns the loaded model or raises ``InferenceModelNotFound``."""
        try:
            return self._models[model_id]
        except KeyError as error:
            raise InferenceModelNotFound(
                f"Model with id {model_id} not loaded."
            ) from error

    def __contains__(self, model_id: str) -> bool:
        """Checks if the model is contained in the manager.

        Args:
            model_id (str): The identifier of the model.

        Returns:
            bool: Whether the model is in the manager.
        """
        return model_id in self._models

    def __getitem__(self, key: str) -> Model:
        """Retrieve a model from the manager by key.

        Args:
            key (str): The identifier of the model.

        Returns:
            Model: The model corresponding to the key.

        Raises:
            InferenceModelNotFound: If the model is not loaded.
        """
        return self._get_model_reference(model_id=key)

    def __len__(self) -> int:
        """Retrieve the number of models in the manager.

        Returns:
            int: The number of models in the manager.
        """
        return len(self._models)

    def keys(self):
        """Retrieve the keys (model identifiers) from the manager.

        Returns:
            KeysView[str]: The keys of the models in the manager.
        """
        return self._models.keys()

    def models(self) -> Dict[str, Model]:
        """Retrieve the models dictionary from the manager.

        Returns:
            Dict[str, Model]: Mapping of model identifiers to model instances.
        """
        return self._models

    def describe_models(self) -> List[ModelDescription]:
        """Builds a :class:`ModelDescription` for every loaded model."""
        return [
            ModelDescription(
                model_id=model_id,
                task_type=model.task_type,
                batch_size=getattr(model, "batch_size", None),
                input_width=getattr(model, "img_size_w", None),
                input_height=getattr(model, "img_size_h", None),
            )
            for model_id, model in self._models.items()
        ]

    def _get_lock_for_a_model(self, model_id: str) -> Lock:
        """Returns (creating if needed) the per-model lock, under the state lock."""
        with acquire_with_timeout(lock=self._state_lock) as acquired:
            if not acquired:
                raise ModelManagerLockAcquisitionError(
                    "Could not acquire lock on Model Manager state to retrieve model lock."
                )
            if model_id not in self._models_state_locks:
                self._models_state_locks[model_id] = Lock()
            return self._models_state_locks[model_id]

    def _dispose_model_lock(self, model_id: str) -> None:
        """Drops the per-model lock entry (no-op if absent), under the state lock."""
        with acquire_with_timeout(lock=self._state_lock) as acquired:
            if not acquired:
                raise ModelManagerLockAcquisitionError(
                    "Could not acquire lock on Model Manager state to dispose model lock."
                )
            if model_id not in self._models_state_locks:
                return None
            del self._models_state_locks[model_id]
Functions
__contains__
__contains__(model_id)

Checks if the model is contained in the manager.

Parameters:

Name Type Description Default
model_id str

The identifier of the model.

required

Returns:

Name Type Description
bool bool

Whether the model is in the manager.

Source code in inference/core/managers/base.py
487
488
489
490
491
492
493
494
495
496
def __contains__(self, model_id: str) -> bool:
    """Checks if the model is contained in the manager.

    Args:
        model_id (str): The identifier of the model.

    Returns:
        bool: Whether the model is in the manager.
    """
    return model_id in self._models
__getitem__
__getitem__(key)

Retrieve a model from the manager by key.

Parameters:

Name Type Description Default
key str

The identifier of the model.

required

Returns:

Name Type Description
Model Model

The model corresponding to the key.

Source code in inference/core/managers/base.py
498
499
500
501
502
503
504
505
506
507
def __getitem__(self, key: str) -> Model:
    """Retrieve a model from the manager by key.

    Args:
        key (str): The identifier of the model.

    Returns:
        Model: The model corresponding to the key.
    """
    return self._get_model_reference(model_id=key)
__len__
__len__()

Retrieve the number of models in the manager.

Returns:

Name Type Description
int int

The number of models in the manager.

Source code in inference/core/managers/base.py
509
510
511
512
513
514
515
def __len__(self) -> int:
    """Retrieve the number of models in the manager.

    Returns:
        int: The number of models in the manager.
    """
    return len(self._models)
add_model
add_model(
    model_id,
    api_key,
    model_id_alias=None,
    endpoint_type=ModelEndpointType.ORT,
    countinference=None,
    service_secret=None,
)

Adds a new model to the manager.

Parameters:

Name Type Description Default
model_id str

The identifier of the model.

required
model Model

The model instance.

required
endpoint_type ModelEndpointType

The endpoint type to use for the model.

ORT
Source code in inference/core/managers/base.py
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
def add_model(
    self,
    model_id: str,
    api_key: str,
    model_id_alias: Optional[str] = None,
    endpoint_type: ModelEndpointType = ModelEndpointType.ORT,
    countinference: Optional[bool] = None,
    service_secret: Optional[str] = None,
) -> None:
    """Adds a new model to the manager.

    Args:
        model_id (str): The identifier of the model.
        model (Model): The model instance.
        endpoint_type (ModelEndpointType, optional): The endpoint type to use for the model.
    """
    if MODELS_CACHE_AUTH_ENABLED:
        if not _check_if_api_key_has_access_to_model(
            api_key=api_key,
            model_id=model_id,
            endpoint_type=endpoint_type,
            countinference=countinference,
            service_secret=service_secret,
        ):
            raise RoboflowAPINotAuthorizedError(
                f"API key {api_key} does not have access to model {model_id}"
            )

    logger.debug(
        f"ModelManager - Adding model with model_id={model_id}, model_id_alias={model_id_alias}"
    )
    resolved_identifier = model_id if model_id_alias is None else model_id_alias
    ids_collector = request_model_ids.get(None)
    if ids_collector is not None:
        ids_collector.add(resolved_identifier)
    model_lock = self._get_lock_for_a_model(model_id=resolved_identifier)
    with acquire_with_timeout(lock=model_lock) as acquired:
        if not acquired:
            # if failed to acquire - then in use, no need to purge lock
            raise ModelManagerLockAcquisitionError(
                f"Could not acquire lock for model with id={resolved_identifier}."
            )
        if resolved_identifier in self._models:
            logger.debug(
                f"ModelManager - model with model_id={resolved_identifier} is already loaded."
            )
            return
        try:
            logger.debug("ModelManager - model initialisation...")
            t_load_start = time.perf_counter()
            model_class = self.model_registry.get_model(
                resolved_identifier,
                api_key,
                countinference=countinference,
                service_secret=service_secret,
            )

            model = model_class(
                model_id=model_id,
                api_key=api_key,
                countinference=countinference,
                service_secret=service_secret,
            )

            # Pass countinference and service_secret to download_model_artifacts_from_roboflow_api if available
            if (
                hasattr(model, "download_model_artifacts_from_roboflow_api")
                and INTERNAL_WEIGHTS_URL_SUFFIX == "serverless"
            ):
                # Only pass these parameters if INTERNAL_WEIGHTS_URL_SUFFIX is "serverless"
                if (
                    hasattr(model, "cache_model_artefacts")
                    and not model.has_model_metadata
                ):
                    # Override the download_model_artifacts_from_roboflow_api method with parameters
                    original_method = (
                        model.download_model_artifacts_from_roboflow_api
                    )
                    model.download_model_artifacts_from_roboflow_api = (
                        lambda: original_method(
                            countinference=countinference,
                            service_secret=service_secret,
                        )
                    )

            load_time = time.perf_counter() - t_load_start
            logger.debug(
                f"ModelManager - model successfully loaded in {load_time:.2f}s."
            )
            self._models[resolved_identifier] = model
            collector = model_load_info.get(None)
            if collector is not None:
                collector.record(model_id=resolved_identifier, load_time=load_time)
        except Exception as error:
            self._dispose_model_lock(model_id=resolved_identifier)
            raise error
check_for_model
check_for_model(model_id)

Checks whether the model with the given ID is in the manager.

Parameters:

Name Type Description Default
model_id str

The identifier of the model.

required

Raises:

Type Description
InferenceModelNotFound

If the model is not found in the manager.

Source code in inference/core/managers/base.py
159
160
161
162
163
164
165
166
167
168
169
def check_for_model(self, model_id: str) -> None:
    """Checks whether the model with the given ID is in the manager.

    Args:
        model_id (str): The identifier of the model.

    Raises:
        InferenceModelNotFound: If the model is not found in the manager.
    """
    if model_id not in self:
        raise InferenceModelNotFound(f"Model with id {model_id} not loaded.")
clear
clear()

Removes all models from the manager.

Source code in inference/core/managers/base.py
473
474
475
476
477
def clear(self) -> None:
    """Removes all models from the manager."""
    model_ids = list(self.keys())
    for model_id in model_ids:
        self.remove(model_id)
get_class_names
get_class_names(model_id)

Retrieves the class names for a given model.

Parameters:

Name Type Description Default
model_id str

The identifier of the model.

required

Returns:

Type Description

List[str]: The class names of the model.

Source code in inference/core/managers/base.py
424
425
426
427
428
429
430
431
432
433
434
def get_class_names(self, model_id):
    """Retrieves the class names for a given model.

    Args:
        model_id (str): The identifier of the model.

    Returns:
        List[str]: The class names of the model.
    """
    model = self._get_model_reference(model_id=model_id)
    return model.class_names
get_task_type
get_task_type(model_id, api_key=None)

Retrieves the task type for a given model.

Parameters:

Name Type Description Default
model_id str

The identifier of the model.

required

Returns:

Name Type Description
str str

The task type of the model.

Source code in inference/core/managers/base.py
436
437
438
439
440
441
442
443
444
445
446
def get_task_type(self, model_id: str, api_key: str = None) -> str:
    """Retrieves the task type for a given model.

    Args:
        model_id (str): The identifier of the model.

    Returns:
        str: The task type of the model.
    """
    model = self._get_model_reference(model_id=model_id)
    return model.task_type
infer_from_request async
infer_from_request(model_id, request, **kwargs)

Runs inference on the specified model with the given request.

Parameters:

Name Type Description Default
model_id str

The identifier of the model.

required
request InferenceRequest

The request to process.

required

Returns:

Name Type Description
InferenceResponse InferenceResponse

The response from the inference.

Source code in inference/core/managers/base.py
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
async def infer_from_request(
    self, model_id: str, request: InferenceRequest, **kwargs
) -> InferenceResponse:
    """Runs inference on the specified model with the given request.

    Args:
        model_id (str): The identifier of the model.
        request (InferenceRequest): The request to process.
        **kwargs: Additional arguments forwarded to ``model_infer``.

    Returns:
        InferenceResponse: The response from the inference.

    Raises:
        Exception: Whatever ``model_infer`` raises is re-raised after the
            error is (best effort) recorded in the metrics cache.
    """
    logger.debug(
        f"ModelManager - inference from request started for model_id={model_id}."
    )
    # Individual requests can opt out of monitoring/metrics collection.
    enable_model_monitoring = not getattr(
        request, "disable_model_monitoring", False
    )
    if METRICS_ENABLED and self.pingback and enable_model_monitoring:
        logger.debug("ModelManager - setting pingback fallback api key...")
        self.pingback.fallback_api_key = request.api_key
    try:
        rtn_val = await self.model_infer(
            model_id=model_id, request=request, **kwargs
        )
        logger.debug(
            f"ModelManager - inference from request finished for model_id={model_id}."
        )
        finish_time = time.time()
        if not DISABLE_INFERENCE_CACHE and enable_model_monitoring:
            # Caching is best effort: failures are logged, never raised.
            try:
                logger.debug(
                    f"ModelManager - caching inference request started for model_id={model_id}"
                )
                cache.zadd(
                    "models",  # fixed: was an f-string with no placeholders
                    value=f"{GLOBAL_INFERENCE_SERVER_ID}:{request.api_key}:{model_id}",
                    score=finish_time,
                    expire=METRICS_INTERVAL * 2,
                )
                # Numpy image payloads are stringified before caching —
                # presumably to keep the cache entry serializable (confirm).
                if (
                    hasattr(request, "image")
                    and hasattr(request.image, "type")
                    and request.image.type == "numpy"
                ):
                    request.image.value = str(request.image.value)
                cache.zadd(
                    f"inference:{GLOBAL_INFERENCE_SERVER_ID}:{model_id}",
                    value=to_cachable_inference_item(request, rtn_val),
                    score=finish_time,
                    expire=METRICS_INTERVAL * 2,
                )
                logger.debug(
                    f"ModelManager - caching inference request finished for model_id={model_id}"
                )
            except Exception as cache_error:
                logger.warning(
                    f"Failed to cache inference data for model {model_id}: {cache_error}"
                )
        return rtn_val
    except Exception as e:
        finish_time = time.time()
        if not DISABLE_INFERENCE_CACHE and enable_model_monitoring:
            try:
                cache.zadd(
                    "models",  # fixed: was an f-string with no placeholders
                    value=f"{GLOBAL_INFERENCE_SERVER_ID}:{request.api_key}:{model_id}",
                    score=finish_time,
                    expire=METRICS_INTERVAL * 2,
                )
                cache.zadd(
                    f"error:{GLOBAL_INFERENCE_SERVER_ID}:{model_id}",
                    value={
                        "request": jsonable_encoder(
                            request.dict(exclude={"image", "subject", "prompt"})
                        ),
                        "error": str(e),
                    },
                    score=finish_time,
                    expire=METRICS_INTERVAL * 2,
                )
            except Exception as cache_error:
                logger.warning(
                    f"Failed to cache error data for model {model_id}: {cache_error}"
                )
        raise
infer_from_request_sync
infer_from_request_sync(model_id, request, **kwargs)

Runs inference on the specified model with the given request.

Parameters:

Name Type Description Default
model_id str

The identifier of the model.

required
request InferenceRequest

The request to process.

required

Returns:

Name Type Description
InferenceResponse InferenceResponse

The response from the inference.

Source code in inference/core/managers/base.py
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
def infer_from_request_sync(
    self, model_id: str, request: InferenceRequest, **kwargs
) -> InferenceResponse:
    """Runs inference on the specified model with the given request.

    Synchronous counterpart of ``infer_from_request``; delegates to
    ``model_infer_sync`` instead of the async ``model_infer``.

    Args:
        model_id (str): The identifier of the model.
        request (InferenceRequest): The request to process.
        **kwargs: Additional arguments forwarded to ``model_infer_sync``.

    Returns:
        InferenceResponse: The response from the inference.

    Raises:
        Exception: Whatever ``model_infer_sync`` raises is re-raised after
            the error is (best effort) recorded in the metrics cache.
    """
    logger.debug(
        f"ModelManager - inference from request started for model_id={model_id}."
    )
    # Individual requests can opt out of monitoring/metrics collection.
    enable_model_monitoring = not getattr(
        request, "disable_model_monitoring", False
    )
    if METRICS_ENABLED and self.pingback and enable_model_monitoring:
        logger.debug("ModelManager - setting pingback fallback api key...")
        self.pingback.fallback_api_key = request.api_key
    try:
        rtn_val = self.model_infer_sync(
            model_id=model_id, request=request, **kwargs
        )
        logger.debug(
            f"ModelManager - inference from request finished for model_id={model_id}."
        )
        finish_time = time.time()
        if not DISABLE_INFERENCE_CACHE and enable_model_monitoring:
            # Caching is best effort: failures are logged, never raised.
            try:
                logger.debug(
                    f"ModelManager - caching inference request started for model_id={model_id}"
                )
                cache.zadd(
                    "models",  # fixed: was an f-string with no placeholders
                    value=f"{GLOBAL_INFERENCE_SERVER_ID}:{request.api_key}:{model_id}",
                    score=finish_time,
                    expire=METRICS_INTERVAL * 2,
                )
                # Numpy image payloads are stringified before caching —
                # presumably to keep the cache entry serializable (confirm).
                if (
                    hasattr(request, "image")
                    and hasattr(request.image, "type")
                    and request.image.type == "numpy"
                ):
                    request.image.value = str(request.image.value)
                cache.zadd(
                    f"inference:{GLOBAL_INFERENCE_SERVER_ID}:{model_id}",
                    value=to_cachable_inference_item(request, rtn_val),
                    score=finish_time,
                    expire=METRICS_INTERVAL * 2,
                )
                logger.debug(
                    f"ModelManager - caching inference request finished for model_id={model_id}"
                )
            except Exception as cache_error:
                logger.warning(
                    f"Failed to cache inference data for model {model_id}: {cache_error}"
                )
        return rtn_val
    except Exception as e:
        finish_time = time.time()
        if not DISABLE_INFERENCE_CACHE and enable_model_monitoring:
            try:
                cache.zadd(
                    "models",  # fixed: was an f-string with no placeholders
                    value=f"{GLOBAL_INFERENCE_SERVER_ID}:{request.api_key}:{model_id}",
                    score=finish_time,
                    expire=METRICS_INTERVAL * 2,
                )
                cache.zadd(
                    f"error:{GLOBAL_INFERENCE_SERVER_ID}:{model_id}",
                    value={
                        "request": jsonable_encoder(
                            request.dict(exclude={"image", "subject", "prompt"})
                        ),
                        "error": str(e),
                    },
                    score=finish_time,
                    expire=METRICS_INTERVAL * 2,
                )
            except Exception as cache_error:
                logger.warning(
                    f"Failed to cache error data for model {model_id}: {cache_error}"
                )
        raise
init_pingback
init_pingback()

Initializes pingback mechanism.

Source code in inference/core/managers/base.py
54
55
56
57
58
59
60
def init_pingback(self):
    """Initializes pingback mechanism."""
    self.num_errors = 0  # in the device
    self.uuid = ROBOFLOW_SERVER_UUID
    # Pingback is only attached when metrics reporting is enabled.
    if not METRICS_ENABLED:
        return
    pingback = PingbackInfo(self)
    self.pingback = pingback
    pingback.start()
keys
keys()

Retrieve the keys (model identifiers) from the manager.

Returns:

Type Description

List[str]: The keys of the models in the manager.

Source code in inference/core/managers/base.py
517
518
519
520
521
522
523
def keys(self):
    """List the identifiers of every model currently held by the manager.

    Returns:
        List[str]: The keys of the models in the manager.
    """
    registry = self._models
    return registry.keys()
make_response
make_response(model_id, predictions, *args, **kwargs)

Creates a response object from the model's predictions.

Parameters:

Name Type Description Default
model_id str

The identifier of the model.

required
predictions List[List[float]]

The model's predictions.

required

Returns:

Name Type Description
InferenceResponse InferenceResponse

The created response object.

Source code in inference/core/managers/base.py
355
356
357
358
359
360
361
362
363
364
365
366
367
368
def make_response(
    self, model_id: str, predictions: List[List[float]], *args, **kwargs
) -> InferenceResponse:
    """Build an inference response object from raw model predictions.

    Args:
        model_id (str): The identifier of the model.
        predictions (List[List[float]]): The model's predictions.

    Returns:
        InferenceResponse: The created response object.
    """
    target_model = self._get_model_reference(model_id=model_id)
    return target_model.make_response(predictions, *args, **kwargs)
models
models()

Retrieve the models dictionary from the manager.

Returns:

Type Description
Dict[str, Model]

Dict[str, Model]: The dictionary of models in the manager.

Source code in inference/core/managers/base.py
525
526
527
528
529
530
531
def models(self) -> Dict[str, Model]:
    """Retrieve the models dictionary from the manager.

    Returns:
        Dict[str, Model]: The dictionary of models held by the manager,
        keyed by model identifier.
    """
    return self._models
postprocess
postprocess(
    model_id,
    predictions,
    preprocess_return_metadata,
    *args,
    **kwargs
)

Processes the model's predictions after inference.

Parameters:

Name Type Description Default
model_id str

The identifier of the model.

required
predictions ndarray

The model's predictions.

required

Returns:

Type Description
List[List[float]]

List[List[float]]: The post-processed predictions.

Source code in inference/core/managers/base.py
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
def postprocess(
    self,
    model_id: str,
    predictions: Tuple[np.ndarray, ...],
    preprocess_return_metadata: PreprocessReturnMetadata,
    *args,
    **kwargs,
) -> List[List[float]]:
    """Run a model's post-processing step over its raw predictions.

    Args:
        model_id (str): The identifier of the model.
        predictions (Tuple[np.ndarray, ...]): The model's raw predictions.
        preprocess_return_metadata (PreprocessReturnMetadata): Metadata
            produced by the matching ``preprocess`` call.

    Returns:
        List[List[float]]: The post-processed predictions.
    """
    target_model = self._get_model_reference(model_id=model_id)
    return target_model.postprocess(
        predictions, preprocess_return_metadata, *args, **kwargs
    )
predict
predict(model_id, *args, **kwargs)

Runs prediction on the specified model.

Parameters:

Name Type Description Default
model_id str

The identifier of the model.

required

Returns:

Type Description
Tuple[ndarray, ...]

np.ndarray: The predictions from the model.

Source code in inference/core/managers/base.py
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
def predict(self, model_id: str, *args, **kwargs) -> Tuple[np.ndarray, ...]:
    """Run a prediction on a model and update its inference metrics.

    Args:
        model_id (str): The identifier of the model.

    Returns:
        np.ndarray: The predictions from the model.
    """
    target_model = self._get_model_reference(model_id=model_id)
    target_model.metrics["num_inferences"] += 1
    started_at = time.perf_counter()
    predictions = target_model.predict(*args, **kwargs)
    elapsed = time.perf_counter() - started_at
    # NOTE(review): this accumulates *total* time under "avg_inference_time";
    # presumably it is divided elsewhere — confirm before relying on the name.
    target_model.metrics["avg_inference_time"] += elapsed
    return predictions
preprocess
preprocess(model_id, request)

Preprocesses the request before inference.

Parameters:

Name Type Description Default
model_id str

The identifier of the model.

required
request InferenceRequest

The request to preprocess.

required

Returns:

Type Description
Tuple[ndarray, PreprocessReturnMetadata]

Tuple[np.ndarray, List[Tuple[int, int]]]: The preprocessed data.

Source code in inference/core/managers/base.py
409
410
411
412
413
414
415
416
417
418
419
420
421
422
def preprocess(
    self, model_id: str, request: InferenceRequest
) -> Tuple[np.ndarray, PreprocessReturnMetadata]:
    """Run a model's preprocessing step on an inference request.

    Args:
        model_id (str): The identifier of the model.
        request (InferenceRequest): The request to preprocess.

    Returns:
        Tuple[np.ndarray, List[Tuple[int, int]]]: The preprocessed data.
    """
    target_model = self._get_model_reference(model_id=model_id)
    # The request is flattened into keyword arguments for the model.
    return target_model.preprocess(**request.dict())
remove
remove(model_id, delete_from_disk=True)

Removes a model from the manager.

Parameters:

Name Type Description Default
model_id str

The identifier of the model.

required
Source code in inference/core/managers/base.py
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
def remove(self, model_id: str, delete_from_disk: bool = True) -> None:
    """Removes a model from the manager.

    Acquires the model's lock (via ``acquire_with_timeout``) before
    evicting it, so the model cannot disappear mid-use.

    Args:
        model_id (str): The identifier of the model.
        delete_from_disk (bool): Whether to also purge the model's on-disk
            cache when clearing it. Defaults to True.

    Raises:
        ModelManagerLockAcquisitionError: If the model's lock cannot be
            acquired.
    """
    try:
        logger.debug(f"Removing model {model_id} from base model manager")
        model_lock = self._get_lock_for_a_model(model_id=model_id)
        with acquire_with_timeout(lock=model_lock) as acquired:
            if not acquired:
                raise ModelManagerLockAcquisitionError(
                    f"Could not acquire lock for model with id={model_id}."
                )
            # Another thread may have removed the model while we waited.
            if model_id not in self._models:
                return None
            self._models[model_id].clear_cache(delete_from_disk=delete_from_disk)
            del self._models[model_id]
            # Drop the per-model lock and release GPU memory after eviction.
            self._dispose_model_lock(model_id=model_id)
            try_releasing_cuda_memory()
    except InferenceModelNotFound:
        logger.warning(
            f"Attempted to remove model with id {model_id}, but it is not loaded. Skipping..."
        )

inference.core.managers.metrics

Functions

get_container_stats

get_container_stats(docker_socket_path)

Gets the container stats. Returns: dict: A dictionary containing the container stats.

Source code in inference/core/managers/metrics.py
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
def get_container_stats(docker_socket_path: str) -> dict:
    """Query the Docker engine over its Unix socket for this container's stats.

    Returns:
        dict: A dictionary containing the container stats.
    """
    try:
        # Inside Docker the hostname doubles as the container id.
        container_id = socket.gethostname()
        conn = http.client.HTTPConnection("localhost")
        # Swap the default TCP socket for the Docker Unix-domain socket.
        conn.sock = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM)
        conn.sock.connect(docker_socket_path)
        conn.request(
            "GET",
            f"/containers/{container_id}/stats?stream=false",
            headers={"Host": "localhost"},
        )
        resp = conn.getresponse()
        payload = resp.read()
        conn.close()
        if resp.status != 200:
            raise Exception(payload.decode())
        return {"stats": json.loads(payload.decode())}
    except Exception as e:
        logger.exception(e)
        raise Exception("An error occurred while fetching container stats.")

get_model_metrics

get_model_metrics(
    inference_server_id, model_id, min=-1, max=float("inf")
)

Gets the metrics for a given model between a specified time range.

Parameters:

Name Type Description Default
inference_server_id str

The identifier of the inference server.

required
model_id str

The identifier of the model.

required
min float

The starting timestamp of the time range. Defaults to -1.

-1
max float

The ending timestamp of the time range. Defaults to float("inf").

float("inf")

Returns:

Name Type Description
dict dict

A dictionary containing the metrics of the model: - num_inferences (int): The number of inferences made. - avg_inference_time (float): The average inference time. - num_errors (int): The number of errors occurred.

Source code in inference/core/managers/metrics.py
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
def get_model_metrics(
    inference_server_id: str, model_id: str, min: float = -1, max: float = float("inf")
) -> dict:
    """
    Gets the metrics for a given model between a specified time range.

    Args:
        inference_server_id (str): The identifier of the inference server.
        model_id (str): The identifier of the model.
        min (float, optional): The starting timestamp of the time range. Defaults to -1.
        max (float, optional): The ending timestamp of the time range. Defaults to float("inf").
            (``min``/``max`` shadow the builtins; names kept for interface compatibility.)

    Returns:
        dict: A dictionary containing the metrics of the model:
              - num_inferences (int): The number of inferences made.
              - avg_inference_time (float): The average inference time.
              - num_errors (int): The number of errors occurred.
    """
    # Fixed: removed unused `now = time.time()` local.
    inferences_with_times = cache.zrangebyscore(
        f"inference:{inference_server_id}:{model_id}", min=min, max=max, withscores=True
    )
    num_inferences = len(inferences_with_times)
    inference_times = []
    for inference, _ in inferences_with_times:
        response = inference["response"]
        # Batched requests store a list of responses; collect every timing.
        if isinstance(response, list):
            inference_times.extend(r["time"] for r in response if "time" in r)
        elif "time" in response:
            inference_times.append(response["time"])
    avg_inference_time = (
        sum(inference_times) / len(inference_times) if inference_times else 0
    )
    errors_with_times = cache.zrangebyscore(
        f"error:{inference_server_id}:{model_id}", min=min, max=max, withscores=True
    )
    return {
        "num_inferences": num_inferences,
        "avg_inference_time": avg_inference_time,
        "num_errors": len(errors_with_times),
    }

get_system_info

get_system_info()

Collects system information such as platform, architecture, hostname, IP address, MAC address, and processor details.

Returns:

Name Type Description
dict dict

A dictionary containing detailed system information.

Source code in inference/core/managers/metrics.py
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
def get_system_info() -> dict:
    """Collects system information such as platform, architecture, hostname, IP address, MAC address, and processor details.

    Returns:
        dict: A dictionary containing detailed system information; best
        effort — whatever was gathered before a failure is still returned.
    """
    info = {}
    # Probes run in order; the first failure aborts the rest but keeps
    # the keys collected so far.
    probes = (
        ("platform", platform.system),
        ("platform_release", platform.release),
        ("platform_version", platform.version),
        ("architecture", platform.machine),
        ("hostname", socket.gethostname),
        ("ip_address", lambda: socket.gethostbyname(socket.gethostname())),
        ("mac_address", lambda: ":".join(re.findall("..", "%012x" % uuid.getnode()))),
        ("processor", platform.processor),
    )
    try:
        for key, probe in probes:
            info[key] = probe()
    except Exception as e:
        logger.exception(e)
    finally:
        # NOTE: returning from ``finally`` also suppresses any exception
        # raised inside the ``except`` handler — preserved for parity.
        return info

inference.core.managers.model_load_collector

Classes

ModelLoadCollector

Thread-safe collector for model cold start events during a request.

A single instance is shared across all threads handling a single request. Each entry stores a model_id alongside the load time.

Mirrors the design of RemoteProcessingTimeCollector from inference_sdk.

Source code in inference/core/managers/model_load_collector.py
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
class ModelLoadCollector:
    """Thread-safe collector for model cold start events during a request.

    A single instance is shared across all threads handling a single request.
    Each entry stores a model_id alongside the load time.

    Mirrors the design of RemoteProcessingTimeCollector from inference_sdk.
    """

    def __init__(self):
        # Guarded by _lock; holds (model_id, load_time) tuples.
        self._entries: list = []
        self._lock = threading.Lock()

    def record(self, model_id: str, load_time: float) -> None:
        """Append one (model_id, load_time) cold-start entry."""
        with self._lock:
            self._entries.append((model_id, load_time))

    def has_data(self) -> bool:
        """Return True once at least one load event has been recorded."""
        with self._lock:
            return bool(self._entries)

    def summarize(self, max_detail_bytes: int = 4096) -> Tuple[float, Optional[str]]:
        """Return (total_load_time, entries_json_or_none).

        Returns the total model load time and a JSON string of individual
        entries.  If the JSON exceeds *max_detail_bytes*, the detail string
        is omitted (None).
        """
        # Snapshot under the lock; summation/serialization run unlocked.
        with self._lock:
            snapshot = list(self._entries)
        total = sum(load_time for _, load_time in snapshot)
        detail = json.dumps(
            [{"m": model_id, "t": load_time} for model_id, load_time in snapshot]
        )
        if len(detail) > max_detail_bytes:
            detail = None
        return total, detail
Functions
summarize
summarize(max_detail_bytes=4096)

Return (total_load_time, entries_json_or_none).

Returns the total model load time and a JSON string of individual entries. If the JSON exceeds max_detail_bytes, the detail string is omitted (None).

Source code in inference/core/managers/model_load_collector.py
28
29
30
31
32
33
34
35
36
37
38
39
40
41
def summarize(self, max_detail_bytes: int = 4096) -> Tuple[float, Optional[str]]:
    """Return (total_load_time, entries_json_or_none).

    Returns the total model load time and a JSON string of individual
    entries.  If the JSON exceeds *max_detail_bytes*, the detail string
    is omitted (None).
    """
    # Snapshot under the lock so accumulation/serialization run unlocked.
    with self._lock:
        snapshot = list(self._entries)
    total = 0
    records = []
    for model_id, load_time in snapshot:
        total += load_time
        records.append({"m": model_id, "t": load_time})
    detail = json.dumps(records)
    return (total, None) if len(detail) > max_detail_bytes else (total, detail)

RequestModelIds

Thread-safe set of model IDs used during a request.

Source code in inference/core/managers/model_load_collector.py
49
50
51
52
53
54
55
56
57
58
59
60
61
62
class RequestModelIds:
    """Thread-safe set of model IDs used during a request."""

    def __init__(self):
        self._lock = threading.Lock()
        # Guarded by _lock; every model id seen for this request.
        self._ids: set = set()

    def add(self, model_id: str) -> None:
        """Record that *model_id* was used during this request."""
        with self._lock:
            self._ids.add(model_id)

    def get_ids(self) -> set:
        """Return a defensive copy of the recorded model ids."""
        with self._lock:
            return set(self._ids)

inference.core.managers.pingback

Classes

PingbackInfo

Class responsible for managing pingback information for Roboflow.

This class initializes a scheduler to periodically post data to Roboflow, containing information about the models, container, and device.

Attributes:

Name Type Description
scheduler BackgroundScheduler

A scheduler for running jobs in the background.

model_manager ModelManager

Reference to the model manager object.

process_startup_time str

Unix timestamp indicating when the process started.

METRICS_URL str

URL to send the pingback data to.

system_info dict

Information about the system.

window_start_timestamp str

Unix timestamp indicating the start of the current window.

Source code in inference/core/managers/pingback.py
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
class PingbackInfo:
    """Class responsible for managing pingback information for Roboflow.

    This class initializes a scheduler to periodically post data to Roboflow, containing information about the models,
    container, and device.

    Attributes:
        scheduler (BackgroundScheduler): A scheduler for running jobs in the background.
        model_manager (ModelManager): Reference to the model manager object.
        process_startup_time (str): Unix timestamp indicating when the process started.
        METRICS_URL (str): URL to send the pingback data to.
        system_info (dict): Information about the system.
        window_start_timestamp (str): Unix timestamp indicating the start of the current window.
    """

    def __init__(self, manager):
        """Initializes PingbackInfo with the given manager.

        Args:
            manager (ModelManager): Reference to the model manager object.
        """
        # Any failure here is debug-logged and swallowed so that metrics
        # setup can never break server startup.
        try:
            self.scheduler = BackgroundScheduler(
                job_defaults={"coalesce": True, "max_instances": 1}
            )
            self.model_manager = manager
            self.process_startup_time = str(int(time.time()))
            logger.debug(
                "UUID: " + self.model_manager.uuid
            )  # To correlate with UI container view
            self.window_start_timestamp = str(int(time.time()))
            context = {
                "api_key": API_KEY,
                "timestamp": str(int(time.time())),
                "device_id": GLOBAL_DEVICE_ID,
                "inference_server_id": GLOBAL_INFERENCE_SERVER_ID,
                "inference_server_version": __version__,
                "tags": TAGS,
            }
            self.environment_info = context | get_system_info()

            # we will set this from model manager when a new api key is used
            # to use in case there is no global ENV api key configured
            self.fallback_api_key = None

        except Exception as e:
            logger.debug(
                "Error sending pingback to Roboflow, if you want to disable this feature unset the ROBOFLOW_ENABLED environment variable. "
                + str(e)
            )

    def start(self):
        """Starts the scheduler to periodically post data to Roboflow.

        If METRICS_ENABLED is False, a warning is logged, and the method returns without starting the scheduler.
        """
        # Fixed: idiomatic truthiness check instead of `== False`.
        if not METRICS_ENABLED:
            logger.warning(
                "Metrics reporting to Roboflow is disabled; not sending back stats to Roboflow."
            )
            return
        try:
            self.scheduler.add_job(
                self.post_data,
                "interval",
                seconds=METRICS_INTERVAL,
                args=[self.model_manager],
                replace_existing=True,
            )
            self.scheduler.start()
        except Exception as e:
            logger.debug(e)

    def stop(self):
        """Stops the scheduler."""
        self.scheduler.shutdown()

    def post_data(self, model_manager):
        """Posts data to Roboflow about the models, container, device, and other relevant metrics.

        Args:
            model_manager (ModelManager): Reference to the model manager object.

        The data is collected and reset for the next window, and a POST request is made to the pingback URL.
        """
        all_data = self.environment_info.copy()
        all_data["inference_results"] = []

        # use fallback api key if env didn't have one
        if self.fallback_api_key and not all_data.get("api_key"):
            all_data["api_key"] = self.fallback_api_key

        try:
            now = time.time()
            start = now - METRICS_INTERVAL
            # Aggregate the last window's results across all loaded models.
            for model_id in model_manager.models():
                results = get_inference_results_for_model(
                    GLOBAL_INFERENCE_SERVER_ID, model_id, min=start, max=now
                )
                all_data["inference_results"] = all_data["inference_results"] + results
            res = requests.post(
                wrap_url(METRICS_URL),
                json=all_data,
                timeout=10,
                verify=ROBOFLOW_API_VERIFY_SSL,
            )
            try:
                api_key_safe_raise_for_status(response=res)
                logger.debug(
                    "Sent metrics to Roboflow {} at {}.".format(
                        METRICS_URL, str(all_data)
                    )
                )
            except Exception:  # fixed: drop unused exception binding
                logger.debug(
                    "Error sending metrics to Roboflow, if you want to disable this feature unset the METRICS_ENABLED environment variable."
                )

        except Exception as e:
            try:
                logger.exception(
                    f"Error sending metrics to Roboflow, if you want to disable this feature unset the METRICS_ENABLED environment variable. Error was: {e}. Data was: {all_data}"
                )

            except Exception:  # fixed: drop unused exception binding
                # Fallback without the payload in case it is not printable.
                logger.debug(
                    f"Error sending metrics to Roboflow, if you want to disable this feature unset the METRICS_ENABLED environment variable. Error was: {e}."
                )
Functions
__init__
__init__(manager)

Initializes PingbackInfo with the given manager.

Parameters:

Name Type Description Default
manager ModelManager

Reference to the model manager object.

required
Source code in inference/core/managers/pingback.py
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
def __init__(self, manager):
    """Initializes PingbackInfo with the given manager.

    Args:
        manager (ModelManager): Reference to the model manager object.
    """
    # Any failure below is debug-logged and swallowed so that metrics
    # setup can never break server startup.
    try:
        self.scheduler = BackgroundScheduler(
            job_defaults={"coalesce": True, "max_instances": 1}
        )
        self.model_manager = manager
        self.process_startup_time = str(int(time.time()))
        logger.debug(
            "UUID: " + self.model_manager.uuid
        )  # To correlate with UI container view
        self.window_start_timestamp = str(int(time.time()))
        # Static environment context merged with probed system details.
        context = {
            "api_key": API_KEY,
            "timestamp": str(int(time.time())),
            "device_id": GLOBAL_DEVICE_ID,
            "inference_server_id": GLOBAL_INFERENCE_SERVER_ID,
            "inference_server_version": __version__,
            "tags": TAGS,
        }
        self.environment_info = context | get_system_info()

        # we will set this from model manager when a new api key is used
        # to use in case there is no global ENV api key configured
        self.fallback_api_key = None

    except Exception as e:
        logger.debug(
            "Error sending pingback to Roboflow, if you want to disable this feature unset the ROBOFLOW_ENABLED environment variable. "
            + str(e)
        )
post_data
post_data(model_manager)

Posts data to Roboflow about the models, container, device, and other relevant metrics.

Parameters:

Name Type Description Default
model_manager ModelManager

Reference to the model manager object.

required

The data is collected and reset for the next window, and a POST request is made to the pingback URL.

Source code in inference/core/managers/pingback.py
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
def post_data(self, model_manager):
    """Posts data to Roboflow about the models, container, device, and other relevant metrics.

    Args:
        model_manager (ModelManager): Reference to the model manager object.

    The data is collected and reset for the next window, and a POST request is made to the pingback URL.
    """
    all_data = self.environment_info.copy()
    all_data["inference_results"] = []

    # use fallback api key if env didn't have one
    if self.fallback_api_key and not all_data.get("api_key"):
        all_data["api_key"] = self.fallback_api_key

    try:
        now = time.time()
        start = now - METRICS_INTERVAL
        # collect per-model inference results for the window [start, now]
        for model_id in model_manager.models():
            results = get_inference_results_for_model(
                GLOBAL_INFERENCE_SERVER_ID, model_id, min=start, max=now
            )
            all_data["inference_results"] = all_data["inference_results"] + results
        res = requests.post(
            wrap_url(METRICS_URL),
            json=all_data,
            timeout=10,
            verify=ROBOFLOW_API_VERIFY_SSL,
        )
        try:
            api_key_safe_raise_for_status(response=res)
            logger.debug(
                "Sent metrics to Roboflow {} at {}.".format(
                    METRICS_URL, str(all_data)
                )
            )
        except Exception as e:
            # Fix: previously logged an f-string with no placeholder, so the
            # caught HTTP error was silently dropped; include it for
            # diagnosability.
            logger.debug(
                f"Error sending metrics to Roboflow, if you want to disable this feature unset the METRICS_ENABLED environment variable. Error was: {e}."
            )

    except Exception as e:
        try:
            logger.exception(
                f"Error sending metrics to Roboflow, if you want to disable this feature unset the METRICS_ENABLED environment variable. Error was: {e}. Data was: {all_data}"
            )

        except Exception as e2:
            # Formatting all_data itself may fail (e.g. unloggable content);
            # fall back to a message without the payload.
            logger.debug(
                f"Error sending metrics to Roboflow, if you want to disable this feature unset the METRICS_ENABLED environment variable. Error was: {e}."
            )
start
start()

Starts the scheduler to periodically post data to Roboflow.

If METRICS_ENABLED is False, a warning is logged, and the method returns without starting the scheduler.

Source code in inference/core/managers/pingback.py
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
def start(self):
    """Starts the scheduler to periodically post data to Roboflow.

    If METRICS_ENABLED is False, a warning is logged, and the method returns without starting the scheduler.
    """
    # idiomatic truthiness check instead of `== False`
    if not METRICS_ENABLED:
        logger.warning(
            "Metrics reporting to Roboflow is disabled; not sending back stats to Roboflow."
        )
        return
    try:
        self.scheduler.add_job(
            self.post_data,
            "interval",
            seconds=METRICS_INTERVAL,
            args=[self.model_manager],
            replace_existing=True,
        )
        self.scheduler.start()
    except Exception as e:
        # best-effort feature: scheduler failures must not crash the server
        logger.debug(e)
stop
stop()

Stops the scheduler.

Source code in inference/core/managers/pingback.py
 99
100
101
def stop(self):
    """Shut down the background scheduler, ending periodic metric posts."""
    scheduler = self.scheduler
    scheduler.shutdown()

Functions

inference.core.managers.prometheus

Classes

CustomCollector

Bases: Collector

Source code in inference/core/managers/prometheus.py
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
class CustomCollector(Collector):
    """Prometheus collector exposing per-model inference metrics and, when a
    stream manager client has been attached, per-pipeline stream metrics."""

    def __init__(self, model_manager, time_window: int = 10):
        """Create a collector.

        Args:
            model_manager: Manager whose models are queried for metrics; may be
                None for server types without model-level metrics.
            time_window (int): Look-back window in seconds for metric queries.
        """
        super(CustomCollector, self).__init__()
        self.model_manager = model_manager
        self.time_window = time_window
        # Set externally (see InferenceInstrumentator.set_stream_manager_client);
        # stream metrics are skipped while this is None.
        self.stream_manager_client = None

    def get_metrics(self, maxModels: int = 25):
        """Collect per-model metrics for the last `time_window` seconds.

        Args:
            maxModels (int): Cap on the number of models queried.

        Returns:
            dict: model_id -> metrics dict; empty when no manager is present.
        """
        now = time.time()
        start = now - self.time_window
        count = 0
        results = {}
        if self.model_manager is None:
            logger.warning(
                "This inference server type does not support custom Prometheus metrics, skipping."
            )
            return results
        for model_id in self.model_manager.models():
            if count >= maxModels:
                break
            try:
                results[model_id] = get_model_metrics(
                    GLOBAL_INFERENCE_SERVER_ID, model_id, min=start, max=now
                )
            except Exception as e:
                # best-effort: a failing model must not block the others
                logger.debug(
                    "Error getting metrics for model " + model_id + ": " + str(e)
                )
            # NOTE: count advances even when the lookup failed, so maxModels
            # bounds attempts, not successes.
            count += 1
        return results

    async def _fetch_stream_metrics(self) -> Dict[str, dict]:
        """Query the stream manager for per-pipeline throughput/latency data."""
        # Pipeline status is fetched via TCP IPC to the stream manager process.
        # Pipelines run in separate subprocesses, so socket-based IPC is required.
        pipelines_response = await self.stream_manager_client.list_pipelines()
        pipeline_ids = pipelines_response.pipelines
        metrics = {}
        for pipeline_id in pipeline_ids:
            status_response = await self.stream_manager_client.get_status(pipeline_id)
            report = status_response.report
            latency_reports = report.get("latency_reports", [])
            sources_metadata = report.get("sources_metadata", [])
            camera_fps = self._average_source_fps(sources_metadata)
            source_label = self._extract_source_label(sources_metadata)
            metrics[pipeline_id] = {
                "inference_throughput": report.get("inference_throughput", 0.0),
                "camera_fps": camera_fps,
                "frame_decoding_latency": self._average_latency_field(
                    latency_reports, "frame_decoding_latency"
                ),
                "inference_latency": self._average_latency_field(
                    latency_reports, "inference_latency"
                ),
                "e2e_latency": self._average_latency_field(
                    latency_reports, "e2e_latency"
                ),
                "source": source_label,
            }
        return metrics

    def get_stream_metrics(self) -> Dict[str, dict]:
        """Synchronously fetch stream metrics; returns {} on any failure."""
        if self.stream_manager_client is None:
            return {}
        try:
            try:
                return asyncio.run(self._fetch_stream_metrics())
            except RuntimeError:
                # asyncio.run raises RuntimeError when an event loop is already
                # running in this thread; run the coroutine on a fresh thread
                # (with its own loop) instead.
                with concurrent.futures.ThreadPoolExecutor(max_workers=1) as pool:
                    return pool.submit(
                        asyncio.run, self._fetch_stream_metrics()
                    ).result()
        except Exception:
            # metrics are best-effort; never propagate into the scrape
            logger.debug("Failed to fetch stream metrics", exc_info=True)
            return {}

    @staticmethod
    def _average_latency_field(latency_reports: List[dict], field: str) -> float:
        """Mean of `field` over reports that carry it; 0.0 when none do."""
        values = [r[field] for r in latency_reports if r.get(field) is not None]
        if not values:
            return 0.0
        return sum(values) / len(values)

    @staticmethod
    def _average_source_fps(sources_metadata: List[dict]) -> float:
        """Mean FPS over sources reporting a positive fps; 0.0 otherwise."""
        values = []
        for src in sources_metadata:
            props = src.get("source_properties") or {}
            fps = props.get("fps")
            if fps is not None and fps > 0:
                values.append(fps)
        if not values:
            return 0.0
        return sum(values) / len(values)

    @staticmethod
    def _sanitize_source_reference(ref: str) -> str:
        """Strip credentials and query parameters from URLs to avoid leaking
        secrets in metrics."""
        parsed = urlparse(ref)
        if parsed.scheme and parsed.hostname:
            # keep only host[:port]; drops user:pass@, query and fragment
            netloc = parsed.hostname + (f":{parsed.port}" if parsed.port else "")
            sanitized = parsed._replace(netloc=netloc, query="", fragment="")
            return urlunparse(sanitized)
        # non-URL references (e.g. device indices) pass through untouched
        return ref

    @staticmethod
    def _extract_source_label(sources_metadata: List[dict]) -> str:
        """Comma-joined, sanitized source references; "" when disabled/empty."""
        if not METRICS_INCLUDE_SOURCE_LABELS:
            return ""
        refs = []
        for src in sources_metadata:
            ref = src.get("source_reference")
            if ref is not None:
                refs.append(CustomCollector._sanitize_source_reference(str(ref)))
        return ",".join(refs) if refs else ""

    def sanitize_string(self, input_string):
        # Prometheus metric names only allow [a-zA-Z0-9_:]; map everything
        # else to underscores.
        sanitized_string = re.sub(r"[^a-zA-Z0-9_]", "_", input_string)
        return sanitized_string

    def collect(self):
        """Yield gauge families for model metrics, totals, and stream metrics.

        Called by the Prometheus registry on every scrape.
        """
        results = self.get_metrics()
        num_inferences_total = 0
        num_errors_total = 0
        avg_inference_time_total = 0
        # NOTE(review): per-model values are encoded in the metric *name*
        # rather than as labels, which is unusual for Prometheus — confirm
        # this is intentional before changing.
        for model_id, metrics in results.items():
            sane_model_id = self.sanitize_string(model_id)
            yield GaugeMetricFamily(
                f"num_inferences_{sane_model_id}",
                f"Number of inferences made in {self.time_window}s",
                value=metrics["num_inferences"],
            )
            yield GaugeMetricFamily(
                f"avg_inference_time_{sane_model_id}",
                f"Average inference time (over inferences completed in {self.time_window}s) to infer this model",
                value=metrics["avg_inference_time"],
            )
            yield GaugeMetricFamily(
                f"num_errors_{sane_model_id}",
                f"Number of errors in {self.time_window}s",
                value=metrics["num_errors"],
            )
            num_inferences_total += metrics["num_inferences"]
            num_errors_total += metrics["num_errors"]
            # NOTE(review): this sums per-model averages rather than computing
            # a weighted mean — "avg_inference_time_total" is a sum of
            # averages; verify that is the intended semantics.
            avg_inference_time_total += metrics["avg_inference_time"]
        yield GaugeMetricFamily(
            "num_inferences_total",
            f"Total number of inferences made in {self.time_window}s",
            value=num_inferences_total,
        )
        yield GaugeMetricFamily(
            "avg_inference_time_total",
            f"Average inference time (over inferences completed in {self.time_window}s) to infer all models.",
            value=avg_inference_time_total,
        )
        yield GaugeMetricFamily(
            "num_errors_total",
            f"Total number of errors in {self.time_window}s",
            value=num_errors_total,
        )

        # Per-pipeline stream metrics, keyed by pipeline_id/source labels.
        stream_metrics = self.get_stream_metrics()
        pipeline_labels = ["pipeline_id", "source"]
        inference_fps = GaugeMetricFamily(
            "inference_pipeline_inference_fps",
            "Inference throughput FPS",
            labels=pipeline_labels,
        )
        camera_fps = GaugeMetricFamily(
            "inference_pipeline_camera_fps",
            "Camera source FPS",
            labels=pipeline_labels,
        )
        frame_decoding_latency = GaugeMetricFamily(
            "inference_pipeline_frame_decoding_latency",
            "Average frame decoding latency (seconds)",
            labels=pipeline_labels,
        )
        inference_latency = GaugeMetricFamily(
            "inference_pipeline_inference_latency",
            "Average inference latency (seconds)",
            labels=pipeline_labels,
        )
        e2e_latency = GaugeMetricFamily(
            "inference_pipeline_e2e_latency",
            "Average end-to-end latency (seconds)",
            labels=pipeline_labels,
        )
        for pipeline_id, pm in stream_metrics.items():
            label_values = [pipeline_id, pm["source"]]
            inference_fps.add_metric(label_values, pm["inference_throughput"])
            camera_fps.add_metric(label_values, pm["camera_fps"])
            frame_decoding_latency.add_metric(
                label_values, pm["frame_decoding_latency"]
            )
            inference_latency.add_metric(label_values, pm["inference_latency"])
            e2e_latency.add_metric(label_values, pm["e2e_latency"])
        yield inference_fps
        yield camera_fps
        yield frame_decoding_latency
        yield inference_latency
        yield e2e_latency
        yield GaugeMetricFamily(
            "inference_pipeline_active_streams",
            "Number of active inference pipelines",
            value=len(stream_metrics),
        )

InferenceInstrumentator

Class responsible for managing the Prometheus metrics for the inference server.

This class initializes the Prometheus Instrumentator and exposes the metrics endpoint.

Source code in inference/core/managers/prometheus.py
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
class InferenceInstrumentator:
    """
    Class responsible for managing the Prometheus metrics for the inference server.

    This class initializes the Prometheus Instrumentator and exposes the metrics endpoint.

    """

    def __init__(self, app, model_manager, endpoint: str = "/metrics"):
        """Instrument the app and register the custom collector.

        Args:
            app: ASGI/FastAPI application to instrument.
            model_manager: Manager passed through to CustomCollector.
            endpoint (str): Path where metrics are exposed.
        """
        self.instrumentator = Instrumentator()
        self.instrumentator.instrument(app).expose(app, endpoint)
        self.collector = CustomCollector(model_manager)
        # register with the global Prometheus registry so scrapes include
        # the custom per-model metrics
        REGISTRY.register(self.collector)

    def set_stream_manager_client(self, stream_manager_client) -> None:
        """Attach a stream manager client, enabling per-pipeline metrics."""
        self.collector.stream_manager_client = stream_manager_client

Functions

core/managers/decorators

inference.core.managers.decorators.base

Classes

ModelManagerDecorator

Bases: ModelManager

Basic decorator, it acts like a ModelManager and contains a ModelManager.

Parameters:

Name Type Description Default
model_manager ModelManager

Instance of a ModelManager.

required

Methods:

Name Description
add_model

Adds a model to the manager.

infer

Processes a complete inference request.

infer_only

Performs only the inference part of a request.

preprocess

Processes the preprocessing part of a request.

get_task_type

Gets the task type associated with a model.

get_class_names

Gets the class names for a given model.

remove

Removes a model from the manager.

__len__

Returns the number of models in the manager.

__getitem__

Retrieves a model by its ID.

__contains__

Checks if a model exists in the manager.

keys

Returns the keys (model IDs) from the manager.

Source code in inference/core/managers/decorators/base.py
 14
 15
 16
 17
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
class ModelManagerDecorator(ModelManager):
    """Basic decorator, it acts like a `ModelManager` and contains a `ModelManager`.

    Args:
        model_manager (ModelManager): Instance of a ModelManager.

    Methods:
        add_model: Adds a model to the manager.
        infer: Processes a complete inference request.
        infer_only: Performs only the inference part of a request.
        preprocess: Processes the preprocessing part of a request.
        get_task_type: Gets the task type associated with a model.
        get_class_names: Gets the class names for a given model.
        remove: Removes a model from the manager.
        __len__: Returns the number of models in the manager.
        __getitem__: Retrieves a model by its ID.
        __contains__: Checks if a model exists in the manager.
        keys: Returns the keys (model IDs) from the manager.
    """

    @property
    def _models(self):
        # Guard: decorators hold no model state of their own — all state
        # lives on the wrapped manager.
        raise ValueError("Should only be accessing self.model_manager._models")

    @property
    def model_registry(self):
        # Guard: registry access must also go through the wrapped manager.
        raise ValueError("Should only be accessing self.model_manager.model_registry")

    def __init__(self, model_manager: ModelManager):
        """Initializes the decorator with an instance of a ModelManager."""
        self.model_manager = model_manager

    def init_pingback(self):
        """Delegates pingback initialization to the wrapped manager."""
        self.model_manager.init_pingback()

    @property
    def pingback(self):
        # The pingback object is owned by the wrapped manager.
        return self.model_manager.pingback

    def add_model(
        self,
        model_id: str,
        api_key: str,
        model_id_alias: Optional[str] = None,
        endpoint_type: ModelEndpointType = ModelEndpointType.ORT,
        countinference: Optional[bool] = None,
        service_secret: Optional[str] = None,
    ):
        """Adds a model to the manager.

        Args:
            model_id (str): The identifier of the model.
            model (Model): The model instance.
            endpoint_type (ModelEndpointType, optional): The endpoint type to use for the model.
        """
        # no-op if the model is already registered (idempotent add)
        if model_id in self:
            return
        self.model_manager.add_model(
            model_id,
            api_key,
            model_id_alias=model_id_alias,
            endpoint_type=endpoint_type,
            countinference=countinference,
            service_secret=service_secret,
        )

    async def infer_from_request(
        self, model_id: str, request: InferenceRequest, **kwargs
    ) -> InferenceResponse:
        """Processes a complete inference request.

        Args:
            model_id (str): The identifier of the model.
            request (InferenceRequest): The request to process.

        Returns:
            InferenceResponse: The response from the inference.
        """
        return await self.model_manager.infer_from_request(model_id, request, **kwargs)

    def infer_from_request_sync(
        self, model_id: str, request: InferenceRequest, **kwargs
    ) -> InferenceResponse:
        """Processes a complete inference request.

        Args:
            model_id (str): The identifier of the model.
            request (InferenceRequest): The request to process.

        Returns:
            InferenceResponse: The response from the inference.
        """
        return self.model_manager.infer_from_request_sync(model_id, request, **kwargs)

    def infer_only(self, model_id: str, request, img_in, img_dims, batch_size=None):
        """Performs only the inference part of a request.

        Args:
            model_id (str): The identifier of the model.
            request: The request to process.
            img_in: Input image.
            img_dims: Image dimensions.
            batch_size (int, optional): Batch size.

        Returns:
            Response from the inference-only operation.
        """
        return self.model_manager.infer_only(
            model_id, request, img_in, img_dims, batch_size
        )

    def preprocess(self, model_id: str, request: InferenceRequest):
        """Processes the preprocessing part of a request.

        Args:
            model_id (str): The identifier of the model.
            request (InferenceRequest): The request to preprocess.
        """
        return self.model_manager.preprocess(model_id, request)

    def get_task_type(self, model_id: str, api_key: str = None) -> str:
        """Gets the task type associated with a model.

        Args:
            model_id (str): The identifier of the model.

        Returns:
            str: The task type.
        """
        # fall back to the globally configured API key when none is supplied
        if api_key is None:
            api_key = API_KEY
        return self.model_manager.get_task_type(model_id, api_key=api_key)

    def get_class_names(self, model_id):
        """Gets the class names for a given model.

        Args:
            model_id: The identifier of the model.

        Returns:
            List of class names.
        """
        return self.model_manager.get_class_names(model_id)

    def remove(self, model_id: str, delete_from_disk: bool = True) -> Model:
        """Removes a model from the manager.

        Args:
            model_id (str): The identifier of the model.

        Returns:
            Model: The removed model.
        """
        return self.model_manager.remove(model_id, delete_from_disk=delete_from_disk)

    def __len__(self) -> int:
        """Returns the number of models in the manager.

        Returns:
            int: Number of models.
        """
        return len(self.model_manager)

    def __getitem__(self, key: str) -> Model:
        """Retrieves a model by its ID.

        Args:
            key (str): The identifier of the model.

        Returns:
            Model: The model instance.
        """
        return self.model_manager[key]

    def __contains__(self, model_id: str):
        """Checks if a model exists in the manager.

        Args:
            model_id (str): The identifier of the model.

        Returns:
            bool: True if the model exists, False otherwise.
        """
        return model_id in self.model_manager

    def keys(self):
        """Returns the keys (model IDs) from the manager.

        Returns:
            List of keys (model IDs).
        """
        return self.model_manager.keys()

    def models(self):
        """Returns the wrapped manager's model mapping."""
        return self.model_manager.models()

    def predict(self, model_id: str, *args, **kwargs) -> Tuple[np.ndarray, ...]:
        """Delegates raw prediction to the wrapped manager."""
        return self.model_manager.predict(model_id, *args, **kwargs)

    def postprocess(
        self,
        model_id: str,
        predictions: Tuple[np.ndarray, ...],
        preprocess_return_metadata: PreprocessReturnMetadata,
        *args,
        **kwargs
    ) -> List[List[float]]:
        """Delegates postprocessing of raw predictions to the wrapped manager."""
        return self.model_manager.postprocess(
            model_id, predictions, preprocess_return_metadata, *args, **kwargs
        )

    def make_response(
        self, model_id: str, predictions: List[List[float]], *args, **kwargs
    ) -> InferenceResponse:
        """Delegates response construction to the wrapped manager."""
        return self.model_manager.make_response(model_id, predictions, *args, **kwargs)

    @property
    def num_errors(self):
        # error counter lives on the wrapped manager
        return self.model_manager.num_errors

    @num_errors.setter
    def num_errors(self, value):
        self.model_manager.num_errors = value
Functions
__contains__
__contains__(model_id)

Checks if a model exists in the manager.

Parameters:

Name Type Description Default
model_id str

The identifier of the model.

required

Returns:

Name Type Description
bool

True if the model exists, False otherwise.

Source code in inference/core/managers/decorators/base.py
188
189
190
191
192
193
194
195
196
197
def __contains__(self, model_id: str):
    """Reports whether the wrapped manager already holds the given model.

    Args:
        model_id (str): The identifier of the model.

    Returns:
        bool: True if the model exists, False otherwise.
    """
    found = model_id in self.model_manager
    return found
__getitem__
__getitem__(key)

Retrieves a model by its ID.

Parameters:

Name Type Description Default
key str

The identifier of the model.

required

Returns:

Name Type Description
Model Model

The model instance.

Source code in inference/core/managers/decorators/base.py
177
178
179
180
181
182
183
184
185
186
def __getitem__(self, key: str) -> Model:
    """Looks up a model by its identifier in the wrapped manager.

    Args:
        key (str): The identifier of the model.

    Returns:
        Model: The model instance.
    """
    inner = self.model_manager
    return inner[key]
__init__
__init__(model_manager)

Initializes the decorator with an instance of a ModelManager.

Source code in inference/core/managers/decorators/base.py
42
43
44
def __init__(self, model_manager: ModelManager):
    """Initializes the decorator with an instance of a ModelManager."""
    self.model_manager = model_manager
__len__
__len__()

Returns the number of models in the manager.

Returns:

Name Type Description
int int

Number of models.

Source code in inference/core/managers/decorators/base.py
169
170
171
172
173
174
175
def __len__(self) -> int:
    """Counts the models held by the wrapped manager.

    Returns:
        int: Number of models.
    """
    inner = self.model_manager
    return len(inner)
add_model
add_model(
    model_id,
    api_key,
    model_id_alias=None,
    endpoint_type=ModelEndpointType.ORT,
    countinference=None,
    service_secret=None,
)

Adds a model to the manager.

Parameters:

Name Type Description Default
model_id str

The identifier of the model.

required
model Model

The model instance.

required
endpoint_type ModelEndpointType

The endpoint type to use for the model.

ORT
Source code in inference/core/managers/decorators/base.py
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
def add_model(
    self,
    model_id: str,
    api_key: str,
    model_id_alias: Optional[str] = None,
    endpoint_type: ModelEndpointType = ModelEndpointType.ORT,
    countinference: Optional[bool] = None,
    service_secret: Optional[str] = None,
):
    """Registers a model with the wrapped manager unless it is already present.

    Args:
        model_id (str): The identifier of the model.
        api_key (str): API key used to resolve the model.
        endpoint_type (ModelEndpointType, optional): The endpoint type to use for the model.
    """
    already_loaded = model_id in self
    if already_loaded:
        return
    self.model_manager.add_model(
        model_id,
        api_key,
        model_id_alias=model_id_alias,
        endpoint_type=endpoint_type,
        countinference=countinference,
        service_secret=service_secret,
    )
get_class_names
get_class_names(model_id)

Gets the class names for a given model.

Parameters:

Name Type Description Default
model_id

The identifier of the model.

required

Returns:

Type Description

List of class names.

Source code in inference/core/managers/decorators/base.py
147
148
149
150
151
152
153
154
155
156
def get_class_names(self, model_id):
    """Fetches the class names exposed by the given model.

    Args:
        model_id: The identifier of the model.

    Returns:
        List of class names.
    """
    delegate = self.model_manager
    return delegate.get_class_names(model_id)
get_task_type
get_task_type(model_id, api_key=None)

Gets the task type associated with a model.

Parameters:

Name Type Description Default
model_id str

The identifier of the model.

required

Returns:

Name Type Description
str str

The task type.

Source code in inference/core/managers/decorators/base.py
134
135
136
137
138
139
140
141
142
143
144
145
def get_task_type(self, model_id: str, api_key: str = None) -> str:
    """Resolves the task type of a model via the wrapped manager.

    Args:
        model_id (str): The identifier of the model.

    Returns:
        str: The task type.
    """
    # fall back to the globally configured key only when none was supplied
    effective_key = API_KEY if api_key is None else api_key
    return self.model_manager.get_task_type(model_id, api_key=effective_key)
infer_from_request async
infer_from_request(model_id, request, **kwargs)

Processes a complete inference request.

Parameters:

Name Type Description Default
model_id str

The identifier of the model.

required
request InferenceRequest

The request to process.

required

Returns:

Name Type Description
InferenceResponse InferenceResponse

The response from the inference.

Source code in inference/core/managers/decorators/base.py
80
81
82
83
84
85
86
87
88
89
90
91
92
async def infer_from_request(
    self, model_id: str, request: InferenceRequest, **kwargs
) -> InferenceResponse:
    """Delegates a full inference request to the wrapped manager.

    Args:
        model_id (str): The identifier of the model.
        request (InferenceRequest): The request to process.

    Returns:
        InferenceResponse: The response from the inference.
    """
    response = await self.model_manager.infer_from_request(model_id, request, **kwargs)
    return response
infer_from_request_sync
infer_from_request_sync(model_id, request, **kwargs)

Processes a complete inference request.

Parameters:

Name Type Description Default
model_id str

The identifier of the model.

required
request InferenceRequest

The request to process.

required

Returns:

Name Type Description
InferenceResponse InferenceResponse

The response from the inference.

Source code in inference/core/managers/decorators/base.py
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
def infer_from_request_sync(
    self, model_id: str, request: InferenceRequest, **kwargs
) -> InferenceResponse:
    """Delegates a full synchronous inference request to the wrapped manager.

    Args:
        model_id (str): The identifier of the model.
        request (InferenceRequest): The request to process.

    Returns:
        InferenceResponse: The response from the inference.
    """
    delegate = self.model_manager
    return delegate.infer_from_request_sync(model_id, request, **kwargs)
infer_only
infer_only(
    model_id, request, img_in, img_dims, batch_size=None
)

Performs only the inference part of a request.

Parameters:

Name Type Description Default
model_id str

The identifier of the model.

required
request

The request to process.

required
img_in

Input image.

required
img_dims

Image dimensions.

required
batch_size int

Batch size.

None

Returns:

Type Description

Response from the inference-only operation.

Source code in inference/core/managers/decorators/base.py
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
def infer_only(self, model_id: str, request, img_in, img_dims, batch_size=None):
    """Runs only the raw inference step through the wrapped manager.

    Args:
        model_id (str): The identifier of the model.
        request: The request to process.
        img_in: Input image.
        img_dims: Image dimensions.
        batch_size (int, optional): Batch size.

    Returns:
        Response from the inference-only operation.
    """
    return self.model_manager.infer_only(
        model_id,
        request,
        img_in,
        img_dims,
        batch_size,
    )
keys
keys()

Returns the keys (model IDs) from the manager.

Returns:

Type Description

List of keys (model IDs).

Source code in inference/core/managers/decorators/base.py
199
200
201
202
203
204
205
def keys(self):
    """Lists the model IDs known to the wrapped manager.

    Returns:
        List of keys (model IDs).
    """
    inner = self.model_manager
    return inner.keys()
preprocess
preprocess(model_id, request)

Processes the preprocessing part of a request.

Parameters:

Name Type Description Default
model_id str

The identifier of the model.

required
request InferenceRequest

The request to preprocess.

required
Source code in inference/core/managers/decorators/base.py
125
126
127
128
129
130
131
132
def preprocess(self, model_id: str, request: InferenceRequest):
    """Delegates request preprocessing to the wrapped manager.

    Args:
        model_id (str): The identifier of the model.
        request (InferenceRequest): The request to preprocess.
    """
    delegate = self.model_manager
    return delegate.preprocess(model_id, request)
remove
remove(model_id, delete_from_disk=True)

Removes a model from the manager.

Parameters:

Name Type Description Default
model_id str

The identifier of the model.

required

Returns:

Name Type Description
Model Model

The removed model.

Source code in inference/core/managers/decorators/base.py
158
159
160
161
162
163
164
165
166
167
def remove(self, model_id: str, delete_from_disk: bool = True) -> Model:
    """Removes a model via the wrapped manager.

    Args:
        model_id (str): The identifier of the model.

    Returns:
        Model: The removed model.
    """
    delegate = self.model_manager
    return delegate.remove(model_id, delete_from_disk=delete_from_disk)

inference.core.managers.decorators.locked_load

Classes

LockedLoadModelManagerDecorator

Bases: ModelManagerDecorator

Must acquire lock to load model

Source code in inference/core/managers/decorators/locked_load.py
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
class LockedLoadModelManagerDecorator(ModelManagerDecorator):
    """Must acquire lock to load model"""

    def add_model(
        self,
        model_id: str,
        api_key: str,
        model_id_alias=None,
        endpoint_type: ModelEndpointType = ModelEndpointType.ORT,
        countinference: Optional[bool] = None,
        service_secret: Optional[str] = None,
    ):
        """Add a model to the manager while holding a lock keyed on the model id.

        The lock string comes from ``lock_str(model_id)``, so concurrent loads of
        the same model are serialized. ``expire=180.0`` is passed to the lock —
        presumably a TTL so a crashed holder cannot block loads forever; confirm
        against ``cache.lock`` semantics.

        Args:
            model_id (str): The identifier of the model.
            api_key (str): API key used to authorize the model load.
            model_id_alias: Optional alias for the model id.
            endpoint_type (ModelEndpointType): Endpoint type used to resolve the model.
            countinference (Optional[bool]): Forwarded to the superclass unchanged.
            service_secret (Optional[str]): Forwarded to the superclass unchanged.

        Returns:
            The result of the superclass ``add_model`` call.
        """
        with cache.lock(lock_str(model_id), expire=180.0):
            return super().add_model(
                model_id,
                api_key,
                model_id_alias=model_id_alias,
                endpoint_type=endpoint_type,
                countinference=countinference,
                service_secret=service_secret,
            )

inference.core.managers.decorators.logger

Classes

WithLogger

Bases: ModelManagerDecorator

Logger Decorator, it logs what's going on inside the manager.

Source code in inference/core/managers/decorators/logger.py
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
class WithLogger(ModelManagerDecorator):
    """Logger Decorator, it logs what's going on inside the manager."""

    def add_model(
        self,
        model_id: str,
        api_key: str,
        model_id_alias: Optional[str] = None,
        endpoint_type: ModelEndpointType = ModelEndpointType.ORT,
        countinference: Optional[bool] = None,
        service_secret: Optional[str] = None,
    ):
        """Adds a model to the manager and logs the action.

        Args:
            model_id (str): The identifier of the model.
            api_key (str): API key used to authorize the model load.
            model_id_alias (Optional[str]): Optional alias for the model id.
            endpoint_type (ModelEndpointType): Endpoint type used to resolve the model.
            countinference (Optional[bool]): Forwarded to the wrapped manager.
            service_secret (Optional[str]): Forwarded to the wrapped manager.

        Returns:
            The result of the add_model method from the superclass.
        """
        logger.info(f"🤖 {model_id} added.")
        return super().add_model(
            model_id,
            api_key,
            model_id_alias=model_id_alias,
            endpoint_type=endpoint_type,
            countinference=countinference,
            service_secret=service_secret,
        )

    async def infer_from_request(
        self, model_id: str, request: InferenceRequest, **kwargs
    ) -> InferenceResponse:
        """Processes a complete inference request and logs both the request and response.

        Args:
            model_id (str): The identifier of the model.
            request (InferenceRequest): The request to process.

        Returns:
            InferenceResponse: The response from the inference.
        """
        logger.info(f"📥 [{model_id}] request={request}.")
        res = await super().infer_from_request(model_id, request, **kwargs)
        logger.info(f"📥 [{model_id}] res={res}.")
        return res

    def infer_from_request_sync(
        self, model_id: str, request: InferenceRequest, **kwargs
    ) -> InferenceResponse:
        """Processes a complete inference request and logs both the request and response.

        Args:
            model_id (str): The identifier of the model.
            request (InferenceRequest): The request to process.

        Returns:
            InferenceResponse: The response from the inference.
        """
        logger.info(f"📥 [{model_id}] request={request}.")
        res = super().infer_from_request_sync(model_id, request, **kwargs)
        logger.info(f"📥 [{model_id}] res={res}.")
        return res

    def remove(self, model_id: str, delete_from_disk: bool = True) -> Model:
        """Removes a model from the manager and logs the action.

        Args:
            model_id (str): The identifier of the model to remove.
            delete_from_disk (bool, optional): Whether to also delete the
                model's cached files from disk. Defaults to True.

        Returns:
            Model: The removed model.
        """
        # Fix: forward delete_from_disk — previously it was dropped, so the
        # caller's choice was ignored and the superclass default always won.
        res = super().remove(model_id, delete_from_disk=delete_from_disk)
        logger.info(f"❌ removed {model_id}, delete_from_disk={delete_from_disk}")
        return res
Functions
add_model
add_model(
    model_id,
    api_key,
    model_id_alias=None,
    endpoint_type=ModelEndpointType.ORT,
    countinference=None,
    service_secret=None,
)

Adds a model to the manager and logs the action.

Parameters:

Name Type Description Default
model_id str

The identifier of the model.

required
model Model

The model instance.

required

Returns:

Type Description

The result of the add_model method from the superclass.

Source code in inference/core/managers/decorators/logger.py
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
def add_model(
    self,
    model_id: str,
    api_key: str,
    model_id_alias: Optional[str] = None,
    endpoint_type: ModelEndpointType = ModelEndpointType.ORT,
    countinference: Optional[bool] = None,
    service_secret: Optional[str] = None,
):
    """Adds a model to the manager and logs the action.

    Args:
        model_id (str): The identifier of the model.
        api_key (str): API key used to authorize the model load.
        model_id_alias (Optional[str]): Optional alias for the model id.
        endpoint_type (ModelEndpointType): Endpoint type used to resolve the model.
        countinference (Optional[bool]): Forwarded to the superclass unchanged.
        service_secret (Optional[str]): Forwarded to the superclass unchanged.

    Returns:
        The result of the add_model method from the superclass.
    """
    logger.info(f"🤖 {model_id} added.")
    return super().add_model(
        model_id,
        api_key,
        model_id_alias=model_id_alias,
        endpoint_type=endpoint_type,
        countinference=countinference,
        service_secret=service_secret,
    )
infer_from_request async
infer_from_request(model_id, request, **kwargs)

Processes a complete inference request and logs both the request and response.

Parameters:

Name Type Description Default
model_id str

The identifier of the model.

required
request InferenceRequest

The request to process.

required

Returns:

Name Type Description
InferenceResponse InferenceResponse

The response from the inference.

Source code in inference/core/managers/decorators/logger.py
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
async def infer_from_request(
    self, model_id: str, request: InferenceRequest, **kwargs
) -> InferenceResponse:
    """Run a full inference request and log both the request and the response.

    Args:
        model_id (str): The identifier of the model.
        request (InferenceRequest): The request to process.

    Returns:
        InferenceResponse: The response produced by the wrapped manager.
    """
    logger.info(f"📥 [{model_id}] request={request}.")
    response = await super().infer_from_request(model_id, request, **kwargs)
    logger.info(f"📥 [{model_id}] res={response}.")
    return response
infer_from_request_sync
infer_from_request_sync(model_id, request, **kwargs)

Processes a complete inference request and logs both the request and response.

Parameters:

Name Type Description Default
model_id str

The identifier of the model.

required
request InferenceRequest

The request to process.

required

Returns:

Name Type Description
InferenceResponse InferenceResponse

The response from the inference.

Source code in inference/core/managers/decorators/logger.py
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
def infer_from_request_sync(
    self, model_id: str, request: InferenceRequest, **kwargs
) -> InferenceResponse:
    """Run a full inference request synchronously, logging request and response.

    Args:
        model_id (str): The identifier of the model.
        request (InferenceRequest): The request to process.

    Returns:
        InferenceResponse: The response produced by the wrapped manager.
    """
    logger.info(f"📥 [{model_id}] request={request}.")
    response = super().infer_from_request_sync(model_id, request, **kwargs)
    logger.info(f"📥 [{model_id}] res={response}.")
    return response
remove
remove(model_id, delete_from_disk=True)

Removes a model from the manager and logs the action.

Parameters:

Name Type Description Default
model_id str

The identifier of the model to remove.

required

Returns:

Name Type Description
Model Model

The removed model.

Source code in inference/core/managers/decorators/logger.py
76
77
78
79
80
81
82
83
84
85
86
87
def remove(self, model_id: str, delete_from_disk: bool = True) -> Model:
    """Removes a model from the manager and logs the action.

    Args:
        model_id (str): The identifier of the model to remove.
        delete_from_disk (bool, optional): Whether to also delete the model's
            cached files from disk. Defaults to True.

    Returns:
        Model: The removed model.
    """
    # Fix: forward delete_from_disk — previously it was dropped, so the
    # caller's choice was ignored and the superclass default always won.
    res = super().remove(model_id, delete_from_disk=delete_from_disk)
    logger.info(f"❌ removed {model_id}, delete_from_disk={delete_from_disk}")
    return res

core/models

Base model classes and common prediction logic shared across model types.

inference.core.models.base

Classes

BaseInference

General inference class.

This class provides a basic interface for inference tasks.

Source code in inference/core/models/base.py
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
class BaseInference:
    """General inference class.

    This class provides a basic interface for inference tasks.
    """

    @usage_collector("model")
    def infer(self, image: Any, **kwargs) -> Any:
        """Run the full preprocess -> predict -> postprocess pipeline.

        Args:
            image: Input to run inference on; may be a BGR numpy array,
                filepath, InferenceRequestImage, PIL Image, byte-string, etc.
            **kwargs: Forwarded to each pipeline stage.

        Returns:
            The postprocessed inference result.
        """
        preprocessed, metadata = self.preprocess(image, **kwargs)
        logger.debug(
            f"Preprocessed input shape: {getattr(preprocessed, 'shape', None)}"
        )
        raw_predictions = self.predict(preprocessed, **kwargs)
        return self.postprocess(raw_predictions, metadata, **kwargs)

    def preprocess(
        self, image: Any, **kwargs
    ) -> Tuple[np.ndarray, PreprocessReturnMetadata]:
        """Convert raw input into a model-ready array; subclasses implement."""
        raise NotImplementedError

    def predict(self, img_in: np.ndarray, **kwargs) -> Tuple[np.ndarray, ...]:
        """Run the model forward pass; subclasses implement."""
        raise NotImplementedError

    def postprocess(
        self,
        predictions: Tuple[np.ndarray, ...],
        preprocess_return_metadata: PreprocessReturnMetadata,
        **kwargs,
    ) -> Any:
        """Turn raw predictions into the final result; subclasses implement."""
        raise NotImplementedError

    def infer_from_request(
        self, request: InferenceRequest
    ) -> Union[InferenceResponse, List[InferenceResponse]]:
        """Run inference for a request object.

        Args:
            request (InferenceRequest): The request object.

        Returns:
            Union[CVInferenceResponse, List[CVInferenceResponse]]: The response object(s).

        Raises:
            NotImplementedError: This method must be implemented by a subclass.
        """
        raise NotImplementedError

    def make_response(
        self, *args, **kwargs
    ) -> Union[InferenceResponse, List[InferenceResponse]]:
        """Build an object detection response.

        Raises:
            NotImplementedError: This method must be implemented by a subclass.
        """
        raise NotImplementedError
Functions
infer
infer(image, **kwargs)

Runs inference on given data. - image: can be a BGR numpy array, filepath, InferenceRequestImage, PIL Image, byte-string, etc.

Source code in inference/core/models/base.py
19
20
21
22
23
24
25
26
27
28
29
30
31
32
@usage_collector("model")
def infer(self, image: Any, **kwargs) -> Any:
    """Run the full preprocess -> predict -> postprocess pipeline.

    Args:
        image: Input to run inference on; may be a BGR numpy array, filepath,
            InferenceRequestImage, PIL Image, byte-string, etc.
        **kwargs: Forwarded to each pipeline stage.

    Returns:
        The postprocessed inference result.
    """
    preprocessed, metadata = self.preprocess(image, **kwargs)
    logger.debug(
        f"Preprocessed input shape: {getattr(preprocessed, 'shape', None)}"
    )
    raw_predictions = self.predict(preprocessed, **kwargs)
    return self.postprocess(raw_predictions, metadata, **kwargs)
infer_from_request
infer_from_request(request)

Runs inference on a request

Parameters:

Name Type Description Default
request InferenceRequest

The request object.

required

Returns:

Type Description
Union[InferenceResponse, List[InferenceResponse]]

Union[CVInferenceResponse, List[CVInferenceResponse]]: The response object(s).

Raises:

Type Description
NotImplementedError

This method must be implemented by a subclass.

Source code in inference/core/models/base.py
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
def infer_from_request(
    self, request: InferenceRequest
) -> Union[InferenceResponse, List[InferenceResponse]]:
    """Run inference for a request object.

    Args:
        request (InferenceRequest): The request object.

    Returns:
        Union[CVInferenceResponse, List[CVInferenceResponse]]: The response object(s).

    Raises:
        NotImplementedError: This method must be implemented by a subclass.
    """
    raise NotImplementedError
make_response
make_response(*args, **kwargs)

Constructs an object detection response.

Raises:

Type Description
NotImplementedError

This method must be implemented by a subclass.

Source code in inference/core/models/base.py
66
67
68
69
70
71
72
73
74
def make_response(
    self, *args, **kwargs
) -> Union[InferenceResponse, List[InferenceResponse]]:
    """Build an object detection response.

    Raises:
        NotImplementedError: This method must be implemented by a subclass.
    """
    raise NotImplementedError

Model

Bases: BaseInference

Base Inference Model (Inherits from BaseInference to define the needed methods)

This class provides the foundational methods for inference and logging, and can be extended by specific models.

Methods:

Name Description
log

Print the given message.

clear_cache

Clears any cache if necessary.

Source code in inference/core/models/base.py
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
class Model(BaseInference):
    """Base Inference Model (Inherits from BaseInference to define the needed methods)

    This class provides the foundational methods for inference and logging, and can be extended by specific models.

    Methods:
        log(m): Print the given message.
        clear_cache(): Clears any cache if necessary.
    """

    def log(self, m):
        """Write the given message to stdout.

        Args:
            m (str): The message to print.
        """
        print(m)

    def clear_cache(self, delete_from_disk: bool = True) -> None:
        """Clear any cached state; a no-op in this base implementation.

        Derived classes should override this as needed.

        Args:
            delete_from_disk (bool, optional): Whether to delete cached files
                from disk as well. Defaults to True.
        """
        return None

    def infer_from_request(
        self,
        request: InferenceRequest,
    ) -> Union[List[InferenceResponse], InferenceResponse]:
        """Run inference for *request* and return the associated response(s).

        Handles both single-image and multi-image requests, and optionally
        attaches a visualization of the predictions.

        Args:
            request (InferenceRequest): The request object containing the
                image(s) to process and inference options such as
                ``visualize_predictions``.

        Returns:
            Union[List[InferenceResponse], InferenceResponse]: One response per
            image, unwrapped to a single response when the request held a
            single image. Each response carries its inference time and,
            when requested, a visualization.

        Examples:
            >>> request = InferenceRequest(image=my_image, visualize_predictions=True)
            >>> response = infer_from_request(request)
            >>> print(response.time)  # Prints the time taken for inference
            0.125
            >>> print(response.visualization)  # Accesses the visualization of the prediction if available
        """
        started_at = perf_counter()
        responses = self.infer(**request.dict(), return_image_dims=False)
        for response in responses:
            # Per-response wall-clock time measured from the start of infer().
            response.time = perf_counter() - started_at
            logger.debug(f"model infer time: {response.time * 1000.0} ms")
            if request.id:
                response.inference_id = request.id

        wants_visualization = (
            hasattr(request, "visualize_predictions") and request.visualize_predictions
        )
        if wants_visualization:
            for response in responses:
                response.visualization = self.draw_predictions(request, response)

        # Single-image requests get the bare response rather than a list.
        if not isinstance(request.image, list) and len(responses) > 0:
            responses = responses[0]

        return responses

    def make_response(
        self, *args, **kwargs
    ) -> Union[InferenceResponse, List[InferenceResponse]]:
        """Build an inference response from the given arguments.

        Args:
            *args: Variable length argument list.
            **kwargs: Arbitrary keyword arguments.

        Returns:
            InferenceResponse: The inference response.
        """
        raise NotImplementedError(f"{self.__class__.__name__}.make_response")
Functions
clear_cache
clear_cache(delete_from_disk=True)

Clears any cache if necessary. This method should be implemented in derived classes as needed.

Parameters:

Name Type Description Default
delete_from_disk bool

Whether to delete cached files from disk. Defaults to True.

True
Source code in inference/core/models/base.py
 95
 96
 97
 98
 99
100
101
def clear_cache(self, delete_from_disk: bool = True) -> None:
    """Clear any cached state; a no-op in this base implementation.

    Derived classes should override this as needed.

    Args:
        delete_from_disk (bool, optional): Whether to delete cached files from
            disk as well. Defaults to True.
    """
    return None
infer_from_request
infer_from_request(request)

Perform inference based on the details provided in the request, and return the associated responses. The function can handle both single and multiple image inference requests. Optionally, it also provides a visualization of the predictions if requested.

Parameters:

Name Type Description Default
request InferenceRequest

The request object containing details for inference, such as the image or images to process, any classes to filter by, and whether or not to visualize the predictions.

required

Returns:

Type Description
Union[List[InferenceResponse], InferenceResponse]

Union[List[InferenceResponse], InferenceResponse]: A list of response objects if the request contains

Union[List[InferenceResponse], InferenceResponse]

multiple images, or a single response object if the request contains one image. Each response object

Union[List[InferenceResponse], InferenceResponse]

contains details about the segmented instances, the time taken for inference, and optionally, a visualization.

Examples:

>>> request = InferenceRequest(image=my_image, visualize_predictions=True)
>>> response = infer_from_request(request)
>>> print(response.time)  # Prints the time taken for inference
0.125
>>> print(response.visualization)  # Accesses the visualization of the prediction if available
Notes
  • The processing time for each response is included within the response itself.
  • If visualize_predictions is set to True in the request, a visualization of the prediction is also included in the response.
Source code in inference/core/models/base.py
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
def infer_from_request(
    self,
    request: InferenceRequest,
) -> Union[List[InferenceResponse], InferenceResponse]:
    """Run inference for *request* and return the associated response(s).

    Handles both single-image and multi-image requests, and optionally
    attaches a visualization of the predictions.

    Args:
        request (InferenceRequest): The request object containing the
            image(s) to process and inference options such as
            ``visualize_predictions``.

    Returns:
        Union[List[InferenceResponse], InferenceResponse]: One response per
        image, unwrapped to a single response when the request held a single
        image. Each response carries its inference time and, when requested,
        a visualization.

    Examples:
        >>> request = InferenceRequest(image=my_image, visualize_predictions=True)
        >>> response = infer_from_request(request)
        >>> print(response.time)  # Prints the time taken for inference
        0.125
        >>> print(response.visualization)  # Accesses the visualization of the prediction if available
    """
    started_at = perf_counter()
    responses = self.infer(**request.dict(), return_image_dims=False)
    for response in responses:
        # Per-response wall-clock time measured from the start of infer().
        response.time = perf_counter() - started_at
        logger.debug(f"model infer time: {response.time * 1000.0} ms")
        if request.id:
            response.inference_id = request.id

    wants_visualization = (
        hasattr(request, "visualize_predictions") and request.visualize_predictions
    )
    if wants_visualization:
        for response in responses:
            response.visualization = self.draw_predictions(request, response)

    # Single-image requests get the bare response rather than a list.
    if not isinstance(request.image, list) and len(responses) > 0:
        responses = responses[0]

    return responses
log
log(m)

Prints the given message.

Parameters:

Name Type Description Default
m str

The message to print.

required
Source code in inference/core/models/base.py
87
88
89
90
91
92
93
def log(self, m):
    """Write the given message to stdout.

    Args:
        m (str): The message to print.
    """
    print(m)
make_response
make_response(*args, **kwargs)

Makes an inference response from the given arguments.

Parameters:

Name Type Description Default
*args

Variable length argument list.

()
**kwargs

Arbitrary keyword arguments.

{}

Returns:

Name Type Description
InferenceResponse Union[InferenceResponse, List[InferenceResponse]]

The inference response.

Source code in inference/core/models/base.py
150
151
152
153
154
155
156
157
158
159
160
161
162
def make_response(
    self, *args, **kwargs
) -> Union[InferenceResponse, List[InferenceResponse]]:
    """Build an inference response from the given arguments.

    Args:
        *args: Variable length argument list.
        **kwargs: Arbitrary keyword arguments.

    Returns:
        InferenceResponse: The inference response.

    Raises:
        NotImplementedError: Always; subclasses must implement this method.
    """
    raise NotImplementedError(f"{self.__class__.__name__}.make_response")

inference.core.models.classification_base

Classes

ClassificationBaseOnnxRoboflowInferenceModel

Bases: OnnxRoboflowInferenceModel

Base class for ONNX models for Roboflow classification inference.

Attributes:

Name Type Description
multiclass bool

Whether the classification is multi-class or not.

Methods:

Name Description
get_infer_bucket_file_list

Get the list of required files for inference.

softmax

Compute softmax values for a given set of scores.

infer

ClassificationInferenceRequest) -> Union[List[Union[ClassificationInferenceResponse, MultiLabelClassificationInferenceResponse]], Union[ClassificationInferenceResponse, MultiLabelClassificationInferenceResponse]]: Perform inference on a given request and return the response.

draw_predictions

Draw prediction visuals on an image.

Source code in inference/core/models/classification_base.py
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
class ClassificationBaseOnnxRoboflowInferenceModel(OnnxRoboflowInferenceModel):
    """Base class for ONNX models for Roboflow classification inference.

    Attributes:
        multiclass (bool): Whether the classification is multi-class or not.

    Methods:
        get_infer_bucket_file_list() -> list: Get the list of required files for inference.
        softmax(x): Compute softmax values for a given set of scores.
        infer(request: ClassificationInferenceRequest) -> Union[List[Union[ClassificationInferenceResponse, MultiLabelClassificationInferenceResponse]], Union[ClassificationInferenceResponse, MultiLabelClassificationInferenceResponse]]: Perform inference on a given request and return the response.
        draw_predictions(inference_request, inference_response): Draw prediction visuals on an image.
    """

    task_type = "classification"

    # Channel-wise normalization constants, applied after pixels are scaled to [0, 1].
    preprocess_means = [0.5, 0.5, 0.5]
    preprocess_stds = [0.5, 0.5, 0.5]

    def __init__(self, *args, **kwargs):
        """Initialize the model, setting whether it is multiclass or not."""
        super().__init__(*args, **kwargs)
        # "MULTICLASS" in environment.json marks the model as multi-label.
        self.multiclass = self.environment.get("MULTICLASS", False)

    def draw_predictions(self, inference_request, inference_response):
        """Draw prediction visuals on an image.

        This method overlays the predictions on the input image, including drawing rectangles and text to visualize the predicted classes.

        Args:
            inference_request: The request object containing the image and parameters.
            inference_response: The response object containing the predictions and other details.

        Returns:
            bytes: The bytes of the visualized image in JPEG format.
        """
        image = load_image_rgb(inference_request.image)
        image = Image.fromarray(image)
        draw = ImageDraw.Draw(image)
        font = ImageFont.load_default()
        if isinstance(inference_response.predictions, list):
            # Single-label classification: predictions is a list, best first.
            # Guard against an empty list (e.g. everything filtered by confidence).
            if inference_response.predictions:
                prediction = inference_response.predictions[0]
                color = self.colors.get(prediction.class_name, "#4892EA")
                # PIL `Image.size` is (width, height); the outline rectangle must be
                # [0, 0, width, height] to trace the full image border.
                draw.rectangle(
                    [0, 0, image.size[0], image.size[1]],
                    outline=color,
                    width=inference_request.visualization_stroke_width,
                )
                text = f"{prediction.class_id} - {prediction.class_name} {prediction.confidence:.2f}"
                text_size = font.getbbox(text)

                # set button size + 10px margins
                button_size = (text_size[2] + 20, text_size[3] + 20)
                button_img = Image.new("RGBA", button_size, color)
                # put text on button with 10px margins
                button_draw = ImageDraw.Draw(button_img)
                button_draw.text((10, 10), text, font=font, fill=(255, 255, 255, 255))

                # put button on source image in position (0, 0)
                image.paste(button_img, (0, 0))
        else:
            # Multi-label classification: predictions is a dict keyed by class name.
            if len(inference_response.predictions) > 0:
                box_color = "#4892EA"
                # See note above: rectangle coordinates are [0, 0, width, height].
                draw.rectangle(
                    [0, 0, image.size[0], image.size[1]],
                    outline=box_color,
                    width=inference_request.visualization_stroke_width,
                )
            row = 0
            predictions = list(inference_response.predictions.items())
            predictions = sorted(
                predictions, key=lambda x: x[1].confidence, reverse=True
            )
            for cls_name, pred in predictions:
                color = self.colors.get(cls_name, "#4892EA")
                text = f"{cls_name} {pred.confidence:.2f}"
                text_size = font.getbbox(text)

                # set button size + 10px margins
                button_size = (text_size[2] + 20, text_size[3] + 20)
                button_img = Image.new("RGBA", button_size, color)
                # put text on button with 10px margins
                button_draw = ImageDraw.Draw(button_img)
                button_draw.text((10, 10), text, font=font, fill=(255, 255, 255, 255))

                # stack one label "button" per class down the left edge
                image.paste(button_img, (0, row))
                row += button_size[1]

        buffered = BytesIO()
        image = image.convert("RGB")
        image.save(buffered, format="JPEG")
        return buffered.getvalue()

    def get_infer_bucket_file_list(self) -> list:
        """Get the list of required files for inference.

        Returns:
            list: A list of required files for inference, e.g., ["environment.json"].
        """
        return ["environment.json"]

    def infer(
        self,
        image: Any,
        disable_preproc_auto_orient: bool = False,
        disable_preproc_contrast: bool = False,
        disable_preproc_grayscale: bool = False,
        disable_preproc_static_crop: bool = False,
        return_image_dims: bool = False,
        **kwargs,
    ):
        """
        Perform inference on the provided image(s) and return the predictions.

        Args:
            image (Any): The image or list of images to be processed.
                - can be a BGR numpy array, filepath, InferenceRequestImage, PIL Image, byte-string, etc.
            disable_preproc_auto_orient (bool, optional): If true, the auto orient preprocessing step is disabled for this call. Default is False.
            disable_preproc_contrast (bool, optional): If true, the auto contrast preprocessing step is disabled for this call. Default is False.
            disable_preproc_grayscale (bool, optional): If true, the grayscale preprocessing step is disabled for this call. Default is False.
            disable_preproc_static_crop (bool, optional): If true, the static crop preprocessing step is disabled for this call. Default is False.
            return_image_dims (bool, optional): If set to True, the function will also return the dimensions of the image. Defaults to False.
            **kwargs: Additional parameters to customize the inference process.

        Returns:
            Union[List[np.array], np.array, Tuple[List[np.array], List[Tuple[int, int]]], Tuple[np.array, Tuple[int, int]]]:
            If `return_image_dims` is True and a list of images is provided, a tuple containing a list of prediction arrays and a list of image dimensions (width, height) is returned.
            If `return_image_dims` is True and a single image is provided, a tuple containing the prediction array and image dimensions (width, height) is returned.
            If `return_image_dims` is False and a list of images is provided, only the list of prediction arrays is returned.
            If `return_image_dims` is False and a single image is provided, only the prediction array is returned.

        Notes:
            - The input image(s) will be preprocessed (normalized and reshaped) before inference.
            - This function uses an ONNX session to perform inference on the input image(s).
        """
        return super().infer(
            image,
            disable_preproc_auto_orient=disable_preproc_auto_orient,
            disable_preproc_contrast=disable_preproc_contrast,
            disable_preproc_grayscale=disable_preproc_grayscale,
            disable_preproc_static_crop=disable_preproc_static_crop,
            return_image_dims=return_image_dims,
            **kwargs,
        )

    def postprocess(
        self,
        predictions: Tuple[np.ndarray],
        preprocess_return_metadata: PreprocessReturnMetadata,
        return_image_dims=False,
        **kwargs,
    ) -> Union[ClassificationInferenceResponse, List[ClassificationInferenceResponse]]:
        """Unpack the raw ONNX output and build response objects."""
        # `predict` returns a 1-tuple; unwrap it before building responses.
        predictions = predictions[0]
        return self.make_response(
            predictions, preprocess_return_metadata["img_dims"], **kwargs
        )

    def predict(self, img_in: np.ndarray, **kwargs) -> Tuple[np.ndarray]:
        """Run the ONNX session on a preprocessed batch; returns a 1-tuple of outputs."""
        # The session is not safe for concurrent use, hence the lock.
        with self._session_lock:
            predictions = run_session_via_iobinding(
                self.onnx_session, self.input_name, img_in
            )
        return (predictions,)

    def preprocess(
        self, image: Any, **kwargs
    ) -> Tuple[np.ndarray, PreprocessReturnMetadata]:
        """Load, batch, scale to [0, 1] and normalize the input image(s)."""
        if isinstance(image, list):
            imgs_with_dims = [
                self.preproc_image(
                    i,
                    disable_preproc_auto_orient=kwargs.get(
                        "disable_preproc_auto_orient", False
                    ),
                    disable_preproc_contrast=kwargs.get(
                        "disable_preproc_contrast", False
                    ),
                    disable_preproc_grayscale=kwargs.get(
                        "disable_preproc_grayscale", False
                    ),
                    disable_preproc_static_crop=kwargs.get(
                        "disable_preproc_static_crop", False
                    ),
                )
                for i in image
            ]
            imgs, img_dims = zip(*imgs_with_dims)
            if isinstance(imgs[0], np.ndarray):
                img_in = np.concatenate(imgs, axis=0)
            elif USE_PYTORCH_FOR_PREPROCESSING:
                img_in = torch.cat(imgs, dim=0)
            else:
                raise ValueError(
                    f"Received a list of images of unknown type, {type(imgs[0])}; "
                    "This is most likely a bug. Contact Roboflow team through github issues "
                    "(https://github.com/roboflow/inference/issues) providing full context of the problem"
                )
        else:
            img_in, img_dims = self.preproc_image(
                image,
                disable_preproc_auto_orient=kwargs.get(
                    "disable_preproc_auto_orient", False
                ),
                disable_preproc_contrast=kwargs.get("disable_preproc_contrast", False),
                disable_preproc_grayscale=kwargs.get(
                    "disable_preproc_grayscale", False
                ),
                disable_preproc_static_crop=kwargs.get(
                    "disable_preproc_static_crop", False
                ),
            )
            img_dims = [img_dims]

        # NOTE(review): this in-place division precedes the float cast below, so
        # `preproc_image` presumably already returns a float array — confirm.
        img_in /= 255.0

        mean = self.preprocess_means
        std = self.preprocess_stds
        if isinstance(img_in, np.ndarray):
            img_in = img_in.astype(np.float32)
        elif USE_PYTORCH_FOR_PREPROCESSING:
            img_in = img_in.float()
        else:
            raise ValueError(
                f"Received an image of unknown type, {type(img_in)}; "
                "This is most likely a bug. Contact Roboflow team through github issues "
                "(https://github.com/roboflow/inference/issues) providing full context of the problem"
            )

        # Per-channel normalization over the NCHW batch.
        img_in[:, 0, :, :] = (img_in[:, 0, :, :] - mean[0]) / std[0]
        img_in[:, 1, :, :] = (img_in[:, 1, :, :] - mean[1]) / std[1]
        img_in[:, 2, :, :] = (img_in[:, 2, :, :] - mean[2]) / std[2]
        return img_in, PreprocessReturnMetadata({"img_dims": img_dims})

    def infer_from_request(
        self,
        request: ClassificationInferenceRequest,
    ) -> Union[List[InferenceResponse], InferenceResponse]:
        """
        Handle an inference request to produce an appropriate response.

        Args:
            request (ClassificationInferenceRequest): The request object encapsulating the image(s) and relevant parameters.

        Returns:
            Union[List[InferenceResponse], InferenceResponse]: The response object(s) containing the predictions, visualization, and other pertinent details. If a list of images was provided, a list of responses is returned. Otherwise, a single response is returned.

        Notes:
            - Starts a timer at the beginning to calculate inference time.
            - Processes the image(s) through the `infer` method.
            - Generates the appropriate response object(s) using `make_response`.
            - Calculates and sets the time taken for inference.
            - If visualization is requested, the predictions are drawn on the image.
        """
        t1 = perf_counter()
        responses = self.infer(**request.dict(), return_image_dims=True)
        for response in responses:
            response.time = perf_counter() - t1
            response.inference_id = getattr(request, "id", None)

        if request.visualize_predictions:
            for response in responses:
                response.visualization = self.draw_predictions(request, response)

        # Unwrap to a single response when the request held a single image.
        if not isinstance(request.image, list):
            responses = responses[0]

        return responses

    def make_response(
        self,
        predictions,
        img_dims,
        confidence: float = 0.5,
        **kwargs,
    ) -> Union[ClassificationInferenceResponse, List[ClassificationInferenceResponse]]:
        """
        Create response objects for the given predictions and image dimensions.

        Args:
            predictions (list): List of prediction arrays from the inference process.
            img_dims (list): List of tuples indicating the dimensions (height, width) of each image.
            confidence (float, optional): Confidence threshold for filtering predictions. Defaults to 0.5.
            **kwargs: Additional parameters to influence the response creation process.

        Returns:
            Union[ClassificationInferenceResponse, List[ClassificationInferenceResponse]]: A response object or a list of response objects encapsulating the prediction details.

        Notes:
            - If the model is multiclass, a `MultiLabelClassificationInferenceResponse` is generated for each image.
            - If the model is not multiclass, a `ClassificationInferenceResponse` is generated for each image.
            - Predictions below the confidence threshold are filtered out.
        """
        responses = []
        confidence_threshold = float(confidence)
        for ind, prediction in enumerate(predictions):
            if self.multiclass:
                preds = prediction[0]
                results = dict()
                predicted_classes = []
                for i, o in enumerate(preds):
                    cls_name = self.class_names[i]
                    score = float(o)
                    results[cls_name] = {"confidence": score, "class_id": i}
                    if score > confidence_threshold:
                        predicted_classes.append(cls_name)
                response = MultiLabelClassificationInferenceResponse(
                    image=InferenceResponseImage(
                        # img_dims entries are (height, width) — use the same
                        # order as the single-label branch below for consistency.
                        width=img_dims[ind][1], height=img_dims[ind][0]
                    ),
                    predicted_classes=predicted_classes,
                    predictions=results,
                )
            else:
                preds = prediction[0]
                preds = self.softmax(preds)
                results = []
                for i, cls_name in enumerate(self.class_names):
                    score = float(preds[i])
                    if score < confidence_threshold:
                        continue
                    pred = {
                        "class_id": i,
                        "class": cls_name,
                        "confidence": round(score, 4),
                    }
                    results.append(pred)
                results = sorted(results, key=lambda x: x["confidence"], reverse=True)

                response = ClassificationInferenceResponse(
                    image=InferenceResponseImage(
                        width=img_dims[ind][1], height=img_dims[ind][0]
                    ),
                    predictions=results,
                    top=results[0]["class"] if results else "",
                    confidence=results[0]["confidence"] if results else 0.0,
                )
            responses.append(response)

        return responses

    @staticmethod
    def softmax(x):
        """Compute softmax values for each set of scores in x.

        Args:
            x (np.array): The input array containing the scores.

        Returns:
            np.array: The softmax values for each set of scores.
        """
        # Subtract the max for numerical stability; the result is unchanged.
        e_x = np.exp(x - np.max(x))
        return e_x / e_x.sum()

    def get_model_output_shape(self) -> Tuple[int, int, int]:
        """Probe the model with a random image and return the raw output shape."""
        test_image = (np.random.rand(1024, 1024, 3) * 255).astype(np.uint8)
        test_image, _ = self.preprocess(test_image)
        output = np.array(self.predict(test_image))
        return output.shape

    def validate_model_classes(self) -> None:
        """Check that the model's class count matches the environment metadata.

        Raises:
            ValueError: If the ONNX output class count differs from
                ``self.num_classes``.
        """
        output_shape = self.get_model_output_shape()
        num_classes = output_shape[3]
        # Explicit check instead of `assert`, which is stripped under `python -O`.
        if num_classes != self.num_classes:
            raise ValueError(
                f"Number of classes in model ({num_classes}) does not match the number of classes in the environment ({self.num_classes})"
            )
Functions
__init__
__init__(*args, **kwargs)

Initialize the model, setting whether it is multiclass or not.

Source code in inference/core/models/classification_base.py
45
46
47
48
def __init__(self, *args, **kwargs):
    """Set up the model and record whether it performs multi-label classification."""
    super().__init__(*args, **kwargs)
    environment = self.environment
    self.multiclass = environment.get("MULTICLASS", False)
draw_predictions
draw_predictions(inference_request, inference_response)

Draw prediction visuals on an image.

This method overlays the predictions on the input image, including drawing rectangles and text to visualize the predicted classes.

Parameters:

Name Type Description Default
inference_request

The request object containing the image and parameters.

required
inference_response

The response object containing the predictions and other details.

required

Returns:

Name Type Description
bytes

The bytes of the visualized image in JPEG format.

Source code in inference/core/models/classification_base.py
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
def draw_predictions(self, inference_request, inference_response):
    """Draw prediction visuals on an image.

    This method overlays the predictions on the input image, including drawing rectangles and text to visualize the predicted classes.

    Args:
        inference_request: The request object containing the image and parameters.
        inference_response: The response object containing the predictions and other details.

    Returns:
        bytes: The bytes of the visualized image in JPEG format.
    """
    image = load_image_rgb(inference_request.image)
    image = Image.fromarray(image)
    draw = ImageDraw.Draw(image)
    font = ImageFont.load_default()
    if isinstance(inference_response.predictions, list):
        # Single-label classification: predictions is a list, best first.
        # Guard against an empty list (e.g. everything filtered by confidence).
        if inference_response.predictions:
            prediction = inference_response.predictions[0]
            color = self.colors.get(prediction.class_name, "#4892EA")
            # PIL `Image.size` is (width, height); the outline rectangle must be
            # [0, 0, width, height] to trace the full image border.
            draw.rectangle(
                [0, 0, image.size[0], image.size[1]],
                outline=color,
                width=inference_request.visualization_stroke_width,
            )
            text = f"{prediction.class_id} - {prediction.class_name} {prediction.confidence:.2f}"
            text_size = font.getbbox(text)

            # set button size + 10px margins
            button_size = (text_size[2] + 20, text_size[3] + 20)
            button_img = Image.new("RGBA", button_size, color)
            # put text on button with 10px margins
            button_draw = ImageDraw.Draw(button_img)
            button_draw.text((10, 10), text, font=font, fill=(255, 255, 255, 255))

            # put button on source image in position (0, 0)
            image.paste(button_img, (0, 0))
    else:
        # Multi-label classification: predictions is a dict keyed by class name.
        if len(inference_response.predictions) > 0:
            box_color = "#4892EA"
            # See note above: rectangle coordinates are [0, 0, width, height].
            draw.rectangle(
                [0, 0, image.size[0], image.size[1]],
                outline=box_color,
                width=inference_request.visualization_stroke_width,
            )
        row = 0
        predictions = list(inference_response.predictions.items())
        predictions = sorted(
            predictions, key=lambda x: x[1].confidence, reverse=True
        )
        for cls_name, pred in predictions:
            color = self.colors.get(cls_name, "#4892EA")
            text = f"{cls_name} {pred.confidence:.2f}"
            text_size = font.getbbox(text)

            # set button size + 10px margins
            button_size = (text_size[2] + 20, text_size[3] + 20)
            button_img = Image.new("RGBA", button_size, color)
            # put text on button with 10px margins
            button_draw = ImageDraw.Draw(button_img)
            button_draw.text((10, 10), text, font=font, fill=(255, 255, 255, 255))

            # stack one label "button" per class down the left edge
            image.paste(button_img, (0, row))
            row += button_size[1]

    buffered = BytesIO()
    image = image.convert("RGB")
    image.save(buffered, format="JPEG")
    return buffered.getvalue()
get_infer_bucket_file_list
get_infer_bucket_file_list()

Get the list of required files for inference.

Returns:

Name Type Description
list list

A list of required files for inference, e.g., ["environment.json"].

Source code in inference/core/models/classification_base.py
123
124
125
126
127
128
129
def get_infer_bucket_file_list(self) -> list:
    """List the files that must be fetched before inference can run.

    Returns:
        list: The required file names, e.g., ["environment.json"].
    """
    required_files = ["environment.json"]
    return required_files
infer
infer(
    image,
    disable_preproc_auto_orient=False,
    disable_preproc_contrast=False,
    disable_preproc_grayscale=False,
    disable_preproc_static_crop=False,
    return_image_dims=False,
    **kwargs
)

Perform inference on the provided image(s) and return the predictions.

Parameters:

Name Type Description Default
image Any

The image or list of images to be processed. - can be a BGR numpy array, filepath, InferenceRequestImage, PIL Image, byte-string, etc.

required
disable_preproc_auto_orient bool

If true, the auto orient preprocessing step is disabled for this call. Default is False.

False
disable_preproc_contrast bool

If true, the auto contrast preprocessing step is disabled for this call. Default is False.

False
disable_preproc_grayscale bool

If true, the grayscale preprocessing step is disabled for this call. Default is False.

False
disable_preproc_static_crop bool

If true, the static crop preprocessing step is disabled for this call. Default is False.

False
return_image_dims bool

If set to True, the function will also return the dimensions of the image. Defaults to False.

False
**kwargs

Additional parameters to customize the inference process.

{}

Returns:

Type Description

Union[List[np.array], np.array, Tuple[List[np.array], List[Tuple[int, int]]], Tuple[np.array, Tuple[int, int]]]:

If return_image_dims is True and a list of images is provided, a tuple containing a list of prediction arrays and a list of image dimensions (width, height) is returned.

If return_image_dims is True and a single image is provided, a tuple containing the prediction array and image dimensions (width, height) is returned.

If return_image_dims is False and a list of images is provided, only the list of prediction arrays is returned.

If return_image_dims is False and a single image is provided, only the prediction array is returned.

Notes
  • The input image(s) will be preprocessed (normalized and reshaped) before inference.
  • This function uses an ONNX session to perform inference on the input image(s).
Source code in inference/core/models/classification_base.py
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
def infer(
    self,
    image: Any,
    disable_preproc_auto_orient: bool = False,
    disable_preproc_contrast: bool = False,
    disable_preproc_grayscale: bool = False,
    disable_preproc_static_crop: bool = False,
    return_image_dims: bool = False,
    **kwargs,
):
    """Run classification inference on one image or a batch of images.

    Args:
        image (Any): Image or list of images to classify; accepts a BGR numpy
            array, filepath, InferenceRequestImage, PIL Image, byte-string, etc.
        disable_preproc_auto_orient (bool, optional): Skip the auto-orient
            preprocessing step for this call. Defaults to False.
        disable_preproc_contrast (bool, optional): Skip the auto-contrast
            preprocessing step for this call. Defaults to False.
        disable_preproc_grayscale (bool, optional): Skip the grayscale
            preprocessing step for this call. Defaults to False.
        disable_preproc_static_crop (bool, optional): Skip the static-crop
            preprocessing step for this call. Defaults to False.
        return_image_dims (bool, optional): Also return each image's
            dimensions alongside the predictions. Defaults to False.
        **kwargs: Extra parameters forwarded to the base implementation.

    Returns:
        The prediction array (or list of arrays for a batch); when
        `return_image_dims` is True, a tuple pairing the prediction(s) with the
        corresponding image dimension tuple(s).

    Notes:
        - Inputs are preprocessed (normalized and reshaped) before the ONNX
          session is invoked by the base implementation.
    """
    # Bundle the per-call preprocessing switches and delegate to the base class.
    preprocessing_overrides = {
        "disable_preproc_auto_orient": disable_preproc_auto_orient,
        "disable_preproc_contrast": disable_preproc_contrast,
        "disable_preproc_grayscale": disable_preproc_grayscale,
        "disable_preproc_static_crop": disable_preproc_static_crop,
    }
    return super().infer(
        image,
        return_image_dims=return_image_dims,
        **preprocessing_overrides,
        **kwargs,
    )
infer_from_request
infer_from_request(request)

Handle an inference request to produce an appropriate response.

Parameters:

Name Type Description Default
request ClassificationInferenceRequest

The request object encapsulating the image(s) and relevant parameters.

required

Returns:

Type Description
Union[List[InferenceResponse], InferenceResponse]

Union[List[InferenceResponse], InferenceResponse]: The response object(s) containing the predictions, visualization, and other pertinent details. If a list of images was provided, a list of responses is returned. Otherwise, a single response is returned.

Notes
  • Starts a timer at the beginning to calculate inference time.
  • Processes the image(s) through the infer method.
  • Generates the appropriate response object(s) using make_response.
  • Calculates and sets the time taken for inference.
  • If visualization is requested, the predictions are drawn on the image.
Source code in inference/core/models/classification_base.py
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
def infer_from_request(
    self,
    request: ClassificationInferenceRequest,
) -> Union[List[InferenceResponse], InferenceResponse]:
    """Handle an inference request and build the matching response(s).

    Args:
        request (ClassificationInferenceRequest): Request carrying the image(s)
            and the inference parameters.

    Returns:
        Union[List[InferenceResponse], InferenceResponse]: One response per
        image, unwrapped to a single response when the request held a single
        image. Each response carries predictions, timing, and (optionally) a
        visualization.

    Notes:
        - Timing starts before inference and is stamped on every response.
        - Visualizations are rendered only when the request asks for them.
    """
    started_at = perf_counter()
    responses = self.infer(**request.dict(), return_image_dims=True)
    inference_id = getattr(request, "id", None)
    for response in responses:
        response.time = perf_counter() - started_at
        response.inference_id = inference_id

    if request.visualize_predictions:
        for response in responses:
            response.visualization = self.draw_predictions(request, response)

    return responses if isinstance(request.image, list) else responses[0]
make_response
make_response(
    predictions, img_dims, confidence=0.5, **kwargs
)

Create response objects for the given predictions and image dimensions.

Parameters:

Name Type Description Default
predictions list

List of prediction arrays from the inference process.

required
img_dims list

List of tuples indicating the dimensions (width, height) of each image.

required
confidence float

Confidence threshold for filtering predictions. Defaults to 0.5.

0.5
**kwargs

Additional parameters to influence the response creation process.

{}

Returns:

Type Description
Union[ClassificationInferenceResponse, List[ClassificationInferenceResponse]]

Union[ClassificationInferenceResponse, List[ClassificationInferenceResponse]]: A response object or a list of response objects encapsulating the prediction details.

Notes
  • If the model is multiclass, a MultiLabelClassificationInferenceResponse is generated for each image.
  • If the model is not multiclass, a ClassificationInferenceResponse is generated for each image.
  • Predictions below the confidence threshold are filtered out.
Source code in inference/core/models/classification_base.py
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
def make_response(
    self,
    predictions,
    img_dims,
    confidence: float = 0.5,
    **kwargs,
) -> Union[ClassificationInferenceResponse, List[ClassificationInferenceResponse]]:
    """
    Create response objects for the given predictions and image dimensions.

    Args:
        predictions (list): List of prediction arrays from the inference process.
        img_dims (list): List of tuples indicating the dimensions (height, width)
            of each image, as produced by preprocessing via ``image.shape[:2]``.
        confidence (float, optional): Confidence threshold for filtering predictions. Defaults to 0.5.
        **kwargs: Additional parameters to influence the response creation process.

    Returns:
        Union[ClassificationInferenceResponse, List[ClassificationInferenceResponse]]: A response object or a list of response objects encapsulating the prediction details.

    Notes:
        - If the model is multiclass, a `MultiLabelClassificationInferenceResponse` is generated for each image.
        - If the model is not multiclass, a `ClassificationInferenceResponse` is generated for each image.
        - Predictions below the confidence threshold are filtered out.
    """
    responses = []
    confidence_threshold = float(confidence)
    for ind, prediction in enumerate(predictions):
        if self.multiclass:
            preds = prediction[0]
            results = {}
            predicted_classes = []
            for i, o in enumerate(preds):
                cls_name = self.class_names[i]
                score = float(o)
                # Every class is reported; only those above the threshold
                # are listed under `predicted_classes`.
                results[cls_name] = {"confidence": score, "class_id": i}
                if score > confidence_threshold:
                    predicted_classes.append(cls_name)
            response = MultiLabelClassificationInferenceResponse(
                image=InferenceResponseImage(
                    # img_dims entries are (height, width) -- same convention
                    # as the single-label branch below (previously this branch
                    # had the indices swapped).
                    width=img_dims[ind][1], height=img_dims[ind][0]
                ),
                predicted_classes=predicted_classes,
                predictions=results,
            )
        else:
            preds = prediction[0]
            preds = self.softmax(preds)
            results = []
            for i, cls_name in enumerate(self.class_names):
                score = float(preds[i])
                # Unlike the multi-label branch, below-threshold classes are
                # dropped entirely from the response.
                if score < confidence_threshold:
                    continue
                pred = {
                    "class_id": i,
                    "class": cls_name,
                    "confidence": round(score, 4),
                }
                results.append(pred)
            results = sorted(results, key=lambda x: x["confidence"], reverse=True)

            response = ClassificationInferenceResponse(
                image=InferenceResponseImage(
                    width=img_dims[ind][1], height=img_dims[ind][0]
                ),
                predictions=results,
                # Highest-confidence class wins; empty string / 0.0 when no
                # prediction survives the threshold.
                top=results[0]["class"] if results else "",
                confidence=results[0]["confidence"] if results else 0.0,
            )
        responses.append(response)

    return responses
softmax staticmethod
softmax(x)

Compute softmax values for each set of scores in x.

Parameters:

Name Type Description Default
x array

The input array containing the scores.

required

Returns:

Type Description

np.array: The softmax values for each set of scores.

Source code in inference/core/models/classification_base.py
370
371
372
373
374
375
376
377
378
379
380
381
@staticmethod
def softmax(x):
    """Compute softmax values for each set of scores in x.

    Args:
        x (np.array): The input array containing the scores.

    Returns:
        np.array: The softmax values for each set of scores.
    """
    e_x = np.exp(x - np.max(x))
    return e_x / e_x.sum()

inference.core.models.inference_models_adapters

Classes

InferenceModelsClassificationAdapter

Bases: Model

Source code in inference/core/models/inference_models_adapters.py
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
class InferenceModelsClassificationAdapter(Model):
    """Adapter exposing an `inference-models` classification model through the
    legacy `inference` `Model` interface (preprocess / predict / postprocess).

    Wraps `AutoModel.from_pretrained(...)`, which may yield either a
    single-label `ClassificationModel` or a `MultiLabelClassificationModel`.
    """

    def __init__(self, model_id: str, api_key: str = None, **kwargs):
        """Load the underlying model for `model_id` via `AutoModel`.

        Args:
            model_id (str): Roboflow model id (aliases are resolved first).
            api_key (str, optional): API key; falls back to the module-level
                `API_KEY` when not provided.
            **kwargs: Forwarded to `AutoModel.from_pretrained`; `countinference`
                and `service_secret` are additionally used to build the extra
                weights-provider headers.
        """
        super().__init__()

        self.metrics = {"num_inferences": 0, "avg_inference_time": 0.0}

        self.api_key = api_key if api_key else API_KEY
        model_id = resolve_roboflow_model_alias(model_id=model_id)

        self.task_type = "classification"
        extra_weights_provider_headers = get_extra_weights_provider_headers(
            countinference=kwargs.get("countinference"),
            service_secret=kwargs.get("service_secret"),
        )
        # Offer only the backends that are valid and not explicitly disabled.
        backend = list(
            VALID_INFERENCE_MODELS_BACKENDS.difference(
                DISABLED_INFERENCE_MODELS_BACKENDS
            )
        )
        self._model: Union[ClassificationModel, MultiLabelClassificationModel] = (
            AutoModel.from_pretrained(
                model_id_or_path=model_id,
                api_key=self.api_key,
                allow_untrusted_packages=ALLOW_INFERENCE_MODELS_UNTRUSTED_PACKAGES,
                allow_direct_local_storage_loading=ALLOW_INFERENCE_MODELS_DIRECTLY_ACCESS_LOCAL_PACKAGES,
                weights_provider_extra_headers=extra_weights_provider_headers,
                backend=backend,
                **kwargs,
            )
        )
        self.class_names = list(self._model.class_names)

    def map_inference_kwargs(self, kwargs: dict) -> dict:
        """Hook for translating `inference`-style kwargs into the wrapped
        model's kwargs; identity by default (subclasses may override)."""
        return kwargs

    def preprocess(self, image: Any, **kwargs):
        """Load input image(s) as BGR numpy arrays and run model pre-processing.

        Returns:
            Tuple of (pre-processed batch, list of per-image shapes); each
            shape is ``image.shape[:2]``, i.e. (height, width).
        """
        is_batch = isinstance(image, list)
        images = image if is_batch else [image]
        np_images: List[np.ndarray] = [
            load_image_bgr(
                v,
                disable_preproc_auto_orient=kwargs.get(
                    "disable_preproc_auto_orient", False
                ),
            )
            for v in images
        ]
        # shape[:2] is (height, width) for each decoded image.
        images_shapes = [i.shape[:2] for i in np_images]
        mapped_kwargs = self.map_inference_kwargs(kwargs)
        return self._model.pre_process(np_images, **mapped_kwargs), images_shapes

    def predict(self, img_in, **kwargs):
        """Run the wrapped model's forward pass on a pre-processed batch."""
        mapped_kwargs = self.map_inference_kwargs(kwargs)
        return self._model.forward(img_in, **mapped_kwargs)

    def postprocess(
        self,
        predictions: Tuple[List[KeyPoints], Optional[List[Detections]]],
        returned_metadata: List[Tuple[int, int]],
        **kwargs,
    ) -> Union[
        List[MultiLabelClassificationInferenceResponse],
        List[ClassificationInferenceResponse],
    ]:
        """Convert raw model output into `inference` response objects.

        The wrapped model's `post_process` output shape selects the builder:
        a list indicates multi-label output, anything else single-label.
        """
        mapped_kwargs = self.map_inference_kwargs(kwargs)
        post_processed_predictions = self._model.post_process(
            predictions, **mapped_kwargs
        )
        if isinstance(post_processed_predictions, list):
            # multi-label classification
            return prepare_multi_label_classification_response(
                post_processed_predictions,
                image_sizes=returned_metadata,
                class_names=self.class_names,
                confidence_threshold=kwargs.get("confidence", 0.5),
            )
        else:
            # single-label classification
            return prepare_classification_response(
                post_processed_predictions,
                image_sizes=returned_metadata,
                class_names=self.class_names,
                confidence_threshold=kwargs.get("confidence", 0.5),
            )

    def clear_cache(self, delete_from_disk: bool = True) -> None:
        """Clears any cache if necessary. TODO: Implement this to delete the cache from the experimental model.

        Args:
            delete_from_disk (bool, optional): Whether to delete cached files from disk. Defaults to True.
        """
        pass

    def infer_from_request(
        self,
        request: ClassificationInferenceRequest,
    ) -> Union[List[InferenceResponse], InferenceResponse]:
        """
        Handle an inference request to produce an appropriate response.

        Args:
            request (ClassificationInferenceRequest): The request object encapsulating the image(s) and relevant parameters.

        Returns:
            Union[List[InferenceResponse], InferenceResponse]: The response object(s) containing the predictions, visualization, and other pertinent details. If a list of images was provided, a list of responses is returned. Otherwise, a single response is returned.

        Notes:
            - Starts a timer at the beginning to calculate inference time.
            - Processes the image(s) through the `infer` method.
            - Generates the appropriate response object(s) using `make_response`.
            - Calculates and sets the time taken for inference.
            - If visualization is requested, the predictions are drawn on the image.
        """
        t1 = perf_counter()
        responses = self.infer(**request.dict(), return_image_dims=True)
        for response in responses:
            # Elapsed wall-clock time since the request started processing.
            response.time = perf_counter() - t1
            response.inference_id = getattr(request, "id", None)

        if request.visualize_predictions:
            for response in responses:
                response.visualization = draw_predictions(
                    request, response, self.class_names
                )

        # Single-image requests get a bare response rather than a 1-item list.
        if not isinstance(request.image, list):
            responses = responses[0]

        return responses
Functions
clear_cache
clear_cache(delete_from_disk=True)

Clears any cache if necessary. TODO: Implement this to delete the cache from the experimental model.

Parameters:

Name Type Description Default
delete_from_disk bool

Whether to delete cached files from disk. Defaults to True.

True
Source code in inference/core/models/inference_models_adapters.py
672
673
674
675
676
677
678
def clear_cache(self, delete_from_disk: bool = True) -> None:
    """Placeholder cache-clearing hook; currently performs no work.

    TODO: Implement this to delete the cache from the experimental model.

    Args:
        delete_from_disk (bool, optional): Whether cached files should also be
            removed from disk once implemented. Defaults to True.
    """
    # Intentionally a no-op until cache management is implemented.
    return None
infer_from_request
infer_from_request(request)

Handle an inference request to produce an appropriate response.

Parameters:

Name Type Description Default
request ClassificationInferenceRequest

The request object encapsulating the image(s) and relevant parameters.

required

Returns:

Type Description
Union[List[InferenceResponse], InferenceResponse]

Union[List[InferenceResponse], InferenceResponse]: The response object(s) containing the predictions, visualization, and other pertinent details. If a list of images was provided, a list of responses is returned. Otherwise, a single response is returned.

Notes
  • Starts a timer at the beginning to calculate inference time.
  • Processes the image(s) through the infer method.
  • Generates the appropriate response object(s) using make_response.
  • Calculates and sets the time taken for inference.
  • If visualization is requested, the predictions are drawn on the image.
Source code in inference/core/models/inference_models_adapters.py
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
def infer_from_request(
    self,
    request: ClassificationInferenceRequest,
) -> Union[List[InferenceResponse], InferenceResponse]:
    """
    Handle an inference request to produce an appropriate response.

    Args:
        request (ClassificationInferenceRequest): The request object encapsulating the image(s) and relevant parameters.

    Returns:
        Union[List[InferenceResponse], InferenceResponse]: The response object(s) containing the predictions, visualization, and other pertinent details. If a list of images was provided, a list of responses is returned. Otherwise, a single response is returned.

    Notes:
        - Starts a timer at the beginning to calculate inference time.
        - Processes the image(s) through the `infer` method.
        - Generates the appropriate response object(s) using `make_response`.
        - Calculates and sets the time taken for inference.
        - If visualization is requested, the predictions are drawn on the image.
    """
    t1 = perf_counter()
    # `request.dict()` is expanded into infer(...) kwargs; image dims are
    # requested because downstream response metadata needs them.
    responses = self.infer(**request.dict(), return_image_dims=True)
    for response in responses:
        # Elapsed wall-clock time since the request started processing.
        response.time = perf_counter() - t1
        response.inference_id = getattr(request, "id", None)

    if request.visualize_predictions:
        for response in responses:
            response.visualization = draw_predictions(
                request, response, self.class_names
            )

    # Single-image requests get a bare response rather than a 1-item list.
    if not isinstance(request.image, list):
        responses = responses[0]

    return responses

InferenceModelsInstanceSegmentationAdapter

Bases: Model

Source code in inference/core/models/inference_models_adapters.py
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
class InferenceModelsInstanceSegmentationAdapter(Model):
    """Adapter exposing an `inference-models` instance-segmentation model
    through the legacy `inference` `Model` interface.

    Wraps `AutoModel.from_pretrained(...)` and converts its detections
    (boxes, masks, class ids) into `InstanceSegmentationInferenceResponse`
    objects.
    """

    def __init__(self, model_id: str, api_key: str = None, **kwargs):
        """Load the underlying segmentation model for `model_id` via `AutoModel`.

        Args:
            model_id (str): Roboflow model id (aliases are resolved first).
            api_key (str, optional): API key; falls back to the module-level
                `API_KEY` when not provided.
            **kwargs: Forwarded to `AutoModel.from_pretrained`; `countinference`
                and `service_secret` are additionally used to build the extra
                weights-provider headers.
        """
        super().__init__()

        self.metrics = {"num_inferences": 0, "avg_inference_time": 0.0}

        self.api_key = api_key if api_key else API_KEY
        model_id = resolve_roboflow_model_alias(model_id=model_id)

        self.task_type = "instance-segmentation"

        extra_weights_provider_headers = get_extra_weights_provider_headers(
            countinference=kwargs.get("countinference"),
            service_secret=kwargs.get("service_secret"),
        )
        # Offer only the backends that are valid and not explicitly disabled.
        backend = list(
            VALID_INFERENCE_MODELS_BACKENDS.difference(
                DISABLED_INFERENCE_MODELS_BACKENDS
            )
        )
        self._model: InstanceSegmentationModel = AutoModel.from_pretrained(
            model_id_or_path=model_id,
            api_key=self.api_key,
            allow_untrusted_packages=ALLOW_INFERENCE_MODELS_UNTRUSTED_PACKAGES,
            allow_direct_local_storage_loading=ALLOW_INFERENCE_MODELS_DIRECTLY_ACCESS_LOCAL_PACKAGES,
            weights_provider_extra_headers=extra_weights_provider_headers,
            backend=backend,
            **kwargs,
        )
        self.class_names = list(self._model.class_names)

    def map_inference_kwargs(self, kwargs: dict) -> dict:
        """Hook for translating `inference`-style kwargs into the wrapped
        model's kwargs; identity by default (subclasses may override)."""
        return kwargs

    def preprocess(self, image: Any, **kwargs):
        """Load the input image(s) as BGR numpy arrays and pre-process them."""
        is_batch = isinstance(image, list)
        images = image if is_batch else [image]
        np_images: List[np.ndarray] = [
            load_image_bgr(
                v,
                disable_preproc_auto_orient=kwargs.get(
                    "disable_preproc_auto_orient", False
                ),
            )
            for v in images
        ]
        mapped_kwargs = self.map_inference_kwargs(kwargs)
        return self._model.pre_process(np_images, **mapped_kwargs)

    def predict(self, img_in, **kwargs):
        """Run the wrapped model's forward pass on a pre-processed batch."""
        mapped_kwargs = self.map_inference_kwargs(kwargs)
        return self._model.forward(img_in, **mapped_kwargs)

    def postprocess(
        self,
        predictions: List[InstanceDetections],
        preprocess_return_metadata: PreprocessingMetadata,
        **kwargs,
    ) -> List[InstanceSegmentationInferenceResponse]:
        """Convert raw detections into `InstanceSegmentationInferenceResponse`s.

        Boxes are translated from corner (x1, y1, x2, y2) to center/size
        form, masks are converted to polygons, and predictions whose class
        is not in `kwargs["class_filter"]` (when provided) are dropped.
        """
        mapped_kwargs = self.map_inference_kwargs(kwargs)
        detections_list = self._model.post_process(
            predictions, preprocess_return_metadata, **mapped_kwargs
        )

        responses: List[InstanceSegmentationInferenceResponse] = []
        for preproc_metadata, det in zip(preprocess_return_metadata, detections_list):
            H = preproc_metadata.original_size.height
            W = preproc_metadata.original_size.width

            xyxy = det.xyxy.detach().cpu().numpy()
            confs = det.confidence.detach().cpu().numpy()
            masks = det.mask.detach().cpu().numpy()
            polys = masks2poly(masks)
            class_ids = det.class_id.detach().cpu().numpy()

            predictions: List[InstanceSegmentationPrediction] = []

            for (x1, y1, x2, y2), mask_as_poly, conf, class_id in zip(
                xyxy, polys, confs, class_ids
            ):
                # Corner box -> center point plus width/height.
                cx = (float(x1) + float(x2)) / 2.0
                cy = (float(y1) + float(y2)) / 2.0
                w = float(x2) - float(x1)
                h = float(y2) - float(y1)
                class_id_int = int(class_id)
                # Fall back to the numeric id when it is outside the range of
                # known class names.
                class_name = (
                    self.class_names[class_id_int]
                    if 0 <= class_id_int < len(self.class_names)
                    else str(class_id_int)
                )
                if (
                    kwargs.get("class_filter")
                    and class_name not in kwargs["class_filter"]
                ):
                    continue
                predictions.append(
                    InstanceSegmentationPrediction(
                        x=cx,
                        y=cy,
                        width=w,
                        height=h,
                        confidence=float(conf),
                        points=[
                            Point(x=point[0], y=point[1]) for point in mask_as_poly
                        ],
                        # `class` is a reserved word, hence the dict expansion.
                        **{"class": class_name},
                        class_id=class_id_int,
                    )
                )

            responses.append(
                InstanceSegmentationInferenceResponse(
                    predictions=predictions,
                    image=InferenceResponseImage(width=W, height=H),
                )
            )
        return responses

    def clear_cache(self, delete_from_disk: bool = True) -> None:
        """Clears any cache if necessary. TODO: Implement this to delete the cache from the experimental model.

        Args:
            delete_from_disk (bool, optional): Whether to delete cached files from disk. Defaults to True.
        """
        pass

    def draw_predictions(
        self,
        inference_request: InferenceRequest,
        inference_response: InferenceResponse,
    ) -> bytes:
        """Draw predictions from an inference response onto the original image provided by an inference request

        Args:
            inference_request (InferenceRequest): The inference request containing the image on which to draw predictions
            inference_response (InferenceResponse): The inference response containing predictions to be drawn

        Returns:
            bytes: The encoded image with predictions drawn (per the return
                annotation; content produced by `draw_detection_predictions`).
        """
        # One stable palette color per class id, cycling through the palette.
        class_id_2_color = {
            i: DEFAULT_COLOR_PALETTE[i % len(DEFAULT_COLOR_PALETTE)]
            for i, class_name in enumerate(self._model.class_names)
        }
        return draw_detection_predictions(
            inference_request=inference_request,
            inference_response=inference_response,
            colors=class_id_2_color,
        )
Functions
clear_cache
clear_cache(delete_from_disk=True)

Clears any cache if necessary. TODO: Implement this to delete the cache from the experimental model.

Parameters:

Name Type Description Default
delete_from_disk bool

Whether to delete cached files from disk. Defaults to True.

True
Source code in inference/core/models/inference_models_adapters.py
350
351
352
353
354
355
356
def clear_cache(self, delete_from_disk: bool = True) -> None:
    """Placeholder cache-clearing hook; currently performs no work.

    TODO: Implement this to delete the cache from the experimental model.

    Args:
        delete_from_disk (bool, optional): Whether cached files should also be
            removed from disk once implemented. Defaults to True.
    """
    # Intentionally a no-op until cache management is implemented.
    return None
draw_predictions
draw_predictions(inference_request, inference_response)

Draw predictions from an inference response onto the original image provided by an inference request

Parameters:

Name Type Description Default
inference_request ObjectDetectionInferenceRequest

The inference request containing the image on which to draw predictions

required
inference_response ObjectDetectionInferenceResponse

The inference response containing predictions to be drawn

required

Returns:

Name Type Description
bytes

The encoded image with predictions drawn (the return annotation is bytes, not a base64 str).

Source code in inference/core/models/inference_models_adapters.py
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
def draw_predictions(
    self,
    inference_request: InferenceRequest,
    inference_response: InferenceResponse,
) -> bytes:
    """Draw predictions from an inference response onto the original image provided by an inference request.

    Args:
        inference_request (InferenceRequest): The inference request containing the image on which to draw predictions.
        inference_response (InferenceResponse): The inference response containing predictions to be drawn.

    Returns:
        bytes: The encoded image with predictions drawn (per the return
            annotation; content produced by `draw_detection_predictions`).
    """
    # One stable palette color per class id, cycling through the palette.
    # (The class names themselves are not needed here, only their count.)
    class_id_2_color = {
        i: DEFAULT_COLOR_PALETTE[i % len(DEFAULT_COLOR_PALETTE)]
        for i in range(len(self._model.class_names))
    }
    return draw_detection_predictions(
        inference_request=inference_request,
        inference_response=inference_response,
        colors=class_id_2_color,
    )

InferenceModelsKeyPointsDetectionAdapter

Bases: Model

Source code in inference/core/models/inference_models_adapters.py
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
class InferenceModelsKeyPointsDetectionAdapter(Model):
    """Adapter exposing an `inference-models` keypoint-detection model through
    the legacy `inference` `Model` interface.

    Wraps `AutoModel.from_pretrained(...)` and converts its (keypoints,
    detections) output pairs into `KeypointsDetectionInferenceResponse`
    objects.
    """

    def __init__(self, model_id: str, api_key: str = None, **kwargs):
        """Load the underlying keypoints model for `model_id` via `AutoModel`.

        Args:
            model_id (str): Roboflow model id (aliases are resolved first).
            api_key (str, optional): API key; falls back to the module-level
                `API_KEY` when not provided.
            **kwargs: Forwarded to `AutoModel.from_pretrained`; `countinference`
                and `service_secret` are additionally used to build the extra
                weights-provider headers.
        """
        super().__init__()

        self.metrics = {"num_inferences": 0, "avg_inference_time": 0.0}

        self.api_key = api_key if api_key else API_KEY
        model_id = resolve_roboflow_model_alias(model_id=model_id)

        self.task_type = "keypoint-detection"

        extra_weights_provider_headers = get_extra_weights_provider_headers(
            countinference=kwargs.get("countinference"),
            service_secret=kwargs.get("service_secret"),
        )
        # Offer only the backends that are valid and not explicitly disabled.
        backend = list(
            VALID_INFERENCE_MODELS_BACKENDS.difference(
                DISABLED_INFERENCE_MODELS_BACKENDS
            )
        )
        self._model: KeyPointsDetectionModel = AutoModel.from_pretrained(
            model_id_or_path=model_id,
            api_key=self.api_key,
            allow_untrusted_packages=ALLOW_INFERENCE_MODELS_UNTRUSTED_PACKAGES,
            allow_direct_local_storage_loading=ALLOW_INFERENCE_MODELS_DIRECTLY_ACCESS_LOCAL_PACKAGES,
            weights_provider_extra_headers=extra_weights_provider_headers,
            backend=backend,
            **kwargs,
        )
        self.class_names = list(self._model.class_names)

    def map_inference_kwargs(self, kwargs: dict) -> dict:
        """Translate `inference`-style kwargs for the wrapped model.

        Copies `request.keypoint_confidence` into `key_points_threshold` when
        a request object is present. NOTE: mutates `kwargs` in place.
        """
        if "request" in kwargs:
            keypoint_confidence_threshold = kwargs["request"].keypoint_confidence
            kwargs["key_points_threshold"] = keypoint_confidence_threshold
        return kwargs

    def preprocess(self, image: Any, **kwargs):
        """Load the input image(s) as BGR numpy arrays and pre-process them."""
        is_batch = isinstance(image, list)
        images = image if is_batch else [image]
        np_images: List[np.ndarray] = [
            load_image_bgr(
                v,
                disable_preproc_auto_orient=kwargs.get(
                    "disable_preproc_auto_orient", False
                ),
            )
            for v in images
        ]
        mapped_kwargs = self.map_inference_kwargs(kwargs)
        return self._model.pre_process(np_images, **mapped_kwargs)

    def predict(self, img_in, **kwargs):
        """Run the wrapped model's forward pass on a pre-processed batch."""
        mapped_kwargs = self.map_inference_kwargs(kwargs)
        return self._model.forward(img_in, **mapped_kwargs)

    def postprocess(
        self,
        predictions: Tuple[List[KeyPoints], Optional[List[Detections]]],
        preprocess_return_metadata: PreprocessingMetadata,
        **kwargs,
    ) -> List[KeypointsDetectionInferenceResponse]:
        """Convert raw (keypoints, detections) output into responses.

        Raises:
            RuntimeError: If the wrapped model provides keypoints without
                accompanying instance detections, which this adapter requires.
        """
        mapped_kwargs = self.map_inference_kwargs(kwargs)
        keypoints_list, detections_list = self._model.post_process(
            predictions, preprocess_return_metadata, **mapped_kwargs
        )
        if detections_list is None:
            raise RuntimeError(
                "Keypoints detection model does not provide instances detection - this is not supported for "
                "models from `inference-models` package which are adapted to work with `inference`."
            )
        key_points_classes = self._model.key_points_classes
        responses: List[KeypointsDetectionInferenceResponse] = []
        for preproc_metadata, keypoints, det in zip(
            preprocess_return_metadata, keypoints_list, detections_list
        ):

            H = preproc_metadata.original_size.height
            W = preproc_metadata.original_size.width

            xyxy = det.xyxy.detach().cpu().numpy()
            confs = det.confidence.detach().cpu().numpy()
            class_ids = det.class_id.detach().cpu().numpy()
            keypoints_xy = keypoints.xy.detach().cpu().tolist()
            keypoints_class_id = keypoints.class_id.detach().cpu().tolist()
            keypoints_confidence = keypoints.confidence.detach().cpu().tolist()
            predictions: List[KeypointsPrediction] = []

            for (
                (x1, y1, x2, y2),
                conf,
                class_id,
                instance_keypoints_xy,
                instance_keypoints_class_id,
                instance_keypoints_confidence,
            ) in zip(
                xyxy,
                confs,
                class_ids,
                keypoints_xy,
                keypoints_class_id,
                keypoints_confidence,
            ):
                # Corner box -> center point plus width/height.
                cx = (float(x1) + float(x2)) / 2.0
                cy = (float(y1) + float(y2)) / 2.0
                w = float(x2) - float(x1)
                h = float(y2) - float(y1)
                class_id_int = int(class_id)
                # Fall back to the numeric id when it is outside the range of
                # known class names.
                class_name = (
                    self.class_names[class_id_int]
                    if 0 <= class_id_int < len(self.class_names)
                    else str(class_id_int)
                )
                if (
                    kwargs.get("class_filter")
                    and class_name not in kwargs["class_filter"]
                ):
                    continue
                predictions.append(
                    KeypointsPrediction(
                        x=cx,
                        y=cy,
                        width=w,
                        height=h,
                        confidence=float(conf),
                        # `class` is a reserved word, hence the dict expansion.
                        **{"class": class_name},
                        class_id=class_id_int,
                        keypoints=model_keypoints_to_response(
                            instance_keypoints_xy=instance_keypoints_xy,
                            instance_keypoints_confidence=instance_keypoints_confidence,
                            instance_keypoints_class_id=instance_keypoints_class_id,
                            key_points_classes=key_points_classes,
                        ),
                    )
                )

            responses.append(
                KeypointsDetectionInferenceResponse(
                    predictions=predictions,
                    image=InferenceResponseImage(width=W, height=H),
                )
            )

        return responses

    def clear_cache(self, delete_from_disk: bool = True) -> None:
        """Clears any cache if necessary. TODO: Implement this to delete the cache from the experimental model.

        Args:
            delete_from_disk (bool, optional): Whether to delete cached files from disk. Defaults to True.
        """
        pass

    def draw_predictions(
        self,
        inference_request: InferenceRequest,
        inference_response: InferenceResponse,
    ) -> bytes:
        """Draw predictions from an inference response onto the original image provided by an inference request

        Args:
            inference_request (InferenceRequest): The inference request containing the image on which to draw predictions
            inference_response (InferenceResponse): The inference response containing predictions to be drawn

        Returns:
            bytes: The encoded image with predictions drawn (per the return
                annotation; content produced by `draw_detection_predictions`).
        """
        # One stable palette color per class id, cycling through the palette.
        class_id_2_color = {
            i: DEFAULT_COLOR_PALETTE[i % len(DEFAULT_COLOR_PALETTE)]
            for i, class_name in enumerate(self._model.class_names)
        }
        return draw_detection_predictions(
            inference_request=inference_request,
            inference_response=inference_response,
            colors=class_id_2_color,
        )
Functions
clear_cache
clear_cache(delete_from_disk=True)

Clears any cache if necessary. TODO: Implement this to delete the cache from the experimental model.

Parameters:

Name Type Description Default
delete_from_disk bool

Whether to delete cached files from disk. Defaults to True.

True
Source code in inference/core/models/inference_models_adapters.py
528
529
530
531
532
533
534
def clear_cache(self, delete_from_disk: bool = True) -> None:
    """Placeholder cache-clearing hook; currently performs no work.

    TODO: Implement this to delete the cache from the experimental model.

    Args:
        delete_from_disk (bool, optional): Whether cached files should also be
            removed from disk once implemented. Defaults to True.
    """
    # Intentionally a no-op until cache management is implemented.
    return None
draw_predictions
draw_predictions(inference_request, inference_response)

Draw predictions from an inference response onto the original image provided by an inference request

Parameters:

Name Type Description Default
inference_request ObjectDetectionInferenceRequest

The inference request containing the image on which to draw predictions

required
inference_response ObjectDetectionInferenceResponse

The inference response containing predictions to be drawn

required

Returns:

Name Type Description
bytes

The encoded image with predictions drawn (the return annotation is bytes, not a base64 str).

Source code in inference/core/models/inference_models_adapters.py
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
def draw_predictions(
    self,
    inference_request: InferenceRequest,
    inference_response: InferenceResponse,
) -> bytes:
    """Draw predictions from an inference response onto the original image provided by an inference request

    Args:
        inference_request (ObjectDetectionInferenceRequest): The inference request containing the image on which to draw predictions
        inference_response (ObjectDetectionInferenceResponse): The inference response containing predictions to be drawn

    Returns:
        str: A base64 encoded image string
    """
    class_id_2_color = {
        i: DEFAULT_COLOR_PALETTE[i % len(DEFAULT_COLOR_PALETTE)]
        for i, class_name in enumerate(self._model.class_names)
    }
    return draw_detection_predictions(
        inference_request=inference_request,
        inference_response=inference_response,
        colors=class_id_2_color,
    )

InferenceModelsObjectDetectionAdapter

Bases: Model

Source code in inference/core/models/inference_models_adapters.py
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
class InferenceModelsObjectDetectionAdapter(Model):
    """Adapter exposing an ``inference_models`` object-detection model through the
    legacy ``Model`` interface (``preprocess`` / ``predict`` / ``postprocess``)."""

    def __init__(self, model_id: str, api_key: str = None, **kwargs):
        """Load the underlying model via ``AutoModel.from_pretrained``.

        Args:
            model_id (str): Roboflow model id (aliases are resolved before loading).
            api_key (str, optional): Roboflow API key; falls back to the module-level
                ``API_KEY`` when not provided.
            **kwargs: Extra options forwarded to ``AutoModel.from_pretrained``;
                ``countinference`` and ``service_secret`` are additionally used to
                build extra weights-provider headers.
        """
        super().__init__()

        self.metrics = {"num_inferences": 0, "avg_inference_time": 0.0}

        self.api_key = api_key if api_key else API_KEY
        model_id = resolve_roboflow_model_alias(model_id=model_id)

        self.task_type = "object-detection"

        extra_weights_provider_headers = get_extra_weights_provider_headers(
            countinference=kwargs.get("countinference"),
            service_secret=kwargs.get("service_secret"),
        )
        # Offer only backends that are valid AND not explicitly disabled.
        backend = list(
            VALID_INFERENCE_MODELS_BACKENDS.difference(
                DISABLED_INFERENCE_MODELS_BACKENDS
            )
        )
        self._model: ObjectDetectionModel = AutoModel.from_pretrained(
            model_id_or_path=model_id,
            api_key=self.api_key,
            allow_untrusted_packages=ALLOW_INFERENCE_MODELS_UNTRUSTED_PACKAGES,
            allow_direct_local_storage_loading=ALLOW_INFERENCE_MODELS_DIRECTLY_ACCESS_LOCAL_PACKAGES,
            weights_provider_extra_headers=extra_weights_provider_headers,
            backend=backend,
            **kwargs,
        )
        self.class_names = list(self._model.class_names)

    def map_inference_kwargs(self, kwargs: dict) -> dict:
        """Hook for subclasses to translate legacy kwargs into the
        ``inference_models`` naming; the base adapter passes them through."""
        return kwargs

    def preprocess(self, image: Any, **kwargs):
        """Load input image(s) as BGR numpy arrays and run the model's pre-processing.

        Args:
            image (Any): A single image or a list of images in any format
                ``load_image_bgr`` accepts.
            **kwargs: ``disable_preproc_auto_orient`` is honored during loading;
                the rest is forwarded (via ``map_inference_kwargs``) to the model.
        """
        is_batch = isinstance(image, list)
        images = image if is_batch else [image]
        np_images: List[np.ndarray] = [
            load_image_bgr(
                v,
                disable_preproc_auto_orient=kwargs.get(
                    "disable_preproc_auto_orient", False
                ),
            )
            for v in images
        ]
        mapped_kwargs = self.map_inference_kwargs(kwargs)
        return self._model.pre_process(np_images, **mapped_kwargs)

    def predict(self, img_in, **kwargs):
        """Run the underlying model's forward pass on pre-processed input."""
        mapped_kwargs = self.map_inference_kwargs(kwargs)
        return self._model.forward(img_in, **mapped_kwargs)

    def postprocess(
        self,
        predictions: List[Detections],
        preprocess_return_metadata: PreprocessingMetadata,
        **kwargs,
    ) -> List[ObjectDetectionInferenceResponse]:
        """Convert raw detections into legacy object-detection responses.

        Boxes are translated from corner (xyxy) format to center/width/height,
        class ids are mapped to names, and an optional ``class_filter`` kwarg
        drops predictions whose class name is not listed.

        Args:
            predictions (List[Detections]): Raw model output to be post-processed.
            preprocess_return_metadata (PreprocessingMetadata): Per-image metadata
                produced by ``preprocess`` (carries original image sizes).
            **kwargs: Forwarded to the model's ``post_process``; ``class_filter``
                (if truthy) restricts the returned classes.

        Returns:
            List[ObjectDetectionInferenceResponse]: One response per input image.
        """
        mapped_kwargs = self.map_inference_kwargs(kwargs)
        detections_list = self._model.post_process(
            predictions, preprocess_return_metadata, **mapped_kwargs
        )

        responses: List[ObjectDetectionInferenceResponse] = []
        for preproc_metadata, det in zip(preprocess_return_metadata, detections_list):
            H = preproc_metadata.original_size.height
            W = preproc_metadata.original_size.width

            xyxy = det.xyxy.detach().cpu().numpy()
            confs = det.confidence.detach().cpu().numpy()
            class_ids = det.class_id.detach().cpu().numpy()

            # Renamed from `predictions` so the method parameter is not shadowed.
            image_predictions: List[ObjectDetectionPrediction] = []

            for (x1, y1, x2, y2), conf, class_id in zip(xyxy, confs, class_ids):
                # Corner (xyxy) box -> center point + width/height.
                cx = (float(x1) + float(x2)) / 2.0
                cy = (float(y1) + float(y2)) / 2.0
                w = float(x2) - float(x1)
                h = float(y2) - float(y1)
                class_id_int = int(class_id)
                # Fall back to the stringified id when it is out of range.
                class_name = (
                    self.class_names[class_id_int]
                    if 0 <= class_id_int < len(self.class_names)
                    else str(class_id_int)
                )
                if (
                    kwargs.get("class_filter")
                    and class_name not in kwargs["class_filter"]
                ):
                    continue
                image_predictions.append(
                    ObjectDetectionPrediction(
                        x=cx,
                        y=cy,
                        width=w,
                        height=h,
                        confidence=float(conf),
                        # `class` is a Python keyword, hence the dict expansion.
                        **{"class": class_name},
                        class_id=class_id_int,
                    )
                )

            responses.append(
                ObjectDetectionInferenceResponse(
                    predictions=image_predictions,
                    image=InferenceResponseImage(width=W, height=H),
                )
            )
        return responses

    def clear_cache(self, delete_from_disk: bool = True) -> None:
        """Clears any cache if necessary. TODO: Implement this to delete the cache from the experimental model.

        Args:
            delete_from_disk (bool, optional): Whether to delete cached files from disk. Defaults to True.
        """
        pass

    def draw_predictions(
        self,
        inference_request: InferenceRequest,
        inference_response: InferenceResponse,
    ) -> bytes:
        """Draw predictions from an inference response onto the original image provided by an inference request.

        Args:
            inference_request (ObjectDetectionInferenceRequest): The inference request containing the image on which to draw predictions
            inference_response (ObjectDetectionInferenceResponse): The inference response containing predictions to be drawn

        Returns:
            bytes: The encoded visualization produced by ``draw_detection_predictions``.
        """
        # Deterministic per-class colors, cycling through the default palette.
        class_id_2_color = {
            i: DEFAULT_COLOR_PALETTE[i % len(DEFAULT_COLOR_PALETTE)]
            for i, _ in enumerate(self._model.class_names)
        }
        return draw_detection_predictions(
            inference_request=inference_request,
            inference_response=inference_response,
            colors=class_id_2_color,
        )
Functions
clear_cache
clear_cache(delete_from_disk=True)

Clears any cache if necessary. TODO: Implement this to delete the cache from the experimental model.

Parameters:

Name Type Description Default
delete_from_disk bool

Whether to delete cached files from disk. Defaults to True.

True
Source code in inference/core/models/inference_models_adapters.py
199
200
201
202
203
204
205
def clear_cache(self, delete_from_disk: bool = True) -> None:
    """Clears any cache if necessary. TODO: Implement this to delete the cache from the experimental model.

    Args:
        delete_from_disk (bool, optional): Whether to delete cached files from disk. Defaults to True.
    """
    pass
draw_predictions
draw_predictions(inference_request, inference_response)

Draw predictions from an inference response onto the original image provided by an inference request

Parameters:

Name Type Description Default
inference_request ObjectDetectionInferenceRequest

The inference request containing the image on which to draw predictions

required
inference_response ObjectDetectionInferenceResponse

The inference response containing predictions to be drawn

required

Returns:

Name Type Description
str bytes

The encoded image bytes with predictions drawn (return annotation is bytes, not a base64 string)

Source code in inference/core/models/inference_models_adapters.py
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
def draw_predictions(
    self,
    inference_request: InferenceRequest,
    inference_response: InferenceResponse,
) -> bytes:
    """Draw predictions from an inference response onto the original image provided by an inference request

    Args:
        inference_request (ObjectDetectionInferenceRequest): The inference request containing the image on which to draw predictions
        inference_response (ObjectDetectionInferenceResponse): The inference response containing predictions to be drawn

    Returns:
        str: A base64 encoded image string
    """
    class_id_2_color = {
        i: DEFAULT_COLOR_PALETTE[i % len(DEFAULT_COLOR_PALETTE)]
        for i, class_name in enumerate(self._model.class_names)
    }
    return draw_detection_predictions(
        inference_request=inference_request,
        inference_response=inference_response,
        colors=class_id_2_color,
    )

InferenceModelsSemanticSegmentationAdapter

Bases: Model

Source code in inference/core/models/inference_models_adapters.py
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
913
914
915
916
917
918
919
920
921
922
923
924
925
926
927
928
929
930
931
932
933
934
935
936
937
938
939
940
941
942
943
944
945
946
947
948
949
950
951
952
953
954
955
956
957
958
959
960
961
962
963
964
965
966
967
968
969
970
971
972
973
974
975
976
977
978
979
980
981
982
983
984
985
986
987
988
989
990
991
992
993
994
995
996
class InferenceModelsSemanticSegmentationAdapter(Model):
    """Adapter exposing an ``inference_models`` semantic-segmentation model through
    the legacy ``Model`` interface (``preprocess`` / ``predict`` / ``postprocess``)."""

    def __init__(self, model_id: str, api_key: str = None, **kwargs):
        """Load the underlying model via ``AutoModel.from_pretrained``.

        Args:
            model_id (str): Roboflow model id (aliases are resolved before loading).
            api_key (str, optional): Roboflow API key; falls back to the module-level
                ``API_KEY`` when not provided.
            **kwargs: Extra options forwarded to ``AutoModel.from_pretrained``;
                ``countinference`` and ``service_secret`` are additionally used to
                build extra weights-provider headers.
        """
        super().__init__()

        self.metrics = {"num_inferences": 0, "avg_inference_time": 0.0}

        self.api_key = api_key if api_key else API_KEY
        model_id = resolve_roboflow_model_alias(model_id=model_id)

        self.task_type = "semantic-segmentation"

        extra_weights_provider_headers = get_extra_weights_provider_headers(
            countinference=kwargs.get("countinference"),
            service_secret=kwargs.get("service_secret"),
        )
        # Offer only backends that are valid AND not explicitly disabled.
        backend = list(
            VALID_INFERENCE_MODELS_BACKENDS.difference(
                DISABLED_INFERENCE_MODELS_BACKENDS
            )
        )
        self._model: SemanticSegmentationModel = AutoModel.from_pretrained(
            model_id_or_path=model_id,
            api_key=self.api_key,
            allow_untrusted_packages=ALLOW_INFERENCE_MODELS_UNTRUSTED_PACKAGES,
            allow_direct_local_storage_loading=ALLOW_INFERENCE_MODELS_DIRECTLY_ACCESS_LOCAL_PACKAGES,
            weights_provider_extra_headers=extra_weights_provider_headers,
            backend=backend,
            **kwargs,
        )
        self.class_names = list(self._model.class_names)

    @property
    def class_map(self):
        """Map of stringified class index -> class name."""
        # match segment.roboflow.com
        return {str(k): v for k, v in enumerate(self.class_names)}

    def map_inference_kwargs(self, kwargs: dict) -> dict:
        """Hook for subclasses to translate legacy kwargs into the
        ``inference_models`` naming; the base adapter passes them through."""
        return kwargs

    def preprocess(self, image: Any, **kwargs):
        """Load input image(s) as BGR numpy arrays and run the model's pre-processing.

        Args:
            image (Any): A single image or a list of images in any format
                ``load_image_bgr`` accepts.
            **kwargs: ``disable_preproc_auto_orient`` is honored during loading;
                the rest is forwarded (via ``map_inference_kwargs``) to the model.
        """
        is_batch = isinstance(image, list)
        images = image if is_batch else [image]
        np_images: List[np.ndarray] = [
            load_image_bgr(
                v,
                disable_preproc_auto_orient=kwargs.get(
                    "disable_preproc_auto_orient", False
                ),
            )
            for v in images
        ]
        mapped_kwargs = self.map_inference_kwargs(kwargs)
        return self._model.pre_process(np_images, **mapped_kwargs)

    def predict(self, img_in, **kwargs):
        """Run the underlying model's forward pass on pre-processed input."""
        mapped_kwargs = self.map_inference_kwargs(kwargs)
        return self._model.forward(img_in, **mapped_kwargs)

    def postprocess(
        self,
        predictions: torch.Tensor,
        preprocess_return_metadata: PreprocessingMetadata,
        **kwargs,
    ) -> List[SemanticSegmentationInferenceResponse]:
        """Convert raw model output into legacy semantic-segmentation responses.

        Each per-image result carries a base64-PNG segmentation mask (class ids as
        uint8 pixels) and a base64-PNG confidence mask (confidence scaled to 0-255).

        Args:
            predictions (torch.Tensor): Raw model output to be post-processed.
            preprocess_return_metadata (PreprocessingMetadata): Per-image metadata
                produced by ``preprocess`` (carries original image sizes).
            **kwargs: Forwarded to the model's ``post_process``.

        Returns:
            List[SemanticSegmentationInferenceResponse]: One response per input image.
        """
        mapped_kwargs = self.map_inference_kwargs(kwargs)
        segmentation_results = self._model.post_process(
            predictions, preprocess_return_metadata, **mapped_kwargs
        )

        responses: List[SemanticSegmentationInferenceResponse] = []
        for preproc_metadata, segmentation in zip(
            preprocess_return_metadata, segmentation_results
        ):
            height = preproc_metadata.original_size.height
            width = preproc_metadata.original_size.width
            response_image = InferenceResponseImage(width=width, height=height)
            # WARNING! This way of conversion is hazardous - first of all, if background class is not in class names,
            # for certain pre-processing, we end up with -1 values which will be wrapped to 255 - second of all,
            # we can support only 256 classes - those constraints should be fine until inference 2.0
            response_predictions = SemanticSegmentationPrediction(
                segmentation_mask=self.img_to_b64_str(
                    segmentation.segmentation_map.to(torch.uint8)
                ),
                confidence_mask=self.img_to_b64_str(
                    # Confidence in [0, 1] is quantized to uint8 [0, 255].
                    (segmentation.confidence * 255).to(torch.uint8)
                ),
                class_map=self.class_map,
                image=dict(response_image),
            )
            response = SemanticSegmentationInferenceResponse(
                predictions=response_predictions,
                image=response_image,
            )
            responses.append(response)
        return responses

    def clear_cache(self, delete_from_disk: bool = True) -> None:
        """Clears any cache if necessary. TODO: Implement this to delete the cache from the experimental model.

        Args:
            delete_from_disk (bool, optional): Whether to delete cached files from disk. Defaults to True.
        """
        pass

    def img_to_b64_str(self, img: torch.Tensor) -> str:
        """Encode a uint8 tensor as a base64 PNG string.

        Args:
            img (torch.Tensor): Tensor of dtype ``torch.uint8`` to be encoded.

        Returns:
            str: Base64-encoded PNG bytes (ASCII string).

        Raises:
            ValueError: If ``img`` is not of dtype ``torch.uint8``.
        """
        if img.dtype != torch.uint8:
            raise ValueError(
                f"img_to_b64_str requires uint8 tensor but got dtype {img.dtype}"
            )

        img = Image.fromarray(img.cpu().numpy())
        buffered = io.BytesIO()
        img.save(buffered, format="PNG")

        img_str = base64.b64encode(buffered.getvalue())
        img_str = img_str.decode("ascii")

        return img_str

    def draw_predictions(
        self,
        inference_request: InferenceRequest,
        inference_response: InferenceResponse,
    ) -> bytes:
        # Intentionally unsupported: segmentation responses already embed the masks.
        raise NotImplementedError(
            "draw_predictions(...) is not implemented for semantic segmentation models - responses contain "
            "visualization already."
        )
Functions
clear_cache
clear_cache(delete_from_disk=True)

Clears any cache if necessary. TODO: Implement this to delete the cache from the experimental model.

Parameters:

Name Type Description Default
delete_from_disk bool

Whether to delete cached files from disk. Defaults to True.

True
Source code in inference/core/models/inference_models_adapters.py
965
966
967
968
969
970
971
def clear_cache(self, delete_from_disk: bool = True) -> None:
    """Clears any cache if necessary. TODO: Implement this to delete the cache from the experimental model.

    Args:
        delete_from_disk (bool, optional): Whether to delete cached files from disk. Defaults to True.
    """
    pass

Functions

draw_predictions

draw_predictions(
    inference_request, inference_response, class_names
)

Draw prediction visuals on an image.

This method overlays the predictions on the input image, including drawing rectangles and text to visualize the predicted classes.

Parameters:

Name Type Description Default
inference_request

The request object containing the image and parameters.

required
inference_response

The response object containing the predictions and other details.

required
class_names List[str]

List of class names corresponding to the model's classes.

required

Returns:

Name Type Description
bytes

The bytes of the visualized image in JPEG format.

Source code in inference/core/models/inference_models_adapters.py
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
def _paste_label_button(image, font, text, color, position):
    """Render *text* on a solid-color button and paste it onto *image*.

    Args:
        image: PIL image to paste onto (mutated in place).
        font: PIL font used to measure and render the text.
        text (str): Label text to render.
        color: Button background color.
        position (tuple): (x, y) paste position on *image*.

    Returns:
        tuple: The (width, height) of the pasted button, so callers can stack labels.
    """
    text_size = font.getbbox(text)
    # set button size + 10px margins
    button_size = (text_size[2] + 20, text_size[3] + 20)
    button_img = Image.new("RGBA", button_size, color)
    # put text on button with 10px margins
    button_draw = ImageDraw.Draw(button_img)
    button_draw.text((10, 10), text, font=font, fill=(255, 255, 255, 255))
    image.paste(button_img, position)
    return button_size


def draw_predictions(inference_request, inference_response, class_names: List[str]):
    """Draw prediction visuals on an image.

    This method overlays the predictions on the input image, including drawing rectangles and text to visualize the predicted classes.

    Args:
        inference_request: The request object containing the image and parameters.
        inference_response: The response object containing the predictions and other details.
        class_names: List of class names corresponding to the model's classes.

    Returns:
        bytes: The bytes of the visualized image in JPEG format.
    """
    image = load_image_rgb(inference_request.image)
    image = Image.fromarray(image)
    draw = ImageDraw.Draw(image)
    font = ImageFont.load_default()
    default_color = "#4892EA"
    # Deterministic per-class colors, cycling through the default palette:
    # one map keyed by class id (multiclass branch), one by name (multi-label branch).
    class_id_2_color = {
        i: DEFAULT_COLOR_PALETTE[i % len(DEFAULT_COLOR_PALETTE)]
        for i in range(len(class_names))
    }
    class_name_2_color = {
        class_name: DEFAULT_COLOR_PALETTE[i % len(DEFAULT_COLOR_PALETTE)]
        for i, class_name in enumerate(class_names)
    }
    # PIL's Image.size is (width, height).
    width, height = image.size
    if isinstance(inference_response.predictions, list):
        # Multiclass result: a ranked list - visualize the top prediction only.
        prediction = inference_response.predictions[0]
        color = class_id_2_color.get(prediction.class_id, default_color)
        # BUGFIX: rectangle expects [x0, y0, x1, y1]; the original passed
        # [0, 0, height, width], swapping the axes for non-square images.
        draw.rectangle(
            [0, 0, width, height],
            outline=color,
            width=inference_request.visualization_stroke_width,
        )
        text = f"{prediction.class_id} - {prediction.class_name} {prediction.confidence:.2f}"
        # put button on source image in position (0, 0)
        _paste_label_button(image, font, text, color, (0, 0))
    else:
        # Multi-label result: a dict keyed by class name.
        if len(inference_response.predictions) > 0:
            box_color = default_color
            # BUGFIX: same axis swap as above.
            draw.rectangle(
                [0, 0, width, height],
                outline=box_color,
                width=inference_request.visualization_stroke_width,
            )
        row = 0
        predictions = sorted(
            inference_response.predictions.items(),
            key=lambda item: item[1].confidence,
            reverse=True,
        )
        for cls_name, pred in predictions:
            # BUGFIX: the original looked up string class names in the int-keyed
            # map, so every label silently fell back to the default color.
            color = class_name_2_color.get(cls_name, default_color)
            text = f"{cls_name} {pred.confidence:.2f}"
            # stack label buttons down the left edge
            button_size = _paste_label_button(image, font, text, color, (0, row))
            row += button_size[1]

    buffered = BytesIO()
    image = image.convert("RGB")
    image.save(buffered, format="JPEG")
    return buffered.getvalue()

inference.core.models.instance_segmentation_base

Classes

InstanceSegmentationBaseOnnxRoboflowInferenceModel

Bases: OnnxRoboflowInferenceModel

Roboflow ONNX Instance Segmentation model.

This class implements an instance segmentation specific inference method for ONNX models provided by Roboflow.

Source code in inference/core/models/instance_segmentation_base.py
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
class InstanceSegmentationBaseOnnxRoboflowInferenceModel(OnnxRoboflowInferenceModel):
    """Roboflow ONNX Instance Segmentation model.

    This class implements an instance segmentation specific inference method
    for ONNX models provided by Roboflow.
    """

    task_type = "instance-segmentation"
    # Number of mask coefficients appended to each prediction row
    # (presumably the 32 prototype-mask coefficients of the segmentation
    # head — confirm against the exported model).
    num_masks = 32

    def infer(
        self,
        image: Any,
        class_agnostic_nms: bool = False,
        confidence: float = DEFAULT_CONFIDENCE,
        disable_preproc_auto_orient: bool = False,
        disable_preproc_contrast: bool = False,
        disable_preproc_grayscale: bool = False,
        disable_preproc_static_crop: bool = False,
        iou_threshold: float = DEFAULT_IOU_THRESH,
        mask_decode_mode: str = DEFAULT_MASK_DECODE_MODE,
        max_candidates: int = DEFAULT_MAX_CANDIDATES,
        max_detections: int = DEFAUlT_MAX_DETECTIONS,
        return_image_dims: bool = False,
        tradeoff_factor: float = DEFAULT_TRADEOFF_FACTOR,
        **kwargs,
    ) -> Union[PREDICTIONS_TYPE, Tuple[PREDICTIONS_TYPE, List[Tuple[int, int]]]]:
        """
        Process an image or list of images for instance segmentation.

        Args:
            image (Any): An image or a list of images for processing.
                - can be a BGR numpy array, filepath, InferenceRequestImage, PIL Image, byte-string, etc.
            class_agnostic_nms (bool, optional): Whether to use class-agnostic non-maximum suppression. Defaults to False.
            confidence (float, optional): Confidence threshold for predictions. Defaults to 0.4.
            iou_threshold (float, optional): IoU threshold for non-maximum suppression. Defaults to 0.3.
            mask_decode_mode (str, optional): Decoding mode for masks. Choices are "accurate", "tradeoff", and "fast". Defaults to "accurate".
            max_candidates (int, optional): Maximum number of candidate detections. Defaults to 3000.
            max_detections (int, optional): Maximum number of detections after non-maximum suppression. Defaults to 300.
            return_image_dims (bool, optional): Whether to return the dimensions of the processed images. Defaults to False.
            tradeoff_factor (float, optional): Tradeoff factor used when `mask_decode_mode` is set to "tradeoff". Must be in [0.0, 1.0]. Defaults to 0.5.
            disable_preproc_auto_orient (bool, optional): If true, the auto orient preprocessing step is disabled for this call. Default is False.
            disable_preproc_contrast (bool, optional): If true, the auto contrast preprocessing step is disabled for this call. Default is False.
            disable_preproc_grayscale (bool, optional): If true, the grayscale preprocessing step is disabled for this call. Default is False.
            disable_preproc_static_crop (bool, optional): If true, the static crop preprocessing step is disabled for this call. Default is False.
            **kwargs: Additional parameters to customize the inference process.

        Returns:
            Union[List[List[List[float]]], Tuple[List[List[List[float]]], List[Tuple[int, int]]]]: The list of predictions, with each prediction being a list of lists. Optionally, also returns the dimensions of the processed images.

        Raises:
            InvalidMaskDecodeArgument: If an invalid `mask_decode_mode` is provided or if the `tradeoff_factor` is outside the allowed range.

        Notes:
            - Processes input images and normalizes them.
            - Makes predictions using the ONNX runtime.
            - Applies non-maximum suppression to the predictions.
            - Decodes the masks according to the specified mode.
        """
        # Pure pass-through: the base class drives preprocess/predict/postprocess.
        return super().infer(
            image,
            class_agnostic_nms=class_agnostic_nms,
            confidence=confidence,
            disable_preproc_auto_orient=disable_preproc_auto_orient,
            disable_preproc_contrast=disable_preproc_contrast,
            disable_preproc_grayscale=disable_preproc_grayscale,
            disable_preproc_static_crop=disable_preproc_static_crop,
            iou_threshold=iou_threshold,
            mask_decode_mode=mask_decode_mode,
            max_candidates=max_candidates,
            max_detections=max_detections,
            return_image_dims=return_image_dims,
            tradeoff_factor=tradeoff_factor,
            **kwargs,
        )

    def postprocess(
        self,
        predictions: Tuple[np.ndarray, np.ndarray],
        preprocess_return_metadata: PreprocessReturnMetadata,
        **kwargs,
    ) -> Union[
        InstanceSegmentationInferenceResponse,
        List[InstanceSegmentationInferenceResponse],
    ]:
        """Apply NMS, decode masks to polygons, and build the response.

        Args:
            predictions: Tuple of (raw detections, mask prototypes) from `predict`.
            preprocess_return_metadata: Metadata captured by `preprocess`
                (original image dims, network input shape, static-crop flag).
            **kwargs: Thresholds and mask decode options forwarded from `infer`.

        Raises:
            InvalidMaskDecodeArgument: On an unknown `mask_decode_mode` or a
                `tradeoff_factor` outside [0.0, 1.0].
        """
        predictions, protos = predictions
        predictions = w_np_non_max_suppression(
            predictions,
            conf_thresh=kwargs["confidence"],
            iou_thresh=kwargs["iou_threshold"],
            class_agnostic=kwargs["class_agnostic_nms"],
            max_detections=kwargs["max_detections"],
            max_candidate_detections=kwargs["max_candidates"],
            num_masks=self.num_masks,
        )
        infer_shape = (self.img_size_h, self.img_size_w)
        masks = []
        mask_decode_mode = kwargs["mask_decode_mode"]
        tradeoff_factor = kwargs["tradeoff_factor"]
        img_in_shape = preprocess_return_metadata["im_shape"]

        predictions = [np.array(p) for p in predictions]

        # Per prediction row: [:4] box, [4] confidence, [6] class id,
        # [7:] mask coefficients (see `make_response` index usage below).
        for pred, proto, img_dim in zip(
            predictions, protos, preprocess_return_metadata["img_dims"]
        ):
            if pred.size == 0:
                # No detections survived NMS for this image.
                masks.append([])
                continue
            if mask_decode_mode == "accurate":
                batch_masks = process_mask_accurate(
                    proto, pred[:, 7:], pred[:, :4], img_in_shape[2:]
                )
                output_mask_shape = img_in_shape[2:]
            elif mask_decode_mode == "tradeoff":
                if not 0 <= tradeoff_factor <= 1:
                    raise InvalidMaskDecodeArgument(
                        f"Invalid tradeoff_factor: {tradeoff_factor}. Must be in [0.0, 1.0]"
                    )
                batch_masks = process_mask_tradeoff(
                    proto,
                    pred[:, 7:],
                    pred[:, :4],
                    img_in_shape[2:],
                    tradeoff_factor,
                )
                output_mask_shape = batch_masks.shape[1:]
            elif mask_decode_mode == "fast":
                batch_masks = process_mask_fast(
                    proto, pred[:, 7:], pred[:, :4], img_in_shape[2:]
                )
                output_mask_shape = batch_masks.shape[1:]
            else:
                raise InvalidMaskDecodeArgument(
                    f"Invalid mask_decode_mode: {mask_decode_mode}. Must be one of ['accurate', 'fast', 'tradeoff']"
                )
            polys = masks2poly(batch_masks)
            # Rescale boxes from network input space back to original image space.
            pred[:, :4] = post_process_bboxes(
                [pred[:, :4]],
                infer_shape,
                [img_dim],
                self.preproc,
                resize_method=self.resize_method,
                disable_preproc_static_crop=preprocess_return_metadata[
                    "disable_preproc_static_crop"
                ],
            )[0]
            polys = post_process_polygons(
                img_dim,
                polys,
                output_mask_shape,
                self.preproc,
                resize_method=self.resize_method,
            )
            masks.append(polys)
        return self.make_response(
            predictions, masks, preprocess_return_metadata["img_dims"], **kwargs
        )

    def preprocess(
        self, image: Any, **kwargs
    ) -> Tuple[np.ndarray, PreprocessReturnMetadata]:
        """Load image(s), normalize to [0, 1], and record metadata for postprocess."""
        img_in, img_dims = self.load_image(
            image,
            disable_preproc_auto_orient=kwargs.get("disable_preproc_auto_orient"),
            disable_preproc_contrast=kwargs.get("disable_preproc_contrast"),
            disable_preproc_grayscale=kwargs.get("disable_preproc_grayscale"),
            disable_preproc_static_crop=kwargs.get("disable_preproc_static_crop"),
        )

        img_in /= 255.0
        return img_in, PreprocessReturnMetadata(
            {
                "img_dims": img_dims,
                "im_shape": img_in.shape,
                "disable_preproc_static_crop": kwargs.get(
                    "disable_preproc_static_crop"
                ),
            }
        )

    def make_response(
        self,
        predictions: List[List[List[float]]],
        masks: List[List[List[float]]],
        img_dims: List[Tuple[int, int]],
        class_filter: List[str] = [],
        **kwargs,
    ) -> Union[
        InstanceSegmentationInferenceResponse,
        List[InstanceSegmentationInferenceResponse],
    ]:
        """
        Create instance segmentation inference response objects for the provided predictions and masks.

        Args:
            predictions (List[List[List[float]]]): List of prediction data, one for each image.
            masks (List[List[List[float]]]): List of masks corresponding to the predictions.
            img_dims (List[Tuple[int, int]]): List of image dimensions corresponding to the processed images.
            class_filter (List[str], optional): List of class names to filter predictions by.
                Defaults to an empty list (no filtering). NOTE: the default is only
                ever read, never mutated, so sharing it across calls is safe.

        Returns:
            List[InstanceSegmentationInferenceResponse]: One response per processed image.

        Notes:
            - For each image, constructs an `InstanceSegmentationInferenceResponse` object.
            - Each response contains a list of `InstanceSegmentationPrediction` objects.
        """
        responses = []
        for ind, (batch_predictions, batch_masks) in enumerate(zip(predictions, masks)):
            # Use a distinct name rather than rebinding the `predictions` parameter.
            image_predictions = []
            for pred, mask in zip(batch_predictions, batch_masks):
                if class_filter and self.class_names[int(pred[6])] not in class_filter:
                    # TODO: logger.debug
                    continue
                # Passing args as a dictionary here since one of the args is 'class' (a protected term in Python)
                image_predictions.append(
                    InstanceSegmentationPrediction(
                        **{
                            "x": pred[0] + (pred[2] - pred[0]) / 2,
                            "y": pred[1] + (pred[3] - pred[1]) / 2,
                            "width": pred[2] - pred[0],
                            "height": pred[3] - pred[1],
                            "points": [Point(x=point[0], y=point[1]) for point in mask],
                            "confidence": pred[4],
                            "class": self.class_names[int(pred[6])],
                            "class_id": int(pred[6]),
                        }
                    )
                )
            response = InstanceSegmentationInferenceResponse(
                predictions=image_predictions,
                image=InferenceResponseImage(
                    width=img_dims[ind][1], height=img_dims[ind][0]
                ),
            )
            responses.append(response)
        return responses

    def predict(self, img_in: np.ndarray, **kwargs) -> Tuple[np.ndarray, np.ndarray]:
        """Runs inference on the ONNX model.

        Args:
            img_in (np.ndarray): The preprocessed image(s) to run inference on.

        Returns:
            Tuple[np.ndarray, np.ndarray]: The ONNX model predictions and the ONNX model protos.

        Raises:
            NotImplementedError: This method must be implemented by a subclass.
        """
        raise NotImplementedError("predict must be implemented by a subclass")

    def validate_model_classes(self) -> None:
        """Check that the model's class count matches the environment's.

        Raises:
            ValueError: If the class counts disagree.
        """
        output_shape = self.get_model_output_shape()
        num_classes = get_num_classes_from_model_prediction_shape(
            output_shape[2], masks=self.num_masks
        )
        # Explicit check instead of `assert`: assertions are stripped under
        # `python -O`, which would silently skip this validation.
        if num_classes != self.num_classes:
            raise ValueError(
                f"Number of classes in model ({num_classes}) does not match the number of classes in the environment ({self.num_classes})"
            )
Functions
infer
infer(
    image,
    class_agnostic_nms=False,
    confidence=DEFAULT_CONFIDENCE,
    disable_preproc_auto_orient=False,
    disable_preproc_contrast=False,
    disable_preproc_grayscale=False,
    disable_preproc_static_crop=False,
    iou_threshold=DEFAULT_IOU_THRESH,
    mask_decode_mode=DEFAULT_MASK_DECODE_MODE,
    max_candidates=DEFAULT_MAX_CANDIDATES,
    max_detections=DEFAUlT_MAX_DETECTIONS,
    return_image_dims=False,
    tradeoff_factor=DEFAULT_TRADEOFF_FACTOR,
    **kwargs
)

Process an image or list of images for instance segmentation.

Parameters:

Name Type Description Default
image Any

An image or a list of images for processing. - can be a BGR numpy array, filepath, InferenceRequestImage, PIL Image, byte-string, etc.

required
class_agnostic_nms bool

Whether to use class-agnostic non-maximum suppression. Defaults to False.

False
confidence float

Confidence threshold for predictions. Defaults to 0.4.

DEFAULT_CONFIDENCE
iou_threshold float

IoU threshold for non-maximum suppression. Defaults to 0.3.

DEFAULT_IOU_THRESH
mask_decode_mode str

Decoding mode for masks. Choices are "accurate", "tradeoff", and "fast". Defaults to "accurate".

DEFAULT_MASK_DECODE_MODE
max_candidates int

Maximum number of candidate detections. Defaults to 3000.

DEFAULT_MAX_CANDIDATES
max_detections int

Maximum number of detections after non-maximum suppression. Defaults to 300.

DEFAUlT_MAX_DETECTIONS
return_image_dims bool

Whether to return the dimensions of the processed images. Defaults to False.

False
tradeoff_factor float

Tradeoff factor used when mask_decode_mode is set to "tradeoff". Must be in [0.0, 1.0]. Defaults to 0.5.

DEFAULT_TRADEOFF_FACTOR
disable_preproc_auto_orient bool

If true, the auto orient preprocessing step is disabled for this call. Default is False.

False
disable_preproc_contrast bool

If true, the auto contrast preprocessing step is disabled for this call. Default is False.

False
disable_preproc_grayscale bool

If true, the grayscale preprocessing step is disabled for this call. Default is False.

False
disable_preproc_static_crop bool

If true, the static crop preprocessing step is disabled for this call. Default is False.

False
**kwargs

Additional parameters to customize the inference process.

{}

Returns:

Type Description
Union[PREDICTIONS_TYPE, Tuple[PREDICTIONS_TYPE, List[Tuple[int, int]]]]

Union[List[List[List[float]]], Tuple[List[List[List[float]]], List[Tuple[int, int]]]]: The list of predictions, with each prediction being a list of lists. Optionally, also returns the dimensions of the processed images.

Raises:

Type Description
InvalidMaskDecodeArgument

If an invalid mask_decode_mode is provided or if the tradeoff_factor is outside the allowed range.

Notes
  • Processes input images and normalizes them.
  • Makes predictions using the ONNX runtime.
  • Applies non-maximum suppression to the predictions.
  • Decodes the masks according to the specified mode.
Source code in inference/core/models/instance_segmentation_base.py
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
def infer(
    self,
    image: Any,
    class_agnostic_nms: bool = False,
    confidence: float = DEFAULT_CONFIDENCE,
    disable_preproc_auto_orient: bool = False,
    disable_preproc_contrast: bool = False,
    disable_preproc_grayscale: bool = False,
    disable_preproc_static_crop: bool = False,
    iou_threshold: float = DEFAULT_IOU_THRESH,
    mask_decode_mode: str = DEFAULT_MASK_DECODE_MODE,
    max_candidates: int = DEFAULT_MAX_CANDIDATES,
    max_detections: int = DEFAUlT_MAX_DETECTIONS,
    return_image_dims: bool = False,
    tradeoff_factor: float = DEFAULT_TRADEOFF_FACTOR,
    **kwargs,
) -> Union[PREDICTIONS_TYPE, Tuple[PREDICTIONS_TYPE, List[Tuple[int, int]]]]:
    """Run instance segmentation on one image or a batch of images.

    Args:
        image (Any): A single image or a list of images. Accepts a BGR numpy
            array, filepath, InferenceRequestImage, PIL Image, byte-string, etc.
        class_agnostic_nms (bool, optional): Apply non-maximum suppression
            across classes rather than per class. Defaults to False.
        confidence (float, optional): Minimum confidence for a detection to
            be kept. Defaults to 0.4.
        iou_threshold (float, optional): IoU cutoff used during non-maximum
            suppression. Defaults to 0.3.
        mask_decode_mode (str, optional): One of "accurate", "tradeoff", or
            "fast"; selects the mask decoding strategy. Defaults to "accurate".
        max_candidates (int, optional): Cap on candidate detections before
            suppression. Defaults to 3000.
        max_detections (int, optional): Cap on detections after suppression.
            Defaults to 300.
        return_image_dims (bool, optional): Also return the processed image
            dimensions. Defaults to False.
        tradeoff_factor (float, optional): Only used with
            `mask_decode_mode="tradeoff"`; must lie in [0.0, 1.0].
            Defaults to 0.5.
        disable_preproc_auto_orient (bool, optional): Skip auto-orientation
            preprocessing for this call. Defaults to False.
        disable_preproc_contrast (bool, optional): Skip contrast
            preprocessing for this call. Defaults to False.
        disable_preproc_grayscale (bool, optional): Skip grayscale
            preprocessing for this call. Defaults to False.
        disable_preproc_static_crop (bool, optional): Skip static-crop
            preprocessing for this call. Defaults to False.
        **kwargs: Extra options forwarded to the base implementation.

    Returns:
        Union[List[List[List[float]]], Tuple[List[List[List[float]]], List[Tuple[int, int]]]]:
            Per-image prediction lists, optionally paired with image dimensions.

    Raises:
        InvalidMaskDecodeArgument: For an unknown `mask_decode_mode` or a
            `tradeoff_factor` outside the allowed range.
    """
    # Thin wrapper: all work (preprocessing, ONNX session run, NMS, mask
    # decoding) happens in the parent implementation.
    parent = super()
    return parent.infer(
        image,
        confidence=confidence,
        iou_threshold=iou_threshold,
        class_agnostic_nms=class_agnostic_nms,
        max_candidates=max_candidates,
        max_detections=max_detections,
        mask_decode_mode=mask_decode_mode,
        tradeoff_factor=tradeoff_factor,
        return_image_dims=return_image_dims,
        disable_preproc_auto_orient=disable_preproc_auto_orient,
        disable_preproc_contrast=disable_preproc_contrast,
        disable_preproc_grayscale=disable_preproc_grayscale,
        disable_preproc_static_crop=disable_preproc_static_crop,
        **kwargs,
    )
make_response
make_response(
    predictions, masks, img_dims, class_filter=[], **kwargs
)

Create instance segmentation inference response objects for the provided predictions and masks.

Parameters:

Name Type Description Default
predictions List[List[List[float]]]

List of prediction data, one for each image.

required
masks List[List[List[float]]]

List of masks corresponding to the predictions.

required
img_dims List[Tuple[int, int]]

List of image dimensions corresponding to the processed images.

required
class_filter List[str]

List of class names to filter predictions by. Defaults to an empty list (no filtering).

[]

Returns:

Type Description
Union[InstanceSegmentationInferenceResponse, List[InstanceSegmentationInferenceResponse]]

Union[InstanceSegmentationInferenceResponse, List[InstanceSegmentationInferenceResponse]]: A single instance segmentation response or a list of instance segmentation responses based on the number of processed images.

Notes
  • For each image, constructs an InstanceSegmentationInferenceResponse object.
  • Each response contains a list of InstanceSegmentationPrediction objects.
Source code in inference/core/models/instance_segmentation_base.py
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
def make_response(
    self,
    predictions: List[List[List[float]]],
    masks: List[List[List[float]]],
    img_dims: List[Tuple[int, int]],
    class_filter: List[str] = [],
    **kwargs,
) -> Union[
    InstanceSegmentationInferenceResponse,
    List[InstanceSegmentationInferenceResponse],
]:
    """
    Create instance segmentation inference response objects for the provided predictions and masks.

    Args:
        predictions (List[List[List[float]]]): List of prediction data, one for each image.
        masks (List[List[List[float]]]): List of masks corresponding to the predictions.
        img_dims (List[Tuple[int, int]]): List of image dimensions corresponding to the processed images.
        class_filter (List[str], optional): List of class names to filter predictions by.
            Defaults to an empty list (no filtering). NOTE: the default is only
            ever read, never mutated, so sharing it across calls is safe.

    Returns:
        List[InstanceSegmentationInferenceResponse]: One response per processed image.

    Notes:
        - For each image, constructs an `InstanceSegmentationInferenceResponse` object.
        - Each response contains a list of `InstanceSegmentationPrediction` objects.
    """
    responses = []
    for ind, (batch_predictions, batch_masks) in enumerate(zip(predictions, masks)):
        # Use a distinct name rather than rebinding the `predictions` parameter.
        image_predictions = []
        for pred, mask in zip(batch_predictions, batch_masks):
            if class_filter and self.class_names[int(pred[6])] not in class_filter:
                # TODO: logger.debug
                continue
            # Passing args as a dictionary here since one of the args is 'class' (a protected term in Python)
            image_predictions.append(
                InstanceSegmentationPrediction(
                    **{
                        "x": pred[0] + (pred[2] - pred[0]) / 2,
                        "y": pred[1] + (pred[3] - pred[1]) / 2,
                        "width": pred[2] - pred[0],
                        "height": pred[3] - pred[1],
                        "points": [Point(x=point[0], y=point[1]) for point in mask],
                        "confidence": pred[4],
                        "class": self.class_names[int(pred[6])],
                        "class_id": int(pred[6]),
                    }
                )
            )
        response = InstanceSegmentationInferenceResponse(
            predictions=image_predictions,
            image=InferenceResponseImage(
                width=img_dims[ind][1], height=img_dims[ind][0]
            ),
        )
        responses.append(response)
    return responses
predict
predict(img_in, **kwargs)

Runs inference on the ONNX model.

Parameters:

Name Type Description Default
img_in ndarray

The preprocessed image(s) to run inference on.

required

Returns:

Type Description
Tuple[ndarray, ndarray]

Tuple[np.ndarray, np.ndarray]: The ONNX model predictions and the ONNX model protos.

Raises:

Type Description
NotImplementedError

This method must be implemented by a subclass.

Source code in inference/core/models/instance_segmentation_base.py
277
278
279
280
281
282
283
284
285
286
287
288
289
def predict(self, img_in: np.ndarray, **kwargs) -> Tuple[np.ndarray, np.ndarray]:
    """Abstract hook that executes the ONNX session for this model.

    Concrete subclasses override this with the actual ONNX runtime call.

    Args:
        img_in (np.ndarray): The preprocessed image(s) to run inference on.

    Returns:
        Tuple[np.ndarray, np.ndarray]: The ONNX model predictions and the
            ONNX model protos.

    Raises:
        NotImplementedError: This method must be implemented by a subclass.
    """
    raise NotImplementedError("predict must be implemented by a subclass")

Functions

inference.core.models.keypoints_detection_base

Classes

KeypointsDetectionBaseOnnxRoboflowInferenceModel

Bases: ObjectDetectionBaseOnnxRoboflowInferenceModel

Roboflow ONNX keypoints detection model. This class implements a keypoints detection specific infer method.

Source code in inference/core/models/keypoints_detection_base.py
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
class KeypointsDetectionBaseOnnxRoboflowInferenceModel(
    ObjectDetectionBaseOnnxRoboflowInferenceModel
):
    """Roboflow ONNX keypoints detection model.

    Extends the object detection base model with keypoints-specific
    postprocessing and response construction.
    """

    task_type = "keypoint-detection"

    # NOTE: the former pass-through __init__ (which only forwarded its
    # arguments to super().__init__) was removed; the inherited constructor
    # behaves identically.

    def get_infer_bucket_file_list(self) -> list:
        """Returns the list of files to be downloaded from the inference bucket for ONNX model.

        Returns:
            list: A list of filenames specific to ONNX models.
        """
        return ["environment.json", "class_names.txt", "keypoints_metadata.json"]

    def postprocess(
        self,
        predictions: Tuple[np.ndarray],
        preproc_return_metadata: PreprocessReturnMetadata,
        class_agnostic_nms=DEFAULT_CLASS_AGNOSTIC_NMS,
        confidence: float = DEFAULT_CONFIDENCE,
        iou_threshold: float = DEFAULT_IOU_THRESH,
        max_candidates: int = DEFAULT_MAX_CANDIDATES,
        max_detections: int = DEFAUlT_MAX_DETECTIONS,
        return_image_dims: bool = False,
        **kwargs,
    ) -> List[KeypointsDetectionInferenceResponse]:
        """Postprocesses the keypoints detection predictions.

        Args:
            predictions (np.ndarray): Raw predictions from the model.
            preproc_return_metadata (PreprocessReturnMetadata): Metadata from preprocessing
                (image dimensions, static-crop flag).
            class_agnostic_nms (bool): Whether to apply class-agnostic non-max suppression. Default is False.
            confidence (float): Confidence threshold for filtering detections. Default is 0.5.
            iou_threshold (float): IoU threshold for non-max suppression. Default is 0.5.
            max_candidates (int): Maximum number of candidate detections. Default is 3000.
            max_detections (int): Maximum number of final detections. Default is 300.

        Returns:
            List[KeypointsDetectionInferenceResponse]: The post-processed predictions.
        """
        predictions = predictions[0]
        number_of_classes = len(self.get_class_names)
        # Columns beyond (4 bbox + 1 objectness + num classes) hold keypoint
        # values; their count is derived from the prediction row width.
        num_masks = predictions.shape[2] - 5 - number_of_classes
        predictions = w_np_non_max_suppression(
            predictions,
            conf_thresh=confidence,
            iou_thresh=iou_threshold,
            class_agnostic=class_agnostic_nms,
            max_detections=max_detections,
            max_candidate_detections=max_candidates,
            num_masks=num_masks,
        )

        infer_shape = (self.img_size_h, self.img_size_w)
        img_dims = preproc_return_metadata["img_dims"]
        # Rescale boxes, then keypoints, from network input space back to the
        # original image space.
        predictions = post_process_bboxes(
            predictions=predictions,
            infer_shape=infer_shape,
            img_dims=img_dims,
            preproc=self.preproc,
            resize_method=self.resize_method,
            disable_preproc_static_crop=preproc_return_metadata[
                "disable_preproc_static_crop"
            ],
        )
        predictions = post_process_keypoints(
            predictions=predictions,
            keypoints_start_index=-num_masks,
            infer_shape=infer_shape,
            img_dims=img_dims,
            preproc=self.preproc,
            resize_method=self.resize_method,
            disable_preproc_static_crop=preproc_return_metadata[
                "disable_preproc_static_crop"
            ],
        )
        return self.make_response(predictions, img_dims, **kwargs)

    def make_response(
        self,
        predictions: List[List[float]],
        img_dims: List[Tuple[int, int]],
        class_filter: Optional[List[str]] = None,
        *args,
        **kwargs,
    ) -> List[KeypointsDetectionInferenceResponse]:
        """Constructs keypoints detection response objects based on predictions.

        Args:
            predictions (List[List[float]]): The list of predictions.
            img_dims (List[Tuple[int, int]]): Dimensions of the images.
            class_filter (Optional[List[str]]): A list of class names to filter, if provided.

        Returns:
            List[KeypointsDetectionInferenceResponse]: A list of response objects containing keypoints detection predictions.
        """
        # Some callers pass the whole preprocess metadata dict; unwrap it.
        if isinstance(img_dims, dict) and "img_dims" in img_dims:
            img_dims = img_dims["img_dims"]
        keypoint_confidence_threshold = 0.0
        if "request" in kwargs:
            keypoint_confidence_threshold = kwargs["request"].keypoint_confidence
        responses = [
            KeypointsDetectionInferenceResponse(
                predictions=[
                    KeypointsPrediction(
                        # Passing args as a dictionary here since one of the args is 'class' (a protected term in Python)
                        **{
                            "x": (pred[0] + pred[2]) / 2,
                            "y": (pred[1] + pred[3]) / 2,
                            "width": pred[2] - pred[0],
                            "height": pred[3] - pred[1],
                            "confidence": pred[4],
                            "class": self.class_names[int(pred[6])],
                            "class_id": int(pred[6]),
                            "keypoints": model_keypoints_to_response(
                                keypoints_metadata=self.keypoints_metadata,
                                keypoints=pred[7:],
                                predicted_object_class_id=int(pred[6]),
                                keypoint_confidence_threshold=keypoint_confidence_threshold,
                            ),
                        }
                    )
                    for pred in batch_predictions
                    if not class_filter
                    or self.class_names[int(pred[6])] in class_filter
                ],
                image=InferenceResponseImage(
                    width=img_dims[ind][1], height=img_dims[ind][0]
                ),
            )
            for ind, batch_predictions in enumerate(predictions)
        ]
        return responses

    def keypoints_count(self) -> int:
        """Return the number of keypoints per object; must be overridden.

        Raises:
            NotImplementedError: Always, in this base class.
        """
        raise NotImplementedError

    def validate_model_classes(self) -> None:
        """Check that the model's class count matches the environment's.

        Raises:
            ValueError: If the class counts disagree.
        """
        num_keypoints = self.keypoints_count()
        output_shape = self.get_model_output_shape()
        num_classes = get_num_classes_from_model_prediction_shape(
            len_prediction=output_shape[2], keypoints=num_keypoints
        )
        if num_classes != self.num_classes:
            raise ValueError(
                f"Number of classes in model ({num_classes}) does not match the number of classes in the environment ({self.num_classes})"
            )
Functions
get_infer_bucket_file_list
get_infer_bucket_file_list()

Returns the list of files to be downloaded from the inference bucket for ONNX model.

Returns:

Name Type Description
list list

A list of filenames specific to ONNX models.

Source code in inference/core/models/keypoints_detection_base.py
38
39
40
41
42
43
44
def get_infer_bucket_file_list(self) -> list:
    """List the files this ONNX keypoints model downloads from the inference bucket.

    Returns:
        list: Filenames required to run the model (environment, class names,
            and keypoint metadata).
    """
    return [
        "environment.json",
        "class_names.txt",
        "keypoints_metadata.json",
    ]
make_response
make_response(
    predictions,
    img_dims,
    class_filter=None,
    *args,
    **kwargs
)

Constructs keypoints detection response objects based on predictions.

Parameters:

Name Type Description Default
predictions List[List[float]]

The list of predictions.

required
img_dims List[Tuple[int, int]]

Dimensions of the images.

required
class_filter Optional[List[str]]

A list of class names to filter, if provided.

None

Returns:

Type Description
List[KeypointsDetectionInferenceResponse]

List[KeypointsDetectionInferenceResponse]: A list of response objects containing keypoints detection predictions.

Source code in inference/core/models/keypoints_detection_base.py
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
def make_response(
    self,
    predictions: List[List[float]],
    img_dims: List[Tuple[int, int]],
    class_filter: Optional[List[str]] = None,
    *args,
    **kwargs,
) -> List[KeypointsDetectionInferenceResponse]:
    """Build keypoints detection responses from post-processed predictions.

    Args:
        predictions (List[List[float]]): Per-image prediction lists; each entry
            holds corner-format box coords (indices 0-3), confidence (4),
            class id (6), and keypoint values from index 7 on.
        img_dims (List[Tuple[int, int]]): (height, width) of each source image.
        class_filter (Optional[List[str]]): Class names to keep; everything is
            kept when not provided.

    Returns:
        List[KeypointsDetectionInferenceResponse]: One response per input image.
    """
    # Some callers hand over the whole preprocessing metadata dict rather
    # than the bare dims list - unwrap it.
    if isinstance(img_dims, dict) and "img_dims" in img_dims:
        img_dims = img_dims["img_dims"]
    keypoint_confidence_threshold = 0.0
    if "request" in kwargs:
        keypoint_confidence_threshold = kwargs["request"].keypoint_confidence
    responses = []
    for image_index, image_predictions in enumerate(predictions):
        detections = []
        for pred in image_predictions:
            class_id = int(pred[6])
            class_name = self.class_names[class_id]
            if class_filter and class_name not in class_filter:
                continue
            detections.append(
                KeypointsPrediction(
                    # Kwargs passed via dict because 'class' is a reserved word.
                    **{
                        "x": (pred[0] + pred[2]) / 2,
                        "y": (pred[1] + pred[3]) / 2,
                        "width": pred[2] - pred[0],
                        "height": pred[3] - pred[1],
                        "confidence": pred[4],
                        "class": class_name,
                        "class_id": class_id,
                        "keypoints": model_keypoints_to_response(
                            keypoints_metadata=self.keypoints_metadata,
                            keypoints=pred[7:],
                            predicted_object_class_id=class_id,
                            keypoint_confidence_threshold=keypoint_confidence_threshold,
                        ),
                    }
                )
            )
        responses.append(
            KeypointsDetectionInferenceResponse(
                predictions=detections,
                image=InferenceResponseImage(
                    width=img_dims[image_index][1],
                    height=img_dims[image_index][0],
                ),
            )
        )
    return responses
postprocess
postprocess(
    predictions,
    preproc_return_metadata,
    class_agnostic_nms=DEFAULT_CLASS_AGNOSTIC_NMS,
    confidence=DEFAULT_CONFIDENCE,
    iou_threshold=DEFAULT_IOU_THRESH,
    max_candidates=DEFAULT_MAX_CANDIDATES,
    max_detections=DEFAUlT_MAX_DETECTIONS,
    return_image_dims=False,
    **kwargs
)

Postprocesses the object detection predictions.

Parameters:

Name Type Description Default
predictions ndarray

Raw predictions from the model.

required
img_dims List[Tuple[int, int]]

Dimensions of the images.

required
class_agnostic_nms bool

Whether to apply class-agnostic non-max suppression. Default is False.

DEFAULT_CLASS_AGNOSTIC_NMS
confidence float

Confidence threshold for filtering detections. Default is 0.5.

DEFAULT_CONFIDENCE
iou_threshold float

IoU threshold for non-max suppression. Default is 0.5.

DEFAULT_IOU_THRESH
max_candidates int

Maximum number of candidate detections. Default is 3000.

DEFAULT_MAX_CANDIDATES
max_detections int

Maximum number of final detections. Default is 300.

DEFAUlT_MAX_DETECTIONS

Returns:

Type Description
List[KeypointsDetectionInferenceResponse]

List[KeypointsDetectionInferenceResponse]: The post-processed predictions.

Source code in inference/core/models/keypoints_detection_base.py
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
def postprocess(
    self,
    predictions: Tuple[np.ndarray],
    preproc_return_metadata: PreprocessReturnMetadata,
    class_agnostic_nms=DEFAULT_CLASS_AGNOSTIC_NMS,
    confidence: float = DEFAULT_CONFIDENCE,
    iou_threshold: float = DEFAULT_IOU_THRESH,
    max_candidates: int = DEFAULT_MAX_CANDIDATES,
    max_detections: int = DEFAUlT_MAX_DETECTIONS,
    return_image_dims: bool = False,
    **kwargs,
) -> List[KeypointsDetectionInferenceResponse]:
    """Post-process raw keypoints-model output into response objects.

    Applies non-max suppression, rescales boxes and keypoint coordinates back
    to the original image space, and builds the final responses.

    Args:
        predictions (Tuple[np.ndarray]): Raw model output; only the first
            element is used.
        preproc_return_metadata (PreprocessReturnMetadata): Metadata from
            ``preprocess`` (image dims, static-crop flag).
        class_agnostic_nms (bool): Apply NMS across classes jointly.
        confidence (float): Confidence threshold for filtering detections.
        iou_threshold (float): IoU threshold for non-max suppression.
        max_candidates (int): Maximum candidate detections considered.
        max_detections (int): Maximum detections kept after NMS.
        return_image_dims (bool): Accepted for interface parity; unused here.

    Returns:
        List[KeypointsDetectionInferenceResponse]: The post-processed predictions.
    """
    raw_output = predictions[0]
    class_count = len(self.get_class_names)
    # Columns beyond the 5 box/confidence fields and the per-class scores
    # carry keypoint data; they ride through NMS via the `num_masks` slot.
    keypoint_columns = raw_output.shape[2] - 5 - class_count
    suppressed = w_np_non_max_suppression(
        raw_output,
        conf_thresh=confidence,
        iou_thresh=iou_threshold,
        class_agnostic=class_agnostic_nms,
        max_detections=max_detections,
        max_candidate_detections=max_candidates,
        num_masks=keypoint_columns,
    )

    model_input_shape = (self.img_size_h, self.img_size_w)
    img_dims = preproc_return_metadata["img_dims"]
    static_crop_disabled = preproc_return_metadata["disable_preproc_static_crop"]
    rescaled = post_process_bboxes(
        predictions=suppressed,
        infer_shape=model_input_shape,
        img_dims=img_dims,
        preproc=self.preproc,
        resize_method=self.resize_method,
        disable_preproc_static_crop=static_crop_disabled,
    )
    rescaled = post_process_keypoints(
        predictions=rescaled,
        keypoints_start_index=-keypoint_columns,
        infer_shape=model_input_shape,
        img_dims=img_dims,
        preproc=self.preproc,
        resize_method=self.resize_method,
        disable_preproc_static_crop=static_crop_disabled,
    )
    return self.make_response(rescaled, img_dims, **kwargs)

Functions

inference.core.models.object_detection_base

Classes

ObjectDetectionBaseOnnxRoboflowInferenceModel

Bases: OnnxRoboflowInferenceModel

Roboflow ONNX Object detection model. This class implements an object detection specific infer method.

Source code in inference/core/models/object_detection_base.py
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
class ObjectDetectionBaseOnnxRoboflowInferenceModel(OnnxRoboflowInferenceModel):
    """Roboflow ONNX Object detection model. This class implements an object detection specific infer method."""

    task_type = "object-detection"
    box_format = "xywh"

    def infer(
        self,
        image: Any,
        class_agnostic_nms: bool = DEFAULT_CLASS_AGNOSTIC_NMS,
        confidence: float = DEFAULT_CONFIDENCE,
        disable_preproc_auto_orient: bool = False,
        disable_preproc_contrast: bool = False,
        disable_preproc_grayscale: bool = False,
        disable_preproc_static_crop: bool = False,
        iou_threshold: float = DEFAULT_IOU_THRESH,
        fix_batch_size: bool = False,
        max_candidates: int = DEFAULT_MAX_CANDIDATES,
        max_detections: int = DEFAUlT_MAX_DETECTIONS,
        return_image_dims: bool = False,
        **kwargs,
    ) -> Any:
        """
        Runs object detection inference on one or multiple images and returns the detections.

        Args:
            image (Any): The input image or a list of images to process.
                - can be a BGR numpy array, filepath, InferenceRequestImage, PIL Image, byte-string, etc.
            class_agnostic_nms (bool, optional): Whether to use class-agnostic non-maximum suppression. Defaults to False.
            confidence (float, optional): Confidence threshold for predictions. Defaults to 0.4.
            iou_threshold (float, optional): IoU threshold for non-maximum suppression. Defaults to 0.3.
            fix_batch_size (bool, optional): If True, fix the batch size for predictions. Useful when the model requires a fixed batch size. Defaults to False.
            max_candidates (int, optional): Maximum number of candidate detections. Defaults to 3000.
            max_detections (int, optional): Maximum number of detections after non-maximum suppression. Defaults to 300.
            return_image_dims (bool, optional): Whether to return the dimensions of the processed images along with the predictions. Defaults to False.
            disable_preproc_auto_orient (bool, optional): If true, the auto orient preprocessing step is disabled for this call. Default is False.
            disable_preproc_contrast (bool, optional): If true, the auto contrast preprocessing step is disabled for this call. Default is False.
            disable_preproc_grayscale (bool, optional): If true, the grayscale preprocessing step is disabled for this call. Default is False.
            disable_preproc_static_crop (bool, optional): If true, the static crop preprocessing step is disabled for this call. Default is False.
            **kwargs: Arbitrary keyword arguments forwarded to the base implementation.

        Returns:
            Union[List[ObjectDetectionInferenceResponse], ObjectDetectionInferenceResponse]: One or multiple object detection inference responses based on the number of processed images. Each response contains a list of predictions. If `return_image_dims` is True, it will return a tuple with predictions and image dimensions.

        Raises:
            ValueError: If batching is not enabled for the model and more than one image is passed for processing.
        """
        return super().infer(
            image,
            class_agnostic_nms=class_agnostic_nms,
            confidence=confidence,
            disable_preproc_auto_orient=disable_preproc_auto_orient,
            disable_preproc_contrast=disable_preproc_contrast,
            disable_preproc_grayscale=disable_preproc_grayscale,
            disable_preproc_static_crop=disable_preproc_static_crop,
            iou_threshold=iou_threshold,
            fix_batch_size=fix_batch_size,
            max_candidates=max_candidates,
            max_detections=max_detections,
            return_image_dims=return_image_dims,
            **kwargs,
        )

    def make_response(
        self,
        predictions: List[List[float]],
        img_dims: List[Tuple[int, int]],
        class_filter: Optional[List[str]] = None,
        *args,
        **kwargs,
    ) -> List[ObjectDetectionInferenceResponse]:
        """Constructs object detection response objects based on predictions.

        Args:
            predictions (List[List[float]]): The list of predictions.
            img_dims (List[Tuple[int, int]]): Dimensions of the images.
            class_filter (Optional[List[str]]): A list of class names to filter, if provided.

        Returns:
            List[ObjectDetectionInferenceResponse]: A list of response objects containing object detection predictions.
        """

        # Some callers pass the whole preprocessing metadata dict instead of
        # the bare dims list - unwrap it.
        if isinstance(img_dims, dict) and "img_dims" in img_dims:
            img_dims = img_dims["img_dims"]

        predictions = predictions[
            : len(img_dims)
        ]  # If the batch size was fixed we have empty preds at the end

        responses = [
            ObjectDetectionInferenceResponse(
                predictions=[
                    ObjectDetectionPrediction(
                        # Passing args as a dictionary here since one of the args is 'class' (a protected term in Python)
                        **{
                            "x": (pred[0] + pred[2]) / 2,
                            "y": (pred[1] + pred[3]) / 2,
                            "width": pred[2] - pred[0],
                            "height": pred[3] - pred[1],
                            "confidence": pred[4],
                            "class": self.class_names[int(pred[6])],
                            "class_id": int(pred[6]),
                        }
                    )
                    for pred in batch_predictions
                    if not class_filter
                    or self.class_names[int(pred[6])] in class_filter
                ],
                image=InferenceResponseImage(
                    width=img_dims[ind][1], height=img_dims[ind][0]
                ),
            )
            for ind, batch_predictions in enumerate(predictions)
        ]
        return responses

    def postprocess(
        self,
        predictions: Tuple[np.ndarray, ...],
        preproc_return_metadata: PreprocessReturnMetadata,
        class_agnostic_nms=DEFAULT_CLASS_AGNOSTIC_NMS,
        confidence: float = DEFAULT_CONFIDENCE,
        iou_threshold: float = DEFAULT_IOU_THRESH,
        max_candidates: int = DEFAULT_MAX_CANDIDATES,
        max_detections: int = DEFAUlT_MAX_DETECTIONS,
        return_image_dims: bool = False,
        **kwargs,
    ) -> List[ObjectDetectionInferenceResponse]:
        """Postprocesses the object detection predictions.

        Applies non-max suppression to the raw model output, rescales boxes
        back to the original image space, and builds response objects.

        Args:
            predictions (Tuple[np.ndarray, ...]): Raw predictions from the model; only the first element is used.
            preproc_return_metadata (PreprocessReturnMetadata): Metadata from ``preprocess`` (image dims, static-crop flag).
            class_agnostic_nms (bool): Whether to apply class-agnostic non-max suppression. Default is False.
            confidence (float): Confidence threshold for filtering detections. Default is 0.5.
            iou_threshold (float): IoU threshold for non-max suppression. Default is 0.5.
            max_candidates (int): Maximum number of candidate detections. Default is 3000.
            max_detections (int): Maximum number of final detections. Default is 300.

        Returns:
            List[ObjectDetectionInferenceResponse]: The post-processed predictions.
        """
        predictions = predictions[0]
        predictions = w_np_non_max_suppression(
            predictions,
            conf_thresh=confidence,
            iou_thresh=iou_threshold,
            class_agnostic=class_agnostic_nms,
            max_detections=max_detections,
            max_candidate_detections=max_candidates,
            box_format=self.box_format,
        )

        infer_shape = (self.img_size_h, self.img_size_w)
        img_dims = preproc_return_metadata["img_dims"]
        predictions = post_process_bboxes(
            predictions,
            infer_shape,
            img_dims,
            self.preproc,
            resize_method=self.resize_method,
            disable_preproc_static_crop=preproc_return_metadata[
                "disable_preproc_static_crop"
            ],
        )
        return self.make_response(predictions, img_dims, **kwargs)

    def preprocess(
        self,
        image: Any,
        disable_preproc_auto_orient: bool = False,
        disable_preproc_contrast: bool = False,
        disable_preproc_grayscale: bool = False,
        disable_preproc_static_crop: bool = False,
        fix_batch_size: bool = False,
        **kwargs,
    ) -> Tuple[np.ndarray, PreprocessReturnMetadata]:
        """Preprocesses an object detection inference request.

        Args:
            image (Any): The input image or list of images to load and preprocess.
            disable_preproc_auto_orient (bool): Disable the auto-orient step for this call.
            disable_preproc_contrast (bool): Disable the contrast step for this call.
            disable_preproc_grayscale (bool): Disable the grayscale step for this call.
            disable_preproc_static_crop (bool): Disable the static-crop step for this call.
            fix_batch_size (bool): If True, pad the batch up to ``MAX_BATCH_SIZE``.

        Returns:
            Tuple[np.ndarray, PreprocessReturnMetadata]: Preprocessed image input and the
                metadata (original image dims and static-crop flag) needed by ``postprocess``.

        Raises:
            ValueError: If more images are passed than the fixed batch size allows, or if
                the loaded image is of an unsupported type.
        """
        img_in, img_dims = self.load_image(
            image,
            disable_preproc_auto_orient=disable_preproc_auto_orient,
            disable_preproc_contrast=disable_preproc_contrast,
            disable_preproc_grayscale=disable_preproc_grayscale,
            disable_preproc_static_crop=disable_preproc_static_crop,
        )

        # Scale pixel values from [0, 255] to [0, 1].
        img_in /= 255.0

        if self.batching_enabled:
            batch_padding = 0
            if FIX_BATCH_SIZE or fix_batch_size:
                if MAX_BATCH_SIZE == float("inf"):
                    logger.warning(
                        "Requested fix_batch_size but MAX_BATCH_SIZE is not set. Using dynamic batching."
                    )
                    batch_padding = 0
                else:
                    batch_padding = MAX_BATCH_SIZE - img_in.shape[0]
            if batch_padding < 0:
                raise ValueError(
                    f"Requested fix_batch_size but passed in {img_in.shape[0]} images "
                    f"when the model's batch size is {MAX_BATCH_SIZE}\n"
                    f"Consider turning off fix_batch_size, changing `MAX_BATCH_SIZE` in"
                    f"your inference server config, or passing at most {MAX_BATCH_SIZE} images at a time"
                )
            # Pad both spatial axes (2 and 3) up to the next multiple of 32.
            # NOTE(review): the original local names called axis 2 "width" and
            # axis 3 "height"; for NCHW input it would be the other way round.
            # The padding is applied consistently either way, so behavior is
            # unchanged - confirm the actual layout with load_image.
            axis2_padding = (32 - img_in.shape[2] % 32) % 32
            axis3_padding = (32 - img_in.shape[3] % 32) % 32

            if isinstance(img_in, np.ndarray):
                img_in = np.pad(
                    img_in,
                    (
                        (0, batch_padding),
                        (0, 0),
                        (0, axis2_padding),
                        (0, axis3_padding),
                    ),
                    "constant",
                )
            elif USE_PYTORCH_FOR_PREPROCESSING:
                # torch.nn.functional.pad orders pairs from the LAST axis
                # backwards: (axis3, axis2, axis1, axis0).
                img_in = torch.nn.functional.pad(
                    img_in,
                    (
                        0,
                        axis3_padding,
                        0,
                        axis2_padding,
                        0,
                        0,  # channels
                        0,
                        batch_padding,
                    ),  # batch
                    mode="constant",
                    value=0,
                )
            else:
                raise ValueError(
                    f"Received an image of unknown type, {type(img_in)}; "
                    "This is most likely a bug. Contact Roboflow team through github issues "
                    "(https://github.com/roboflow/inference/issues) providing full context of the problem"
                )

        return img_in, PreprocessReturnMetadata(
            {
                "img_dims": img_dims,
                "disable_preproc_static_crop": disable_preproc_static_crop,
            }
        )

    def predict(self, img_in: np.ndarray, **kwargs) -> Tuple[np.ndarray]:
        """Runs inference on the ONNX model.

        Args:
            img_in (np.ndarray): The preprocessed image(s) to run inference on.

        Returns:
            Tuple[np.ndarray]: The ONNX model predictions.

        Raises:
            NotImplementedError: This method must be implemented by a subclass.
        """
        raise NotImplementedError("predict must be implemented by a subclass")

    def validate_model_classes(self) -> None:
        """Checks that the model's output head matches the configured class list.

        Raises:
            ValueError: If the number of classes implied by the model output shape
                differs from the number of classes in the environment.
        """
        output_shape = self.get_model_output_shape()
        num_classes = get_num_classes_from_model_prediction_shape(
            output_shape[2], masks=0
        )
        # Use an explicit check instead of `assert`: assertions are stripped
        # under `python -O`, which would silently disable this validation.
        # This also matches the keypoints-detection variant of this method.
        if num_classes != self.num_classes:
            raise ValueError(
                f"Number of classes in model ({num_classes}) does not match the number of classes in the environment ({self.num_classes})"
            )
Functions
infer
infer(
    image,
    class_agnostic_nms=DEFAULT_CLASS_AGNOSTIC_NMS,
    confidence=DEFAULT_CONFIDENCE,
    disable_preproc_auto_orient=False,
    disable_preproc_contrast=False,
    disable_preproc_grayscale=False,
    disable_preproc_static_crop=False,
    iou_threshold=DEFAULT_IOU_THRESH,
    fix_batch_size=False,
    max_candidates=DEFAULT_MAX_CANDIDATES,
    max_detections=DEFAUlT_MAX_DETECTIONS,
    return_image_dims=False,
    **kwargs
)

Runs object detection inference on one or multiple images and returns the detections.

Parameters:

Name Type Description Default
image Any

The input image or a list of images to process. - can be a BGR numpy array, filepath, InferenceRequestImage, PIL Image, byte-string, etc.

required
class_agnostic_nms bool

Whether to use class-agnostic non-maximum suppression. Defaults to False.

DEFAULT_CLASS_AGNOSTIC_NMS
confidence float

Confidence threshold for predictions. Defaults to 0.4.

DEFAULT_CONFIDENCE
iou_threshold float

IoU threshold for non-maximum suppression. Defaults to 0.3.

DEFAULT_IOU_THRESH
fix_batch_size bool

If True, fix the batch size for predictions. Useful when the model requires a fixed batch size. Defaults to False.

False
max_candidates int

Maximum number of candidate detections. Defaults to 3000.

DEFAULT_MAX_CANDIDATES
max_detections int

Maximum number of detections after non-maximum suppression. Defaults to 300.

DEFAUlT_MAX_DETECTIONS
return_image_dims bool

Whether to return the dimensions of the processed images along with the predictions. Defaults to False.

False
disable_preproc_auto_orient bool

If true, the auto orient preprocessing step is disabled for this call. Default is False.

False
disable_preproc_contrast bool

If true, the auto contrast preprocessing step is disabled for this call. Default is False.

False
disable_preproc_grayscale bool

If true, the grayscale preprocessing step is disabled for this call. Default is False.

False
disable_preproc_static_crop bool

If true, the static crop preprocessing step is disabled for this call. Default is False.

False
*args

Variable length argument list.

required
**kwargs

Arbitrary keyword arguments.

{}

Returns:

Type Description
Any

Union[List[ObjectDetectionInferenceResponse], ObjectDetectionInferenceResponse]: One or multiple object detection inference responses based on the number of processed images. Each response contains a list of predictions. If return_image_dims is True, it will return a tuple with predictions and image dimensions.

Raises:

Type Description
ValueError

If batching is not enabled for the model and more than one image is passed for processing.

Source code in inference/core/models/object_detection_base.py
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
def infer(
    self,
    image: Any,
    class_agnostic_nms: bool = DEFAULT_CLASS_AGNOSTIC_NMS,
    confidence: float = DEFAULT_CONFIDENCE,
    disable_preproc_auto_orient: bool = False,
    disable_preproc_contrast: bool = False,
    disable_preproc_grayscale: bool = False,
    disable_preproc_static_crop: bool = False,
    iou_threshold: float = DEFAULT_IOU_THRESH,
    fix_batch_size: bool = False,
    max_candidates: int = DEFAULT_MAX_CANDIDATES,
    max_detections: int = DEFAUlT_MAX_DETECTIONS,
    return_image_dims: bool = False,
    **kwargs,
) -> Any:
    """Run object detection inference on one or multiple images.

    Thin wrapper that forwards all thresholds and preprocessing flags to the
    base-class ``infer`` implementation.

    Args:
        image (Any): Input image or list of images (BGR numpy array, filepath,
            InferenceRequestImage, PIL Image, byte-string, etc.).
        class_agnostic_nms (bool, optional): Use class-agnostic non-maximum suppression.
        confidence (float, optional): Confidence threshold for predictions.
        disable_preproc_auto_orient (bool, optional): Disable auto-orient for this call.
        disable_preproc_contrast (bool, optional): Disable contrast adjustment for this call.
        disable_preproc_grayscale (bool, optional): Disable grayscale conversion for this call.
        disable_preproc_static_crop (bool, optional): Disable static crop for this call.
        iou_threshold (float, optional): IoU threshold for non-maximum suppression.
        fix_batch_size (bool, optional): Fix the batch size for predictions.
        max_candidates (int, optional): Maximum number of candidate detections.
        max_detections (int, optional): Maximum detections after non-maximum suppression.
        return_image_dims (bool, optional): Also return processed image dimensions.
        **kwargs: Arbitrary keyword arguments forwarded to the base implementation.

    Returns:
        Union[List[ObjectDetectionInferenceResponse], ObjectDetectionInferenceResponse]:
            One or multiple responses; a (predictions, dims) tuple when
            ``return_image_dims`` is True.

    Raises:
        ValueError: If batching is not enabled for the model and more than one
            image is passed for processing.
    """
    forwarded_options = {
        "class_agnostic_nms": class_agnostic_nms,
        "confidence": confidence,
        "disable_preproc_auto_orient": disable_preproc_auto_orient,
        "disable_preproc_contrast": disable_preproc_contrast,
        "disable_preproc_grayscale": disable_preproc_grayscale,
        "disable_preproc_static_crop": disable_preproc_static_crop,
        "iou_threshold": iou_threshold,
        "fix_batch_size": fix_batch_size,
        "max_candidates": max_candidates,
        "max_detections": max_detections,
        "return_image_dims": return_image_dims,
    }
    # Duplicate keys between the two ** expansions still raise TypeError,
    # matching the original keyword-forwarding call.
    return super().infer(image, **forwarded_options, **kwargs)
make_response
make_response(
    predictions,
    img_dims,
    class_filter=None,
    *args,
    **kwargs
)

Constructs object detection response objects based on predictions.

Parameters:

Name Type Description Default
predictions List[List[float]]

The list of predictions.

required
img_dims List[Tuple[int, int]]

Dimensions of the images.

required
class_filter Optional[List[str]]

A list of class names to filter, if provided.

None

Returns:

Type Description
List[ObjectDetectionInferenceResponse]

List[ObjectDetectionInferenceResponse]: A list of response objects containing object detection predictions.

Source code in inference/core/models/object_detection_base.py
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
def make_response(
    self,
    predictions: List[List[float]],
    img_dims: List[Tuple[int, int]],
    class_filter: Optional[List[str]] = None,
    *args,
    **kwargs,
) -> List[ObjectDetectionInferenceResponse]:
    """Build object detection responses from post-processed predictions.

    Args:
        predictions (List[List[float]]): Per-image prediction lists; each entry
            holds corner-format box coords (indices 0-3), confidence (4), and
            class id (6).
        img_dims (List[Tuple[int, int]]): (height, width) of each source image.
        class_filter (Optional[List[str]]): Class names to keep; everything is
            kept when not provided.

    Returns:
        List[ObjectDetectionInferenceResponse]: One response per input image.
    """
    # Some callers pass the whole preprocessing metadata dict instead of
    # the bare dims list - unwrap it.
    if isinstance(img_dims, dict) and "img_dims" in img_dims:
        img_dims = img_dims["img_dims"]

    # If the batch size was fixed, trailing entries are empty padding.
    predictions = predictions[: len(img_dims)]

    responses = []
    for image_index, image_predictions in enumerate(predictions):
        detections = []
        for pred in image_predictions:
            class_id = int(pred[6])
            class_name = self.class_names[class_id]
            if class_filter and class_name not in class_filter:
                continue
            detections.append(
                ObjectDetectionPrediction(
                    # Kwargs passed via dict because 'class' is a reserved word.
                    **{
                        "x": (pred[0] + pred[2]) / 2,
                        "y": (pred[1] + pred[3]) / 2,
                        "width": pred[2] - pred[0],
                        "height": pred[3] - pred[1],
                        "confidence": pred[4],
                        "class": class_name,
                        "class_id": class_id,
                    }
                )
            )
        responses.append(
            ObjectDetectionInferenceResponse(
                predictions=detections,
                image=InferenceResponseImage(
                    width=img_dims[image_index][1],
                    height=img_dims[image_index][0],
                ),
            )
        )
    return responses
postprocess
postprocess(
    predictions,
    preproc_return_metadata,
    class_agnostic_nms=DEFAULT_CLASS_AGNOSTIC_NMS,
    confidence=DEFAULT_CONFIDENCE,
    iou_threshold=DEFAULT_IOU_THRESH,
    max_candidates=DEFAULT_MAX_CANDIDATES,
    max_detections=DEFAUlT_MAX_DETECTIONS,
    return_image_dims=False,
    **kwargs
)

Postprocesses the object detection predictions.

Parameters:

Name Type Description Default
predictions ndarray

Raw predictions from the model.

required
img_dims List[Tuple[int, int]]

Dimensions of the images.

required
class_agnostic_nms bool

Whether to apply class-agnostic non-max suppression. Default is False.

DEFAULT_CLASS_AGNOSTIC_NMS
confidence float

Confidence threshold for filtering detections. Default is 0.5.

DEFAULT_CONFIDENCE
iou_threshold float

IoU threshold for non-max suppression. Default is 0.5.

DEFAULT_IOU_THRESH
max_candidates int

Maximum number of candidate detections. Default is 3000.

DEFAULT_MAX_CANDIDATES
max_detections int

Maximum number of final detections. Default is 300.

DEFAUlT_MAX_DETECTIONS

Returns:

Type Description
List[ObjectDetectionInferenceResponse]

List[ObjectDetectionInferenceResponse]: The post-processed predictions.

Source code in inference/core/models/object_detection_base.py
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
def postprocess(
    self,
    predictions: Tuple[np.ndarray, ...],
    preproc_return_metadata: PreprocessReturnMetadata,
    class_agnostic_nms: bool = DEFAULT_CLASS_AGNOSTIC_NMS,
    confidence: float = DEFAULT_CONFIDENCE,
    iou_threshold: float = DEFAULT_IOU_THRESH,
    max_candidates: int = DEFAULT_MAX_CANDIDATES,
    max_detections: int = DEFAUlT_MAX_DETECTIONS,
    return_image_dims: bool = False,
    **kwargs,
) -> List[ObjectDetectionInferenceResponse]:
    """Postprocesses the object detection predictions.

    Applies non-max suppression to the raw model output, rescales the
    resulting boxes back to original-image coordinates, and builds the
    response objects via ``make_response``.

    Args:
        predictions (Tuple[np.ndarray, ...]): Raw model outputs; only the
            first element is consumed here.
        preproc_return_metadata (PreprocessReturnMetadata): Metadata produced
            by ``preprocess`` — original image dimensions and the
            static-crop flag.
        class_agnostic_nms (bool): Whether to apply class-agnostic non-max suppression.
        confidence (float): Confidence threshold for filtering detections.
        iou_threshold (float): IoU threshold for non-max suppression.
        max_candidates (int): Maximum number of candidate detections fed to NMS.
        max_detections (int): Maximum number of final detections.
        return_image_dims (bool): Accepted for interface compatibility; not
            referenced in this implementation.
        **kwargs: Forwarded to ``make_response``.

    Returns:
        List[ObjectDetectionInferenceResponse]: The post-processed predictions.
    """
    predictions = predictions[0]
    # Filter and deduplicate detections before any coordinate rescaling.
    predictions = w_np_non_max_suppression(
        predictions,
        conf_thresh=confidence,
        iou_thresh=iou_threshold,
        class_agnostic=class_agnostic_nms,
        max_detections=max_detections,
        max_candidate_detections=max_candidates,
        box_format=self.box_format,
    )

    # Map boxes from the model's input resolution back to each source image.
    infer_shape = (self.img_size_h, self.img_size_w)
    img_dims = preproc_return_metadata["img_dims"]
    predictions = post_process_bboxes(
        predictions,
        infer_shape,
        img_dims,
        self.preproc,
        resize_method=self.resize_method,
        disable_preproc_static_crop=preproc_return_metadata[
            "disable_preproc_static_crop"
        ],
    )
    return self.make_response(predictions, img_dims, **kwargs)
predict
predict(img_in, **kwargs)

Runs inference on the ONNX model.

Parameters:

Name Type Description Default
img_in ndarray

The preprocessed image(s) to run inference on.

required

Returns:

Type Description
Tuple[ndarray]

Tuple[np.ndarray]: The ONNX model predictions.

Raises:

Type Description
NotImplementedError

This method must be implemented by a subclass.

Source code in inference/core/models/object_detection_base.py
301
302
303
304
305
306
307
308
309
310
311
312
313
def predict(self, img_in: np.ndarray, **kwargs) -> Tuple[np.ndarray]:
    """Abstract inference hook for ONNX-backed detection models.

    Args:
        img_in (np.ndarray): Preprocessed image batch to run through the model.

    Returns:
        Tuple[np.ndarray]: Raw model outputs, as produced by the subclass.

    Raises:
        NotImplementedError: Always — a concrete subclass must override this.
    """
    raise NotImplementedError("predict must be implemented by a subclass")
preprocess
preprocess(
    image,
    disable_preproc_auto_orient=False,
    disable_preproc_contrast=False,
    disable_preproc_grayscale=False,
    disable_preproc_static_crop=False,
    fix_batch_size=False,
    **kwargs
)

Preprocesses an object detection inference request.

Parameters:

Name Type Description Default
image Any

One image, or a list of images, to preprocess.

required

Returns:

Type Description
Tuple[ndarray, PreprocessReturnMetadata]

Tuple[np.ndarray, List[Tuple[int, int]]]: Preprocessed image inputs and corresponding dimensions.

Source code in inference/core/models/object_detection_base.py
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
def preprocess(
    self,
    image: Any,
    disable_preproc_auto_orient: bool = False,
    disable_preproc_contrast: bool = False,
    disable_preproc_grayscale: bool = False,
    disable_preproc_static_crop: bool = False,
    fix_batch_size: bool = False,
    **kwargs,
) -> Tuple[np.ndarray, PreprocessReturnMetadata]:
    """Preprocesses an object detection inference request.

    Loads and normalizes the image(s), and — when the model was loaded with
    dynamic batching — pads the batch and spatial dimensions so the tensor
    fits the model input.

    Args:
        image (Any): One image, or a list of images, in any format accepted
            by ``load_image``.
        disable_preproc_auto_orient (bool): Skip auto-orient preprocessing.
        disable_preproc_contrast (bool): Skip contrast preprocessing.
        disable_preproc_grayscale (bool): Skip grayscale preprocessing.
        disable_preproc_static_crop (bool): Skip static-crop preprocessing.
        fix_batch_size (bool): Pad the batch up to ``MAX_BATCH_SIZE`` when set
            (requires ``MAX_BATCH_SIZE`` to be finite).
        **kwargs: Accepted for interface compatibility; not used here.

    Returns:
        Tuple[np.ndarray, PreprocessReturnMetadata]: The preprocessed image
        tensor plus metadata (original image dimensions and the
        static-crop flag) consumed later by ``postprocess``.
    """
    img_in, img_dims = self.load_image(
        image,
        disable_preproc_auto_orient=disable_preproc_auto_orient,
        disable_preproc_contrast=disable_preproc_contrast,
        disable_preproc_grayscale=disable_preproc_grayscale,
        disable_preproc_static_crop=disable_preproc_static_crop,
    )

    # Scale pixel values from [0, 255] into [0, 1].
    img_in /= 255.0

    if self.batching_enabled:
        batch_padding = 0
        if FIX_BATCH_SIZE or fix_batch_size:
            if MAX_BATCH_SIZE == float("inf"):
                logger.warning(
                    "Requested fix_batch_size but MAX_BATCH_SIZE is not set. Using dynamic batching."
                )
                batch_padding = 0
            else:
                batch_padding = MAX_BATCH_SIZE - img_in.shape[0]
        if batch_padding < 0:
            # More images than the fixed batch size allows — fail loudly.
            raise ValueError(
                f"Requested fix_batch_size but passed in {img_in.shape[0]} images "
                f"when the model's batch size is {MAX_BATCH_SIZE}\n"
                f"Consider turning off fix_batch_size, changing `MAX_BATCH_SIZE` in"
                f"your inference server config, or passing at most {MAX_BATCH_SIZE} images at a time"
            )
        # NOTE(review): for an NCHW tensor shape[2] is height and shape[3] is
        # width, so these names look swapped; the padding rounds both spatial
        # dims up to a multiple of 32 regardless, so behavior is unaffected —
        # confirm layout before renaming.
        width_remainder = img_in.shape[2] % 32
        height_remainder = img_in.shape[3] % 32
        if width_remainder > 0:
            width_padding = 32 - width_remainder
        else:
            width_padding = 0
        if height_remainder > 0:
            height_padding = 32 - height_remainder
        else:
            height_padding = 0

        if isinstance(img_in, np.ndarray):
            # np.pad takes (before, after) per axis: batch, channel, dim2, dim3.
            img_in = np.pad(
                img_in,
                (
                    (0, batch_padding),
                    (0, 0),
                    (0, width_padding),
                    (0, height_padding),
                ),
                "constant",
            )
        elif USE_PYTORCH_FOR_PREPROCESSING:
            # torch pad lists dims last-to-first: (last-dim pair, ..., first-dim pair).
            img_in = torch.nn.functional.pad(
                img_in,
                (
                    0,
                    height_padding,  # height padding
                    0,
                    width_padding,  # width padding
                    0,
                    0,  # channels
                    0,
                    batch_padding,
                ),  # batch
                mode="constant",
                value=0,
            )
        else:
            raise ValueError(
                f"Received an image of unknown type, {type(img_in)}; "
                "This is most likely a bug. Contact Roboflow team through github issues "
                "(https://github.com/roboflow/inference/issues) providing full context of the problem"
            )

    return img_in, PreprocessReturnMetadata(
        {
            "img_dims": img_dims,
            "disable_preproc_static_crop": disable_preproc_static_crop,
        }
    )

Functions

inference.core.models.roboflow

Classes

OnnxRoboflowCoreModel

Bases: RoboflowCoreModel

Roboflow Inference Model that operates using an ONNX model file.

Source code in inference/core/models/roboflow.py
1058
1059
1060
1061
class OnnxRoboflowCoreModel(RoboflowCoreModel):
    """Core Roboflow model backed by an ONNX artifact.

    Adds no behavior beyond the ``RoboflowCoreModel`` base; exists as a
    distinct type for ONNX-based core models.
    """

OnnxRoboflowInferenceModel

Bases: RoboflowInferenceModel

Roboflow Inference Model that operates using an ONNX model file.

Source code in inference/core/models/roboflow.py
 758
 759
 760
 761
 762
 763
 764
 765
 766
 767
 768
 769
 770
 771
 772
 773
 774
 775
 776
 777
 778
 779
 780
 781
 782
 783
 784
 785
 786
 787
 788
 789
 790
 791
 792
 793
 794
 795
 796
 797
 798
 799
 800
 801
 802
 803
 804
 805
 806
 807
 808
 809
 810
 811
 812
 813
 814
 815
 816
 817
 818
 819
 820
 821
 822
 823
 824
 825
 826
 827
 828
 829
 830
 831
 832
 833
 834
 835
 836
 837
 838
 839
 840
 841
 842
 843
 844
 845
 846
 847
 848
 849
 850
 851
 852
 853
 854
 855
 856
 857
 858
 859
 860
 861
 862
 863
 864
 865
 866
 867
 868
 869
 870
 871
 872
 873
 874
 875
 876
 877
 878
 879
 880
 881
 882
 883
 884
 885
 886
 887
 888
 889
 890
 891
 892
 893
 894
 895
 896
 897
 898
 899
 900
 901
 902
 903
 904
 905
 906
 907
 908
 909
 910
 911
 912
 913
 914
 915
 916
 917
 918
 919
 920
 921
 922
 923
 924
 925
 926
 927
 928
 929
 930
 931
 932
 933
 934
 935
 936
 937
 938
 939
 940
 941
 942
 943
 944
 945
 946
 947
 948
 949
 950
 951
 952
 953
 954
 955
 956
 957
 958
 959
 960
 961
 962
 963
 964
 965
 966
 967
 968
 969
 970
 971
 972
 973
 974
 975
 976
 977
 978
 979
 980
 981
 982
 983
 984
 985
 986
 987
 988
 989
 990
 991
 992
 993
 994
 995
 996
 997
 998
 999
1000
1001
1002
1003
1004
1005
1006
1007
1008
1009
1010
1011
1012
1013
1014
1015
1016
1017
1018
1019
1020
1021
1022
1023
1024
1025
1026
1027
1028
1029
1030
1031
1032
1033
1034
1035
1036
1037
1038
1039
1040
1041
1042
1043
1044
1045
1046
1047
1048
1049
1050
1051
1052
1053
1054
1055
class OnnxRoboflowInferenceModel(RoboflowInferenceModel):
    """Roboflow Inference Model that operates using an ONNX model file."""

    def __init__(
        self,
        model_id: str,
        onnxruntime_execution_providers: List[
            str
        ] = get_onnxruntime_execution_providers(ONNXRUNTIME_EXECUTION_PROVIDERS),
        *args,
        **kwargs,
    ):
        """Initializes the OnnxRoboflowInferenceModel instance.

        Args:
            model_id (str): The identifier for the specific ONNX model.
            onnxruntime_execution_providers (List[str]): Ordered list of ONNX
                Runtime execution providers to try. NOTE(review): the default
                is computed once at definition time, not per call.
            *args: Variable length argument list.
            **kwargs: Arbitrary keyword arguments.
        """
        super().__init__(model_id, *args, **kwargs)
        if self.load_weights or not self.has_model_metadata:
            self.onnxruntime_execution_providers = onnxruntime_execution_providers
            expanded_execution_providers = []
            for ep in self.onnxruntime_execution_providers:
                # Expand the bare TensorRT provider name into a (name, options)
                # pair enabling the on-disk engine cache and FP16 execution.
                if ep == "TensorrtExecutionProvider":
                    ep = (
                        "TensorrtExecutionProvider",
                        {
                            "trt_engine_cache_enable": True,
                            "trt_engine_cache_path": os.path.join(
                                TENSORRT_CACHE_PATH, self.endpoint
                            ),
                            "trt_fp16_enable": True,
                        },
                    )
                expanded_execution_providers.append(ep)
            self.onnxruntime_execution_providers = expanded_execution_providers

        self.image_loader_threadpool = ThreadPoolExecutor(max_workers=None)
        self._session_lock = Lock()
        try:
            self.initialize_model(**kwargs)
            self.validate_model()
        except ModelArtefactError as e:
            # Validation failed: optionally purge cached artifacts so the next
            # load re-downloads them, then re-raise.
            logger.error(f"Unable to validate model artifacts, clearing cache: {e}")
            if DISK_CACHE_CLEANUP:
                self.clear_cache(delete_from_disk=True)
            else:
                logger.error("NOT deleting model from cache, inspect model artifacts")
            raise ModelArtefactError from e

    def infer(self, image: Any, **kwargs) -> Any:
        """Runs inference on given data.
        - image:
            can be a BGR numpy array, filepath, InferenceRequestImage, PIL Image, byte-string, etc.

        Splits list inputs into chunks of at most the model's batch size and
        merges the per-chunk results.
        """
        input_elements = len(image) if isinstance(image, list) else 1
        max_batch_size = MAX_BATCH_SIZE if self.batching_enabled else self.batch_size
        # Single input, or no effective batch limit — run directly.
        if (input_elements == 1) or (max_batch_size == float("inf")):
            return super().infer(image, **kwargs)
        logger.debug(
            f"Inference will be executed in batches, as there is {input_elements} input elements and "
            f"maximum batch size for a model is set to: {max_batch_size}"
        )
        inference_results = []
        for batch_input in create_batches(sequence=image, batch_size=max_batch_size):
            batch_inference_results = super().infer(batch_input, **kwargs)
            inference_results.append(batch_inference_results)
        return self.merge_inference_results(inference_results=inference_results)

    def merge_inference_results(self, inference_results: List[Any]) -> Any:
        """Flatten per-batch result lists into one flat list."""
        return list(itertools.chain(*inference_results))

    def validate_model(self) -> None:
        """Validate the model artifacts: session presence, a test inference,
        and the class list.

        Failures are counted in the cache (keyed by endpoint, 60 s expiry);
        after more than 3 recorded failures the model is rejected outright.
        Raises ModelArtefactError on any validation failure.
        """
        if MODEL_VALIDATION_DISABLED:
            logger.debug("Model validation disabled.")
            return None
        logger.debug(f"Starting model validation for {self.endpoint}")
        validate_model_error_count = cache.get(
            self.endpoint + "_validate_model_error_count"
        )
        if validate_model_error_count is None:
            validate_model_error_count = 0
        if validate_model_error_count > 3:
            raise ModelArtefactError(
                "Model validation failed multiple times, ignoring this model."
            )
        if not self.load_weights:
            # Without weights there is no session to validate.
            return
        try:
            assert self.onnx_session is not None
        except AssertionError as e:
            cache.set(
                self.endpoint + "_validate_model_error_count",
                validate_model_error_count + 1,
                expire=60,
            )
            raise ModelArtefactError(
                "ONNX session not initialized. Check that the model weights are available."
            ) from e
        try:
            self.run_test_inference()
        except Exception as e:
            cache.set(
                self.endpoint + "_validate_model_error_count",
                validate_model_error_count + 1,
                expire=60,
            )
            raise ModelArtefactError(f"Unable to run test inference. Cause: {e}") from e
        try:
            self.validate_model_classes()
        except Exception as e:
            cache.set(
                self.endpoint + "_validate_model_error_count",
                validate_model_error_count + 1,
                expire=60,
            )
            raise ModelArtefactError(
                f"Unable to validate model classes. Cause: {e}"
            ) from e
        logger.debug(f"Model validation finished for {self.endpoint}")
        # Success: reset the failure counter (kept for an hour).
        cache.set(self.endpoint + "_validate_model_error_count", 0, expire=3600)

    def run_test_inference(self) -> None:
        """Run inference on a random 1024x1024 RGB image to prove the
        model executes end to end."""
        test_image = (np.random.rand(1024, 1024, 3) * 255).astype(np.uint8)
        logger.debug(f"Running test inference. Image size: {test_image.shape}")
        result = self.infer(test_image, usage_inference_test_run=True)
        logger.debug(f"Test inference finished.")
        return result

    def get_model_output_shape(self) -> Tuple[int, int, int]:
        """Determine the model's output shape by running a random test image
        through preprocess + predict and inspecting the first output."""
        test_image = (np.random.rand(1024, 1024, 3) * 255).astype(np.uint8)
        logger.debug(f"Getting model output shape. Image size: {test_image.shape}")
        test_image, _ = self.preprocess(test_image)
        output = self.predict(test_image)[0]
        logger.debug(f"Model output shape test finished.")
        return output.shape

    def validate_model_classes(self) -> None:
        """Hook for subclasses to verify class metadata; no-op by default."""
        pass

    def get_infer_bucket_file_list(self) -> list:
        """Returns the list of files to be downloaded from the inference bucket for ONNX model.

        Returns:
            list: A list of filenames specific to ONNX models.
        """
        return ["environment.json", "class_names.txt"]

    def initialize_model(self, **kwargs) -> None:
        """Initializes the ONNX model, setting up the inference session and other necessary properties."""
        logger.debug("Getting model artefacts")
        self.get_model_artifacts(**kwargs)
        logger.debug("Creating inference session")
        if self.load_weights or not self.has_model_metadata:
            t1_session = perf_counter()
            # Create an ONNX Runtime Session with a list of execution providers in priority order. ORT attempts to load providers until one is successful. This keeps the code across devices identical.
            providers = self.onnxruntime_execution_providers

            # When weights aren't needed, the session exists only to read input
            # metadata (see the `del self.onnx_session` below), so lightweight
            # CPU-side providers suffice.
            if not self.load_weights:
                providers = ["OpenVINOExecutionProvider", "CPUExecutionProvider"]
            try:
                session_options = onnxruntime.SessionOptions()
                session_options.log_severity_level = 3
                # TensorRT does better graph optimization for its EP than onnx
                if has_trt(providers):
                    session_options.graph_optimization_level = (
                        onnxruntime.GraphOptimizationLevel.ORT_DISABLE_ALL
                    )
                self.onnx_session = onnxruntime.InferenceSession(
                    self.cache_file(self.weights_file),
                    providers=providers,
                    sess_options=session_options,
                )
            except Exception as e:
                self.clear_cache(delete_from_disk=DISK_CACHE_CLEANUP)
                raise ModelArtefactError(
                    f"Unable to load ONNX session. Cause: {e}"
                ) from e
            logger.debug(f"Session created in {perf_counter() - t1_session} seconds")

            if REQUIRED_ONNX_PROVIDERS:
                available_providers = onnxruntime.get_available_providers()
                for provider in REQUIRED_ONNX_PROVIDERS:
                    if provider not in available_providers:
                        raise OnnxProviderNotAvailable(
                            f"Required ONNX Execution Provider {provider} is not availble. "
                            "Check that you are using the correct docker image on a supported device. "
                            "Export list of available providers as ONNXRUNTIME_EXECUTION_PROVIDERS environmental variable, "
                            "consult documentation for more details."
                        )

            # Read batch size and spatial dims from the model's first input.
            inputs = self.onnx_session.get_inputs()[0]
            input_shape = inputs.shape
            self.batch_size = input_shape[0]
            self.img_size_h = input_shape[2]
            self.img_size_w = input_shape[3]
            self.input_name = inputs.name
            # String dims mean dynamic spatial size: fall back to the preproc
            # resize config, or to 640x640 when no resize step is configured.
            if isinstance(self.img_size_h, str) or isinstance(self.img_size_w, str):
                if "resize" in self.preproc:
                    self.img_size_h = int(self.preproc["resize"]["height"])
                    self.img_size_w = int(self.preproc["resize"]["width"])
                else:
                    self.img_size_h = 640
                    self.img_size_w = 640

            # A string batch dim means the model supports dynamic batching.
            if isinstance(self.batch_size, str):
                self.batching_enabled = True
                logger.debug(
                    f"Model {self.endpoint} is loaded with dynamic batching enabled"
                )
            else:
                self.batching_enabled = False
                logger.debug(
                    f"Model {self.endpoint} is loaded with dynamic batching disabled"
                )

            model_metadata = {
                "batch_size": self.batch_size,
                "img_size_h": self.img_size_h,
                "img_size_w": self.img_size_w,
            }
            logger.debug(f"Writing model metadata to memcache")
            self.write_model_metadata_to_memcache(model_metadata)
            if not self.load_weights:  # had to load weights to get metadata
                del self.onnx_session
        else:
            if not self.has_model_metadata:
                raise ValueError(
                    "This should be unreachable, should get weights if we don't have model metadata"
                )
            logger.debug(f"Loading model metadata from memcache")
            metadata = self.model_metadata_from_memcache()
            self.batch_size = metadata["batch_size"]
            self.img_size_h = metadata["img_size_h"]
            self.img_size_w = metadata["img_size_w"]
            if isinstance(self.batch_size, str):
                self.batching_enabled = True
                logger.debug(
                    f"Model {self.endpoint} is loaded with dynamic batching enabled"
                )
            else:
                self.batching_enabled = False
                logger.debug(
                    f"Model {self.endpoint} is loaded with dynamic batching disabled"
                )

        logger.debug("Model initialisation finished.")

    def load_image(
        self,
        image: Any,
        disable_preproc_auto_orient: bool = False,
        disable_preproc_contrast: bool = False,
        disable_preproc_grayscale: bool = False,
        disable_preproc_static_crop: bool = False,
    ) -> Tuple[np.ndarray, Tuple[Tuple[int, int], ...]]:
        """Load one image or a list of images into a single batched tensor.

        Multi-image inputs are preprocessed in parallel on the instance's
        thread pool, then concatenated along the batch axis. Returns the
        batch plus a tuple of per-image (dims) entries.
        """
        if isinstance(image, list) and len(image) > 1:
            preproc_image = partial(
                self.preproc_image,
                disable_preproc_auto_orient=disable_preproc_auto_orient,
                disable_preproc_contrast=disable_preproc_contrast,
                disable_preproc_grayscale=disable_preproc_grayscale,
                disable_preproc_static_crop=disable_preproc_static_crop,
            )
            imgs_with_dims = self.image_loader_threadpool.map(preproc_image, image)
            imgs, img_dims = zip(*imgs_with_dims)
            if isinstance(imgs[0], np.ndarray):
                img_in = np.concatenate(imgs, axis=0)
            elif USE_PYTORCH_FOR_PREPROCESSING:
                img_in = torch.cat(imgs, dim=0)
            else:
                raise ValueError(
                    f"Received a list of images of unknown type, {type(imgs[0])}; "
                    "This is most likely a bug. Contact Roboflow team through github issues "
                    "(https://github.com/roboflow/inference/issues) providing full context of the problem"
                )
        else:
            if isinstance(image, list):
                image = image[0]
            img_in, img_dims = self.preproc_image(
                image,
                disable_preproc_auto_orient=disable_preproc_auto_orient,
                disable_preproc_contrast=disable_preproc_contrast,
                disable_preproc_grayscale=disable_preproc_grayscale,
                disable_preproc_static_crop=disable_preproc_static_crop,
            )
            # Normalize to a 1-tuple so callers always get a tuple of dims.
            img_dims = (img_dims,)
        return img_in, img_dims

    @property
    def weights_file(self) -> str:
        """Returns the file containing the ONNX model weights.

        Returns:
            str: The file path to the weights file.
        """
        return "weights.onnx"
Attributes
weights_file property
weights_file

Returns the file containing the ONNX model weights.

Returns:

Name Type Description
str str

The file path to the weights file.

Functions
__init__
__init__(
    model_id,
    onnxruntime_execution_providers=get_onnxruntime_execution_providers(
        ONNXRUNTIME_EXECUTION_PROVIDERS
    ),
    *args,
    **kwargs
)

Initializes the OnnxRoboflowInferenceModel instance.

Parameters:

Name Type Description Default
model_id str

The identifier for the specific ONNX model.

required
*args

Variable length argument list.

()
**kwargs

Arbitrary keyword arguments.

{}
Source code in inference/core/models/roboflow.py
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
def __init__(
    self,
    model_id: str,
    onnxruntime_execution_providers: List[
        str
    ] = get_onnxruntime_execution_providers(ONNXRUNTIME_EXECUTION_PROVIDERS),
    *args,
    **kwargs,
):
    """Initializes the OnnxRoboflowInferenceModel instance.

    Args:
        model_id (str): The identifier for the specific ONNX model.
        onnxruntime_execution_providers (List[str]): Ordered list of ONNX
            Runtime execution providers to try. NOTE(review): the default is
            computed once at definition time, not per call.
        *args: Variable length argument list.
        **kwargs: Arbitrary keyword arguments.
    """
    super().__init__(model_id, *args, **kwargs)
    if self.load_weights or not self.has_model_metadata:
        self.onnxruntime_execution_providers = onnxruntime_execution_providers
        expanded_execution_providers = []
        for ep in self.onnxruntime_execution_providers:
            # Expand the bare TensorRT provider name into a (name, options)
            # pair enabling the on-disk engine cache and FP16 execution.
            if ep == "TensorrtExecutionProvider":
                ep = (
                    "TensorrtExecutionProvider",
                    {
                        "trt_engine_cache_enable": True,
                        "trt_engine_cache_path": os.path.join(
                            TENSORRT_CACHE_PATH, self.endpoint
                        ),
                        "trt_fp16_enable": True,
                    },
                )
            expanded_execution_providers.append(ep)
        self.onnxruntime_execution_providers = expanded_execution_providers

    self.image_loader_threadpool = ThreadPoolExecutor(max_workers=None)
    self._session_lock = Lock()
    try:
        self.initialize_model(**kwargs)
        self.validate_model()
    except ModelArtefactError as e:
        # Validation failed: optionally purge cached artifacts so the next
        # load re-downloads them, then re-raise.
        logger.error(f"Unable to validate model artifacts, clearing cache: {e}")
        if DISK_CACHE_CLEANUP:
            self.clear_cache(delete_from_disk=True)
        else:
            logger.error("NOT deleting model from cache, inspect model artifacts")
        raise ModelArtefactError from e
get_infer_bucket_file_list
get_infer_bucket_file_list()

Returns the list of files to be downloaded from the inference bucket for ONNX model.

Returns:

Name Type Description
list list

A list of filenames specific to ONNX models.

Source code in inference/core/models/roboflow.py
899
900
901
902
903
904
905
def get_infer_bucket_file_list(self) -> list:
    """List the artifact files an ONNX model requires from the inference bucket.

    Returns:
        list: Filenames specific to ONNX models.
    """
    required_files = ["environment.json", "class_names.txt"]
    return required_files
infer
infer(image, **kwargs)

Runs inference on given data. - image: can be a BGR numpy array, filepath, InferenceRequestImage, PIL Image, byte-string, etc.

Source code in inference/core/models/roboflow.py
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
def infer(self, image: Any, **kwargs) -> Any:
    """Runs inference on given data.
    - image:
        can be a BGR numpy array, filepath, InferenceRequestImage, PIL Image, byte-string, etc.

    Splits list inputs into chunks of at most the model's batch size and
    merges the per-chunk results.
    """
    input_elements = len(image) if isinstance(image, list) else 1
    max_batch_size = MAX_BATCH_SIZE if self.batching_enabled else self.batch_size
    # Single input, or no effective batch limit — run directly.
    if (input_elements == 1) or (max_batch_size == float("inf")):
        return super().infer(image, **kwargs)
    logger.debug(
        f"Inference will be executed in batches, as there is {input_elements} input elements and "
        f"maximum batch size for a model is set to: {max_batch_size}"
    )
    inference_results = []
    for batch_input in create_batches(sequence=image, batch_size=max_batch_size):
        batch_inference_results = super().infer(batch_input, **kwargs)
        inference_results.append(batch_inference_results)
    return self.merge_inference_results(inference_results=inference_results)
initialize_model
initialize_model(**kwargs)

Initializes the ONNX model, setting up the inference session and other necessary properties.

Source code in inference/core/models/roboflow.py
 907
 908
 909
 910
 911
 912
 913
 914
 915
 916
 917
 918
 919
 920
 921
 922
 923
 924
 925
 926
 927
 928
 929
 930
 931
 932
 933
 934
 935
 936
 937
 938
 939
 940
 941
 942
 943
 944
 945
 946
 947
 948
 949
 950
 951
 952
 953
 954
 955
 956
 957
 958
 959
 960
 961
 962
 963
 964
 965
 966
 967
 968
 969
 970
 971
 972
 973
 974
 975
 976
 977
 978
 979
 980
 981
 982
 983
 984
 985
 986
 987
 988
 989
 990
 991
 992
 993
 994
 995
 996
 997
 998
 999
1000
1001
1002
1003
1004
1005
def initialize_model(self, **kwargs) -> None:
    """Initializes the ONNX model, setting up the inference session and other necessary properties.

    When weights must be loaded (or no metadata is cached), an ONNX Runtime session is
    created and batch size / input resolution are derived from the session's input
    signature, then cached; otherwise cached metadata alone is used.

    Raises:
        ModelArtefactError: if the ONNX session cannot be created.
        OnnxProviderNotAvailable: if a required execution provider is missing.
        ValueError: if the metadata-only path is reached without cached metadata
            (should be unreachable).
    """
    logger.debug("Getting model artefacts")
    self.get_model_artifacts(**kwargs)
    logger.debug("Creating inference session")
    if self.load_weights or not self.has_model_metadata:
        t1_session = perf_counter()
        # Create an ONNX Runtime Session with a list of execution providers in priority
        # order. ORT attempts to load providers until one is successful. This keeps the
        # code across devices identical.
        providers = self.onnxruntime_execution_providers

        if not self.load_weights:
            # Metadata-only load: a cheap CPU-side session is enough.
            providers = ["OpenVINOExecutionProvider", "CPUExecutionProvider"]
        try:
            session_options = onnxruntime.SessionOptions()
            session_options.log_severity_level = 3
            # TensorRT does better graph optimization for its EP than onnx
            if has_trt(providers):
                session_options.graph_optimization_level = (
                    onnxruntime.GraphOptimizationLevel.ORT_DISABLE_ALL
                )
            self.onnx_session = onnxruntime.InferenceSession(
                self.cache_file(self.weights_file),
                providers=providers,
                sess_options=session_options,
            )
        except Exception as e:
            # Corrupted artefacts are a likely cause — drop them so the next
            # attempt re-downloads.
            self.clear_cache(delete_from_disk=DISK_CACHE_CLEANUP)
            raise ModelArtefactError(
                f"Unable to load ONNX session. Cause: {e}"
            ) from e
        logger.debug(f"Session created in {perf_counter() - t1_session} seconds")

        if REQUIRED_ONNX_PROVIDERS:
            available_providers = onnxruntime.get_available_providers()
            for provider in REQUIRED_ONNX_PROVIDERS:
                if provider not in available_providers:
                    raise OnnxProviderNotAvailable(
                        f"Required ONNX Execution Provider {provider} is not available. "
                        "Check that you are using the correct docker image on a supported device. "
                        "Export list of available providers as ONNXRUNTIME_EXECUTION_PROVIDERS environmental variable, "
                        "consult documentation for more details."
                    )

        inputs = self.onnx_session.get_inputs()[0]
        input_shape = inputs.shape
        self.batch_size = input_shape[0]
        self.img_size_h = input_shape[2]
        self.img_size_w = input_shape[3]
        self.input_name = inputs.name
        # Dynamic spatial dims are reported as strings; fall back to the declared
        # preprocessing resize target, or 640x640 when none is configured.
        if isinstance(self.img_size_h, str) or isinstance(self.img_size_w, str):
            if "resize" in self.preproc:
                self.img_size_h = int(self.preproc["resize"]["height"])
                self.img_size_w = int(self.preproc["resize"]["width"])
            else:
                self.img_size_h = 640
                self.img_size_w = 640

        # A string batch dimension means the model accepts dynamic batch sizes.
        self.batching_enabled = isinstance(self.batch_size, str)
        logger.debug(
            f"Model {self.endpoint} is loaded with dynamic batching "
            + ("enabled" if self.batching_enabled else "disabled")
        )

        model_metadata = {
            "batch_size": self.batch_size,
            "img_size_h": self.img_size_h,
            "img_size_w": self.img_size_w,
        }
        logger.debug("Writing model metadata to memcache")
        self.write_model_metadata_to_memcache(model_metadata)
        if not self.load_weights:  # had to load weights to get metadata
            del self.onnx_session
    else:
        if not self.has_model_metadata:
            raise ValueError(
                "This should be unreachable, should get weights if we don't have model metadata"
            )
        logger.debug("Loading model metadata from memcache")
        metadata = self.model_metadata_from_memcache()
        self.batch_size = metadata["batch_size"]
        self.img_size_h = metadata["img_size_h"]
        self.img_size_w = metadata["img_size_w"]
        self.batching_enabled = isinstance(self.batch_size, str)
        logger.debug(
            f"Model {self.endpoint} is loaded with dynamic batching "
            + ("enabled" if self.batching_enabled else "disabled")
        )

    logger.debug("Model initialisation finished.")

RoboflowCoreModel

Bases: RoboflowInferenceModel

Base Roboflow inference model (Inherits from CvModel since all Roboflow models are CV models currently).

Source code in inference/core/models/roboflow.py
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
class RoboflowCoreModel(RoboflowInferenceModel):
    """Base Roboflow inference model (Inherits from CvModel since all Roboflow models are CV models currently)."""

    def __init__(
        self,
        model_id: str,
        api_key=None,
        **kwargs,
    ):
        """Initializes the RoboflowCoreModel instance.

        Args:
            model_id (str): The identifier for the specific model.
            api_key ([type], optional): The API key for authentication. Defaults to None.
        """
        super().__init__(model_id, api_key=api_key, **kwargs)
        # Core models fetch their weights eagerly at construction time.
        self.download_weights()

    def download_weights(self) -> None:
        """Downloads the model weights from the configured source.

        This method includes handling for AWS access keys and error handling.
        Resolution order: local cache first, then the S3 artefacts bucket (when
        available), then the Roboflow API.

        Raises:
            RoboflowAPINotAuthorizedError: when model-cache auth is enabled and
                the API key has no access to this model.
        """
        if MODELS_CACHE_AUTH_ENABLED:
            if not _check_if_api_key_has_access_to_model(
                api_key=self.api_key,
                model_id=self.endpoint,
                endpoint_type=ModelEndpointType.CORE_MODEL,
            ):
                raise RoboflowAPINotAuthorizedError(
                    f"API key {self.api_key} does not have access to model {self.endpoint}"
                )
        infer_bucket_files = self.get_infer_bucket_file_list()
        if are_all_files_cached(files=infer_bucket_files, model_id=self.endpoint):
            logger.debug("Model artifacts already downloaded, loading from cache")
            return None
        if is_model_artefacts_bucket_available():
            self.download_model_artefacts_from_s3()
            return None
        self.download_model_from_roboflow_api()

    def download_model_from_roboflow_api(self) -> None:
        """Download every weights file listed by the Roboflow API into the local cache.

        The download is serialised with a per-model file lock so concurrent
        processes do not fetch the same artefacts simultaneously.

        Raises:
            ModelArtefactError: when the API response lacks the `weights` key.
        """
        # Use the same lock file pattern as in clear_cache
        lock_dir = MODEL_CACHE_DIR + "/_file_locks"  # Dedicated lock directory
        os.makedirs(lock_dir, exist_ok=True)  # Ensure lock directory exists.
        lock_file = os.path.join(lock_dir, f"{os.path.basename(self.cache_dir)}.lock")
        try:
            lock = FileLock(lock_file, timeout=120)  # 120 second timeout for downloads
            with lock:
                api_data = get_roboflow_model_data(
                    api_key=self.api_key,
                    model_id=self.endpoint,
                    endpoint_type=ModelEndpointType.CORE_MODEL,
                    device_id=self.device_id,
                )
                if "weights" not in api_data:
                    raise ModelArtefactError(
                        f"`weights` key not available in Roboflow API response while downloading model weights."
                    )
                for weights_url_key in api_data["weights"]:
                    weights_url = api_data["weights"][weights_url_key]
                    t1 = perf_counter()
                    model_weights_response = get_from_url(
                        weights_url, json_response=False
                    )
                    filename = weights_url.split("?")[0].split("/")[-1]
                    save_bytes_in_cache(
                        content=model_weights_response.content,
                        file=filename,
                        model_id=self.endpoint,
                    )
                    if perf_counter() - t1 > 120:
                        # NOTE(review): iteration continues over the keys of the
                        # ORIGINAL response, while URLs are read from the refreshed
                        # `api_data` — presumably to avoid expired signed URLs;
                        # confirm the refreshed response always contains the same keys.
                        logger.debug(
                            "Weights download took longer than 120 seconds, refreshing API request"
                        )
                        api_data = get_roboflow_model_data(
                            api_key=self.api_key,
                            model_id=self.endpoint,
                            endpoint_type=ModelEndpointType.CORE_MODEL,
                            device_id=self.device_id,
                        )
        except Exception as e:
            logger.error(f"Error downloading model artifacts: {e}")
            raise

    def get_device_id(self) -> str:
        """Returns the device ID associated with this model.

        Returns:
            str: The device ID.
        """
        return self.device_id

    def get_infer_bucket_file_list(self) -> List[str]:
        """Abstract method to get the list of files to be downloaded from the inference bucket.

        Raises:
            NotImplementedError: This method must be implemented in subclasses.

        Returns:
            List[str]: A list of filenames.
        """
        raise NotImplementedError(
            "get_infer_bucket_file_list not implemented for RoboflowCoreModel"
        )

    def preprocess_image(self, image: Image.Image) -> Image.Image:
        """Abstract method to preprocess an image.

        Raises:
            NotImplementedError: This method must be implemented in subclasses.

        Returns:
            Image.Image: The preprocessed PIL image.
        """
        raise NotImplementedError(self.__class__.__name__ + ".preprocess_image")

    @property
    def weights_file(self) -> str:
        """Abstract property representing the file containing the model weights. For core models, all model artifacts are handled through get_infer_bucket_file_list method."""
        # Intentionally None: core models have no single weights file.
        return None

    @property
    def model_artifact_bucket(self):
        """S3 bucket that holds artefacts for core models."""
        return CORE_MODEL_BUCKET
Attributes
weights_file property
weights_file

Abstract property representing the file containing the model weights. For core models, all model artifacts are handled through get_infer_bucket_file_list method.

Functions
__init__
__init__(model_id, api_key=None, **kwargs)

Initializes the RoboflowCoreModel instance.

Parameters:

Name Type Description Default
model_id str

The identifier for the specific model.

required
api_key [type]

The API key for authentication. Defaults to None.

None
Source code in inference/core/models/roboflow.py
633
634
635
636
637
638
639
640
641
642
643
644
645
646
def __init__(
    self,
    model_id: str,
    api_key=None,
    **kwargs,
):
    """Initializes the RoboflowCoreModel instance.

    Args:
        model_id (str): The identifier for the specific model.
        api_key (str, optional): The API key for authentication. Defaults to None.
        **kwargs: Forwarded to the parent initializer.
    """
    super().__init__(model_id, api_key=api_key, **kwargs)
    # Core models fetch their weights eagerly at construction time.
    self.download_weights()
download_weights
download_weights()

Downloads the model weights from the configured source.

This method includes handling for AWS access keys and error handling.

Source code in inference/core/models/roboflow.py
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
def download_weights(self) -> None:
    """Fetch model weights, preferring local cache, then S3, then the Roboflow API.

    Raises:
        RoboflowAPINotAuthorizedError: when model-cache auth is enabled and the
            API key has no access to this model.
    """
    if MODELS_CACHE_AUTH_ENABLED and not _check_if_api_key_has_access_to_model(
        api_key=self.api_key,
        model_id=self.endpoint,
        endpoint_type=ModelEndpointType.CORE_MODEL,
    ):
        raise RoboflowAPINotAuthorizedError(
            f"API key {self.api_key} does not have access to model {self.endpoint}"
        )
    required_files = self.get_infer_bucket_file_list()
    if are_all_files_cached(files=required_files, model_id=self.endpoint):
        logger.debug("Model artifacts already downloaded, loading from cache")
        return
    if is_model_artefacts_bucket_available():
        self.download_model_artefacts_from_s3()
        return
    self.download_model_from_roboflow_api()
get_device_id
get_device_id()

Returns the device ID associated with this model.

Returns:

Name Type Description
str str

The device ID.

Source code in inference/core/models/roboflow.py
716
717
718
719
720
721
722
def get_device_id(self) -> str:
    """Expose the identifier of the device this model is associated with.

    Returns:
        str: The device ID.
    """
    device_identifier = self.device_id
    return device_identifier
get_infer_bucket_file_list
get_infer_bucket_file_list()

Abstract method to get the list of files to be downloaded from the inference bucket.

Raises:

Type Description
NotImplementedError

This method must be implemented in subclasses.

Returns:

Type Description
List[str]

List[str]: A list of filenames.

Source code in inference/core/models/roboflow.py
724
725
726
727
728
729
730
731
732
733
734
735
def get_infer_bucket_file_list(self) -> List[str]:
    """Abstract hook listing the files to download from the inference bucket.

    Subclasses must override this; the base implementation always raises.

    Raises:
        NotImplementedError: always, on the base class.

    Returns:
        List[str]: A list of filenames (when overridden).
    """
    raise NotImplementedError(
        "get_infer_bucket_file_list not implemented for RoboflowCoreModel"
    )
preprocess_image
preprocess_image(image)

Abstract method to preprocess an image.

Raises:

Type Description
NotImplementedError

This method must be implemented in subclasses.

Returns:

Type Description
Image

Image.Image: The preprocessed PIL image.

Source code in inference/core/models/roboflow.py
737
738
739
740
741
742
743
744
745
746
def preprocess_image(self, image: Image.Image) -> Image.Image:
    """Abstract hook for image preprocessing; the base class always raises.

    Raises:
        NotImplementedError: always — subclasses must override this.

    Returns:
        Image.Image: The preprocessed PIL image (when overridden).
    """
    implementing_class = self.__class__.__name__
    raise NotImplementedError(implementing_class + ".preprocess_image")

RoboflowInferenceModel

Bases: Model

Base Roboflow inference model.

Source code in inference/core/models/roboflow.py
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
class RoboflowInferenceModel(Model):
    """Base Roboflow inference model."""

    def __init__(
        self,
        model_id: str,
        cache_dir_root=MODEL_CACHE_DIR,
        api_key=None,
        load_weights=True,
        **kwargs,
    ):
        """
        Initialize the RoboflowInferenceModel object.

        Args:
            model_id (str): The unique identifier for the model.
            cache_dir_root (str, optional): The root directory for the cache. Defaults to MODEL_CACHE_DIR.
            api_key (str, optional): API key for authentication. Defaults to None.
            load_weights (bool, optional): Whether weights should be loaded, as opposed
                to relying on cached metadata only. Defaults to True.
        """
        super().__init__()
        self.load_weights = load_weights
        self.metrics = {"num_inferences": 0, "avg_inference_time": 0.0}
        # Fall back to the globally configured API key when none is provided.
        self.api_key = api_key if api_key else API_KEY
        # Aliases are resolved to canonical model ids before splitting.
        model_id = resolve_roboflow_model_alias(model_id=model_id)
        self.dataset_id, self.version_id = get_model_id_chunks(model_id=model_id)
        self.endpoint = model_id
        self.device_id = GLOBAL_DEVICE_ID
        self.cache_dir = os.path.join(cache_dir_root, self.endpoint)
        # Populated later by load_model_artifacts_from_cache when present.
        self.keypoints_metadata: Optional[dict] = None
        initialise_cache(model_id=self.endpoint)

    def cache_file(self, f: str) -> str:
        """Get the cache file path for a given file.

        Args:
            f (str): Filename.

        Returns:
            str: Full path to the cached file.
        """
        # Thin delegation; path layout is owned by the cache helpers.
        return get_cache_file_path(file=f, model_id=self.endpoint)

    def clear_cache(self, delete_from_disk: bool = True) -> None:
        """Clear the cache directory for this model.

        Args:
            delete_from_disk (bool, optional): Whether to delete cached files from disk. Defaults to True.
        """
        clear_cache(model_id=self.endpoint, delete_from_disk=delete_from_disk)

    def draw_predictions(
        self,
        inference_request: InferenceRequest,
        inference_response: InferenceResponse,
    ) -> bytes:
        """Draw predictions from an inference response onto the original image provided by an inference request

        Args:
            inference_request (ObjectDetectionInferenceRequest): The inference request containing the image on which to draw predictions
            inference_response (ObjectDetectionInferenceResponse): The inference response containing predictions to be drawn

        Returns:
            bytes: The annotated image serialized to bytes, per the return annotation
                (rendering is delegated to draw_detection_predictions — confirm exact format there).
        """
        return draw_detection_predictions(
            inference_request=inference_request,
            inference_response=inference_response,
            colors=self.colors,
        )

    @property
    def get_class_names(self):
        """Class names for this model (set from cached artefacts during loading)."""
        return self.class_names

    def get_device_id(self) -> str:
        """Return the identifier of the device on which the model is deployed.

        Returns:
            str: Device identifier.
        """
        device_identifier = self.device_id
        return device_identifier

    def get_infer_bucket_file_list(self) -> List[str]:
        """Abstract hook listing inference bucket files; subclasses must override.

        Raises:
            NotImplementedError: always, on the base class.

        Returns:
            List[str]: A list of inference bucket files (when overridden).
        """
        implementing_class = self.__class__.__name__
        raise NotImplementedError(
            implementing_class + ".get_infer_bucket_file_list"
        )

    @property
    def cache_key(self):
        """Memcache key under which this model's metadata is stored."""
        return f"metadata:{self.endpoint}"

    @staticmethod
    def model_metadata_from_memcache_endpoint(endpoint):
        """Look up cached model metadata for an arbitrary endpoint id (None when absent)."""
        return cache.get(f"metadata:{endpoint}")

    def model_metadata_from_memcache(self):
        """Fetch this model's cached metadata (None when absent)."""
        return cache.get(self.cache_key)

    def write_model_metadata_to_memcache(self, metadata):
        """Store `metadata` in the cache under this model's key with the configured TTL."""
        cache.set(
            self.cache_key, metadata, expire=MODEL_METADATA_CACHE_EXPIRATION_TIMEOUT
        )

    @property
    def has_model_metadata(self):
        """True when this model's metadata is present in the cache."""
        return self.model_metadata_from_memcache() is not None

    def get_model_artifacts(
        self,
        countinference: Optional[bool] = None,
        service_secret: Optional[str] = None,
        **kwargs,
    ) -> None:
        """Fetch or load the model artifacts.

        Downloads the model artifacts from S3 or the Roboflow API if they are not already cached,
        then loads them from the cache into instance attributes.

        Args:
            countinference (bool, optional): Forwarded to the access check and download call.
            service_secret (str, optional): Forwarded to the access check and download call.

        Raises:
            RoboflowAPINotAuthorizedError: when model-cache auth is enabled and the
                API key has no access to this model.
        """
        if MODELS_CACHE_AUTH_ENABLED:
            if not _check_if_api_key_has_access_to_model(
                api_key=self.api_key,
                model_id=self.endpoint,
                endpoint_type=ModelEndpointType.ORT,
                countinference=countinference,
                service_secret=service_secret,
            ):
                raise RoboflowAPINotAuthorizedError(
                    f"API key {self.api_key} does not have access to model {self.endpoint}"
                )
        self.cache_model_artefacts(
            countinference=countinference,
            service_secret=service_secret,
            **kwargs,
        )
        self.load_model_artifacts_from_cache()

    def cache_model_artefacts(
        self,
        countinference: Optional[bool] = None,
        service_secret: Optional[str] = None,
        **kwargs,
    ) -> None:
        """Ensure all required artefact files are present in the local cache.

        Resolution order: already cached, then the S3 artefacts bucket (when
        available), then the Roboflow API.
        """
        infer_bucket_files = self.get_all_required_infer_bucket_file()

        if are_all_files_cached(files=infer_bucket_files, model_id=self.endpoint):
            return None
        if is_model_artefacts_bucket_available():
            self.download_model_artefacts_from_s3()
            return None
        self.download_model_artifacts_from_roboflow_api(
            countinference=countinference,
            service_secret=service_secret,
            **kwargs,
        )

    def get_all_required_infer_bucket_file(self) -> List[str]:
        """Full list of artefact files needed to load the model, weights file included.

        Returns:
            List[str]: filenames with any None entries (e.g. a missing weights file) dropped.
        """
        required_files = self.get_infer_bucket_file_list()
        required_files.append(self.weights_file)
        logger.debug(f"List of files required to load model: {required_files}")
        return [file_name for file_name in required_files if file_name is not None]

    def download_model_artefacts_from_s3(self) -> None:
        """Download all required model artefacts from the S3 artefacts bucket into the local cache.

        Raises:
            ModelArtefactError: if downloading the artefacts from S3 fails.
        """
        logger.debug("Downloading model artifacts from S3")
        # Computed outside the `try` so the except-clause below can never reference
        # an unbound `s3_keys` (which previously masked the original error with a
        # NameError when a failure occurred before the assignment).
        infer_bucket_files = self.get_all_required_infer_bucket_file()
        cache_directory = get_cache_dir()
        s3_keys = [f"{self.endpoint}/{file}" for file in infer_bucket_files]
        try:
            download_s3_files_to_directory(
                bucket=self.model_artifact_bucket,
                keys=s3_keys,
                target_dir=cache_directory,
                s3_client=S3_CLIENT,
            )
        except Exception as error:
            raise ModelArtefactError(
                f"Could not obtain model artefacts from S3 with keys {s3_keys}. Cause: {error}"
            ) from error

    @property
    def model_artifact_bucket(self):
        """S3 bucket that holds inference artefacts for this model type."""
        return INFER_BUCKET

    def download_model_artifacts_from_roboflow_api(
        self,
        countinference: Optional[bool] = None,
        service_secret: Optional[str] = None,
        **kwargs,
    ) -> None:
        """Download weights, environment and class metadata from the Roboflow API into the cache.

        Versioned models (`version_id` set) use the ORT endpoint; otherwise the
        instant-model endpoint is used. Downloads are serialised with a per-model
        file lock so concurrent processes do not fetch the same artefacts at once.

        Raises:
            ModelArtefactError: when a required key is missing from the API response.
        """
        logger.debug("Downloading model artifacts from Roboflow API")

        # Use the same lock file pattern as in clear_cache
        lock_dir = MODEL_CACHE_DIR + "/_file_locks"  # Dedicated lock directory
        os.makedirs(lock_dir, exist_ok=True)  # Ensure lock directory exists.
        lock_file = os.path.join(lock_dir, f"{os.path.basename(self.cache_dir)}.lock")
        try:
            lock = FileLock(lock_file, timeout=120)  # 120 second timeout for downloads
            with lock:
                if self.version_id is not None:
                    # Versioned model: description comes from the ORT endpoint.
                    api_data = get_roboflow_model_data(
                        api_key=self.api_key,
                        model_id=self.endpoint,
                        endpoint_type=ModelEndpointType.ORT,
                        device_id=self.device_id,
                        countinference=countinference,
                        service_secret=service_secret,
                    )
                    if "ort" not in api_data.keys():
                        raise ModelArtefactError(
                            "Could not find `ort` key in roboflow API model description response."
                        )
                    api_data = api_data["ort"]
                    if "classes" in api_data:
                        save_text_lines_in_cache(
                            content=api_data["classes"],
                            file="class_names.txt",
                            model_id=self.endpoint,
                        )
                    if "model" not in api_data:
                        raise ModelArtefactError(
                            "Could not find `model` key in roboflow API model description response."
                        )
                    if "environment" not in api_data:
                        raise ModelArtefactError(
                            "Could not find `environment` key in roboflow API model description response."
                        )
                    # Versioned models serve the environment as a URL to fetch.
                    environment = get_from_url(api_data["environment"])
                    model_weights_response = get_from_url(
                        api_data["model"],
                        json_response=False,
                    )
                else:
                    # Instant model: description comes from the instant-model endpoint.
                    api_data = get_roboflow_instant_model_data(
                        api_key=self.api_key,
                        model_id=self.endpoint,
                        countinference=countinference,
                        service_secret=service_secret,
                    )
                    if (
                        "modelFiles" not in api_data
                        or "ort" not in api_data["modelFiles"]
                        or "model" not in api_data["modelFiles"]["ort"]
                    ):
                        raise ModelArtefactError(
                            "Could not find `modelFiles` key or `modelFiles`.`ort` or `modelFiles`.`ort`.`model` key in roboflow API model description response."
                        )
                    if "environment" not in api_data:
                        raise ModelArtefactError(
                            "Could not find `environment` key in roboflow API model description response."
                        )
                    model_weights_response = get_from_url(
                        api_data["modelFiles"]["ort"]["model"],
                        json_response=False,
                    )
                    # Instant models embed the environment inline in the response.
                    environment = api_data["environment"]
                    if "classes" in api_data:
                        save_text_lines_in_cache(
                            content=api_data["classes"],
                            file="class_names.txt",
                            model_id=self.endpoint,
                        )

                save_bytes_in_cache(
                    content=model_weights_response.content,
                    file=self.weights_file,
                    model_id=self.endpoint,
                )
                if "colors" in api_data:
                    environment["COLORS"] = api_data["colors"]
                save_json_in_cache(
                    content=environment,
                    file="environment.json",
                    model_id=self.endpoint,
                )
                if "keypoints_metadata" in api_data:
                    # TODO: make sure backend provides that
                    save_json_in_cache(
                        content=api_data["keypoints_metadata"],
                        file="keypoints_metadata.json",
                        model_id=self.endpoint,
                    )
        except Exception as e:
            logger.error(f"Error downloading model artifacts: {e}")
            raise

    def load_model_artifacts_from_cache(self) -> None:
        """Populate model attributes (environment, class names, colors, preprocessing
        config, keypoints metadata) from cached artefact files.

        Raises:
            ModelArtefactError: when `PREPROCESSING` is missing from the environment file.
        """
        logger.debug("Model artifacts already downloaded, loading model from cache")
        infer_bucket_files = self.get_all_required_infer_bucket_file()
        if "environment.json" in infer_bucket_files:
            self.environment = load_json_from_cache(
                file="environment.json",
                model_id=self.endpoint,
                object_pairs_hook=OrderedDict,
            )
        # NOTE(review): if "environment.json" is not among the required files,
        # self.environment is never set here yet is read below — confirm all
        # subclasses include it in get_infer_bucket_file_list.
        if "class_names.txt" in infer_bucket_files:
            self.class_names = load_text_file_from_cache(
                file="class_names.txt",
                model_id=self.endpoint,
                split_lines=True,
                strip_white_chars=True,
            )
        else:
            self.class_names = get_class_names_from_environment_file(
                environment=self.environment
            )
        self.colors = get_color_mapping_from_environment(
            environment=self.environment,
            class_names=self.class_names,
        )
        if "keypoints_metadata.json" in infer_bucket_files:
            self.keypoints_metadata = parse_keypoints_metadata(
                load_json_from_cache(
                    file="keypoints_metadata.json",
                    model_id=self.endpoint,
                    object_pairs_hook=OrderedDict,
                )
            )
        self.num_classes = len(self.class_names)
        if "PREPROCESSING" not in self.environment:
            raise ModelArtefactError(
                "Could not find `PREPROCESSING` key in environment file."
            )
        # PREPROCESSING may be stored either as a dict or as a JSON-encoded string.
        if issubclass(type(self.environment["PREPROCESSING"]), dict):
            self.preproc = self.environment["PREPROCESSING"]
        else:
            self.preproc = json.loads(self.environment["PREPROCESSING"])
        if self.preproc.get("resize"):
            self.resize_method = self.preproc["resize"].get("format", "Stretch to")
            # Known-but-unsupported formats fall back to the closest supported one.
            if self.resize_method in [
                "Fit (reflect edges) in",
                "Fit within",
                "Fill (with center crop) in",
            ]:
                fallback_resize_method = "Fit (black edges) in"
                logger.warning(
                    "Unsupported resize method '%s', defaulting to '%s' - this may result in degraded model performance.",
                    self.resize_method,
                    fallback_resize_method,
                )
                self.resize_method = fallback_resize_method
            # Anything else unrecognised falls back to plain stretching.
            if self.resize_method not in [
                "Stretch to",
                "Fit (black edges) in",
                "Fit (grey edges) in",
                "Fit (white edges) in",
            ]:
                logger.error(
                    "Unsupported resize method '%s', defaulting to 'Stretch to' - this may result in degraded model performance.",
                    self.resize_method,
                )
                self.resize_method = "Stretch to"
        else:
            logger.error(
                "Unknown resize method, defaulting to 'Stretch to' - this may result in degraded model performance."
            )
            self.resize_method = "Stretch to"
        logger.debug(f"Resize method is '{self.resize_method}'")
        self.multiclass = self.environment.get("MULTICLASS", False)

    def initialize_model(self, **kwargs) -> None:
        """Initialize the model; subclasses must provide the implementation.

        Raises:
            NotImplementedError: always, on the base class.
        """
        implementing_class = self.__class__.__name__
        raise NotImplementedError(implementing_class + ".initialize_model")

    def preproc_image(
        self,
        image: Union[Any, InferenceRequestImage],
        disable_preproc_auto_orient: bool = False,
        disable_preproc_contrast: bool = False,
        disable_preproc_grayscale: bool = False,
        disable_preproc_static_crop: bool = False,
    ) -> Tuple[np.ndarray, Tuple[int, int]]:
        """
        Preprocesses an inference request image by loading it, then applying any pre-processing specified by the Roboflow platform, then scaling it to the inference input dimensions.

        Args:
            image (Union[Any, InferenceRequestImage]): An object containing information necessary to load the image for inference.
            disable_preproc_auto_orient (bool, optional): If true, the auto orient preprocessing step is disabled for this call. Default is False.
            disable_preproc_contrast (bool, optional): If true, the contrast preprocessing step is disabled for this call. Default is False.
            disable_preproc_grayscale (bool, optional): If true, the grayscale preprocessing step is disabled for this call. Default is False.
            disable_preproc_static_crop (bool, optional): If true, the static crop preprocessing step is disabled for this call. Default is False.

        Returns:
            Tuple[np.ndarray, Tuple[int, int]]: A tuple containing the preprocessed image pixel data
                (a numpy array, or a torch tensor when USE_PYTORCH_FOR_PREPROCESSING is enabled) and
                a tuple of the image's original size.

        Raises:
            ValueError: If `self.resize_method` is not one of the supported resize methods, or if the
                image ends up with an unexpected type during preprocessing.
        """
        # Auto-orient is skipped when disabled for this call, when the model's preprocessing
        # configuration does not include an "auto-orient" step, or when globally disabled.
        np_image, is_bgr = load_image(
            image,
            disable_preproc_auto_orient=disable_preproc_auto_orient
            or "auto-orient" not in self.preproc.keys()
            or DISABLE_PREPROC_AUTO_ORIENT,
        )
        preprocessed_image, img_dims = self.preprocess_image(
            np_image,
            disable_preproc_contrast=disable_preproc_contrast,
            disable_preproc_grayscale=disable_preproc_grayscale,
            disable_preproc_static_crop=disable_preproc_static_crop,
        )

        if USE_PYTORCH_FOR_PREPROCESSING:
            # Convert HWC numpy data to a contiguous NCHW float tensor, on GPU when available,
            # so that the resize below can run with torch.
            preprocessed_image = torch.from_numpy(
                np.ascontiguousarray(preprocessed_image)
            )
            if torch.cuda.is_available():
                preprocessed_image = preprocessed_image.cuda()
            preprocessed_image = (
                preprocessed_image.permute(2, 0, 1).unsqueeze(0).contiguous().float()
            )
        if self.resize_method == "Stretch to":
            if isinstance(preprocessed_image, np.ndarray):
                preprocessed_image = preprocessed_image.astype(np.float32)
                resized = cv2.resize(
                    preprocessed_image,
                    (self.img_size_w, self.img_size_h),
                )
            elif USE_PYTORCH_FOR_PREPROCESSING:
                resized = torch.nn.functional.interpolate(
                    preprocessed_image,
                    size=(self.img_size_h, self.img_size_w),
                    mode="bilinear",
                )
            else:
                raise ValueError(
                    f"Received an image of unknown type, {type(preprocessed_image)}; "
                    "This is most likely a bug. Contact Roboflow team through github issues "
                    "(https://github.com/roboflow/inference/issues) providing full context of the problem"
                )

        elif self.resize_method == "Fit (black edges) in":
            resized = letterbox_image(
                preprocessed_image, (self.img_size_w, self.img_size_h)
            )
        elif self.resize_method == "Fit (white edges) in":
            resized = letterbox_image(
                preprocessed_image,
                (self.img_size_w, self.img_size_h),
                color=(255, 255, 255),
            )
        elif self.resize_method == "Fit (grey edges) in":
            resized = letterbox_image(
                preprocessed_image,
                (self.img_size_w, self.img_size_h),
                color=(114, 114, 114),
            )
        else:
            # BUGFIX: previously an unrecognized resize method fell through all branches,
            # leaving `resized` unbound and triggering an UnboundLocalError below.
            # Raise a clear, actionable error instead.
            raise ValueError(
                f"Unsupported resize method: {self.resize_method!r}. Expected one of: "
                "'Stretch to', 'Fit (black edges) in', 'Fit (white edges) in', 'Fit (grey edges) in'."
            )

        if is_bgr:
            # Normalize channel order to RGB; the torch branch swaps the channel axis of NCHW.
            if isinstance(resized, np.ndarray):
                resized = cv2.cvtColor(resized, cv2.COLOR_BGR2RGB)
            else:
                resized = resized[:, [2, 1, 0], :, :]

        if isinstance(resized, np.ndarray):
            # HWC -> CHW, float32, with leading batch dimension (NCHW).
            img_in = np.transpose(resized, (2, 0, 1))
            img_in = img_in.astype(np.float32)
            img_in = np.expand_dims(img_in, axis=0)
        elif USE_PYTORCH_FOR_PREPROCESSING:
            img_in = resized.float()
        else:
            raise ValueError(
                f"Received an image of unknown type, {type(resized)}; "
                "This is most likely a bug. Contact Roboflow team through github issues "
                "(https://github.com/roboflow/inference/issues) providing full context of the problem"
            )

        return img_in, img_dims

    def preprocess_image(
        self,
        image: np.ndarray,
        disable_preproc_contrast: bool = False,
        disable_preproc_grayscale: bool = False,
        disable_preproc_static_crop: bool = False,
    ) -> Tuple[np.ndarray, Tuple[int, int]]:
        """
        Preprocesses the given image using the model's configured preprocessing steps.

        Note: the previous docstring incorrectly described PIL images; this method
        operates on numpy arrays, as the signature shows.

        Args:
            image (np.ndarray): The numpy image to preprocess.
            disable_preproc_contrast (bool, optional): If true, the contrast preprocessing step is disabled for this call. Default is False.
            disable_preproc_grayscale (bool, optional): If true, the grayscale preprocessing step is disabled for this call. Default is False.
            disable_preproc_static_crop (bool, optional): If true, the static crop preprocessing step is disabled for this call. Default is False.

        Returns:
            Tuple[np.ndarray, Tuple[int, int]]: The preprocessed image and the image's
                original dimensions, as returned by `prepare`.
        """
        return prepare(
            image,
            self.preproc,
            disable_preproc_contrast=disable_preproc_contrast,
            disable_preproc_grayscale=disable_preproc_grayscale,
            disable_preproc_static_crop=disable_preproc_static_crop,
        )

    @property
    def weights_file(self) -> str:
        """Abstract property naming the file that holds the model weights.

        Concrete model classes must override this property.

        Returns:
            str: The file path to the weights file.

        Raises:
            NotImplementedError: Always, unless overridden by a subclass.
        """
        raise NotImplementedError(f"{type(self).__name__}.weights_file")
Attributes
weights_file property
weights_file

Abstract property representing the file containing the model weights.

Raises:

Type Description
NotImplementedError

This property must be implemented in subclasses.

Returns:

Name Type Description
str str

The file path to the weights file.

Functions
__init__
__init__(
    model_id,
    cache_dir_root=MODEL_CACHE_DIR,
    api_key=None,
    load_weights=True,
    **kwargs
)

Initialize the RoboflowInferenceModel object.

Parameters:

Name Type Description Default
model_id str

The unique identifier for the model.

required
cache_dir_root str

The root directory for the cache. Defaults to MODEL_CACHE_DIR.

MODEL_CACHE_DIR
api_key str

API key for authentication. Defaults to None.

None
Source code in inference/core/models/roboflow.py
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
def __init__(
    self,
    model_id: str,
    cache_dir_root=MODEL_CACHE_DIR,
    api_key=None,
    load_weights=True,
    **kwargs,
):
    """
    Initialize the RoboflowInferenceModel object.

    Args:
        model_id (str): The unique identifier for the model.
        cache_dir_root (str, optional): The root directory for the cache. Defaults to MODEL_CACHE_DIR.
        api_key (str, optional): API key for authentication. Defaults to None.
    """
    super().__init__()
    self.load_weights = load_weights
    self.metrics = {"num_inferences": 0, "avg_inference_time": 0.0}
    self.api_key = api_key if api_key else API_KEY
    model_id = resolve_roboflow_model_alias(model_id=model_id)
    self.dataset_id, self.version_id = get_model_id_chunks(model_id=model_id)
    self.endpoint = model_id
    self.device_id = GLOBAL_DEVICE_ID
    self.cache_dir = os.path.join(cache_dir_root, self.endpoint)
    self.keypoints_metadata: Optional[dict] = None
    initialise_cache(model_id=self.endpoint)
cache_file
cache_file(f)

Get the cache file path for a given file.

Parameters:

Name Type Description Default
f str

Filename.

required

Returns:

Name Type Description
str str

Full path to the cached file.

Source code in inference/core/models/roboflow.py
144
145
146
147
148
149
150
151
152
153
def cache_file(self, f: str) -> str:
    """Get the cache file path for a given file.

    Args:
        f (str): Filename.

    Returns:
        str: Full path to the cached file.
    """
    return get_cache_file_path(file=f, model_id=self.endpoint)
clear_cache
clear_cache(delete_from_disk=True)

Clear the cache directory.

Parameters:

Name Type Description Default
delete_from_disk bool

Whether to delete cached files from disk. Defaults to True.

True
Source code in inference/core/models/roboflow.py
155
156
157
158
159
160
161
def clear_cache(self, delete_from_disk: bool = True) -> None:
    """Clear the cache directory.

    Args:
        delete_from_disk (bool, optional): Whether to delete cached files from disk. Defaults to True.
    """
    clear_cache(model_id=self.endpoint, delete_from_disk=delete_from_disk)
draw_predictions
draw_predictions(inference_request, inference_response)

Draw predictions from an inference response onto the original image provided by an inference request

Parameters:

Name Type Description Default
inference_request ObjectDetectionInferenceRequest

The inference request containing the image on which to draw predictions

required
inference_response ObjectDetectionInferenceResponse

The inference response containing predictions to be drawn

required

Returns:

Name Type Description
str bytes

A base64 encoded image string

Source code in inference/core/models/roboflow.py
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
def draw_predictions(
    self,
    inference_request: InferenceRequest,
    inference_response: InferenceResponse,
) -> bytes:
    """Draw predictions from an inference response onto the original image provided by an inference request

    Args:
        inference_request (ObjectDetectionInferenceRequest): The inference request containing the image on which to draw predictions
        inference_response (ObjectDetectionInferenceResponse): The inference response containing predictions to be drawn

    Returns:
        str: A base64 encoded image string
    """
    return draw_detection_predictions(
        inference_request=inference_request,
        inference_response=inference_response,
        colors=self.colors,
    )
get_device_id
get_device_id()

Get the device identifier on which the model is deployed.

Returns:

Name Type Description
str str

Device identifier.

Source code in inference/core/models/roboflow.py
187
188
189
190
191
192
193
194
def get_device_id(self) -> str:
    """
    Get the device identifier on which the model is deployed.

    Returns:
        str: Device identifier.
    """
    return self.device_id
get_infer_bucket_file_list
get_infer_bucket_file_list()

Get a list of inference bucket files.

Raises:

Type Description
NotImplementedError

If the method is not implemented.

Returns:

Type Description
List[str]

List[str]: A list of inference bucket files.

Source code in inference/core/models/roboflow.py
196
197
198
199
200
201
202
203
204
205
206
207
def get_infer_bucket_file_list(self) -> List[str]:
    """Get a list of inference bucket files.

    Raises:
        NotImplementedError: If the method is not implemented.

    Returns:
        List[str]: A list of inference bucket files.
    """
    raise NotImplementedError(
        self.__class__.__name__ + ".get_infer_bucket_file_list"
    )
get_model_artifacts
get_model_artifacts(
    countinference=None, service_secret=None, **kwargs
)

Fetch or load the model artifacts.

Downloads the model artifacts from S3 or the Roboflow API if they are not already cached.

Source code in inference/core/models/roboflow.py
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
def get_model_artifacts(
    self,
    countinference: Optional[bool] = None,
    service_secret: Optional[str] = None,
    **kwargs,
) -> None:
    """Fetch or load the model artifacts.

    Downloads the model artifacts from S3 or the Roboflow API if they are not already cached.
    """
    if MODELS_CACHE_AUTH_ENABLED:
        if not _check_if_api_key_has_access_to_model(
            api_key=self.api_key,
            model_id=self.endpoint,
            endpoint_type=ModelEndpointType.ORT,
            countinference=countinference,
            service_secret=service_secret,
        ):
            raise RoboflowAPINotAuthorizedError(
                f"API key {self.api_key} does not have access to model {self.endpoint}"
            )
    self.cache_model_artefacts(
        countinference=countinference,
        service_secret=service_secret,
        **kwargs,
    )
    self.load_model_artifacts_from_cache()
initialize_model
initialize_model(**kwargs)

Initialize the model.

Raises:

Type Description
NotImplementedError

If the method is not implemented.

Source code in inference/core/models/roboflow.py
481
482
483
484
485
486
487
def initialize_model(self, **kwargs) -> None:
    """Initialize the model.

    Raises:
        NotImplementedError: If the method is not implemented.
    """
    raise NotImplementedError(self.__class__.__name__ + ".initialize_model")
preproc_image
preproc_image(
    image,
    disable_preproc_auto_orient=False,
    disable_preproc_contrast=False,
    disable_preproc_grayscale=False,
    disable_preproc_static_crop=False,
)

Preprocesses an inference request image by loading it, then applying any pre-processing specified by the Roboflow platform, then scaling it to the inference input dimensions.

Parameters:

Name Type Description Default
image Union[Any, InferenceRequestImage]

An object containing information necessary to load the image for inference.

required
disable_preproc_auto_orient bool

If true, the auto orient preprocessing step is disabled for this call. Default is False.

False
disable_preproc_contrast bool

If true, the contrast preprocessing step is disabled for this call. Default is False.

False
disable_preproc_grayscale bool

If true, the grayscale preprocessing step is disabled for this call. Default is False.

False
disable_preproc_static_crop bool

If true, the static crop preprocessing step is disabled for this call. Default is False.

False

Returns:

Type Description
Tuple[ndarray, Tuple[int, int]]

Tuple[np.ndarray, Tuple[int, int]]: A tuple containing a numpy array of the preprocessed image pixel data and a tuple of the image's original size.

Source code in inference/core/models/roboflow.py
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
def preproc_image(
    self,
    image: Union[Any, InferenceRequestImage],
    disable_preproc_auto_orient: bool = False,
    disable_preproc_contrast: bool = False,
    disable_preproc_grayscale: bool = False,
    disable_preproc_static_crop: bool = False,
) -> Tuple[np.ndarray, Tuple[int, int]]:
    """
    Preprocesses an inference request image by loading it, then applying any pre-processing specified by the Roboflow platform, then scaling it to the inference input dimensions.

    Args:
        image (Union[Any, InferenceRequestImage]): An object containing information necessary to load the image for inference.
        disable_preproc_auto_orient (bool, optional): If true, the auto orient preprocessing step is disabled for this call. Default is False.
        disable_preproc_contrast (bool, optional): If true, the contrast preprocessing step is disabled for this call. Default is False.
        disable_preproc_grayscale (bool, optional): If true, the grayscale preprocessing step is disabled for this call. Default is False.
        disable_preproc_static_crop (bool, optional): If true, the static crop preprocessing step is disabled for this call. Default is False.

    Returns:
        Tuple[np.ndarray, Tuple[int, int]]: A tuple containing a numpy array of the preprocessed image pixel data and a tuple of the images original size.
    """
    np_image, is_bgr = load_image(
        image,
        disable_preproc_auto_orient=disable_preproc_auto_orient
        or "auto-orient" not in self.preproc.keys()
        or DISABLE_PREPROC_AUTO_ORIENT,
    )
    preprocessed_image, img_dims = self.preprocess_image(
        np_image,
        disable_preproc_contrast=disable_preproc_contrast,
        disable_preproc_grayscale=disable_preproc_grayscale,
        disable_preproc_static_crop=disable_preproc_static_crop,
    )

    if USE_PYTORCH_FOR_PREPROCESSING:
        preprocessed_image = torch.from_numpy(
            np.ascontiguousarray(preprocessed_image)
        )
        if torch.cuda.is_available():
            preprocessed_image = preprocessed_image.cuda()
        preprocessed_image = (
            preprocessed_image.permute(2, 0, 1).unsqueeze(0).contiguous().float()
        )
    if self.resize_method == "Stretch to":
        if isinstance(preprocessed_image, np.ndarray):
            preprocessed_image = preprocessed_image.astype(np.float32)
            resized = cv2.resize(
                preprocessed_image,
                (self.img_size_w, self.img_size_h),
            )
        elif USE_PYTORCH_FOR_PREPROCESSING:
            resized = torch.nn.functional.interpolate(
                preprocessed_image,
                size=(self.img_size_h, self.img_size_w),
                mode="bilinear",
            )
        else:
            raise ValueError(
                f"Received an image of unknown type, {type(preprocessed_image)}; "
                "This is most likely a bug. Contact Roboflow team through github issues "
                "(https://github.com/roboflow/inference/issues) providing full context of the problem"
            )

    elif self.resize_method == "Fit (black edges) in":
        resized = letterbox_image(
            preprocessed_image, (self.img_size_w, self.img_size_h)
        )
    elif self.resize_method == "Fit (white edges) in":
        resized = letterbox_image(
            preprocessed_image,
            (self.img_size_w, self.img_size_h),
            color=(255, 255, 255),
        )
    elif self.resize_method == "Fit (grey edges) in":
        resized = letterbox_image(
            preprocessed_image,
            (self.img_size_w, self.img_size_h),
            color=(114, 114, 114),
        )

    if is_bgr:
        if isinstance(resized, np.ndarray):
            resized = cv2.cvtColor(resized, cv2.COLOR_BGR2RGB)
        else:
            resized = resized[:, [2, 1, 0], :, :]

    if isinstance(resized, np.ndarray):
        img_in = np.transpose(resized, (2, 0, 1))
        img_in = img_in.astype(np.float32)
        img_in = np.expand_dims(img_in, axis=0)
    elif USE_PYTORCH_FOR_PREPROCESSING:
        img_in = resized.float()
    else:
        raise ValueError(
            f"Received an image of unknown type, {type(resized)}; "
            "This is most likely a bug. Contact Roboflow team through github issues "
            "(https://github.com/roboflow/inference/issues) providing full context of the problem"
        )

    return img_in, img_dims
preprocess_image
preprocess_image(
    image,
    disable_preproc_contrast=False,
    disable_preproc_grayscale=False,
    disable_preproc_static_crop=False,
)

Preprocesses the given image using specified preprocessing steps.

Parameters:

Name Type Description Default
image Image

The PIL image to preprocess.

required
disable_preproc_contrast bool

If true, the contrast preprocessing step is disabled for this call. Default is False.

False
disable_preproc_grayscale bool

If true, the grayscale preprocessing step is disabled for this call. Default is False.

False
disable_preproc_static_crop bool

If true, the static crop preprocessing step is disabled for this call. Default is False.

False

Returns:

Type Description
Tuple[ndarray, Tuple[int, int]]

Tuple[np.ndarray, Tuple[int, int]]: The preprocessed image and its original dimensions.

Source code in inference/core/models/roboflow.py
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
def preprocess_image(
    self,
    image: np.ndarray,
    disable_preproc_contrast: bool = False,
    disable_preproc_grayscale: bool = False,
    disable_preproc_static_crop: bool = False,
) -> Tuple[np.ndarray, Tuple[int, int]]:
    """
    Preprocesses the given image using specified preprocessing steps.

    Args:
        image (Image.Image): The PIL image to preprocess.
        disable_preproc_contrast (bool, optional): If true, the contrast preprocessing step is disabled for this call. Default is False.
        disable_preproc_grayscale (bool, optional): If true, the grayscale preprocessing step is disabled for this call. Default is False.
        disable_preproc_static_crop (bool, optional): If true, the static crop preprocessing step is disabled for this call. Default is False.

    Returns:
        Image.Image: The preprocessed PIL image.
    """
    return prepare(
        image,
        self.preproc,
        disable_preproc_contrast=disable_preproc_contrast,
        disable_preproc_grayscale=disable_preproc_grayscale,
        disable_preproc_static_crop=disable_preproc_static_crop,
    )

Functions

core/models/utils

inference.core.models.utils.keypoints

Functions

superset_keypoints_count

superset_keypoints_count(keypoints_metadata={})

Returns the number of keypoints in the superset.

Source code in inference/core/models/utils/keypoints.py
 7
 8
 9
10
11
12
13
def superset_keypoints_count(keypoints_metadata={}) -> int:
    """Returns the number of keypoints in the superset."""
    max_keypoints = 0
    for keypoints in keypoints_metadata.values():
        if len(keypoints) > max_keypoints:
            max_keypoints = len(keypoints)
    return max_keypoints

core/registries

Model and block registries for dynamic lookup and plugin discovery.

inference.core.registries.base

Classes

ModelRegistry

An object which is able to return model classes based on given model IDs and model types.

Attributes:

Name Type Description
registry_dict dict

A dictionary mapping model types to model classes.

Source code in inference/core/registries/base.py
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
class ModelRegistry:
    """An object which is able to return model classes based on given model IDs and model types.

    Attributes:
        registry_dict (dict): A dictionary mapping model types to model classes.
    """

    def __init__(self, registry_dict) -> None:
        """Initializes the ModelRegistry with the given dictionary of registered models.

        Args:
            registry_dict (dict): A dictionary mapping model types to model classes.
        """
        self.registry_dict = registry_dict

    def get_model(
        self,
        model_type: str,
        model_id: str,
        **kwargs,
    ) -> Model:
        """Returns the model class based on the given model type.

        Args:
            model_type (str): The type of the model to be retrieved.
            model_id (str): The ID of the model to be retrieved (unused in the current implementation).

        Returns:
            Model: The model class corresponding to the given model type.

        Raises:
            ModelNotRecognisedError: If the model_type is not found in the registry_dict.
        """
        if model_type not in self.registry_dict:
            raise ModelNotRecognisedError(
                f"Could not find model of type: {model_type} in configured registry."
            )
        return self.registry_dict[model_type]
Functions
__init__
__init__(registry_dict)

Initializes the ModelRegistry with the given dictionary of registered models.

Parameters:

Name Type Description Default
registry_dict dict

A dictionary mapping model types to model classes.

required
Source code in inference/core/registries/base.py
14
15
16
17
18
19
20
def __init__(self, registry_dict) -> None:
    """Initializes the ModelRegistry with the given dictionary of registered models.

    Args:
        registry_dict (dict): A dictionary mapping model types to model classes.
    """
    self.registry_dict = registry_dict
get_model
get_model(model_type, model_id, **kwargs)

Returns the model class based on the given model type.

Parameters:

Name Type Description Default
model_type str

The type of the model to be retrieved.

required
model_id str

The ID of the model to be retrieved (unused in the current implementation).

required

Returns:

Name Type Description
Model Model

The model class corresponding to the given model type.

Raises:

Type Description
ModelNotRecognisedError

If the model_type is not found in the registry_dict.

Source code in inference/core/registries/base.py
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
def get_model(
    self,
    model_type: str,
    model_id: str,
    **kwargs,
) -> Model:
    """Returns the model class based on the given model type.

    Args:
        model_type (str): The type of the model to be retrieved.
        model_id (str): The ID of the model to be retrieved (unused in the current implementation).

    Returns:
        Model: The model class corresponding to the given model type.

    Raises:
        ModelNotRecognisedError: If the model_type is not found in the registry_dict.
    """
    if model_type not in self.registry_dict:
        raise ModelNotRecognisedError(
            f"Could not find model of type: {model_type} in configured registry."
        )
    return self.registry_dict[model_type]

inference.core.registries.roboflow

Classes

RoboflowModelRegistry

Bases: ModelRegistry

A Roboflow-specific model registry which gets the model type using the model id, then returns a model class based on the model type.

Source code in inference/core/registries/roboflow.py
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
class RoboflowModelRegistry(ModelRegistry):
    """A Roboflow-specific model registry which gets the model type using the model id,
    then returns a model class based on the model type.
    """

    def get_model(
        self,
        model_id: ModelID,
        api_key: str,
        countinference: Optional[bool] = None,
        service_secret: Optional[str] = None,
    ) -> Model:
        """Returns the model class based on the given model id and API key.

        Args:
            model_id (str): The ID of the model to be retrieved.
            api_key (str): The API key used to authenticate.

        Returns:
            Model: The model class corresponding to the given model ID and type.

        Raises:
            ModelNotRecognisedError: If the model type is not supported or found.
        """
        model_type = get_model_type(
            model_id,
            api_key,
            countinference=countinference,
            service_secret=service_secret,
        )
        logger.debug(f"Model type: {model_type}")

        if model_type not in self.registry_dict:
            raise ModelNotRecognisedError(
                f"Model type not supported, you may want to try a different inference server configuration or endpoint: {model_type}"
            )
        return self.registry_dict[model_type]
Functions
get_model
get_model(
    model_id,
    api_key,
    countinference=None,
    service_secret=None,
)

Returns the model class based on the given model id and API key.

Parameters:

Name Type Description Default
model_id str

The ID of the model to be retrieved.

required
api_key str

The API key used to authenticate.

required

Returns:

Name Type Description
Model Model

The model class corresponding to the given model ID and type.

Raises:

Type Description
ModelNotRecognisedError

If the model type is not supported or found.

Source code in inference/core/registries/roboflow.py
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
def get_model(
    self,
    model_id: ModelID,
    api_key: str,
    countinference: Optional[bool] = None,
    service_secret: Optional[str] = None,
) -> Model:
    """Returns the model class based on the given model id and API key.

    Args:
        model_id (str): The ID of the model to be retrieved.
        api_key (str): The API key used to authenticate.

    Returns:
        Model: The model class corresponding to the given model ID and type.

    Raises:
        ModelNotRecognisedError: If the model type is not supported or found.
    """
    model_type = get_model_type(
        model_id,
        api_key,
        countinference=countinference,
        service_secret=service_secret,
    )
    logger.debug(f"Model type: {model_type}")

    if model_type not in self.registry_dict:
        raise ModelNotRecognisedError(
            f"Model type not supported, you may want to try a different inference server configuration or endpoint: {model_type}"
        )
    return self.registry_dict[model_type]

Functions

get_model_type

get_model_type(
    model_id,
    api_key=None,
    countinference=None,
    service_secret=None,
)

Retrieves the model type based on the given model ID and API key.

Parameters:

Name Type Description Default
model_id str

The ID of the model.

required
api_key str

The API key used to authenticate.

None

Returns:

Name Type Description
tuple Tuple[TaskType, ModelType]

The project task type and the model type.

Raises:

Type Description
WorkspaceLoadError

If the workspace could not be loaded or if the API key is invalid.

DatasetLoadError

If the dataset could not be loaded due to invalid ID, workspace ID or version ID.

MissingDefaultModelError

If default model is not configured and API does not provide this info

MalformedRoboflowAPIResponseError

Roboflow API responds in invalid format.

Source code in inference/core/registries/roboflow.py
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
def get_model_type(
    model_id: ModelID,
    api_key: Optional[str] = None,
    countinference: Optional[bool] = None,
    service_secret: Optional[str] = None,
) -> Tuple[TaskType, ModelType]:
    """Retrieves the model type based on the given model ID and API key.

    Resolution order: alias resolution -> generic models -> (optional)
    auth check -> metadata cache -> stub models -> Roboflow API lookup.

    Args:
        model_id (str): The ID of the model.
        api_key (str): The API key used to authenticate.
        countinference (Optional[bool]): Forwarded to the Roboflow API calls.
        service_secret (Optional[str]): Forwarded to the Roboflow API calls.

    Returns:
        tuple: The project task type and the model type.

    Raises:
        WorkspaceLoadError: If the workspace could not be loaded or if the API key is invalid.
        DatasetLoadError: If the dataset could not be loaded due to invalid ID, workspace ID or version ID.
        MissingDefaultModelError: If default model is not configured and API does not provide this info
        MalformedRoboflowAPIResponseError: Roboflow API responds in invalid format.
        ModelArtefactError: If the API response lacks the data needed to resolve
            the task type and model type.
        RoboflowAPINotAuthorizedError: If cache auth is enabled and the API key
            has no access to the model.
        MissingApiKeyError: If a stub model is requested without an API key.
    """

    model_id = resolve_roboflow_model_alias(model_id=model_id)
    dataset_id, version_id = get_model_id_chunks(model_id=model_id)
    # first check if the model id as a whole is in the GENERIC_MODELS dictionary
    if model_id in GENERIC_MODELS:
        logger.debug(f"Loading generic model: {model_id}.")
        return GENERIC_MODELS[model_id]

    # then check if the dataset id is in the GENERIC_MODELS dictionary
    if dataset_id in GENERIC_MODELS:
        logger.debug(f"Loading generic model: {dataset_id}.")
        return GENERIC_MODELS[dataset_id]

    if MODELS_CACHE_AUTH_ENABLED:
        if not _check_if_api_key_has_access_to_model(
            api_key=api_key,
            model_id=model_id,
            countinference=countinference,
            service_secret=service_secret,
        ):
            raise RoboflowAPINotAuthorizedError(
                f"API key {api_key} does not have access to model {model_id}"
            )

    cached_metadata = get_model_metadata_from_cache(
        dataset_id=dataset_id, version_id=version_id
    )

    if cached_metadata is not None:
        return cached_metadata[0], cached_metadata[1]
    if version_id == STUB_VERSION_ID:
        if api_key is None:
            raise MissingApiKeyError(
                "Stub model version provided but no API key was provided. API key is required to load stub models."
            )
        workspace_id = get_roboflow_workspace(api_key=api_key)
        project_task_type = get_roboflow_dataset_type(
            api_key=api_key, workspace_id=workspace_id, dataset_id=dataset_id
        )
        model_type = "stub"
        save_model_metadata_in_cache(
            dataset_id=dataset_id,
            version_id=version_id,
            project_task_type=project_task_type,
            model_type=model_type,
        )
        return project_task_type, model_type

    if version_id is not None:
        api_data = get_roboflow_model_data(
            api_key=api_key,
            model_id=model_id,
            countinference=countinference,
            service_secret=service_secret,
            endpoint_type=ModelEndpointType.ORT,
            device_id=GLOBAL_DEVICE_ID,
        ).get("ort")
        task_type_field = "type"
    elif not USE_INFERENCE_MODELS:
        api_data = get_roboflow_instant_model_data(
            api_key=api_key,
            model_id=model_id,
            countinference=countinference,
            service_secret=service_secret,
        )
        task_type_field = "taskType"
    else:
        api_data = get_model_metadata_from_inference_models_registry(
            api_key=api_key,
            model_id=model_id,
            countinference=countinference,
            service_secret=service_secret,
        )
        task_type_field = "taskType"

    # BUGFIX: validate api_data BEFORE dereferencing it. Previously the
    # `api_data is None` check ran only after `.get("type"/"taskType")`,
    # so a missing "ort" section (or a None helper response) raised
    # AttributeError instead of the intended ModelArtefactError.
    if api_data is None:
        raise ModelArtefactError("Error loading model artifacts from Roboflow API.")
    project_task_type = api_data.get(task_type_field, "object-detection")

    # some older projects do not have type field - hence defaulting
    model_type = api_data.get("modelType")
    if model_type is None or model_type == "ort":
        # some very old model versions do not have modelType reported - and API respond in a generic way -
        # then we shall attempt using default model for given task type
        model_type = MODEL_TYPE_DEFAULTS.get(project_task_type)

    if model_type is None or project_task_type is None:
        raise ModelArtefactError("Error loading model artifacts from Roboflow API.")
    save_model_metadata_in_cache(
        dataset_id=dataset_id,
        version_id=version_id,
        project_task_type=project_task_type,
        model_type=model_type,
    )

    return project_task_type, model_type

core/utils

General-purpose utilities: image encoding, file I/O, hashing, URL handling, and more.

inference.core.utils.container

Functions

is_docker_socket_mounted

is_docker_socket_mounted(docker_socket_path)

Check if the given path is a mounted Docker socket.

Parameters:

Name Type Description Default
docker_socket_path str

The path to the socket file.

required

Returns:

Name Type Description
bool bool

True if the path is a Unix socket, False otherwise.

Source code in inference/core/utils/container.py
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
def is_docker_socket_mounted(docker_socket_path: str) -> bool:
    """
    Check if the given path is a mounted Docker socket.

    Args:
        docker_socket_path (str): The path to the socket file.

    Returns:
        bool: True if the path is a Unix socket, False otherwise.
    """
    # EAFP: stat the path directly and treat any failure (missing path,
    # permission error) as "not mounted". The previous exists()+stat()
    # sequence was racy - the file could vanish between the two calls.
    try:
        socket_stat = os.stat(docker_socket_path)
    except OSError:
        return False
    return stat.S_ISSOCK(socket_stat.st_mode)

inference.core.utils.environment

Classes

Functions

safe_env_to_type

safe_env_to_type(
    variable_name, default_value=None, type_constructor=None
)

Converts env variable to specified type, but only if variable is set - otherwise default is returned. If type_constructor is not given - value of type str will be returned.

Source code in inference/core/utils/environment.py
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
def safe_env_to_type(
    variable_name: str,
    default_value: Optional[T] = None,
    type_constructor: Optional[Union[Type[T], Callable[[str], T]]] = None,
) -> Optional[T]:
    """
    Converts env variable to specified type, but only if variable is set - otherwise default is returned.
    If `type_constructor` is not given - value of type str will be returned.
    """
    # EAFP lookup: a KeyError means the variable is not set at all
    # (an empty string still counts as "set").
    try:
        raw_value = os.environ[variable_name]
    except KeyError:
        return default_value
    return raw_value if type_constructor is None else type_constructor(raw_value)

safe_split_value

safe_split_value(value, delimiter=',')

Splits a separated environment variable into a list.

Parameters:

Name Type Description Default
value str

The environment variable value to be split.

required
delimiter str

Delimiter to be used

','

Returns:

Type Description
Optional[List[str]]

list or None: The split values as a list, or None if the input is None.

Source code in inference/core/utils/environment.py
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
def safe_split_value(value: Optional[str], delimiter: str = ",") -> Optional[List[str]]:
    """
    Splits a separated environment variable into a list.

    Args:
        value (str): The environment variable value to be split.
        delimiter(str): Delimiter to be used

    Returns:
        list or None: The split values as a list, or None if the input is None.
    """
    # None passes through unchanged; any string (including "") is split.
    return None if value is None else value.split(delimiter)

str2bool

str2bool(value)

Converts an environment variable to a boolean value.

Parameters:

Name Type Description Default
value str or bool

The environment variable value to be converted.

required

Returns:

Name Type Description
bool bool

The converted boolean value.

Raises:

Type Description
InvalidEnvironmentVariableError

If the value is not 'true', 'false', or a boolean.

Source code in inference/core/utils/environment.py
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
def str2bool(value: Any) -> bool:
    """
    Converts an environment variable to a boolean value.

    Args:
        value (str or bool): The environment variable value to be converted.

    Returns:
        bool: The converted boolean value.

    Raises:
        InvalidEnvironmentVariableError: If the value is not 'true', 'false', or a boolean.
    """
    # Booleans pass straight through.
    if isinstance(value, bool):
        return value
    # Only 'true'/'false' strings (any casing) are accepted.
    if isinstance(value, str):
        normalized = value.lower()
        if normalized == "true":
            return True
        if normalized == "false":
            return False
    raise InvalidEnvironmentVariableError(
        f"Expected a boolean environment variable (true or false) but got '{value}'"
    )

inference.core.utils.file_system

Classes

AtomicPath

Context manager for atomic file writes.

Ensures that files are either written completely or not at all, preventing partial/corrupted files from power failures or crashes.

Usage

with AtomicPath(target_path, allow_override=False) as temp_path:
    # Write to temp_path
    with open(temp_path, 'w') as f:
        f.write(data)

# File is atomically moved to target_path on successful exit
Source code in inference/core/utils/file_system.py
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
class AtomicPath:
    """Context manager for atomic file writes.

    Ensures that files are either written completely or not at all,
    preventing partial/corrupted files from power failures or crashes.

    Usage:
        with AtomicPath(target_path, allow_override=False) as temp_path:
            # Write to temp_path
            with open(temp_path, 'w') as f:
                f.write(data)
        # File is atomically moved to target_path on successful exit
    """

    def __init__(self, target_path: str, allow_override: bool = False):
        # target_path: final destination; allow_override: whether an
        # existing file there may be replaced.
        self.target_path = target_path
        self.allow_override = allow_override
        self.temp_path: Optional[str] = None
        self.temp_file = None

    def __enter__(self) -> str:
        """Create a temporary sibling of the target and return its path."""
        ensure_write_is_allowed(
            path=self.target_path, allow_override=self.allow_override
        )
        ensure_parent_dir_exists(path=self.target_path)

        # The temp file must live in the target's directory so the final
        # rename stays on one filesystem (cross-device renames are not atomic).
        dir_name = os.path.dirname(os.path.abspath(self.target_path))
        base_name = os.path.basename(self.target_path)
        self.temp_file = tempfile.NamedTemporaryFile(
            dir=dir_name, prefix=".tmp_", suffix="_" + base_name, delete=False
        )
        self.temp_path = self.temp_file.name
        # Close immediately - the caller re-opens temp_path to write.
        self.temp_file.close()
        return self.temp_path

    def __exit__(self, exc_type, exc_val, exc_tb):
        """Atomically publish the temp file on success; discard it on failure."""
        if exc_type is None:
            try:
                # BUGFIX: os.replace() is an atomic overwrite on both POSIX
                # and Windows (Python >= 3.3). The previous Windows-only
                # os.remove() + os.rename() sequence was NOT atomic: a crash
                # between the two calls could leave the target missing -
                # exactly the failure mode this class exists to prevent.
                os.replace(self.temp_path, self.target_path)
            except Exception:
                try:
                    os.unlink(self.temp_path)
                except OSError:
                    pass
                raise
        else:
            # Error occurred inside the with-block - clean up temp file
            try:
                os.unlink(self.temp_path)
            except OSError:
                pass
        return False  # Don't suppress exceptions

inference.core.utils.image_utils

Classes

Functions

attempt_loading_image_from_string

attempt_loading_image_from_string(
    value, cv_imread_flags=cv2.IMREAD_COLOR
)

Attempt to load an image from a string.

Parameters:

Name Type Description Default
value Union[str, bytes, bytearray, _IOBase]

The image data in string format.

required
cv_imread_flags int

OpenCV flags used for image reading.

IMREAD_COLOR

Returns:

Type Description
Tuple[ndarray, bool]

Tuple[np.ndarray, bool]: A tuple of the loaded image in numpy array format and a boolean flag indicating if the image is in BGR format.

Source code in inference/core/utils/image_utils.py
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
def attempt_loading_image_from_string(
    value: Union[str, bytes, bytearray, _IOBase],
    cv_imread_flags: int = cv2.IMREAD_COLOR,
) -> Tuple[np.ndarray, bool]:
    """
    Attempt to load an image from a string.

    Tries the candidate loaders in order (base64 string, encoded image
    bytes, file-like buffer, pickled numpy payload) and returns the first
    successful decode.

    Args:
        value (Union[str, bytes, bytearray, _IOBase]): The image data in string format.
        cv_imread_flags (int): OpenCV flags used for image reading.

    Returns:
        Tuple[np.ndarray, bool]: A tuple of the loaded image in numpy array format and a boolean flag indicating if the image is in BGR format.

    Raises:
        InvalidImageTypeDeclared: Propagated from the numpy loader.
        InputFormatInferenceFailed: If no loader could decode the payload.
    """
    # BUGFIX: use `except Exception:` instead of bare `except:` so that
    # SystemExit / KeyboardInterrupt are not swallowed while falling
    # through the chain of candidate loaders.
    try:
        return load_image_base64(value=value, cv_imread_flags=cv_imread_flags), True
    except Exception:
        pass
    try:
        return (
            load_image_from_encoded_bytes(value=value, cv_imread_flags=cv_imread_flags),
            True,
        )
    except Exception:
        pass
    try:
        return (
            load_image_from_buffer(value=value, cv_imread_flags=cv_imread_flags),
            True,
        )
    except Exception:
        pass
    try:
        return load_image_from_numpy_str(value=value), True
    except InvalidImageTypeDeclared as error:
        raise error
    except InvalidNumpyInput as error:
        raise InputFormatInferenceFailed(
            message="Input image format could not be inferred from string.",
            public_message="Input image format could not be inferred from string.",
        ) from error

choose_image_decoding_flags

choose_image_decoding_flags(disable_preproc_auto_orient)

Choose the appropriate OpenCV image decoding flags.

Parameters:

Name Type Description Default
disable_preproc_auto_orient bool

Flag to disable preprocessing auto-orientation.

required

Returns:

Name Type Description
int int

OpenCV image decoding flags.

Source code in inference/core/utils/image_utils.py
107
108
109
110
111
112
113
114
115
116
117
118
119
def choose_image_decoding_flags(disable_preproc_auto_orient: bool) -> int:
    """Choose the appropriate OpenCV image decoding flags.

    Args:
        disable_preproc_auto_orient (bool): Flag to disable preprocessing auto-orientation.

    Returns:
        int: OpenCV image decoding flags.
    """
    # When auto-orientation is disabled, ask OpenCV to ignore EXIF orientation.
    if disable_preproc_auto_orient:
        return cv2.IMREAD_COLOR | cv2.IMREAD_IGNORE_ORIENTATION
    return cv2.IMREAD_COLOR

convert_gray_image_to_bgr

convert_gray_image_to_bgr(image)

Convert a grayscale image to BGR format.

Parameters:

Name Type Description Default
image ndarray

The grayscale image.

required

Returns:

Type Description
ndarray

np.ndarray: The converted BGR image.

Source code in inference/core/utils/image_utils.py
536
537
538
539
540
541
542
543
544
545
546
547
548
549
def convert_gray_image_to_bgr(image: np.ndarray) -> np.ndarray:
    """
    Convert a grayscale image to BGR format.

    Args:
        image (np.ndarray): The grayscale image.

    Returns:
        np.ndarray: The converted BGR image.
    """
    # Single-channel inputs: either a 2-D array or a (H, W, 1) array.
    is_single_channel = image.ndim == 2 or image.shape[2] == 1
    if is_single_channel:
        return cv2.cvtColor(image, cv2.COLOR_GRAY2BGR)
    return image

encode_image_to_jpeg_bytes

encode_image_to_jpeg_bytes(image, jpeg_quality=90)

Encode a numpy image to JPEG format in bytes.

Parameters:

Name Type Description Default
image ndarray

The numpy array representing a BGR image.

required
jpeg_quality int

Quality of the JPEG image.

90

Returns:

Name Type Description
bytes bytes

The JPEG encoded image.

Source code in inference/core/utils/image_utils.py
592
593
594
595
596
597
598
599
600
601
602
603
604
605
def encode_image_to_jpeg_bytes(image: np.ndarray, jpeg_quality: int = 90) -> bytes:
    """
    Encode a numpy image to JPEG format in bytes.

    Args:
        image (np.ndarray): The numpy array representing a BGR image.
        jpeg_quality (int): Quality of the JPEG image.

    Returns:
        bytes: The JPEG encoded image.
    """
    params = [int(cv2.IMWRITE_JPEG_QUALITY), jpeg_quality]
    _success, encoded = cv2.imencode(".jpg", image, params)
    return np.array(encoded).tobytes()

extract_image_payload_and_type

extract_image_payload_and_type(value)

Extract the image payload and type from the given value.

This function supports different types of image inputs (e.g., InferenceRequestImage, dict, etc.) and extracts the relevant data and image type for further processing.

Parameters:

Name Type Description Default
value Any

The input value which can be an image or information to derive the image.

required

Returns:

Type Description
Tuple[Any, Optional[ImageType]]

Tuple[Any, Optional[ImageType]]: A tuple containing the extracted image data and the corresponding image type.

Source code in inference/core/utils/image_utils.py
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
def extract_image_payload_and_type(value: Any) -> Tuple[Any, Optional[ImageType]]:
    """Extract the image payload and type from the given value.

    This function supports different types of image inputs (e.g., InferenceRequestImage, dict, etc.)
    and extracts the relevant data and image type for further processing.

    Args:
        value (Any): The input value which can be an image or information to derive the image.

    Returns:
        Tuple[Any, Optional[ImageType]]: A tuple containing the extracted image data and the corresponding image type.

    Raises:
        InvalidImageTypeDeclared: If a type is declared but is not one of the
            known ImageType values.
    """
    image_type = None
    # idiomatic isinstance() instead of issubclass(type(...), ...) -
    # identical semantics, clearer intent
    if isinstance(value, InferenceRequestImage):
        image_type = value.type
        value = value.value
    elif isinstance(value, dict):
        image_type = value.get("type")
        value = value.get("value")
    if image_type is None:
        return value, image_type
    # Build the allowed set only when a type was actually declared
    # (previously computed even on the early-return path).
    allowed_payload_types = {e.value for e in ImageType}
    if image_type.lower() not in allowed_payload_types:
        raise InvalidImageTypeDeclared(
            message=f"Declared image type: {image_type.lower()} which is not in allowed types: {allowed_payload_types}.",
            public_message="Image declaration contains not recognised image type.",
        )
    return value, ImageType(image_type.lower())

load_image

load_image(value, disable_preproc_auto_orient=False)

Loads an image based on the specified type and value.

Parameters:

Name Type Description Default
value Any

Image value which could be an instance of InferenceRequestImage, a dict with 'type' and 'value' keys, or inferred based on the value's content.

required

Returns:

Type Description
Tuple[ndarray, bool]

Image.Image: The loaded PIL image, converted to RGB.

Raises:

Type Description
NotImplementedError

If the specified image type is not supported.

InvalidNumpyInput

If the numpy input method is used and the input data is invalid.

Source code in inference/core/utils/image_utils.py
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
def load_image(
    value: Any,
    disable_preproc_auto_orient: bool = False,
) -> Tuple[np.ndarray, bool]:
    """Loads an image based on the specified type and value.

    Args:
        value (Any): Image value which could be an instance of InferenceRequestImage,
            a dict with 'type' and 'value' keys, or inferred based on the value's content.
        disable_preproc_auto_orient (bool): When True, EXIF auto-orientation is
            disabled via the OpenCV decoding flags.

    Returns:
        Tuple[np.ndarray, bool]: The loaded image as a numpy array (grayscale
            inputs are expanded to 3 channels) and a flag indicating whether
            the array is in BGR channel order.

    Raises:
        NotImplementedError: If the specified image type is not supported.
        InvalidNumpyInput: If the numpy input method is used and the input data is invalid.
    """
    cv_imread_flags = choose_image_decoding_flags(
        disable_preproc_auto_orient=disable_preproc_auto_orient
    )
    value, image_type = extract_image_payload_and_type(value=value)
    if image_type is not None:
        # An explicit type was declared (InferenceRequestImage / dict payload).
        np_image, is_bgr = load_image_with_known_type(
            value=value,
            image_type=image_type,
            cv_imread_flags=cv_imread_flags,
        )
    else:
        # No declared type - fall back to content-based inference.
        np_image, is_bgr = load_image_with_inferred_type(
            value, cv_imread_flags=cv_imread_flags
        )
    np_image = convert_gray_image_to_bgr(image=np_image)
    logger.debug(f"Loaded inference image. Shape: {getattr(np_image, 'shape', None)}")
    return np_image, is_bgr

load_image_base64

load_image_base64(value, cv_imread_flags=cv2.IMREAD_COLOR)

Loads an image from a base64 encoded string using OpenCV.

Parameters:

Name Type Description Default
value str

Base64 encoded string representing the image.

required

Returns:

Type Description
ndarray

np.ndarray: The loaded image as a numpy array.

Source code in inference/core/utils/image_utils.py
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
def load_image_base64(
    value: Union[str, bytes], cv_imread_flags=cv2.IMREAD_COLOR
) -> np.ndarray:
    """Loads an image from a base64 encoded string using OpenCV.

    Args:
        value (Union[str, bytes]): Base64 encoded string (or UTF-8 bytes of
            such a string) representing the image.
        cv_imread_flags (int): OpenCV flags used for image decoding.

    Returns:
        np.ndarray: The loaded image as a numpy array.

    Raises:
        InputImageLoadError: If the payload is not valid base64, is empty,
            or does not decode to an image.
    """
    # New routes accept images via json body (str), legacy routes accept bytes which need to be decoded as strings
    if not isinstance(value, str):
        try:
            value = value.decode("utf-8")
        except UnicodeDecodeError:
            raise InputImageLoadError(
                message="Could not decode image bytes as base64 string - the payload appears to be raw image bytes, not a base64-encoded string.",
                public_message="Invalid base64 input: the image payload contains raw bytes instead of a base64-encoded string. Please base64-encode the image before sending.",
            )
    # Strip whatever BASE64_DATA_TYPE_PATTERN matches - presumably a
    # "data:<mime>;base64," URI prefix; confirm against its definition.
    value = BASE64_DATA_TYPE_PATTERN.sub("", value)
    try:
        value = pybase64.b64decode(value)
    except binascii.Error as error:
        raise InputImageLoadError(
            message="Could not load valid image from base64 string.",
            public_message="Malformed base64 input image.",
        ) from error
    if len(value) == 0:
        raise InputImageLoadError(
            message="Could not load valid image from base64 string.",
            public_message="Empty image payload.",
        )
    image_np = np.frombuffer(value, np.uint8)
    result = cv2.imdecode(image_np, cv_imread_flags)
    # cv2.imdecode signals failure by returning None rather than raising.
    if result is None:
        raise InputImageLoadError(
            message="Could not load valid image from base64 string.",
            public_message="Malformed base64 input image.",
        )
    return result

load_image_from_buffer

load_image_from_buffer(
    value, cv_imread_flags=cv2.IMREAD_COLOR
)

Loads an image from a multipart-encoded input.

Parameters:

Name Type Description Default
value Any

Multipart-encoded input representing the image.

required

Returns:

Type Description
ndarray

Image.Image: The loaded PIL image.

Source code in inference/core/utils/image_utils.py
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
def load_image_from_buffer(
    value: _IOBase,
    cv_imread_flags: int = cv2.IMREAD_COLOR,
) -> np.ndarray:
    """Loads an image from a multipart-encoded input.

    Args:
        value (_IOBase): File-like object holding encoded image bytes.
        cv_imread_flags (int): OpenCV flags used for image decoding.

    Returns:
        np.ndarray: The decoded image as a numpy array.

    Raises:
        InputImageLoadError: If the buffer contents cannot be decoded as an image.
    """
    # Rewind first - the buffer may already have been read by the caller.
    value.seek(0)
    raw_bytes = value.read()
    decoded = cv2.imdecode(np.frombuffer(raw_bytes, np.uint8), cv_imread_flags)
    if decoded is None:
        raise InputImageLoadError(
            message="Could not load valid image from buffer.",
            public_message="Could not decode bytes into image.",
        )
    return decoded

load_image_from_encoded_bytes

load_image_from_encoded_bytes(
    value, cv_imread_flags=cv2.IMREAD_COLOR
)

Load an image from encoded bytes.

Parameters:

Name Type Description Default
value bytes

The byte sequence representing the image.

required
cv_imread_flags int

OpenCV flags used for image reading.

IMREAD_COLOR

Returns:

Type Description
ndarray

np.ndarray: The loaded image as a numpy array.

Source code in inference/core/utils/image_utils.py
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
def load_image_from_encoded_bytes(
    value: bytes, cv_imread_flags: int = cv2.IMREAD_COLOR
) -> np.ndarray:
    """
    Load an image from encoded bytes.

    Args:
        value (bytes): The byte sequence representing the image.
        cv_imread_flags (int): OpenCV flags used for image reading.

    Returns:
        np.ndarray: The loaded image as a numpy array.

    Raises:
        InputImageLoadError: If the bytes cannot be decoded as an image.
    """
    buffer = np.asarray(bytearray(value), dtype=np.uint8)
    decoded = cv2.imdecode(buffer, cv_imread_flags)
    # None signals that OpenCV could not interpret the bytes as an image.
    if decoded is None:
        raise InputImageLoadError(
            message=f"Could not decode bytes as image.",
            public_message="Data is not image.",
        )
    return decoded

load_image_from_numpy_str

load_image_from_numpy_str(value)

Loads an image from a numpy array string.

Parameters:

Name Type Description Default
value Union[bytes, str]

Base64 string or byte sequence representing the pickled numpy array of the image.

required

Returns:

Type Description
ndarray

Image.Image: The loaded PIL image.

Raises:

Type Description
InvalidNumpyInput

If the numpy data is invalid.

Source code in inference/core/utils/image_utils.py
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
def load_image_from_numpy_str(value: Union[bytes, str]) -> np.ndarray:
    """Loads an image from a numpy array string.

    Args:
        value (Union[bytes, str]): Base64 string or byte sequence representing the pickled numpy array of the image.

    Returns:
        np.ndarray: The unpickled, validated image array.

    Raises:
        InvalidImageTypeDeclared: If numpy input is disabled by configuration.
        InvalidNumpyInput: If the numpy data is invalid.
    """
    # Unpickling is gated behind explicit configuration (see security note below).
    if not ALLOW_NUMPY_INPUT:
        raise InvalidImageTypeDeclared(
            message=f"NumPy image type is not supported in this configuration of `inference`.",
            public_message=f"NumPy image type is not supported in this configuration of `inference`.",
        )
    try:
        if isinstance(value, str):
            value = pybase64.b64decode(value)
        # SECURITY NOTE(review): pickle.loads on externally-supplied bytes can
        # execute arbitrary code; the ALLOW_NUMPY_INPUT gate above is the only
        # safeguard here - do not enable it for untrusted callers.
        data = pickle.loads(value)
    except (EOFError, TypeError, pickle.UnpicklingError, binascii.Error) as error:
        raise InvalidNumpyInput(
            message=f"Could not unpickle image data. Cause: {error}",
            public_message="Could not deserialize pickle payload.",
        ) from error
    validate_numpy_image(data=data)
    return data

load_image_from_url

load_image_from_url(
    value, cv_imread_flags=cv2.IMREAD_COLOR
)

Loads an image from a given URL.

Parameters:

Name Type Description Default
value str

URL of the image.

required

Returns:

Type Description
ndarray

Image.Image: The loaded PIL image.

Source code in inference/core/utils/image_utils.py
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
def load_image_from_url(
    value: str, cv_imread_flags: int = cv2.IMREAD_COLOR
) -> np.ndarray:
    """Loads an image from a given URL.

    Args:
        value (str): URL of the image.
        cv_imread_flags (int): OpenCV flags used for image decoding.

    Returns:
        np.ndarray: The decoded image.

    Raises:
        InputImageLoadError: If the URL is invalid, blocked by the
            destination allow/deny checks, or its content cannot be decoded
            as an image.
    """
    _ensure_url_input_allowed()
    try:
        parsed_url = urllib.parse.urlparse(value)
    except ValueError as error:
        message = "Provided image URL is invalid"
        raise InputImageLoadError(
            message=message,
            public_message=message,
        ) from error
    # Destination validation before any network request: scheme, FQDN, and
    # configured white-/black-lists must all pass.
    _ensure_resource_schema_allowed(schema=parsed_url.scheme)
    domain_extraction_result = tldextract.TLDExtract(suffix_list_urls=())(
        parsed_url.netloc
    )  # we get rid of potential ports and parse FQDNs
    _ensure_resource_fqdn_allowed(fqdn=domain_extraction_result.fqdn)
    address_parts_concatenated = _concatenate_chunks_of_network_location(
        extraction_result=domain_extraction_result
    )  # concatenation of chunks - even if there is no FQDN, but address
    # it allows white-/black-list verification
    _ensure_location_matches_destination_whitelist(
        destination=address_parts_concatenated
    )
    _ensure_location_matches_destination_blacklist(
        destination=address_parts_concatenated
    )
    try:
        # NOTE(review): no `timeout` is passed here, so a hanging server can
        # block this call indefinitely - consider requests.get(..., timeout=...).
        response = requests.get(value, stream=True)
        api_key_safe_raise_for_status(response=response)
        return load_image_from_encoded_bytes(
            value=response.content, cv_imread_flags=cv_imread_flags
        )
    except (RequestException, ConnectionError) as error:
        raise InputImageLoadError(
            message=f"Could not load image from url: {value}. Details: {error}",
            public_message="Data pointed by URL could not be decoded into image.",
        )

load_image_with_inferred_type

load_image_with_inferred_type(
    value, cv_imread_flags=cv2.IMREAD_COLOR
)

Load an image by inferring its type.

Parameters:

Name Type Description Default
value Any

The image data.

required
cv_imread_flags int

Flags used for OpenCV's imread function.

IMREAD_COLOR

Returns:

Type Description
Tuple[ndarray, bool]

Tuple[np.ndarray, bool]: Loaded image as a numpy array and a boolean indicating if the image is in BGR format.

Raises:

Type Description
NotImplementedError

If the image type could not be inferred.

Source code in inference/core/utils/image_utils.py
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
def load_image_with_inferred_type(
    value: Any,
    cv_imread_flags: int = cv2.IMREAD_COLOR,
) -> Tuple[np.ndarray, bool]:
    """Load an image by inferring its type.

    Args:
        value (Any): The image data.
        cv_imread_flags (int): Flags used for OpenCV's imread function.

    Returns:
        Tuple[np.ndarray, bool]: Loaded image as a numpy array and a boolean indicating if the image is in BGR format.

    Raises:
        NotImplementedError: If the image type could not be inferred.
    """
    # Already a numpy array - just validate and pass through (BGR assumed).
    if isinstance(value, (np.ndarray, np.generic)):
        validate_numpy_image(data=value)
        return value, True
    # PIL image - convert to an RGB array.
    if isinstance(value, Image.Image):
        return np.asarray(value.convert("RGB")), False
    if isinstance(value, str):
        if value.startswith("http"):
            return (
                load_image_from_url(value=value, cv_imread_flags=cv_imread_flags),
                True,
            )
        if ALLOW_LOADING_IMAGES_FROM_LOCAL_FILESYSTEM and os.path.isfile(value):
            return cv2.imread(value, cv_imread_flags), True
    # Fall back to trying each string-payload decoder in turn.
    return attempt_loading_image_from_string(
        value=value, cv_imread_flags=cv_imread_flags
    )

load_image_with_known_type

load_image_with_known_type(
    value, image_type, cv_imread_flags=cv2.IMREAD_COLOR
)

Load an image using the known image type.

Supports various image types (e.g., NUMPY, PILLOW, etc.) and loads them into a numpy array format.

Parameters:

Name Type Description Default
value Any

The image data.

required
image_type ImageType

The type of the image.

required
cv_imread_flags int

Flags used for OpenCV's imread function.

IMREAD_COLOR

Returns:

Type Description
Tuple[ndarray, bool]

Tuple[np.ndarray, bool]: A tuple of the loaded image as a numpy array and a boolean indicating if the image is in BGR format.

Source code in inference/core/utils/image_utils.py
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
def load_image_with_known_type(
    value: Any,
    image_type: ImageType,
    cv_imread_flags: int = cv2.IMREAD_COLOR,
) -> Tuple[np.ndarray, bool]:
    """Load an image using the known image type.

    Supports various image types (e.g., NUMPY, PILLOW, etc.) and loads them into a numpy array format.

    Args:
        value (Any): The image data.
        image_type (ImageType): The type of the image.
        cv_imread_flags (int): Flags used for OpenCV's imread function.

    Returns:
        Tuple[np.ndarray, bool]: A tuple of the loaded image as a numpy array and a boolean indicating if the image is in BGR format.

    Raises:
        InputImageLoadError: If local-file loading is requested while disabled.
    """
    local_file_requested = image_type is ImageType.FILE
    if local_file_requested and not ALLOW_LOADING_IMAGES_FROM_LOCAL_FILESYSTEM:
        raise InputImageLoadError(
            message="Loading images from local filesystem is disabled.",
            public_message="Loading images from local filesystem is disabled.",
        )
    image = IMAGE_LOADERS[image_type](value, cv_imread_flags)
    # Every loader except the PILLOW one yields BGR pixel data
    return image, image_type is not ImageType.PILLOW

np_image_to_base64

np_image_to_base64(image)

Convert a numpy image to a base64 encoded byte string.

Parameters:

Name Type Description Default
image ndarray

The numpy array representing an image.

required

Returns:

Name Type Description
bytes bytes

The base64 encoded image.

Source code in inference/core/utils/image_utils.py
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
@deprecated(
    reason="Method replaced with inference.core.utils.image_utils.encode_image_to_jpeg_bytes"
)
def np_image_to_base64(image: np.ndarray) -> bytes:
    """
    TODO: This function is broken: https://github.com/roboflow/inference/issues/439
    Convert a numpy image to a base64 encoded byte string.

    NOTE(review): despite the name and summary above, the body returns the raw
    JPEG-encoded bytes with no base64 step (see the linked issue). Existing
    callers may depend on that output, which is presumably why the function is
    deprecated in favor of `encode_image_to_jpeg_bytes` rather than fixed.

    Args:
        image (np.ndarray): The numpy array representing an image; must be in a
            layout accepted by `PIL.Image.fromarray` - TODO confirm channel
            order expected by callers (RGB vs BGR).

    Returns:
        bytes: The JPEG-encoded image bytes (NOT base64 - see note above).
    """
    # Conversion to RGB ensures JPEG-compatible mode before saving.
    image = Image.fromarray(image)
    with BytesIO() as buffer:
        image = image.convert("RGB")
        image.save(buffer, format="JPEG")
        buffer.seek(0)
        return buffer.getvalue()

validate_numpy_image

validate_numpy_image(data)

Validate if the provided data is a valid numpy image.

Parameters:

Name Type Description Default
data ndarray

The numpy array representing an image.

required

Raises:

Type Description
InvalidNumpyInput

If the provided data is not a valid numpy image.

Source code in inference/core/utils/image_utils.py
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
def validate_numpy_image(data: np.ndarray) -> None:
    """
    Validate if the provided data is a valid numpy image.

    Accepts 2-dimensional (grayscale) arrays and 3-dimensional arrays with
    1 or 3 channels in the last dimension.

    Args:
        data (np.ndarray): The numpy array representing an image.

    Raises:
        InvalidNumpyInput: If the provided data is not a valid numpy image.
    """
    if not issubclass(type(data), np.ndarray):
        raise InvalidNumpyInput(
            message=f"Data provided as input could not be decoded into np.ndarray object.",
            public_message=f"Data provided as input could not be decoded into np.ndarray object.",
        )
    if len(data.shape) != 3 and len(data.shape) != 2:
        raise InvalidNumpyInput(
            message=f"For image given as np.ndarray expected 2 or 3 dimensions, got {len(data.shape)} dimensions.",
            public_message=f"For image given as np.ndarray expected 2 or 3 dimensions.",
        )
    # BUG FIX: the channel check must only apply to 3-D arrays. Previously it
    # also ran for 2-D (grayscale) inputs, where shape[-1] is the image WIDTH,
    # so any valid grayscale image whose width was not 1 or 3 was rejected.
    if len(data.shape) == 3 and data.shape[-1] != 3 and data.shape[-1] != 1:
        raise InvalidNumpyInput(
            message=f"For image given as np.ndarray expected 1 or 3 channels, got {data.shape[-1]} channels.",
            public_message="For image given as np.ndarray expected 1 or 3 channels.",
        )

xyxy_to_xywh

xyxy_to_xywh(xyxy)

Convert bounding box format from (xmin, ymin, xmax, ymax) to (xcenter, ycenter, width, height).

Parameters:

Name Type Description Default
xyxy List[int]

List containing the coordinates in (xmin, ymin, xmax, ymax) format.

required

Returns:

Type Description

List[int]: List containing the converted coordinates in (xcenter, ycenter, width, height) format.

Source code in inference/core/utils/image_utils.py
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
def xyxy_to_xywh(xyxy):
    """
    Convert bounding box format from (xmin, ymin, xmax, ymax) to (xcenter, ycenter, width, height).

    Args:
        xyxy (List[int]): Coordinates in (xmin, ymin, xmax, ymax) format.

    Returns:
        List[int]: Coordinates in (xcenter, ycenter, width, height) format,
        truncated to integers.
    """
    x_min, y_min, x_max, y_max = xyxy[0], xyxy[1], xyxy[2], xyxy[3]
    center_x = (x_min + x_max) / 2
    center_y = (y_min + y_max) / 2
    # abs() keeps the size positive even if the corners arrive swapped
    width = abs(x_max - x_min)
    height = abs(y_max - y_min)
    return [int(center_x), int(center_y), int(width), int(height)]

inference.core.utils.onnx

Functions

get_onnxruntime_execution_providers

get_onnxruntime_execution_providers(value)

Extracts the ONNX runtime execution providers from the given string.

The input string is expected to be a comma-separated list, possibly enclosed within square brackets and containing single quotes.

Parameters:

Name Type Description Default
value str

The string containing the list of ONNX runtime execution providers.

required

Returns:

Type Description
List[str]

List[str]: A list of strings representing each execution provider.

Source code in inference/core/utils/onnx.py
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
def get_onnxruntime_execution_providers(value: str) -> List[str]:
    """Extracts the ONNX runtime execution providers from the given string.

    The input string is expected to be a comma-separated list, possibly enclosed
    within square brackets and containing single quotes.

    Args:
        value (str): The string containing the list of ONNX runtime execution providers.

    Returns:
        List[str]: A list of strings representing each execution provider.
    """
    if not value:
        return []
    # Strip list-literal decoration so "['a', 'b']" and "a,b" parse alike
    for character in ("[", "]", "'", " "):
        value = value.replace(character, "")
    return value.split(",")

inference.core.utils.postprocess

Functions

cosine_similarity

cosine_similarity(a, b)

Compute the cosine similarity between two vectors.

Parameters:

Name Type Description Default
a ndarray

Vector A.

required
b ndarray

Vector B.

required

Returns:

Name Type Description
float Union[number, ndarray]

Cosine similarity between vectors A and B.

Source code in inference/core/utils/postprocess.py
14
15
16
17
18
19
20
21
22
23
24
25
def cosine_similarity(a: np.ndarray, b: np.ndarray) -> Union[np.number, np.ndarray]:
    """
    Compute the cosine similarity between two vectors.

    Args:
        a (np.ndarray): Vector A.
        b (np.ndarray): Vector B.

    Returns:
        float: Cosine similarity between vectors A and B.
    """
    # dot(a, b) / (||a|| * ||b||); the norms share a single sqrt call
    numerator = np.dot(a, b)
    denominator = np.sqrt(np.vdot(a, a) * np.vdot(b, b))
    return numerator / denominator

crop_mask

crop_mask(masks, boxes)

"Crop" predicted masks by zeroing out everything not in the predicted bbox. Vectorized by Chong (thanks Chong).

Source code in inference/core/utils/postprocess.py
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
def crop_mask(masks: np.ndarray, boxes: np.ndarray) -> np.ndarray:
    """
    "Crop" predicted masks by zeroing out everything not in the predicted bbox.
    Vectorized by Chong (thanks Chong).

    Args:
        - masks should be a size [n, h, w] array of masks
        - boxes should be a size [n, 4] array of bbox coords in (x1, y1, x2, y2) form

    Returns:
        np.ndarray: masks with all pixels outside their bbox zeroed.
    """
    num_masks, height, width = masks.shape
    # Split (n, 4, 1) into four (n, 1, 1) coordinate arrays
    x1, y1, x2, y2 = np.split(boxes[:, :, None], 4, 1)
    cols = np.arange(width, dtype=x1.dtype)[None, None, :]  # shape (1, 1, w)
    rows = np.arange(height, dtype=x1.dtype)[None, :, None]  # shape (1, h, 1)
    # Broadcasting yields an (n, h, w) inside-the-box indicator
    inside = (cols >= x1) * (cols < x2) * (rows >= y1) * (rows < y2)
    return masks * inside

get_static_crop_dimensions

get_static_crop_dimensions(
    orig_shape, preproc, disable_preproc_static_crop=False
)

Generates a transformation based on preprocessing configuration.

Parameters:

Name Type Description Default
orig_shape tuple

The original shape of the object (e.g., image) - (height, width).

required
preproc dict

Preprocessing configuration dictionary, containing information such as static cropping.

required
disable_preproc_static_crop bool

If true, the static crop preprocessing step is disabled for this call. Default is False.

False

Returns:

Name Type Description
tuple Tuple[Tuple[int, int], Tuple[int, int]]

A tuple containing the shift in the x and y directions, and the updated original shape after cropping.

Source code in inference/core/utils/postprocess.py
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
def get_static_crop_dimensions(
    orig_shape: Tuple[int, int],
    preproc: dict,
    disable_preproc_static_crop: bool = False,
) -> Tuple[Tuple[int, int], Tuple[int, int]]:
    """
    Generates a transformation based on preprocessing configuration.

    Args:
        orig_shape (tuple): The original shape of the object (e.g., image) - (height, width).
        preproc (dict): Preprocessing configuration dictionary, containing information such as static cropping.
        disable_preproc_static_crop (bool, optional): If true, the static crop preprocessing step is disabled for this call. Default is False.

    Returns:
        tuple: A tuple containing the shift in the x and y directions, and the updated original shape after cropping.

    Raises:
        PostProcessingError: If a required key is missing from the
            preprocessing configuration.
    """
    try:
        if static_crop_should_be_applied(
            preprocessing_config=preproc,
            disable_preproc_static_crop=disable_preproc_static_crop,
        ):
            x_min, y_min, x_max, y_max = standardise_static_crop(
                static_crop_config=preproc[STATIC_CROP_KEY]
            )
        else:
            # No crop configured: the crop region is the whole image
            x_min, y_min, x_max, y_max = 0, 0, 1, 1
        # Crop bounds are relative (0..1); convert the top-left corner to pixels
        crop_shift_x, crop_shift_y = (
            round(x_min * orig_shape[1]),
            round(y_min * orig_shape[0]),
        )
        cropped_percent_x = x_max - x_min
        cropped_percent_y = y_max - y_min
        orig_shape = (
            round(orig_shape[0] * cropped_percent_y),
            round(orig_shape[1] * cropped_percent_x),
        )
        return (crop_shift_x, crop_shift_y), orig_shape
    except KeyError as error:
        # Chain the KeyError explicitly so the missing key is preserved as the
        # cause in tracebacks (previously only implicit context was kept).
        raise PostProcessingError(
            f"Could not find a proper configuration key {error} in post-processing."
        ) from error

mask2multipoly

mask2multipoly(mask)

Find all contours in the mask and return them as a float32 array.

Parameters:

Name Type Description Default
mask ndarray

A binary mask.

required

Returns:

Type Description
ndarray

np.ndarray: Contours represented as a float32 array.

Source code in inference/core/utils/postprocess.py
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
def mask2multipoly(mask: np.ndarray) -> List[np.ndarray]:
    """
    Find all contours in the mask and return them as a list of float32 arrays.

    Args:
        mask (np.ndarray): A binary mask; presumably uint8, as required by
            cv2.findContours - callers in this module normalize to uint8 first.

    Returns:
        List[np.ndarray]: One (N, 2) float32 array of (x, y) points per
        contour; a list holding a single empty (0, 2) array when the mask
        contains no contours.
    """
    # findContours returns (contours, hierarchy); only contours are needed.
    contours = cv2.findContours(mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)[0]
    if contours:
        # Flatten each contour from (N, 1, 2) to (N, 2) point pairs.
        contours = [c.reshape(-1, 2).astype("float32") for c in contours]
    else:
        contours = [np.zeros((0, 2)).astype("float32")]
    return contours

mask2poly

mask2poly(mask)

Find contours in the mask and return them as a float32 array.

Parameters:

Name Type Description Default
mask ndarray

A binary mask.

required

Returns:

Type Description
ndarray

np.ndarray: Contours represented as a float32 array.

Source code in inference/core/utils/postprocess.py
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
def mask2poly(mask: np.ndarray) -> np.ndarray:
    """
    Find contours in the mask and return the largest one as a float32 array.

    Args:
        mask (np.ndarray): A binary mask; presumably uint8, as required by
            cv2.findContours - callers in this module normalize to uint8 first.

    Returns:
        np.ndarray: The contour with the most points, as an (N, 2) float32
        array of (x, y) coordinates; an empty (0, 2) array when no contour
        is found.
    """
    # findContours returns (contours, hierarchy); only contours are needed.
    contours = cv2.findContours(mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)[0]
    if contours:
        # Keep only the contour with the greatest number of points.
        contours = np.array(
            contours[np.array([len(x) for x in contours]).argmax()]
        ).reshape(-1, 2)
    else:
        contours = np.zeros((0, 2))
    return contours.astype("float32")

masks2multipoly

masks2multipoly(masks)

Converts binary masks to polygonal segments.

Parameters:

Name Type Description Default
masks ndarray

A set of binary masks, where masks are multiplied by 255 and converted to uint8 type.

required

Returns:

Name Type Description
list List[ndarray]

A list of segments, where each segment is obtained by converting the corresponding mask.

Source code in inference/core/utils/postprocess.py
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
def masks2multipoly(masks: np.ndarray) -> List[np.ndarray]:
    """Converts binary masks to polygonal segments.

    Args:
        masks (numpy.ndarray): A set of binary masks.

    Returns:
        list: One list of polygons (all external contours) per input mask.
    """

    def _as_contiguous_uint8(single_mask: np.ndarray) -> np.ndarray:
        # Normalize one mask to a C-contiguous uint8 array, avoiding copies
        # whenever the input layout already permits a zero-copy view.
        if single_mask.dtype == np.bool_:
            contiguous = (
                single_mask
                if single_mask.flags.c_contiguous
                else np.ascontiguousarray(single_mask)
            )
            return contiguous.view(np.uint8)
        if single_mask.dtype == np.uint8:
            return (
                single_mask
                if single_mask.flags.c_contiguous
                else np.ascontiguousarray(single_mask)
            )
        # Other dtypes: threshold to bool, then view as uint8 (one allocation).
        thresholded = single_mask > 0
        if not thresholded.flags.c_contiguous:
            thresholded = np.ascontiguousarray(thresholded)
        return thresholded.view(np.uint8)

    # Per-mask processing avoids materializing a full N x H x W uint8 copy.
    segments = []
    for mask in masks:
        normalized = _as_contiguous_uint8(mask)
        if not np.any(normalized):
            # Empty mask: emit a single empty polygon without contour search.
            segments.append([np.zeros((0, 2), dtype=np.float32)])
        else:
            segments.append(mask2multipoly(normalized))
    return segments

masks2poly

masks2poly(masks)

Converts binary masks to polygonal segments.

Parameters:

Name Type Description Default
masks ndarray

A set of binary masks, where masks are multiplied by 255 and converted to uint8 type.

required

Returns:

Name Type Description
list List[ndarray]

A list of segments, where each segment is obtained by converting the corresponding mask.

Source code in inference/core/utils/postprocess.py
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
def masks2poly(masks: np.ndarray) -> List[np.ndarray]:
    """Converts binary masks to polygonal segments.

    Args:
        masks (numpy.ndarray): A set of binary masks.

    Returns:
        list: One polygon (the largest contour) per input mask.
    """

    def _as_contiguous_uint8(single_mask: np.ndarray) -> np.ndarray:
        # Normalize one mask to a C-contiguous uint8 array, avoiding copies
        # whenever the input layout already permits a zero-copy view.
        if single_mask.dtype == np.bool_:
            contiguous = (
                single_mask
                if single_mask.flags.c_contiguous
                else np.ascontiguousarray(single_mask)
            )
            return contiguous.view(np.uint8)
        if single_mask.dtype == np.uint8:
            return (
                single_mask
                if single_mask.flags.c_contiguous
                else np.ascontiguousarray(single_mask)
            )
        # Other dtypes: threshold to bool, then view as uint8 (one allocation).
        thresholded = single_mask > 0
        if not thresholded.flags.c_contiguous:
            thresholded = np.ascontiguousarray(thresholded)
        return thresholded.view(np.uint8)

    # Per-mask processing avoids materializing a full N x H x W uint8 copy.
    segments = []
    for mask in masks:
        normalized = _as_contiguous_uint8(mask)
        if not np.any(normalized):
            # Empty mask: emit an empty polygon without contour search.
            segments.append(np.zeros((0, 2), dtype=np.float32))
        else:
            segments.append(mask2poly(normalized))
    return segments

post_process_bboxes

post_process_bboxes(
    predictions,
    infer_shape,
    img_dims,
    preproc,
    disable_preproc_static_crop=False,
    resize_method="Stretch to",
)

Postprocesses each patch of detections by scaling them to the original image coordinates and by shifting them based on a static crop preproc (if applied).

Parameters:

Name Type Description Default
predictions List[List[List[float]]]

The predictions output from NMS, indices are: batch x prediction x [x1, y1, x2, y2, ...].

required
infer_shape Tuple[int, int]

The shape of the inference image.

required
img_dims List[Tuple[int, int]]

The dimensions of the original image for each batch, indices are: batch x [height, width].

required
preproc dict

Preprocessing configuration dictionary.

required
disable_preproc_static_crop bool

If true, the static crop preprocessing step is disabled for this call. Default is False.

False
resize_method str

Resize method for image. Defaults to "Stretch to".

'Stretch to'

Returns:

Type Description
List[List[List[float]]]

List[List[List[float]]]: The scaled and shifted predictions, indices are: batch x prediction x [x1, y1, x2, y2, ...].

Source code in inference/core/utils/postprocess.py
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
def post_process_bboxes(
    predictions: List[List[List[float]]],
    infer_shape: Tuple[int, int],
    img_dims: List[Tuple[int, int]],
    preproc: dict,
    disable_preproc_static_crop: bool = False,
    resize_method: str = "Stretch to",
) -> List[List[List[float]]]:
    """
    Postprocesses each patch of detections by scaling them to the original image coordinates and by shifting them based on a static crop preproc (if applied).

    Args:
        predictions (List[List[List[float]]]): The predictions output from NMS, indices are: batch x prediction x [x1, y1, x2, y2, ...].
        infer_shape (Tuple[int, int]): The shape of the inference image.
        img_dims (List[Tuple[int, int]]): The dimensions of the original image for each batch, indices are: batch x [height, width].
        preproc (dict): Preprocessing configuration dictionary.
        disable_preproc_static_crop (bool, optional): If true, the static crop preprocessing step is disabled for this call. Default is False.
        resize_method (str, optional): Resize method for image. Defaults to "Stretch to".

    Returns:
        List[List[List[float]]]: The scaled and shifted predictions, indices are: batch x prediction x [x1, y1, x2, y2, ...].
    """
    padded_fit_methods = {
        "Fit (black edges) in",
        "Fit (white edges) in",
        "Fit (grey edges) in",
    }
    scaled_predictions = []
    for batch_index, batch_predictions in enumerate(predictions):
        if len(batch_predictions) == 0:
            # Preserve batch alignment even when a batch has no detections
            scaled_predictions.append([])
            continue
        np_batch_predictions = np.array(batch_predictions)
        # The first four columns are the (x1, y1, x2, y2) box coordinates
        boxes = np_batch_predictions[:, :4]
        (crop_shift_x, crop_shift_y), origin_shape = get_static_crop_dimensions(
            img_dims[batch_index],
            preproc,
            disable_preproc_static_crop=disable_preproc_static_crop,
        )
        if resize_method == "Stretch to":
            boxes = stretch_bboxes(
                predicted_bboxes=boxes,
                infer_shape=infer_shape,
                origin_shape=origin_shape,
            )
        elif resize_method in padded_fit_methods:
            boxes = undo_image_padding_for_predicted_boxes(
                predicted_bboxes=boxes,
                infer_shape=infer_shape,
                origin_shape=origin_shape,
            )
        boxes = clip_boxes_coordinates(
            predicted_bboxes=boxes,
            origin_shape=origin_shape,
        )
        # Undo the static crop by shifting back into full-image coordinates
        boxes = shift_bboxes(
            bboxes=boxes,
            shift_x=crop_shift_x,
            shift_y=crop_shift_y,
        )
        np_batch_predictions[:, :4] = boxes
        scaled_predictions.append(np_batch_predictions.tolist())
    return scaled_predictions

post_process_keypoints

post_process_keypoints(
    predictions,
    keypoints_start_index,
    infer_shape,
    img_dims,
    preproc,
    disable_preproc_static_crop=False,
    resize_method="Stretch to",
)

Scales and shifts keypoints based on the given image shapes and preprocessing method.

This function performs polygon scaling and shifting based on the specified resizing method and pre-processing steps. The polygons are transformed according to the ratio and padding between two images.

Parameters:

Name Type Description Default
predictions List[List[List[float]]]

predictions from model

required
keypoints_start_index int

offset in the 3rd dimension pointing where in the prediction start keypoints [(x, y, cfg), ...] for each keypoint class

required
img_dims list of tuple of int

Shape of the source image (height, width).

required
infer_shape tuple of int

Shape of the target image (height, width).

required
preproc object

Preprocessing details used for generating the transformation.

required
resize_method str

Resizing method, either "Stretch to", "Fit (black edges) in", "Fit (white edges) in", or "Fit (grey edges) in". Defaults to "Stretch to".

'Stretch to'
disable_preproc_static_crop bool

flag to disable static crop

False
Source code in inference/core/utils/postprocess.py
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
def post_process_keypoints(
    predictions: List[List[List[float]]],
    keypoints_start_index: int,
    infer_shape: Tuple[int, int],
    img_dims: List[Tuple[int, int]],
    preproc: dict,
    disable_preproc_static_crop: bool = False,
    resize_method: str = "Stretch to",
) -> List[List[List[float]]]:
    """Scales and shifts keypoints based on the given image shapes and preprocessing method.

    This function performs keypoint scaling and shifting based on the specified resizing
    method and pre-processing steps, mapping keypoints back to source-image coordinates.

    Args:
        predictions: predictions from model
        keypoints_start_index: offset in the 3rd dimension pointing where in the prediction start keypoints [(x, y, cfg), ...] for each keypoint class
        infer_shape (tuple of int): Shape of the inference image (height, width).
        img_dims (list of tuple of int): Per-batch shapes of the source images (height, width).
        preproc (object): Preprocessing details used for generating the transformation.
        disable_preproc_static_crop: flag to disable static crop
        resize_method (str, optional): Resizing method, either "Stretch to", "Fit (black edges) in", "Fit (white edges) in", or "Fit (grey edges) in". Defaults to "Stretch to".

    Returns:
        list of list of list: predictions with post-processed keypoints
    """
    padded_fit_methods = {
        "Fit (black edges) in",
        "Fit (white edges) in",
        "Fit (grey edges) in",
    }
    scaled_predictions = []
    for batch_index, batch_predictions in enumerate(predictions):
        if len(batch_predictions) == 0:
            # Preserve batch alignment even when a batch has no detections
            scaled_predictions.append([])
            continue
        np_batch_predictions = np.array(batch_predictions)
        # Keypoint triplets occupy the trailing columns of each prediction row
        keypoints = np_batch_predictions[:, keypoints_start_index:]
        (crop_shift_x, crop_shift_y), origin_shape = get_static_crop_dimensions(
            img_dims[batch_index],
            preproc,
            disable_preproc_static_crop=disable_preproc_static_crop,
        )
        if resize_method == "Stretch to":
            keypoints = stretch_keypoints(
                keypoints=keypoints,
                infer_shape=infer_shape,
                origin_shape=origin_shape,
            )
        elif resize_method in padded_fit_methods:
            keypoints = undo_image_padding_for_predicted_keypoints(
                keypoints=keypoints,
                infer_shape=infer_shape,
                origin_shape=origin_shape,
            )
        keypoints = clip_keypoints_coordinates(
            keypoints=keypoints, origin_shape=origin_shape
        )
        # Undo the static crop by shifting back into full-image coordinates
        keypoints = shift_keypoints(
            keypoints=keypoints, shift_x=crop_shift_x, shift_y=crop_shift_y
        )
        np_batch_predictions[:, keypoints_start_index:] = keypoints
        scaled_predictions.append(np_batch_predictions.tolist())
    return scaled_predictions

post_process_polygons

post_process_polygons(
    origin_shape,
    polys,
    infer_shape,
    preproc,
    resize_method="Stretch to",
)

Scales and shifts polygons based on the given image shapes and preprocessing method.

This function performs polygon scaling and shifting based on the specified resizing method and pre-processing steps. The polygons are transformed according to the ratio and padding between two images.

Parameters:

Name Type Description Default
origin_shape tuple of int

Shape of the source image (height, width).

required
infer_shape tuple of int

Shape of the target image (height, width).

required
polys list of list of tuple

List of polygons, where each polygon is represented by a list of (x, y) coordinates.

required
preproc object

Preprocessing details used for generating the transformation.

required
resize_method str

Resizing method, either "Stretch to", "Fit (black edges) in", "Fit (white edges) in", or "Fit (grey edges) in". Defaults to "Stretch to".

'Stretch to'

Returns:

Type Description
List[List[Tuple[float, float]]]

list of list of tuple: A list of shifted and scaled polygons.

Source code in inference/core/utils/postprocess.py
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
def post_process_polygons(
    origin_shape: Tuple[int, int],
    polys: List[List[Tuple[float, float]]],
    infer_shape: Tuple[int, int],
    preproc: dict,
    resize_method: str = "Stretch to",
) -> List[List[Tuple[float, float]]]:
    """Scales and shifts polygons based on the given image shapes and preprocessing method.

    This function performs polygon scaling and shifting based on the specified resizing method and
    pre-processing steps. The polygons are transformed according to the ratio and padding between two images.

    Args:
        origin_shape (tuple of int): Shape of the source image (height, width).
        polys (list of list of tuple): List of polygons, where each polygon is represented by a list of (x, y) coordinates.
        infer_shape (tuple of int): Shape of the target image (height, width).
        preproc (object): Preprocessing details used for generating the transformation.
        resize_method (str, optional): Resizing method, either "Stretch to", "Fit (black edges) in", "Fit (white edges) in", or "Fit (grey edges) in". Defaults to "Stretch to".

    Returns:
        list of list of tuple: A list of shifted and scaled polygons.
    """
    (crop_shift_x, crop_shift_y), origin_shape = get_static_crop_dimensions(
        origin_shape, preproc
    )
    scaled_polys = []
    if resize_method == "Stretch to":
        scaled_polys = scale_polygons(
            polygons=polys,
            x_scale=origin_shape[1] / infer_shape[1],
            y_scale=origin_shape[0] / infer_shape[0],
        )
    elif resize_method in {
        "Fit (black edges) in",
        "Fit (white edges) in",
        "Fit (grey edges) in",
    }:
        scaled_polys = undo_image_padding_for_predicted_polygons(
            polygons=polys,
            infer_shape=infer_shape,
            origin_shape=origin_shape,
        )
    # Undo the static crop by shifting every point back into full-image space
    return [
        [(point[0] + crop_shift_x, point[1] + crop_shift_y) for point in poly]
        for poly in scaled_polys
    ]

process_mask_accurate

process_mask_accurate(protos, masks_in, bboxes, shape)

Returns masks that are the size of the original image.

Parameters:

Name Type Description Default
protos ndarray

Prototype masks.

required
masks_in ndarray

Input masks.

required
bboxes ndarray

Bounding boxes.

required
shape tuple

Target shape.

required

Returns:

Type Description
ndarray

numpy.ndarray: Processed masks.

Source code in inference/core/utils/postprocess.py
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
def process_mask_accurate(
    protos: np.ndarray,
    masks_in: np.ndarray,
    bboxes: np.ndarray,
    shape: Tuple[int, int],
) -> np.ndarray:
    """Returns masks that are the size of the original image.

    Args:
        protos (numpy.ndarray): Prototype masks.
        masks_in (numpy.ndarray): Input masks.
        bboxes (numpy.ndarray): Bounding boxes.
        shape (tuple): Target shape - (height, width).

    Returns:
        numpy.ndarray: Processed masks, resized to `shape` and binarized.
    """
    masks = preprocess_segmentation_masks(
        protos=protos,
        masks_in=masks_in,
        shape=shape,
    )
    if len(masks.shape) == 2:
        masks = np.expand_dims(masks, axis=0)
    # cv2.resize expects HWC layout, so move the mask axis last
    masks = masks.transpose((1, 2, 0))
    # BUG FIX: the third positional argument of cv2.resize is `dst`, not
    # `interpolation` - the original call placed cv2.INTER_LINEAR in the
    # `dst` slot. Pass it by keyword. Output is unchanged only because
    # INTER_LINEAR happens to be cv2.resize's default interpolation.
    masks = cv2.resize(masks, (shape[1], shape[0]), interpolation=cv2.INTER_LINEAR)
    if len(masks.shape) == 2:
        # A single mask collapses to 2-D after resize; restore the channel axis
        masks = np.expand_dims(masks, axis=2)
    masks = masks.transpose((2, 0, 1))
    masks = crop_mask(masks, bboxes)
    # Binarize: suppress pixels below the 0.5 confidence threshold
    masks[masks < 0.5] = 0
    return masks

process_mask_fast

process_mask_fast(protos, masks_in, bboxes, shape)

Returns masks in their original size.

Parameters:

Name Type Description Default
protos ndarray

Prototype masks.

required
masks_in ndarray

Input masks.

required
bboxes ndarray

Bounding boxes.

required
shape tuple

Target shape.

required

Returns:

Type Description
ndarray

numpy.ndarray: Processed masks.

Source code in inference/core/utils/postprocess.py
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
def process_mask_fast(
    protos: np.ndarray,
    masks_in: np.ndarray,
    bboxes: np.ndarray,
    shape: Tuple[int, int],
) -> np.ndarray:
    """Returns masks in their original size.

    Fast path: crops the low-resolution combined masks directly with boxes
    scaled down to prototype-mask resolution, skipping any upsampling.

    Args:
        protos (numpy.ndarray): Prototype masks in CHW layout.
        masks_in (numpy.ndarray): Input masks.
        bboxes (numpy.ndarray): Bounding boxes expressed in image coordinates.
        shape (tuple): Target (height, width).

    Returns:
        numpy.ndarray: Processed masks, with values below 0.5 zeroed.
    """
    image_height, image_width = shape
    _, mask_height, mask_width = protos.shape  # CHW layout
    combined = preprocess_segmentation_masks(
        protos=protos,
        masks_in=masks_in,
        shape=shape,
    )
    # bring the boxes from image space down to prototype-mask space;
    # deepcopy keeps the caller's bboxes untouched
    scaled_boxes = scale_bboxes(
        bboxes=deepcopy(bboxes),
        scale_x=mask_width / image_width,
        scale_y=mask_height / image_height,
    )
    combined = crop_mask(combined, scaled_boxes)
    combined[combined < 0.5] = 0
    return combined

process_mask_tradeoff

process_mask_tradeoff(
    protos, masks_in, bboxes, shape, tradeoff_factor
)

Returns masks that are the size of the original image with a tradeoff factor applied.

Parameters:

Name Type Description Default
protos ndarray

Prototype masks.

required
masks_in ndarray

Input masks.

required
bboxes ndarray

Bounding boxes.

required
shape tuple

Target shape.

required
tradeoff_factor float

Tradeoff factor for resizing masks.

required

Returns:

Type Description
ndarray

numpy.ndarray: Processed masks.

Source code in inference/core/utils/postprocess.py
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
def process_mask_tradeoff(
    protos: np.ndarray,
    masks_in: np.ndarray,
    bboxes: np.ndarray,
    shape: Tuple[int, int],
    tradeoff_factor: float,
) -> np.ndarray:
    """Returns masks that are the size of the original image with a tradeoff factor applied.

    Args:
        protos (numpy.ndarray): Prototype masks in CHW layout.
        masks_in (numpy.ndarray): Input masks.
        bboxes (numpy.ndarray): Bounding boxes.
        shape (tuple): Target (height, width).
        tradeoff_factor (float): Tradeoff factor for resizing masks; 0 keeps
            prototype resolution (fast), 1 resizes fully to image resolution
            (accurate), values in between interpolate the working resolution.

    Returns:
        numpy.ndarray: Processed masks, with values below 0.5 zeroed.
    """
    c, mh, mw = protos.shape  # CHW
    masks = preprocess_segmentation_masks(
        protos=protos,
        masks_in=masks_in,
        shape=shape,
    )

    # Order = 1 -> bilinear
    if len(masks.shape) == 2:
        masks = np.expand_dims(masks, axis=0)
    masks = masks.transpose((1, 2, 0))
    ih, iw = shape
    # blend the working resolution between prototype size and image size
    h = int(mh * (1 - tradeoff_factor) + ih * tradeoff_factor)
    w = int(mw * (1 - tradeoff_factor) + iw * tradeoff_factor)
    if tradeoff_factor != 0:
        # BUG FIX: cv2.resize expects dsize as (width, height); the previous
        # (h, w) ordering produced transposed masks for non-square targets.
        # Also, cv2.resize's third positional parameter is `dst`, not
        # `interpolation`, so the INTER_LINEAR flag must be passed by keyword.
        masks = cv2.resize(masks, (w, h), interpolation=cv2.INTER_LINEAR)
    if len(masks.shape) == 2:
        masks = np.expand_dims(masks, axis=2)
    masks = masks.transpose((2, 0, 1))
    c, mh, mw = masks.shape
    down_sampled_boxes = scale_bboxes(
        bboxes=deepcopy(bboxes),
        scale_x=mw / iw,
        scale_y=mh / ih,
    )
    masks = crop_mask(masks, down_sampled_boxes)
    masks[masks < 0.5] = 0
    return masks

sigmoid

sigmoid(x)

Computes the sigmoid function for the given input.

The sigmoid function is defined as: f(x) = 1 / (1 + exp(-x))

Parameters:

Name Type Description Default
x float or ndarray

Input value or array for which the sigmoid function is to be computed.

required

Returns:

Type Description
Union[float, number, ndarray]

float or numpy.ndarray: The computed sigmoid value(s).

Source code in inference/core/utils/postprocess.py
685
686
687
688
689
690
691
692
693
694
695
696
697
def sigmoid(x: Union[float, np.ndarray]) -> Union[float, np.number, np.ndarray]:
    """Computes the logistic sigmoid of the input.

    Defined as f(x) = 1 / (1 + exp(-x)); maps any real input into (0, 1),
    elementwise for array inputs.

    Args:
        x (float or numpy.ndarray): Value(s) to transform.

    Returns:
        float or numpy.ndarray: The computed sigmoid value(s).
    """
    return np.reciprocal(1.0 + np.exp(-x))

inference.core.utils.preprocess

Functions

letterbox_image

letterbox_image(image, desired_size, color=(0, 0, 0))

Resize and pad image to fit the desired size, preserving its aspect ratio.

Parameters: - image: numpy array representing the image. - desired_size: tuple (width, height) representing the target dimensions. - color: tuple (B, G, R) representing the color to pad with.

Returns: - letterboxed image.

Source code in inference/core/utils/preprocess.py
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
def letterbox_image(
    image: ImageMetaType,
    desired_size: Tuple[int, int],
    color: Tuple[int, int, int] = (0, 0, 0),
) -> ImageMetaType:
    """
    Resize and pad image to fit the desired size, preserving its aspect ratio.

    Parameters:
    - image: numpy array representing the image.
    - desired_size: tuple (width, height) representing the target dimensions.
    - color: tuple (B, G, R) representing the color to pad with.

    Returns:
    - letterboxed image.
    """
    scaled = resize_image_keeping_aspect_ratio(
        image=image,
        desired_size=desired_size,
    )
    # numpy arrays carry (H, W, ...) leading dims; tensors carry (..., H, W)
    if isinstance(scaled, np.ndarray):
        height, width = scaled.shape[:2]
    else:
        height, width = scaled.shape[-2:]
    # split leftover space evenly; the bottom/right side absorbs any odd pixel
    pad_top = (desired_size[1] - height) // 2
    pad_bottom = desired_size[1] - height - pad_top
    pad_left = (desired_size[0] - width) // 2
    pad_right = desired_size[0] - width - pad_left
    if isinstance(scaled, np.ndarray):
        return cv2.copyMakeBorder(
            scaled,
            pad_top,
            pad_bottom,
            pad_left,
            pad_right,
            cv2.BORDER_CONSTANT,
            value=color,
        )
    if USE_PYTORCH_FOR_PREPROCESSING:
        # torch pad order is (left, right, top, bottom); only the first
        # channel of `color` is used as the constant fill value
        return torch.nn.functional.pad(
            scaled,
            (pad_left, pad_right, pad_top, pad_bottom),
            "constant",
            color[0],
        )
    raise ValueError(
        f"Received an image of unknown type, {type(scaled)}; "
        "This is most likely a bug. Contact Roboflow team through github issues "
        "(https://github.com/roboflow/inference/issues) providing full context of the problem"
    )

prepare

prepare(
    image,
    preproc,
    disable_preproc_contrast=False,
    disable_preproc_grayscale=False,
    disable_preproc_static_crop=False,
)

Prepares an image by applying a series of preprocessing steps defined in the preproc dictionary.

Parameters:

Name Type Description Default
image Image

The input image as a numpy array.

required
preproc dict

Dictionary containing preprocessing steps. Example: { "resize": {"enabled": true, "width": 416, "height": 416, "format": "Stretch to"}, "static-crop": {"y_min": 25, "x_max": 75, "y_max": 75, "enabled": true, "x_min": 25}, "auto-orient": {"enabled": true}, "grayscale": {"enabled": true}, "contrast": {"enabled": true, "type": "Adaptive Equalization"} }

required
disable_preproc_contrast bool

If true, the contrast preprocessing step is disabled for this call. Default is False.

False
disable_preproc_grayscale bool

If true, the grayscale preprocessing step is disabled for this call. Default is False.

False
disable_preproc_static_crop bool

If true, the static crop preprocessing step is disabled for this call. Default is False.

False

Returns:

Name Type Description
ndarray

numpy.ndarray: The preprocessed image.

tuple Tuple[int, int]

The dimensions of the image.

Note

The function uses global flags like DISABLE_PREPROC_AUTO_ORIENT, DISABLE_PREPROC_STATIC_CROP, etc. to conditionally enable or disable certain preprocessing steps.

Source code in inference/core/utils/preprocess.py
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
def prepare(
    image: np.ndarray,
    preproc,
    disable_preproc_contrast: bool = False,
    disable_preproc_grayscale: bool = False,
    disable_preproc_static_crop: bool = False,
) -> Tuple[np.ndarray, Tuple[int, int]]:
    """
    Prepares an image by applying a series of preprocessing steps defined in the `preproc` dictionary.

    Args:
        image (np.ndarray): The input image as a numpy array (or a tensor when
            pytorch-based preprocessing is enabled - see the type dispatch below).
        preproc (dict): Dictionary containing preprocessing steps. Example:
            {
                "resize": {"enabled": true, "width": 416, "height": 416, "format": "Stretch to"},
                "static-crop": {"y_min": 25, "x_max": 75, "y_max": 75, "enabled": true, "x_min": 25},
                "auto-orient": {"enabled": true},
                "grayscale": {"enabled": true},
                "contrast": {"enabled": true, "type": "Adaptive Equalization"}
            }
        disable_preproc_contrast (bool, optional): If true, the contrast preprocessing step is disabled for this call. Default is False.
        disable_preproc_grayscale (bool, optional): If true, the grayscale preprocessing step is disabled for this call. Default is False.
        disable_preproc_static_crop (bool, optional): If true, the static crop preprocessing step is disabled for this call. Default is False.

    Returns:
        np.ndarray: The preprocessed image.
        tuple: The (height, width) of the image measured BEFORE any preprocessing.

    Raises:
        PreProcessingError: If `preproc` lacks a key required by an enabled step.

    Note:
        The function uses global flags like `DISABLE_PREPROC_AUTO_ORIENT`, `DISABLE_PREPROC_STATIC_CROP`, etc.
        to conditionally enable or disable certain preprocessing steps.
    """
    try:
        # numpy images are (H, W, ...); tensors are (..., H, W)
        if isinstance(image, np.ndarray):
            h, w = image.shape[0:2]
        elif USE_PYTORCH_FOR_PREPROCESSING:
            h, w = image.shape[-2:]
        else:
            raise ValueError(
                f"Received an image of unknown type, {type(image)}; "
                "This is most likely a bug. Contact Roboflow team through github issues "
                "(https://github.com/roboflow/inference/issues) providing full context of the problem"
            )

        # dims are captured before cropping, so callers receive the original size
        img_dims = (h, w)
        if static_crop_should_be_applied(
            preprocessing_config=preproc,
            disable_preproc_static_crop=disable_preproc_static_crop,
        ):
            image = take_static_crop(
                image=image, crop_parameters=preproc[STATIC_CROP_KEY]
            )
        if contrast_adjustments_should_be_applied(
            preprocessing_config=preproc,
            disable_preproc_contrast=disable_preproc_contrast,
        ):
            adjustment_type = ContrastAdjustmentType(preproc[CONTRAST_KEY][TYPE_KEY])
            image = apply_contrast_adjustment(
                image=image, adjustment_type=adjustment_type
            )
        if grayscale_conversion_should_be_applied(
            preprocessing_config=preproc,
            disable_preproc_grayscale=disable_preproc_grayscale,
        ):
            image = apply_grayscale_conversion(image=image)
        return image, img_dims
    except KeyError as error:
        raise PreProcessingError(
            f"Pre-processing of image failed due to misconfiguration. Missing key: {error}."
        ) from error

resize_image_keeping_aspect_ratio

resize_image_keeping_aspect_ratio(image, desired_size)

Resize an image preserving its aspect ratio.

Parameters: - image: numpy array representing the image. - desired_size: tuple (width, height) representing the target dimensions.

Source code in inference/core/utils/preprocess.py
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
def resize_image_keeping_aspect_ratio(
    image: ImageMetaType,
    desired_size: Tuple[int, int],
) -> ImageMetaType:
    """
    Resize an image preserving its aspect ratio.

    Parameters:
    - image: numpy array representing the image.
    - desired_size: tuple (width, height) representing the target dimensions.
    """
    # numpy arrays are (H, W, ...); tensors are (..., H, W)
    if isinstance(image, np.ndarray):
        aspect = image.shape[1] / image.shape[0]
    elif USE_PYTORCH_FOR_PREPROCESSING:
        aspect = image.shape[-1] / image.shape[-2]
    else:
        raise ValueError(
            f"Received an image of unknown type, {type(image)}; "
            "This is most likely a bug. Contact Roboflow team through github issues "
            "(https://github.com/roboflow/inference/issues) providing full context of the problem"
        )
    target_width, target_height = desired_size
    target_aspect = target_width / target_height

    # width is the binding dimension when the image is wider than the target;
    # otherwise height binds - the other dimension shrinks to keep the ratio
    if aspect >= target_aspect:
        out_width = target_width
        out_height = int(target_width / aspect)
    else:
        out_height = target_height
        out_width = int(target_height * aspect)

    if isinstance(image, np.ndarray):
        # cv2.resize takes dsize as (width, height)
        return cv2.resize(image, (out_width, out_height))
    elif USE_PYTORCH_FOR_PREPROCESSING:
        # interpolate takes size as (height, width)
        return torch.nn.functional.interpolate(
            image, size=(out_height, out_width), mode="bilinear"
        )
    else:
        raise ValueError(
            f"Received an image of unknown type, {type(image)}; "
            "This is most likely a bug. Contact Roboflow team through github issues "
            "(https://github.com/roboflow/inference/issues) providing full context of the problem"
        )

inference.core.utils.torchscript_guard

core/workflows/core_steps/analytics/detection_event_log

inference.core.workflows.core_steps.analytics.detection_event_log.v1

Classes

DetectionEvent dataclass

Stores event data for a tracked detection.

Source code in inference/core/workflows/core_steps/analytics/detection_event_log/v1.py
36
37
38
39
40
41
42
43
44
45
46
47
@dataclass
class DetectionEvent:
    """Stores event data for a tracked detection."""

    # Tracker-assigned identifier of the object.
    tracker_id: int
    # Class label of the detection.
    class_name: str
    # Frame number at which the object first appeared.
    first_seen_frame: int
    # Relative time (seconds since video start) of first appearance.
    first_seen_timestamp: float
    # Frame number at which the object was most recently seen.
    last_seen_frame: int
    # Relative time (seconds since video start) of the most recent sighting.
    last_seen_timestamp: float
    # Number of frames the object has been observed in.
    frame_count: int = 1
    # True once frame_count has reached the logging threshold.
    logged: bool = False

DetectionEventLogBlockV1

Bases: WorkflowBlock

Block that tracks detection events over time.

Maintains a dictionary of tracked objects with: - First seen timestamp and frame - Last seen timestamp and frame - Class name - Frame count (number of frames the object has been seen)

Only logs objects that have been seen for at least frame_threshold frames. Runs cleanup every flush_interval frames, removing events not seen for stale_frames.

Source code in inference/core/workflows/core_steps/analytics/detection_event_log/v1.py
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
class DetectionEventLogBlockV1(WorkflowBlock):
    """
    Block that tracks detection events over time.

    Maintains a dictionary of tracked objects with:
    - First seen timestamp and frame
    - Last seen timestamp and frame
    - Class name
    - Frame count (number of frames the object has been seen)

    Only logs objects that have been seen for at least frame_threshold frames.
    Runs cleanup every flush_interval frames, removing events not seen for stale_frames.
    """

    def __init__(self):
        # Dict[video_id, Dict[tracker_id, DetectionEvent]]
        self._event_logs: Dict[str, Dict[int, DetectionEvent]] = {}
        # Dict[video_id, last_flush_frame]
        self._last_flush_frame: Dict[str, int] = {}
        # Dict[video_id, frame_count] - internal frame counter (increments each run)
        self._frame_count: Dict[str, int] = {}
        # Dict[video_id, last_access_frame] - tracks when each video was last accessed (global frame count)
        self._last_access: Dict[str, int] = {}
        # Dict[video_id, reference_timestamp] - stores extracted reference timestamp per video
        self._reference_timestamps: Dict[str, float] = {}
        # Global frame counter for tracking video access order
        self._global_frame: int = 0

    @classmethod
    def get_manifest(cls) -> Type[WorkflowBlockManifest]:
        return BlockManifest

    def _get_relative_time(
        self, current_frame: int, metadata: VideoMetadata, fallback_fps: float
    ) -> float:
        """Calculate relative time in seconds since video started.

        Uses frame number and FPS when available, otherwise uses fallback_fps.
        Frame 1 corresponds to 0.0 seconds.
        """
        fps = metadata.fps if metadata.fps and metadata.fps != 0 else fallback_fps
        return (current_frame - 1) / fps

    def _evict_oldest_video(self) -> None:
        """Remove the oldest video stream data when MAX_VIDEOS is exceeded."""
        if len(self._event_logs) <= MAX_VIDEOS:
            return

        # Find the video with the oldest last access time
        oldest_video_id = min(self._last_access, key=self._last_access.get)

        # Remove all data for this video
        self._event_logs.pop(oldest_video_id, None)
        self._last_flush_frame.pop(oldest_video_id, None)
        self._frame_count.pop(oldest_video_id, None)
        self._last_access.pop(oldest_video_id, None)
        self._reference_timestamps.pop(oldest_video_id, None)

    def _remove_stale_events(
        self,
        event_log: Dict[int, DetectionEvent],
        current_frame: int,
        stale_frames: int,
        frame_threshold: int,
    ) -> List[DetectionEvent]:
        """Remove events that haven't been seen for stale_frames.

        Returns list of removed LOGGED events (events that met frame_threshold).
        These are "complete" events - objects that were tracked long enough
        to be logged and have now left the scene.
        """
        stale_tracker_ids = []
        complete_events = []

        for tracker_id, event in event_log.items():
            frames_since_seen = current_frame - event.last_seen_frame
            if frames_since_seen > stale_frames:
                stale_tracker_ids.append(tracker_id)
                # Only return logged events as "complete" - pending events are just discarded
                if event.frame_count >= frame_threshold:
                    complete_events.append(event)

        for tracker_id in stale_tracker_ids:
            del event_log[tracker_id]

        return complete_events

    def run(
        self,
        image: WorkflowImageData,
        detections: sv.Detections,
        frame_threshold: int,
        flush_interval: int,
        stale_frames: int,
        fallback_fps: float = 1.0,
        reference_timestamp: Optional[float] = None,
    ) -> BlockResult:
        """Process detections and update the event log.

        Args:
            image: Workflow image data containing video metadata.
            detections: Tracked detections with tracker_id from ByteTracker.
            frame_threshold: Minimum frames an object must be seen before logging.
            flush_interval: How often to run stale event cleanup.
            stale_frames: Remove events not seen for this many frames.
            fallback_fps: FPS to use when video metadata doesn't provide FPS.
            reference_timestamp: Optional Unix timestamp when video started. When provided,
                absolute timestamps are included in output.

        Returns:
            Dictionary containing event_log, detections, total_logged, and total_pending.
        """
        metadata = image.video_metadata
        video_id = metadata.video_identifier

        # Track global frame count and video access for eviction
        self._global_frame += 1
        self._last_access[video_id] = self._global_frame

        # Increment internal frame counter
        current_frame = self._frame_count.get(video_id, 0) + 1
        self._frame_count[video_id] = current_frame

        current_time = self._get_relative_time(current_frame, metadata, fallback_fps)

        # If reference_timestamp not provided, try to extract from video metadata
        effective_reference_timestamp = reference_timestamp
        if effective_reference_timestamp is None:
            # Check if we already have a stored reference timestamp for this video
            if video_id in self._reference_timestamps:
                effective_reference_timestamp = self._reference_timestamps[video_id]
            elif metadata.frame_timestamp is not None:
                # Calculate reference timestamp: frame_timestamp - relative_time
                # This gives us the timestamp when the video/stream started
                # frame_timestamp is a datetime object, so we need to convert to Unix timestamp
                frame_ts = metadata.frame_timestamp.timestamp()
                effective_reference_timestamp = frame_ts - current_time
                self._reference_timestamps[video_id] = effective_reference_timestamp
                logger.debug(
                    f"Extracted reference_timestamp for video {video_id}: {effective_reference_timestamp} "
                    f"(frame_timestamp={frame_ts}, relative_time={current_time})"
                )

        # Initialize event log for this video if needed
        event_log = self._event_logs.setdefault(video_id, {})

        # Evict oldest video if we've exceeded MAX_VIDEOS (after adding current video)
        self._evict_oldest_video()

        # Initialize last flush frame if not set
        if video_id not in self._last_flush_frame:
            self._last_flush_frame[video_id] = current_frame

        # Check if it's time to run cleanup
        complete_events_list = []
        last_flush = self._last_flush_frame.get(video_id, 0)
        if (current_frame - last_flush) >= flush_interval:
            complete_events_list = self._remove_stale_events(
                event_log, current_frame, stale_frames, frame_threshold
            )
            self._last_flush_frame[video_id] = current_frame

        # Format complete events
        complete_events = self._format_complete_events(
            complete_events_list, effective_reference_timestamp
        )

        # Process detections
        if detections.tracker_id is None or len(detections.tracker_id) == 0:
            # No tracked detections, return current log
            event_log_dict, total_logged, total_pending = self._format_event_log(
                event_log, frame_threshold, effective_reference_timestamp
            )
            return {
                OUTPUT_KEY: event_log_dict,
                DETECTIONS_OUTPUT_KEY: detections,
                "total_logged": total_logged,
                "total_pending": total_pending,
                "complete_events": complete_events,
            }

        # Get class names
        class_names = detections.data.get("class_name", [])
        if (
            len(class_names) == 0
            and hasattr(detections, "class_id")
            and detections.class_id is not None
        ):
            class_names = [f"class_{cid}" for cid in detections.class_id]

        # Update event log for each tracked detection
        for i, tracker_id in enumerate(detections.tracker_id):
            tracker_id = int(tracker_id)
            class_name = str(class_names[i]) if len(class_names) > 0 else "unknown"

            if tracker_id in event_log:
                # Update existing event
                event = event_log[tracker_id]
                event.last_seen_frame = current_frame
                event.last_seen_timestamp = current_time
                event.frame_count += 1

                # Mark as logged once threshold is reached
                if event.frame_count >= frame_threshold and not event.logged:
                    event.logged = True
                    logger.debug(
                        f"Object {tracker_id} ({event.class_name}) logged after {event.frame_count} frames"
                    )
            else:
                # Create new event
                event_log[tracker_id] = DetectionEvent(
                    tracker_id=tracker_id,
                    class_name=class_name,
                    first_seen_frame=current_frame,
                    first_seen_timestamp=current_time,
                    last_seen_frame=current_frame,
                    last_seen_timestamp=current_time,
                    frame_count=1,
                    logged=False,
                )

        event_log_dict, total_logged, total_pending = self._format_event_log(
            event_log, frame_threshold, effective_reference_timestamp
        )
        return {
            OUTPUT_KEY: event_log_dict,
            DETECTIONS_OUTPUT_KEY: detections,
            "total_logged": total_logged,
            "total_pending": total_pending,
            "complete_events": complete_events,
        }

    @staticmethod
    def _format_single_event(
        event: DetectionEvent,
        reference_timestamp: Optional[float],
    ) -> Dict[str, Any]:
        """Serialize one DetectionEvent for output.

        Internal timestamps are relative (seconds since video start) and are
        renamed to *_relative; absolute *_timestamp keys are added only when a
        reference_timestamp is available. The internal `logged` flag is dropped.

        This helper centralizes formatting previously duplicated between
        _format_complete_events and _format_event_log.
        """
        event_data = asdict(event)
        del event_data["logged"]

        first_seen_relative = event_data.pop("first_seen_timestamp")
        last_seen_relative = event_data.pop("last_seen_timestamp")
        event_data["first_seen_relative"] = first_seen_relative
        event_data["last_seen_relative"] = last_seen_relative

        # Add absolute timestamps if reference_timestamp is provided
        if reference_timestamp is not None:
            event_data["first_seen_timestamp"] = (
                first_seen_relative + reference_timestamp
            )
            event_data["last_seen_timestamp"] = (
                last_seen_relative + reference_timestamp
            )

        return event_data

    def _format_complete_events(
        self,
        complete_events: List[DetectionEvent],
        reference_timestamp: Optional[float] = None,
    ) -> Dict[str, Any]:
        """Format complete events for output.

        Args:
            complete_events: List of DetectionEvent objects that have completed (gone stale).
            reference_timestamp: Optional reference timestamp for absolute time calculation.

        Returns:
            Dictionary with tracker_id as key and event data as value.
        """
        return {
            str(event.tracker_id): self._format_single_event(
                event, reference_timestamp
            )
            for event in complete_events
        }

    def _format_event_log(
        self,
        event_log: Dict[int, DetectionEvent],
        frame_threshold: int,
        reference_timestamp: Optional[float] = None,
    ) -> tuple:
        """Format the event log for output, split into logged vs pending events.

        Returns:
            Tuple of (event_log_dict, total_logged, total_pending)
        """
        logged_events = {}
        pending_events = {}

        for tracker_id, event in event_log.items():
            event_data = self._format_single_event(event, reference_timestamp)
            # events meeting the threshold are "logged"; the rest are "pending"
            if event.frame_count >= frame_threshold:
                logged_events[str(tracker_id)] = event_data
            else:
                pending_events[str(tracker_id)] = event_data

        event_log_dict = {
            "logged": logged_events,
            "pending": pending_events,
        }

        return event_log_dict, len(logged_events), len(pending_events)
Functions
run
run(
    image,
    detections,
    frame_threshold,
    flush_interval,
    stale_frames,
    fallback_fps=1.0,
    reference_timestamp=None,
)

Process detections and update the event log.

Parameters:

Name Type Description Default
image WorkflowImageData

Workflow image data containing video metadata.

required
detections Detections

Tracked detections with tracker_id from ByteTracker.

required
frame_threshold int

Minimum frames an object must be seen before logging.

required
flush_interval int

How often to run stale event cleanup.

required
stale_frames int

Remove events not seen for this many frames.

required
fallback_fps float

FPS to use when video metadata doesn't provide FPS.

1.0
reference_timestamp Optional[float]

Optional Unix timestamp when video started. When provided, absolute timestamps are included in output.

None

Returns:

Type Description
BlockResult

Dictionary containing event_log, detections, total_logged, and total_pending.

Source code in inference/core/workflows/core_steps/analytics/detection_event_log/v1.py
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
def run(
    self,
    image: WorkflowImageData,
    detections: sv.Detections,
    frame_threshold: int,
    flush_interval: int,
    stale_frames: int,
    fallback_fps: float = 1.0,
    reference_timestamp: Optional[float] = None,
) -> BlockResult:
    """Process detections and update the event log.

    Args:
        image: Workflow image data containing video metadata.
        detections: Tracked detections with tracker_id from ByteTracker.
        frame_threshold: Minimum frames an object must be seen before logging.
        flush_interval: How often to run stale event cleanup.
        stale_frames: Remove events not seen for this many frames.
        fallback_fps: FPS to use when video metadata doesn't provide FPS.
        reference_timestamp: Optional Unix timestamp when video started. When provided,
            absolute timestamps are included in output.

    Returns:
        Dictionary containing event_log, detections, total_logged, and total_pending.
    """
    metadata = image.video_metadata
    video_id = metadata.video_identifier

    # Track global frame count and video access for eviction
    self._global_frame += 1
    self._last_access[video_id] = self._global_frame

    # Increment internal frame counter (per-video, 1-based)
    current_frame = self._frame_count.get(video_id, 0) + 1
    self._frame_count[video_id] = current_frame

    # Seconds since video start, derived from frame index and FPS
    # (fallback_fps used when metadata carries no FPS).
    current_time = self._get_relative_time(current_frame, metadata, fallback_fps)

    # If reference_timestamp not provided, try to extract from video metadata
    effective_reference_timestamp = reference_timestamp
    if effective_reference_timestamp is None:
        # Check if we already have a stored reference timestamp for this video
        if video_id in self._reference_timestamps:
            effective_reference_timestamp = self._reference_timestamps[video_id]
        elif metadata.frame_timestamp is not None:
            # Calculate reference timestamp: frame_timestamp - relative_time
            # This gives us the timestamp when the video/stream started
            # frame_timestamp is a datetime object, so we need to convert to Unix timestamp
            frame_ts = metadata.frame_timestamp.timestamp()
            effective_reference_timestamp = frame_ts - current_time
            # Cache it so later frames reuse the same video-start anchor
            self._reference_timestamps[video_id] = effective_reference_timestamp
            logger.debug(
                f"Extracted reference_timestamp for video {video_id}: {effective_reference_timestamp} "
                f"(frame_timestamp={frame_ts}, relative_time={current_time})"
            )

    # Initialize event log for this video if needed.
    # Keys are int tracker ids mapping to DetectionEvent records; note that
    # _format_event_log appears to stringify ids for output — confirm there.
    event_log = self._event_logs.setdefault(video_id, {})

    # Evict oldest video if we've exceeded MAX_VIDEOS (after adding current video)
    self._evict_oldest_video()

    # Initialize last flush frame if not set
    if video_id not in self._last_flush_frame:
        self._last_flush_frame[video_id] = current_frame

    # Check if it's time to run cleanup: stale-event removal only happens
    # every `flush_interval` frames, not on every call.
    complete_events_list = []
    last_flush = self._last_flush_frame.get(video_id, 0)
    if (current_frame - last_flush) >= flush_interval:
        complete_events_list = self._remove_stale_events(
            event_log, current_frame, stale_frames, frame_threshold
        )
        self._last_flush_frame[video_id] = current_frame

    # Format complete events (events just removed as stale during this flush)
    complete_events = self._format_complete_events(
        complete_events_list, effective_reference_timestamp
    )

    # Process detections
    if detections.tracker_id is None or len(detections.tracker_id) == 0:
        # No tracked detections, return current log unchanged
        event_log_dict, total_logged, total_pending = self._format_event_log(
            event_log, frame_threshold, effective_reference_timestamp
        )
        return {
            OUTPUT_KEY: event_log_dict,
            DETECTIONS_OUTPUT_KEY: detections,
            "total_logged": total_logged,
            "total_pending": total_pending,
            "complete_events": complete_events,
        }

    # Get class names; fall back to synthetic "class_<id>" labels when the
    # detections carry class ids but no class_name data entry.
    class_names = detections.data.get("class_name", [])
    if (
        len(class_names) == 0
        and hasattr(detections, "class_id")
        and detections.class_id is not None
    ):
        class_names = [f"class_{cid}" for cid in detections.class_id]

    # Update event log for each tracked detection
    for i, tracker_id in enumerate(detections.tracker_id):
        # Normalize numpy scalar to a plain int key
        tracker_id = int(tracker_id)
        class_name = str(class_names[i]) if len(class_names) > 0 else "unknown"

        if tracker_id in event_log:
            # Update existing event
            event = event_log[tracker_id]
            event.last_seen_frame = current_frame
            event.last_seen_timestamp = current_time
            event.frame_count += 1

            # Mark as logged once threshold is reached (one-shot transition)
            if event.frame_count >= frame_threshold and not event.logged:
                event.logged = True
                logger.debug(
                    f"Object {tracker_id} ({event.class_name}) logged after {event.frame_count} frames"
                )
        else:
            # Create new event; class_name is frozen at first sighting
            event_log[tracker_id] = DetectionEvent(
                tracker_id=tracker_id,
                class_name=class_name,
                first_seen_frame=current_frame,
                first_seen_timestamp=current_time,
                last_seen_frame=current_frame,
                last_seen_timestamp=current_time,
                frame_count=1,
                logged=False,
            )

    event_log_dict, total_logged, total_pending = self._format_event_log(
        event_log, frame_threshold, effective_reference_timestamp
    )
    return {
        OUTPUT_KEY: event_log_dict,
        DETECTIONS_OUTPUT_KEY: detections,
        "total_logged": total_logged,
        "total_pending": total_pending,
        "complete_events": complete_events,
    }

core/workflows/core_steps/classical_cv/camera_focus

inference.core.workflows.core_steps.classical_cv.camera_focus.v1

Classes

Functions

calculate_brenner_measure

calculate_brenner_measure(
    input_image,
    text_color=(255, 255, 255),
    text_thickness=2,
)

Brenner's focus measure.

Parameters

input_image : np.ndarray — the input image in grayscale. text_color : Tuple[int, int, int], optional — the color of the text displaying the Brenner value, in BGR format; default is white (255, 255, 255). text_thickness : int, optional — the thickness of the text displaying the Brenner value; default is 2.

Returns

Tuple[np.ndarray, float] The Brenner image and the Brenner value.

Source code in inference/core/workflows/core_steps/classical_cv/camera_focus/v1.py
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
def calculate_brenner_measure(
    input_image: np.ndarray,
    text_color: Tuple[int, int, int] = (255, 255, 255),
    text_thickness: int = 2,
) -> Tuple[np.ndarray, float]:
    """
    Brenner's focus measure.

    Parameters
    ----------
    input_image : np.ndarray
        The input image in grayscale.
    text_color : Tuple[int, int, int], optional
        The color of the text displaying the Brenner value, in BGR format. Default is white (255, 255, 255).
    text_thickness : int, optional
        The thickness of the text displaying the Brenner value. Default is 2.

    Returns
    -------
    Tuple[np.ndarray, float]
        The Brenner image and the Brenner value.
    """
    # Convert image to grayscale if it has 3 channels
    if len(input_image.shape) == 3:
        input_image = cv2.cvtColor(input_image, cv2.COLOR_BGR2GRAY)

    # int16 avoids uint8 wrap-around when subtracting pixel values
    converted_image = input_image.astype(np.int16)

    # Get the dimensions of the image
    height, width = converted_image.shape

    # Initialize two matrices for horizontal and vertical focus measures
    horizontal_diff = np.zeros((height, width))
    vertical_diff = np.zeros((height, width))

    # Brenner gradient: difference between pixels two apart, clipped at 0
    horizontal_diff[:, : width - 2] = np.clip(
        converted_image[:, 2:] - converted_image[:, :-2], 0, None
    )
    vertical_diff[: height - 2, :] = np.clip(
        converted_image[2:, :] - converted_image[:-2, :], 0, None
    )

    # Calculate final focus measure
    focus_measure = np.max((horizontal_diff, vertical_diff), axis=0) ** 2

    # Compute the scalar focus value once; reused for the overlay and the return
    focus_value = focus_measure.mean()

    # Convert focus measure matrix to 8-bit for visualization.
    # Guard against division by zero on perfectly flat images (max == 0),
    # which previously produced NaNs that were then cast to uint8.
    max_measure = focus_measure.max()
    if max_measure > 0:
        focus_measure_image = ((focus_measure / max_measure) * 255).astype(np.uint8)
    else:
        focus_measure_image = np.zeros((height, width), dtype=np.uint8)

    # Display the Brenner value on the top left of the image
    cv2.putText(
        focus_measure_image,
        f"Focus value: {focus_value:.2f}",
        (10, 30),
        cv2.FONT_HERSHEY_SIMPLEX,
        1,
        text_color,
        text_thickness,
    )

    return focus_measure_image, focus_value

inference.core.workflows.core_steps.classical_cv.camera_focus.v2

Classes

Functions

visualize_tenengrad_measure

visualize_tenengrad_measure(
    input_image,
    underexposed_threshold=16,
    overexposed_threshold=239,
    show_zebra_warnings=True,
    grid_overlay="3x3",
    show_hud=True,
    show_focus_peaking=True,
    show_center_marker=True,
    detections=None,
)

Tenengrad focus measure with visualization overlay.

Uses Sobel operators to compute gradient magnitudes as a focus metric. Higher values indicate sharper/more in-focus images.

Returns the input image unchanged if no visualizations are enabled.

Source code in inference/core/workflows/core_steps/classical_cv/camera_focus/v2.py
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
def visualize_tenengrad_measure(
    input_image: np.ndarray,
    underexposed_threshold: int = 16,
    overexposed_threshold: int = 239,
    show_zebra_warnings: bool = True,
    grid_overlay: str = "3x3",
    show_hud: bool = True,
    show_focus_peaking: bool = True,
    show_center_marker: bool = True,
    detections: Optional[sv.Detections] = None,
) -> Tuple[np.ndarray, float, List[Optional[float]]]:
    """
    Tenengrad focus measure with visualization overlay.

    Gradient magnitudes computed with Sobel operators serve as the focus
    metric; larger values correspond to sharper/more in-focus images.

    When every visualization toggle is off, the input image is returned
    untouched alongside the computed focus values.
    """
    gray, focus_measure, focus_value, bbox_focus_measures = _compute_tenengrad(
        input_image, detections
    )

    divisions = GRID_DIVISIONS.get(grid_overlay, 0)
    overlays_requested = any(
        (
            show_zebra_warnings,
            show_hud,
            show_focus_peaking,
            show_center_marker,
            divisions > 0,
        )
    )
    if not overlays_requested:
        return input_image, focus_value, bbox_focus_measures

    # Work on a BGR canvas: copy color input, or promote grayscale to 3 channels
    if len(input_image.shape) == 3:
        canvas = input_image.copy()
    else:
        canvas = cv2.cvtColor(input_image, cv2.COLOR_GRAY2BGR)

    # Overlay order matters visually: exposure warnings first, HUD on top
    if show_zebra_warnings:
        canvas = _apply_zebra_warnings(
            canvas, gray, underexposed_threshold, overexposed_threshold
        )
    if show_focus_peaking:
        canvas = _apply_focus_peaking(canvas, focus_measure)
    if show_center_marker:
        canvas = _draw_center_marker(canvas)
    if divisions > 0:
        canvas = _draw_grid(canvas, divisions)
    if show_hud:
        canvas = _draw_hud_overlay(canvas, focus_value, gray, input_image)

    return canvas, focus_value, bbox_focus_measures

core/workflows/core_steps/classical_cv/contours

inference.core.workflows.core_steps.classical_cv.contours.v1

Classes

Functions

find_and_draw_contours

find_and_draw_contours(
    image, color=(255, 0, 255), thickness=3
)

Finds and draws contours on the image.

Parameters:

Name Type Description Default
image ndarray

Input thresholded image.

required
color tuple

Color of the contour lines in BGR. Defaults to purple (255, 0, 255).

(255, 0, 255)
thickness int

Thickness of the contour lines. Defaults to 3.

3

Returns:

Name Type Description
tuple Tuple[ndarray, int]

Image with contours drawn and number of contours.

Source code in inference/core/workflows/core_steps/classical_cv/contours/v1.py
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
def find_and_draw_contours(
    image: np.ndarray, color: Tuple[int, int, int] = (255, 0, 255), thickness: int = 3
) -> tuple:
    """
    Finds and draws contours on the image.

    Args:
        image (np.ndarray): Input thresholded image.
        color (tuple, optional): Color of the contour lines in BGR. Defaults to purple (255, 0, 255).
        thickness (int, optional): Thickness of the contour lines. Defaults to 3.

    Returns:
        tuple: A 3-tuple ``(contour_image, contours, hierarchy)`` where
            ``contour_image`` is a BGR image with contours drawn,
            ``contours`` is the sequence of detected contours, and
            ``hierarchy`` is the hierarchy array from ``cv2.findContours``
            (may be ``None`` when no contours are found).
    """
    # NOTE(review): the previous annotation/docstring claimed a 2-tuple
    # (image, count), but the function has always returned 3 values; the
    # annotation was corrected rather than the return, to keep callers intact.

    # If not in grayscale, convert to grayscale
    if len(image.shape) == 3 and image.shape[2] == 3:
        image = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)

    # RETR_EXTERNAL keeps only outermost contours; CHAIN_APPROX_SIMPLE
    # compresses straight segments to their end points.
    contours, hierarchy = cv2.findContours(
        image, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE
    )

    # Draw contours on a BGR copy of the (grayscale) working image
    contour_image = cv2.cvtColor(image, cv2.COLOR_GRAY2BGR)
    cv2.drawContours(contour_image, contours, -1, color, thickness)

    return contour_image, contours, hierarchy

core/workflows/core_steps/classical_cv/distance_measurement

inference.core.workflows.core_steps.classical_cv.distance_measurement.v1

Functions

has_overlap

has_overlap(bbox1, bbox2)

Check if two bounding boxes overlap.

Parameters:

Name Type Description Default
bbox1 Tuple[int, int, int, int]

A tuple of (x_min, y_min, x_max, y_max) for the first bounding box.

required
bbox2 Tuple[int, int, int, int]

A tuple of (x_min, y_min, x_max, y_max) for the second bounding box.

required

Returns:

Type Description
bool

True if the bounding boxes overlap, False otherwise.

Source code in inference/core/workflows/core_steps/classical_cv/distance_measurement/v1.py
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
def has_overlap(
    bbox1: Tuple[int, int, int, int], bbox2: Tuple[int, int, int, int]
) -> bool:
    """
    Check if two bounding boxes overlap.

    Boxes that merely touch at an edge or a corner count as overlapping.

    Args:
        bbox1: A tuple of (x_min, y_min, x_max, y_max) for the first bounding box.
        bbox2: A tuple of (x_min, y_min, x_max, y_max) for the second bounding box.

    Returns:
        True if the bounding boxes overlap, False otherwise.
    """
    ax_min, ay_min, ax_max, ay_max = bbox1
    bx_min, by_min, bx_max, by_max = bbox2

    # Overlap on an axis means neither interval ends before the other starts
    x_overlap = ax_max >= bx_min and bx_max >= ax_min
    y_overlap = ay_max >= by_min and by_max >= ay_min
    return x_overlap and y_overlap

core/workflows/core_steps/classical_cv/image_blur

inference.core.workflows.core_steps.classical_cv.image_blur.v1

Classes

Functions

apply_blur

apply_blur(image, blur_type, ksize=5)

Applies the specified blur to the image.

Parameters:

Name Type Description Default
image ndarray

Input image.

required
blur_type str

Type of blur ('average', 'gaussian', 'median', 'bilateral').

required
ksize int

Kernel size for the blur. Defaults to 5.

5

Returns:

Type Description
ndarray

np.ndarray: Blurred image.

Source code in inference/core/workflows/core_steps/classical_cv/image_blur/v1.py
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
def apply_blur(image: np.ndarray, blur_type: str, ksize: int = 5) -> np.ndarray:
    """
    Applies the specified blur to the image.

    Args:
        image: Input image.
        blur_type (str): Type of blur ('average', 'gaussian', 'median', 'bilateral').
        ksize (int, optional): Kernel size for the blur. Defaults to 5.

    Returns:
        np.ndarray: Blurred image.

    Raises:
        ValueError: If ``blur_type`` is not one of the supported types.
    """
    # Dispatch table keeps each variant's OpenCV call in one place; lambdas
    # defer the work until we know the requested type is valid.
    blur_operations = {
        "average": lambda: cv2.blur(image, (ksize, ksize)),
        "gaussian": lambda: cv2.GaussianBlur(image, (ksize, ksize), 0),
        "median": lambda: cv2.medianBlur(image, ksize),
        "bilateral": lambda: cv2.bilateralFilter(image, ksize, 75, 75),
    }
    if blur_type not in blur_operations:
        raise ValueError(f"Unknown blur type: {blur_type}")
    return blur_operations[blur_type]()

core/workflows/core_steps/classical_cv/mask_area_measurement

inference.core.workflows.core_steps.classical_cv.mask_area_measurement.v1

Functions

compute_detection_areas

compute_detection_areas(detections)

Compute the area of all detections in square pixels.

For bounding-box-only detections, areas are computed in a single vectorized operation. For detections with segmentation masks, the area is the count of non-zero mask pixels (via cv2.countNonZero). This correctly handles masks with holes — hole pixels are zero and are not counted. Falls back to the bounding box area when the mask pixel count is zero.

Parameters:

Name Type Description Default
detections Detections

A supervision Detections object.

required

Returns:

Type Description
List[float]

List of areas in square pixels, one per detection.

Source code in inference/core/workflows/core_steps/classical_cv/mask_area_measurement/v1.py
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
def compute_detection_areas(detections: sv.Detections) -> List[float]:
    """Compute the area of all detections in square pixels.

    For bounding-box-only detections, areas are computed in a single vectorized
    operation. For detections with segmentation masks, the area is the count of
    non-zero mask pixels (via ``cv.countNonZero``). This correctly handles masks
    with holes — hole pixels are zero and are not counted. Falls back to the
    bounding box area when the mask pixel count is zero.

    Args:
        detections: A supervision Detections object.

    Returns:
        List of areas in square pixels, one per detection.
    """
    n = len(detections)
    if n == 0:
        return []

    # Bounding-box-only: vectorize as documented instead of looping per box
    if detections.mask is None:
        boxes = np.asarray(detections.xyxy, dtype=float)
        return ((boxes[:, 2] - boxes[:, 0]) * (boxes[:, 3] - boxes[:, 1])).tolist()

    areas = []
    for i in range(n):
        count = cv.countNonZero(detections.mask[i].astype(np.uint8))
        if count > 0:
            areas.append(float(count))
        else:
            # Degenerate/empty mask: fall back to the bounding box area
            x1, y1, x2, y2 = detections.xyxy[i]
            areas.append(float((x2 - x1) * (y2 - y1)))

    return areas

core/workflows/core_steps/classical_cv/motion_detection

inference.core.workflows.core_steps.classical_cv.motion_detection.v1

Classes

Functions

clip_contours_to_contour

clip_contours_to_contour(contours, clip_contour)

Clip OpenCV contours to another contour and return clipped OpenCV contours.

Parameters:

Name Type Description Default
contours List[ndarray]

List of OpenCV contours, each as numpy array of shape (N, 1, 2)

required
clip_contour ndarray

Clip contour as numpy array of shape (M, 2) with xy points

required

Returns:

Type Description
List[ndarray]

List of clipped OpenCV contours as numpy arrays of shape (N, 1, 2).

List[ndarray]

Only includes contours that overlap with the clip contour.

Source code in inference/core/workflows/core_steps/classical_cv/motion_detection/v1.py
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
def clip_contours_to_contour(
    contours: List[np.ndarray], clip_contour: np.ndarray
) -> List[np.ndarray]:
    """
    Clip OpenCV contours to another contour and return clipped OpenCV contours.

    Args:
        contours: List of OpenCV contours, each as numpy array of shape (N, 1, 2)
        clip_contour: Clip contour as numpy array of shape (M, 2) with xy points

    Returns:
        List of clipped OpenCV contours as numpy arrays of shape (N, 1, 2).
        Only includes contours that overlap with the clip contour.
    """
    clipping_region = Polygon(clip_contour)
    clipped_contours = []

    for raw_contour in contours:
        # OpenCV stores contours as (N, 1, 2); flatten to (N, 2) xy points
        xy_points = raw_contour.reshape(-1, 2)

        # A polygon needs at least three vertices
        if len(xy_points) < 3:
            continue

        try:
            intersection = Polygon(xy_points).intersection(clipping_region)

            if intersection.is_empty:
                continue

            # The intersection may be one polygon or several disjoint pieces
            if intersection.geom_type == "Polygon":
                pieces = [intersection]
            elif intersection.geom_type == "MultiPolygon":
                pieces = list(intersection.geoms)
            else:
                pieces = []

            for piece in pieces:
                # Drop the closing point shapely repeats at the end of a ring
                ring = list(piece.exterior.coords[:-1])
                if len(ring) >= 3:
                    clipped_contours.append(list_to_contour(ring))

        except Exception:
            # Silently skip contours that fail shapely operations
            # (e.g., self-intersecting polygons)
            continue

    return clipped_contours

list_to_contour

list_to_contour(list_of_tuples)

Convert a list of (x, y) tuples to an OpenCV contour format.

Parameters:

Name Type Description Default
list_of_tuples List[Tuple]

List of coordinate tuples [(x1, y1), (x2, y2), ...]

required

Returns:

Type Description
ndarray

NumPy array of shape (N, 1, 2) suitable for OpenCV operations

Source code in inference/core/workflows/core_steps/classical_cv/motion_detection/v1.py
357
358
359
360
361
362
363
364
365
366
367
368
369
370
def list_to_contour(list_of_tuples: List[Tuple]) -> np.ndarray:
    """
    Convert a list of (x, y) tuples to an OpenCV contour format.

    Args:
        list_of_tuples: List of coordinate tuples [(x1, y1), (x2, y2), ...]

    Returns:
        NumPy array of shape (N, 1, 2) suitable for OpenCV operations
    """
    # Truncate coordinates to ints, then reshape to OpenCV's (N, 1, 2) layout
    coordinates = [[int(point[0]), int(point[1])] for point in list_of_tuples]
    return np.asarray(coordinates, dtype=np.int32).reshape(-1, 1, 2)

core/workflows/core_steps/classical_cv/pixel_color_count

inference.core.workflows.core_steps.classical_cv.pixel_color_count.v1

Classes

Functions

count_specific_color_pixels

count_specific_color_pixels(image, target_color, tolerance)

Counts the number of pixels that match the target color within the given tolerance.

Parameters:

Name Type Description Default
image ndarray

Input image.

required
target_color Union[str, tuple]

Target color in hex format (e.g., '#431112') or BGR tuple (e.g., (18, 17, 67)).

required
tolerance int

Tolerance for color matching.

required

Returns:

Name Type Description
int int

Number of pixels that match the target color.

Source code in inference/core/workflows/core_steps/classical_cv/pixel_color_count/v1.py
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
def count_specific_color_pixels(
    image: np.ndarray,
    target_color: Union[str, Tuple[int, int, int]],
    tolerance: int,
) -> int:
    """
    Counts the number of pixels that match the target color within the given tolerance.

    Args:
        image: Input image.
        target_color (Union[str, tuple]): Target color in hex format (e.g., '#431112') or BGR tuple (e.g., (18, 17, 67)).
        tolerance (int): Tolerance for color matching.

    Returns:
        int: Number of pixels that match the target color.
    """
    # Normalize the requested color (hex string or tuple) into a BGR array
    bgr_color = np.array(convert_color_to_bgr_tuple(color=target_color))

    # inRange yields a binary mask of pixels within [color - tol, color + tol]
    mask = cv2.inRange(image, bgr_color - tolerance, bgr_color + tolerance)
    return int(cv2.countNonZero(mask))

core/workflows/core_steps/classical_cv/sift

inference.core.workflows.core_steps.classical_cv.sift.v1

Classes

Functions

apply_sift

apply_sift(image)

Applies SIFT to the image. Args: image: Input image. Returns: np.ndarray: Image with keypoints drawn. list: Keypoints detected. np.ndarray: Descriptors of the keypoints.

Source code in inference/core/workflows/core_steps/classical_cv/sift/v1.py
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
def apply_sift(image: np.ndarray) -> tuple:
    """
    Applies SIFT to the image.

    Args:
        image: Input image (BGR).
    Returns:
        np.ndarray: Image with keypoints drawn.
        list: Keypoints detected, as dictionaries.
        np.ndarray: Descriptors of the keypoints (may be None when no
            keypoints are found).
    """
    # NOTE(review): the previous return annotation "(np.ndarray, list,
    # np.ndarray)" evaluated to a runtime tuple of classes, which is not a
    # valid typing construct; replaced with a plain `tuple` annotation.
    gray = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)
    sift = cv2.SIFT_create()
    kp, des = sift.detectAndCompute(gray, None)
    # NOTE(review): `image` is passed as drawKeypoints' outImage argument,
    # so the caller's input image is drawn into (mutated) as a side effect —
    # confirm callers expect this before changing it.
    img_with_kp = cv2.drawKeypoints(gray, kp, image)
    # Convert keypoints to plain dictionaries for serialization
    keypoints = [
        {
            "pt": (point.pt[0], point.pt[1]),
            "size": point.size,
            "angle": point.angle,
            "response": point.response,
            "octave": point.octave,
            "class_id": point.class_id,
        }
        for point in kp
    ]
    return img_with_kp, keypoints, des

core/workflows/core_steps/classical_cv/sift_comparison

inference.core.workflows.core_steps.classical_cv.sift_comparison.v2

Classes

Functions

apply_sift

apply_sift(image, visualize=False)

Applies SIFT to the image. Args: image: Input image. visualize: Whether to visualize keypoints on the image. Returns: img_with_kp: Image with keypoints drawn (if visualize is True). kp: List of cv2.KeyPoint objects. keypoints_dicts: List of keypoints as dictionaries. des: Descriptors of the keypoints.

Source code in inference/core/workflows/core_steps/classical_cv/sift_comparison/v2.py
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
def apply_sift(
    image: np.ndarray, visualize: bool = False
) -> tuple:
    """
    Applies SIFT to the image.

    Args:
        image: Input image (BGR).
        visualize: Whether to visualize keypoints on the image.
    Returns:
        img_with_kp: Image with keypoints drawn (None unless visualize is True).
        kp: List of cv2.KeyPoint objects.
        keypoints_dicts: List of keypoints as dictionaries.
        des: Descriptors of the keypoints (may be None when no keypoints
            are found).
    """
    # NOTE(review): the previous return annotation "(Optional[np.ndarray],
    # list, list, np.ndarray)" evaluated to a runtime tuple of classes, which
    # is not a valid typing construct; replaced with a plain `tuple`
    # annotation. Also annotated `visualize` as bool.
    gray = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)
    sift = cv2.SIFT_create()
    kp, des = sift.detectAndCompute(gray, None)
    img_with_kp = None
    if visualize:
        img_with_kp = cv2.drawKeypoints(gray, kp, None)
    # Convert keypoints to plain dictionaries for serialization
    keypoints_dicts = [
        {
            "pt": (point.pt[0], point.pt[1]),
            "size": point.size,
            "angle": point.angle,
            "response": point.response,
            "octave": point.octave,
            "class_id": point.class_id,
        }
        for point in kp
    ]
    return img_with_kp, kp, keypoints_dicts, des

core/workflows/core_steps/classical_cv/size_measurement

inference.core.workflows.core_steps.classical_cv.size_measurement.v1

Functions

compute_aligned_dimensions

compute_aligned_dimensions(contour)

Compute the width and height of an object based on its contour, ensuring proper orientation.

This function: 1. Finds the minimum area rectangle that encloses the contour 2. Determines which edges correspond to width and height by analyzing their angles 3. Returns dimensions where width is the more horizontal edge and height is the more vertical edge

Parameters:

Name Type Description Default
contour ndarray

Array of points representing the object's contour

required

Returns:

Type Description
Tuple[float, float]

Tuple[float, float]: A tuple of (width_pixels, height_pixels) where: - width_pixels: Length of the more horizontal edge - height_pixels: Length of the more vertical edge

Note

The function uses angle analysis to ensure consistent width/height assignment regardless of the object's rotation. The edge closer to horizontal (0° or 180°) is always considered the width.

Source code in inference/core/workflows/core_steps/classical_cv/size_measurement/v1.py
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
def compute_aligned_dimensions(contour: np.ndarray) -> Tuple[float, float]:
    """
    Compute the width and height of an object based on its contour, ensuring proper orientation.

    Finds the minimum-area rotated rectangle around the contour, then decides
    which of its two edge directions is the "width" by checking which edge
    lies closer to horizontal (0 or 180 degrees). The more horizontal edge is
    reported as the width and the other as the height, independent of the
    object's rotation.

    Args:
        contour (np.ndarray): Array of points representing the object's contour

    Returns:
        Tuple[float, float]: A tuple of (width_pixels, height_pixels) where:
            - width_pixels: Length of the more horizontal edge
            - height_pixels: Length of the more vertical edge
    """
    min_rect = cv.minAreaRect(contour)
    corners = np.array(cv.boxPoints(min_rect), dtype=np.float32)

    # Two adjacent edges of the rotated rectangle
    first_edge = corners[1] - corners[0]
    second_edge = corners[2] - corners[1]

    first_length = np.linalg.norm(first_edge)
    second_length = np.linalg.norm(second_edge)

    # Edge angles in degrees; a lower horizontal_score means closer to horizontal
    first_angle = np.degrees(np.arctan2(first_edge[1], first_edge[0]))
    second_angle = np.degrees(np.arctan2(second_edge[1], second_edge[0]))

    if horizontal_score(first_angle) < horizontal_score(second_angle):
        width_pixels, height_pixels = first_length, second_length
    else:
        width_pixels, height_pixels = second_length, first_length

    return float(width_pixels), float(height_pixels)

get_detection_dimensions

get_detection_dimensions(detection, index)

Retrieve the width and height dimensions of a detected object in pixels.

Parameters:

Name Type Description Default
detection Detections

Detection object containing masks and/or bounding boxes

required
index int

Index of the specific detection to analyze

required

Returns:

Type Description
Tuple[Optional[float], Optional[float]]

Tuple[float, float]: A tuple of (width_pixels, height_pixels) where: - width_pixels: Width of the object in pixels - height_pixels: Height of the object in pixels

Notes

The function uses two methods to compute dimensions: 1. If a segmentation mask is available: - Extracts the largest contour from the mask - Uses compute_aligned_dimensions() to get orientation-aware measurements 2. If no mask is available: - Falls back to using the bounding box dimensions - Simply computes width and height as box edges

Source code in inference/core/workflows/core_steps/classical_cv/size_measurement/v1.py
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
def get_detection_dimensions(
    detection: sv.Detections, index: int
) -> Tuple[Optional[float], Optional[float]]:
    """
    Retrieve the width and height dimensions of a detected object in pixels.

    Args:
        detection (sv.Detections): Detection object containing masks and/or bounding boxes
        index (int): Index of the specific detection to analyze

    Returns:
        Tuple[Optional[float], Optional[float]]: A tuple of (width_pixels, height_pixels):
            - width_pixels: Width of the object in pixels
            - height_pixels: Height of the object in pixels

    Notes:
        The function uses two methods to compute dimensions:
        1. If a segmentation mask is available and yields a usable contour:
           - Extracts the largest contour from the mask
           - Uses compute_aligned_dimensions() to get orientation-aware measurements
        2. Otherwise (no mask, an empty mask, or a zero-area contour):
           - Falls back to using the bounding box dimensions
           - Simply computes width and height as box edges
    """
    if detection.mask is not None:
        mask = detection.mask[index].astype(np.uint8)
        contours, _ = cv.findContours(mask, cv.RETR_EXTERNAL, cv.CHAIN_APPROX_SIMPLE)
        if contours:
            largest_contour = max(contours, key=cv.contourArea)
            if cv.contourArea(largest_contour) > 0:
                return compute_aligned_dimensions(largest_contour)
        # The mask produced no usable contour (empty or zero-area). Previously
        # this returned (None, None) even though the bounding box is always
        # available; fall through to the bbox measurement instead.

    bbox = detection.xyxy[index]
    w = bbox[2] - bbox[0]
    h = bbox[3] - bbox[1]
    return float(w), float(h)

horizontal_score

horizontal_score(angle)

Determine how close an angle is to horizontal (0 or 180 degrees). Lower score means more horizontal.

Source code in inference/core/workflows/core_steps/classical_cv/size_measurement/v1.py
128
129
130
131
132
133
134
def horizontal_score(angle: float) -> float:
    """
    Score how far an angle deviates from the horizontal axis (0 or 180 degrees).
    A score of 0 means perfectly horizontal; larger values mean more vertical.
    """
    # Python's % always yields a value in [0, 180) here, covering negative angles.
    deviation = angle % 180
    return min(deviation, 180 - deviation)

parse_reference_dimensions

parse_reference_dimensions(reference_dimensions)

Parse reference dimensions from various input formats.

Source code in inference/core/workflows/core_steps/classical_cv/size_measurement/v1.py
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
def parse_reference_dimensions(
    reference_dimensions: Union[str, Tuple[float, float], List[float]],
) -> Tuple[float, float]:
    """Parse reference dimensions from various input formats.

    Args:
        reference_dimensions: Either a string in the format 'width,height'
            or a two-element sequence of numbers.

    Returns:
        Tuple of (width, height) as floats.

    Raises:
        ValueError: If the input cannot be parsed into exactly two numbers.
    """
    if isinstance(reference_dimensions, str):
        parts = reference_dimensions.split(",")
        if len(parts) != 2:
            raise ValueError(
                "reference_dimensions must be a string in the format 'width,height'"
            )
        reference_dimensions = parts

    if len(reference_dimensions) != 2:
        raise ValueError("reference_dimensions must have two values (width, height)")

    # Convert both the string-split and the sequence paths to float so the
    # declared Tuple[float, float] return type actually holds (previously a
    # sequence input was returned with its elements unconverted).
    try:
        width, height = (float(value) for value in reference_dimensions)
    except (TypeError, ValueError) as error:
        raise ValueError("Invalid format for reference_dimensions") from error

    return width, height

core/workflows/core_steps/classical_cv/threshold

inference.core.workflows.core_steps.classical_cv.threshold.v1

Classes

Functions

apply_thresholding

apply_thresholding(
    image, threshold_type, thresh_value, max_value
)

Applies the specified thresholding to the image.

Parameters:

Name Type Description Default
image ndarray

Input image in grayscale.

required
threshold_type str

Type of thresholding ('binary', 'binary_inv', 'trunc', 'tozero', 'tozero_inv', 'adaptive_mean', 'adaptive_gaussian', 'otsu').

required
thresh_value int

Threshold value.

required
max_value int

Maximum value to use with the THRESH_BINARY and THRESH_BINARY_INV thresholding types.

required

Returns:

Type Description
ndarray

np.ndarray: Image with thresholding applied.

Source code in inference/core/workflows/core_steps/classical_cv/threshold/v1.py
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
def apply_thresholding(
    image: np.ndarray, threshold_type: str, thresh_value: int, max_value: int
) -> np.ndarray:
    """
    Applies the specified thresholding to the image.

    Args:
        image (np.ndarray): Input image in grayscale.
        threshold_type (str): Type of thresholding ('binary', 'binary_inv', 'trunc', 'tozero', 'tozero_inv', 'adaptive_mean', 'adaptive_gaussian', 'otsu').
        thresh_value (int): Threshold value (ignored by the adaptive and otsu modes).
        max_value (int): Maximum value to use with the THRESH_BINARY and THRESH_BINARY_INV thresholding types.

    Returns:
        np.ndarray: Image with thresholding applied.

    Raises:
        ValueError: If `threshold_type` is not one of the supported modes.
    """
    # Modes with dedicated call signatures are handled up front.
    if threshold_type == "adaptive_mean":
        return cv2.adaptiveThreshold(
            image, max_value, cv2.ADAPTIVE_THRESH_MEAN_C, cv2.THRESH_BINARY, 11, 2
        )
    if threshold_type == "adaptive_gaussian":
        return cv2.adaptiveThreshold(
            image,
            max_value,
            cv2.ADAPTIVE_THRESH_GAUSSIAN_C,
            cv2.THRESH_BINARY,
            11,
            2,
        )
    if threshold_type == "otsu":
        # Otsu picks the threshold automatically, so the fixed value is 0.
        _, result = cv2.threshold(
            image, 0, max_value, cv2.THRESH_BINARY + cv2.THRESH_OTSU
        )
        return result

    # Every remaining mode is a plain cv2.threshold call differing only in flag.
    simple_mode_flags = {
        "binary": "THRESH_BINARY",
        "binary_inv": "THRESH_BINARY_INV",
        "trunc": "THRESH_TRUNC",
        "tozero": "THRESH_TOZERO",
        "tozero_inv": "THRESH_TOZERO_INV",
    }
    if threshold_type not in simple_mode_flags:
        raise ValueError(f"Unknown threshold type: {threshold_type}")
    _, result = cv2.threshold(
        image, thresh_value, max_value, getattr(cv2, simple_mode_flags[threshold_type])
    )
    return result

core/workflows/core_steps/common

inference.core.workflows.core_steps.common.utils

Classes

Functions

remove_unexpected_keys_from_dictionary

remove_unexpected_keys_from_dictionary(
    dictionary, expected_keys
)

This function mutates input dictionary

Source code in inference/core/workflows/core_steps/common/utils.py
430
431
432
433
434
435
436
437
438
def remove_unexpected_keys_from_dictionary(
    dictionary: dict,
    expected_keys: set,
) -> dict:
    """Drop every key not in `expected_keys`, mutating the input `dictionary` in place."""
    # Snapshot the keys first: deleting while iterating the live view is illegal.
    for key in list(dictionary):
        if key not in expected_keys:
            del dictionary[key]
    return dictionary

core/workflows/core_steps/fusion/detections_list_rollup

inference.core.workflows.core_steps.fusion.detections_list_rollup.v1

Functions

merge_crop_predictions

merge_crop_predictions(
    parent_prediction,
    child_predictions,
    confidence_strategy="max",
    overlap_threshold=0.0,
    keypoint_merge_threshold=10.0,
)

Merge predictions from multiple crops back to parent image coordinates.

Parameters:

Name Type Description Default
parent_prediction

Supervision Detections object that defines the crop locations. Each detection in this prediction represents one crop region.

required
child_predictions List

List of Supervision Detections objects from crops. Order matches the detection order in parent_prediction.

required
confidence_strategy str

How to handle confidence when merging overlaps. Options: "max", "mean", "min"

'max'
overlap_threshold float

Minimum IoU/overlap ratio to merge detections (0.0 to 1.0). - 0.0: Only merge if detections touch or overlap at all (default) - >0.0: Only merge if overlap ratio exceeds this threshold - 1.0: Only merge completely overlapping detections

0.0
keypoint_merge_threshold float

Maximum distance in pixels to merge keypoints (default: 10). For keypoint detections, merges detections if their average keypoint distance is below this threshold.

10.0

Returns:

Type Description
Tuple

Tuple of (detections, crop_zones):

Tuple
  • detections: Detections object with merged predictions in parent image coordinates. Works for both instance segmentation (with masks) and object detection (without masks).
Tuple
  • crop_zones: List of lists of (x, y) tuples. Each inner list defines the rectangular zone boundary of a crop in parent image coordinates as 4 corner points.
Source code in inference/core/workflows/core_steps/fusion/detections_list_rollup/v1.py
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
def merge_crop_predictions(
    parent_prediction,
    child_predictions: List,
    confidence_strategy: str = "max",
    overlap_threshold: float = 0.0,
    keypoint_merge_threshold: float = 10.0,
) -> Tuple:
    """
    Merge predictions from multiple crops back to parent image coordinates.

    Args:
        parent_prediction: Supervision Detections object that defines the crop locations.
                          Each detection in this prediction represents one crop region.
        child_predictions: List of Supervision Detections objects from crops.
                          Order matches the detection order in parent_prediction.
        confidence_strategy: How to handle confidence when merging overlaps.
                           Options: "max", "mean", "min"
        overlap_threshold: Minimum IoU/overlap ratio to merge detections (0.0 to 1.0).
                         - 0.0: Only merge if detections touch or overlap at all (default)
                         - >0.0: Only merge if overlap ratio exceeds this threshold
                         - 1.0: Only merge completely overlapping detections
        keypoint_merge_threshold: Maximum distance in pixels to merge keypoints (default: 10).
                                For keypoint detections, merges detections if their average
                                keypoint distance is below this threshold.

    Returns:
        Tuple of (detections, crop_zones):
        - detections: Detections object with merged predictions in parent image coordinates.
                     Works for both instance segmentation (with masks) and object detection (without masks).
        - crop_zones: List of lists of (x, y) tuples. Each inner list defines the rectangular
                     zone boundary of a crop in parent image coordinates as 4 corner points.
    """
    # Each child prediction corresponds positionally to one crop in the parent.
    if len(parent_prediction) != len(child_predictions):
        raise ValueError(
            f"Number of detections in parent_prediction ({len(parent_prediction)}) "
            f"must match number of child predictions ({len(child_predictions)})"
        )

    # Extract parent image shape from parent prediction's data
    # root_parent_dimensions is a list of tuples, one per detection (all should be the same)
    root_parent_dims = parent_prediction.data.get("root_parent_dimensions")

    if root_parent_dims is None or len(root_parent_dims) == 0:
        raise ValueError(
            "parent_prediction must have 'root_parent_dimensions' in its data attribute"
        )

    # Get the first tuple (all should be identical for the same parent image)
    parent_image_shape = root_parent_dims[0]

    # Build crop zones list - one zone per crop/child prediction
    crop_zones = []
    for i in range(len(parent_prediction)):
        crop_bbox = parent_prediction.xyxy[i]  # [x_min, y_min, x_max, y_max]
        x_min, y_min, x_max, y_max = (
            crop_bbox[0],
            crop_bbox[1],
            crop_bbox[2],
            crop_bbox[3],
        )

        # Create zone as list of 4 corner points: top-left, top-right, bottom-right, bottom-left
        zone = [
            (float(x_min), float(y_min)),  # top-left
            (float(x_max), float(y_min)),  # top-right
            (float(x_max), float(y_max)),  # bottom-right
            (float(x_min), float(y_max)),  # bottom-left
        ]
        crop_zones.append(zone)

    # Check if we have instance segmentation (with masks) or object detection (without masks)
    # NOTE(review): a single crop with masks flips the whole merge into mask mode;
    # this assumes all children produce homogeneous output types - confirm upstream.
    has_masks = False
    is_keypoint_detection = False
    for child_pred in child_predictions:
        if child_pred.mask is not None and len(child_pred.mask) > 0:
            has_masks = True
            break

    # Separate scan for keypoint predictions; prediction_type may be stored
    # per-detection as a numpy array or as a plain scalar string.
    for child_pred in child_predictions:
        # Check for keypoint detection
        if "prediction_type" in child_pred.data:
            pred_type = child_pred.data["prediction_type"]
            if isinstance(pred_type, np.ndarray):
                if len(pred_type) > 0 and pred_type[0] == "keypoint-detection":
                    is_keypoint_detection = True
                    break
            elif pred_type == "keypoint-detection":
                is_keypoint_detection = True
                break

    # Group predictions by class
    class_predictions = {}

    # Iterate through each crop region and its corresponding child predictions
    for i, child_pred in enumerate(child_predictions):
        # Get crop location from parent prediction
        crop_bbox = parent_prediction.xyxy[i]  # [x_min, y_min, x_max, y_max]
        x_min, y_min = int(crop_bbox[0]), int(crop_bbox[1])

        # Process each detection in the child prediction
        for j in range(len(child_pred)):
            class_id = child_pred.class_id[j]
            confidence = child_pred.confidence[j]

            # Prepare keypoint data if present
            keypoint_data = {}
            if is_keypoint_detection and "keypoints_xy" in child_pred.data:
                # Transform keypoint coordinates from crop to parent space
                keypoints_xy = child_pred.data["keypoints_xy"][
                    j
                ]  # Shape: (num_keypoints, 2)

                # Transform coordinates
                transformed_keypoints = []
                for kp in keypoints_xy:
                    transformed_kp = [kp[0] + x_min, kp[1] + y_min]
                    transformed_keypoints.append(transformed_kp)

                keypoint_data["keypoints_xy"] = transformed_keypoints

                # Copy other keypoint data
                if "keypoints_class_name" in child_pred.data:
                    keypoint_data["keypoints_class_name"] = child_pred.data[
                        "keypoints_class_name"
                    ][j]
                if "keypoints_class_id" in child_pred.data:
                    keypoint_data["keypoints_class_id"] = child_pred.data[
                        "keypoints_class_id"
                    ][j]
                if "keypoints_confidence" in child_pred.data:
                    keypoint_data["keypoints_confidence"] = child_pred.data[
                        "keypoints_confidence"
                    ][j]

            # Collect per-detection data fields to preserve individual detection metadata
            # This is crucial for preserving class_name and other fields when multiple
            # detections have the same class_id but different values
            detection_data = {}
            for key in child_pred.data.keys():
                if key not in [
                    "detection_id",
                    "parent_id",
                    "inference_id",
                    "keypoints_xy",
                    "keypoints_class_name",
                    "keypoints_class_id",
                    "keypoints_confidence",
                ]:
                    if j < len(child_pred.data[key]):
                        detection_data[key] = child_pred.data[key][j]

            if has_masks and child_pred.mask is not None:
                # Instance segmentation - transform mask
                mask = child_pred.mask[j]
                transformed_mask = _transform_mask_to_parent(
                    mask, x_min, y_min, parent_image_shape
                )

                # Store prediction with transformed mask
                if class_id not in class_predictions:
                    class_predictions[class_id] = []

                class_predictions[class_id].append(
                    {
                        "mask": transformed_mask,
                        "confidence": confidence,
                        "class_id": class_id,
                        "bbox": None,  # Will compute from mask
                        "keypoint_data": keypoint_data,
                        "detection_data": detection_data,  # Store per-detection metadata
                    }
                )
            else:
                # Object detection - transform bounding box
                bbox = child_pred.xyxy[j]  # [x_min, y_min, x_max, y_max]
                transformed_bbox = np.array(
                    [bbox[0] + x_min, bbox[1] + y_min, bbox[2] + x_min, bbox[3] + y_min]
                )

                # Store prediction with transformed bbox
                if class_id not in class_predictions:
                    class_predictions[class_id] = []

                class_predictions[class_id].append(
                    {
                        "bbox": transformed_bbox,
                        "confidence": confidence,
                        "class_id": class_id,
                        "mask": None,
                        "keypoint_data": keypoint_data,
                        "detection_data": detection_data,  # Store per-detection metadata
                    }
                )

    # Merge overlapping predictions for each class
    merged_masks = []
    merged_bboxes = []
    merged_confidences = []
    merged_class_ids = []

    # Collect all data field names from child predictions
    all_data_keys = set()
    for child_pred in child_predictions:
        all_data_keys.update(child_pred.data.keys())

    # Initialize lists for each data field
    merged_data = {
        key: []
        for key in all_data_keys
        if key
        not in [
            "keypoints_xy",
            "keypoints_class_name",
            "keypoints_class_id",
            "keypoints_confidence",
        ]
    }

    # Collect keypoint data separately
    all_keypoints_data = {
        "keypoints_xy": [],
        "keypoints_class_name": [],
        "keypoints_class_id": [],
        "keypoints_confidence": [],
    }

    # Build mapping from class_id to typical data values
    # (first-seen values per class serve as fallbacks for merged detections)
    class_id_to_data = {}
    for child_pred in child_predictions:
        for i in range(len(child_pred)):
            class_id = child_pred.class_id[i]
            if class_id not in class_id_to_data:
                class_id_to_data[class_id] = {}
                # Store sample values for this class_id (except ID fields and keypoint fields)
                for key in child_pred.data.keys():
                    if key not in [
                        "detection_id",
                        "parent_id",
                        "inference_id",
                        "keypoints_xy",
                        "keypoints_class_name",
                        "keypoints_class_id",
                        "keypoints_confidence",
                    ]:
                        if key in child_pred.data and i < len(child_pred.data[key]):
                            class_id_to_data[class_id][key] = child_pred.data[key][i]

    # Get a sample inference_id and parent_id from the first child prediction if available
    sample_inference_id = None
    sample_parent_id = None
    if len(child_predictions) > 0 and len(child_predictions[0]) > 0:
        if "inference_id" in child_predictions[0].data:
            sample_inference_id = child_predictions[0].data["inference_id"][0]
        if "parent_id" in child_predictions[0].data:
            sample_parent_id = child_predictions[0].data["parent_id"][0]

    # Merge within each class using the strategy appropriate to the output type.
    for class_id, preds in class_predictions.items():
        if is_keypoint_detection:
            # For keypoint detection, merge based on keypoint proximity
            merged_preds = _merge_keypoint_detections(
                preds, confidence_strategy, keypoint_merge_threshold
            )
        elif has_masks:
            merged_preds = _merge_overlapping_masks(
                preds, confidence_strategy, overlap_threshold
            )
        else:
            merged_preds = _merge_overlapping_bboxes(
                preds, confidence_strategy, overlap_threshold
            )

        for pred in merged_preds:
            if has_masks:
                merged_masks.append(pred["mask"])
            else:
                # For non-mask detections, collect bboxes
                if "bbox" in pred and pred["bbox"] is not None:
                    merged_bboxes.append(pred["bbox"])
            merged_confidences.append(pred["confidence"])
            merged_class_ids.append(pred["class_id"])

            # Collect keypoint data if present
            if "keypoint_data" in pred and pred["keypoint_data"]:
                kp_data = pred["keypoint_data"]
                all_keypoints_data["keypoints_xy"].append(kp_data.get("keypoints_xy"))
                all_keypoints_data["keypoints_class_name"].append(
                    kp_data.get("keypoints_class_name")
                )
                all_keypoints_data["keypoints_class_id"].append(
                    kp_data.get("keypoints_class_id")
                )
                all_keypoints_data["keypoints_confidence"].append(
                    kp_data.get("keypoints_confidence")
                )

            # Add data fields for this detection
            for key in all_data_keys:
                # Skip keypoint fields as they're handled separately
                if key in [
                    "keypoints_xy",
                    "keypoints_class_name",
                    "keypoints_class_id",
                    "keypoints_confidence",
                ]:
                    continue

                if key == "detection_id":
                    # Generate new UUID for merged detection
                    merged_data[key].append(str(uuid.uuid4()))
                elif key == "parent_id":
                    # Use sample parent_id or generate new one
                    merged_data[key].append(
                        sample_parent_id if sample_parent_id else str(uuid.uuid4())
                    )
                elif key == "inference_id":
                    # Use the same inference_id as inputs (they're from same inference batch)
                    merged_data[key].append(
                        sample_inference_id
                        if sample_inference_id
                        else str(uuid.uuid4())
                    )
                elif key == "root_parent_dimensions":
                    # Add the parent image shape as a list [height, width]
                    merged_data[key].append(list(parent_image_shape))
                elif key == "parent_dimensions":
                    # Parent dimensions should be same as root_parent_dimensions for merged results
                    merged_data[key].append(list(parent_image_shape))
                elif key == "image_dimensions":
                    # Image dimensions for this detection
                    merged_data[key].append(list(parent_image_shape))
                elif key == "root_parent_coordinates":
                    # Root parent coordinates [y, x] - should be [0, 0] for the root
                    if (
                        pred["class_id"] in class_id_to_data
                        and key in class_id_to_data[pred["class_id"]]
                    ):
                        merged_data[key].append(class_id_to_data[pred["class_id"]][key])
                    else:
                        merged_data[key].append([0, 0])
                elif key == "parent_coordinates":
                    # Parent coordinates [y, x]
                    if (
                        pred["class_id"] in class_id_to_data
                        and key in class_id_to_data[pred["class_id"]]
                    ):
                        merged_data[key].append(class_id_to_data[pred["class_id"]][key])
                    else:
                        merged_data[key].append([0, 0])
                elif key == "root_parent_id":
                    # Root parent ID
                    if (
                        pred["class_id"] in class_id_to_data
                        and key in class_id_to_data[pred["class_id"]]
                    ):
                        merged_data[key].append(class_id_to_data[pred["class_id"]][key])
                    else:
                        merged_data[key].append("image")
                elif key == "prediction_type":
                    # Prediction type should be 'instance-segmentation'
                    # NOTE(review): this is hardcoded even for bbox-only or
                    # keypoint merges - confirm downstream consumers expect it.
                    merged_data[key].append("instance-segmentation")
                else:
                    # For other fields like class_name, check pred dict first (per-detection data)
                    # then fall back to class_id_to_data (class-level defaults)
                    if key in pred.get("detection_data", {}):
                        merged_data[key].append(pred["detection_data"][key])
                    elif (
                        pred["class_id"] in class_id_to_data
                        and key in class_id_to_data[pred["class_id"]]
                    ):
                        merged_data[key].append(class_id_to_data[pred["class_id"]][key])
                    else:
                        merged_data[key].append(None)

    if not merged_confidences:
        # Return empty detections if no detections
        return Detections.empty(), crop_zones

    # Convert to numpy arrays
    merged_confidences_array = np.array(merged_confidences, dtype=np.float32)
    merged_class_ids_array = np.array(merged_class_ids, dtype=int)

    if has_masks:
        # Instance segmentation - stack masks and compute bounding boxes
        merged_masks_array = np.stack(merged_masks, axis=0)

        # Compute bounding boxes from masks
        xyxy = []
        for mask in merged_masks_array:
            rows, cols = np.where(mask)
            if len(rows) > 0:
                x_min, x_max = cols.min(), cols.max()
                y_min, y_max = rows.min(), rows.max()
                xyxy.append([x_min, y_min, x_max + 1, y_max + 1])
            else:
                # Empty mask after merging: degenerate zero-size box.
                xyxy.append([0, 0, 0, 0])

        xyxy_array = np.array(xyxy, dtype=np.float32)

        # Create Detections object with masks
        result = Detections(
            xyxy=xyxy_array,
            mask=merged_masks_array,
            confidence=merged_confidences_array,
            class_id=merged_class_ids_array,
        )
    else:
        # Object detection - use bounding boxes directly
        if merged_bboxes:
            xyxy_array = np.array(merged_bboxes, dtype=np.float32)
        else:
            # Shouldn't happen, but handle edge case
            xyxy_array = np.zeros((len(merged_confidences), 4), dtype=np.float32)

        # Create Detections object without masks
        result = Detections(
            xyxy=xyxy_array,
            confidence=merged_confidences_array,
            class_id=merged_class_ids_array,
        )

    # Convert data fields to numpy arrays with proper dtypes
    for key, values in merged_data.items():
        if key in [
            "class_name",
            "prediction_type",
            "detection_id",
            "parent_id",
            "inference_id",
            "root_parent_id",
        ]:
            # String fields - use 'U' dtype (Unicode strings), not np.str_
            result.data[key] = np.array(values, dtype=str)
        elif key in [
            "root_parent_dimensions",
            "parent_dimensions",
            "image_dimensions",
            "root_parent_coordinates",
            "parent_coordinates",
        ]:
            # Array/coordinate fields - convert to numpy arrays of integers
            result.data[key] = np.array(values, dtype=int)
        else:
            # Other fields - store as is
            result.data[key] = np.array(values)

    # Add keypoint data if it exists
    if is_keypoint_detection:
        if all_keypoints_data["keypoints_xy"]:
            result.data["keypoints_xy"] = np.array(
                all_keypoints_data["keypoints_xy"], dtype=object
            )
        if all_keypoints_data["keypoints_class_name"]:
            result.data["keypoints_class_name"] = np.array(
                all_keypoints_data["keypoints_class_name"], dtype=object
            )
        if all_keypoints_data["keypoints_class_id"]:
            result.data["keypoints_class_id"] = np.array(
                all_keypoints_data["keypoints_class_id"], dtype=object
            )
        if all_keypoints_data["keypoints_confidence"]:
            result.data["keypoints_confidence"] = np.array(
                all_keypoints_data["keypoints_confidence"], dtype=object
            )

    return result, crop_zones

core/workflows/core_steps/fusion/detections_stitch

inference.core.workflows.core_steps.fusion.detections_stitch.v1

Classes

Functions

move_detections

move_detections(detections, offset, resolution_wh)

Copied from: https://github.com/roboflow/supervision/blob/5123085037ec594524fc8f9d9b71b1cd9f487e8d/supervision/detection/tools/inference_slicer.py#L17-L16 to avoid fragile contract with supervision, as this function is not element of public API.

Source code in inference/core/workflows/core_steps/fusion/detections_stitch/v1.py
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
def move_detections(
    detections: sv.Detections,
    offset: Optional[np.ndarray],
    resolution_wh: Optional[Tuple[int, int]],
) -> sv.Detections:
    """
    Shift detections in place by ``offset`` (boxes always, masks when present).

    Copied from: https://github.com/roboflow/supervision/blob/5123085037ec594524fc8f9d9b71b1cd9f487e8d/supervision/detection/tools/inference_slicer.py#L17-L16
    to avoid fragile contract with supervision, as this function is not element of public
    API.

    Raises:
        ValueError: When detections are non-empty and ``offset`` is missing, or
            when masks are present and ``resolution_wh`` is missing.
    """
    # Nothing to translate - return unchanged.
    if len(detections) == 0:
        return detections
    if offset is None:
        raise ValueError("To move non-empty detections offset is needed, but not given")
    detections.xyxy = move_boxes(xyxy=detections.xyxy, offset=offset)
    # Boxes moved; if there is no mask channel we are done.
    if detections.mask is None:
        return detections
    if resolution_wh is None:
        raise ValueError(
            "To move non-empty detections with segmentation mask, resolution_wh is needed, but not given."
        )
    detections.mask = move_masks(
        masks=detections.mask, offset=offset, resolution_wh=resolution_wh
    )
    return detections

core/workflows/core_steps

inference.core.workflows.core_steps.loader

Classes

core/workflows/core_steps/models/foundation/anthropic_claude

inference.core.workflows.core_steps.models.foundation.anthropic_claude.v3

Classes

Functions

execute_claude_request

execute_claude_request(
    roboflow_api_key,
    anthropic_api_key,
    system_prompt,
    messages,
    model_version,
    max_tokens,
    temperature,
    extended_thinking,
    thinking_budget_tokens,
)

Route to proxied or direct execution based on API key format.

Source code in inference/core/workflows/core_steps/models/foundation/anthropic_claude/v3.py
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
def execute_claude_request(
    roboflow_api_key: Optional[str],
    anthropic_api_key: str,
    system_prompt: Optional[str],
    messages: List[dict],
    model_version: str,
    max_tokens: Optional[int],
    temperature: Optional[float],
    extended_thinking: Optional[bool],
    thinking_budget_tokens: Optional[int],
) -> str:
    """Route to proxied or direct execution based on API key format."""
    # Arguments shared by both execution paths.
    request_kwargs = {
        "system_prompt": system_prompt,
        "messages": messages,
        "model_version": model_version,
        "max_tokens": max_tokens,
        "temperature": temperature,
        "extended_thinking": extended_thinking,
        "thinking_budget_tokens": thinking_budget_tokens,
    }
    # Keys issued by Roboflow (rf_key:* prefixes) must go through the proxy.
    uses_roboflow_managed_key = anthropic_api_key.startswith(
        ("rf_key:account", "rf_key:user:")
    )
    if uses_roboflow_managed_key:
        return _execute_proxied_claude_request(
            roboflow_api_key=roboflow_api_key,
            anthropic_api_key=anthropic_api_key,
            **request_kwargs,
        )
    return _execute_direct_claude_request(
        anthropic_api_key=anthropic_api_key,
        **request_kwargs,
    )

core/workflows/core_steps/models/foundation/gaze

inference.core.workflows.core_steps.models.foundation.gaze.v1

Classes

GazeBlockV1

Bases: WorkflowBlock

Source code in inference/core/workflows/core_steps/models/foundation/gaze/v1.py
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
class GazeBlockV1(WorkflowBlock):
    """Workflow block wrapping the "gaze" core model.

    For each input image the block outputs face predictions (supervision
    detections carrying facial-landmark keypoints) plus per-face gaze yaw and
    pitch angles expressed in degrees. Inference is dispatched either to the
    local model manager or to a remote inference HTTP API, depending on the
    configured step execution mode.
    """

    def __init__(
        self,
        model_manager: ModelManager,
        api_key: Optional[str],
        step_execution_mode: StepExecutionMode,
    ):
        # Dependencies injected by the workflows execution engine - the names
        # the engine resolves are declared in get_init_parameters().
        self._model_manager = model_manager
        self._api_key = api_key
        self._step_execution_mode = step_execution_mode

    @classmethod
    def get_init_parameters(cls) -> List[str]:
        """Names of constructor parameters supplied by the execution engine."""
        return ["model_manager", "api_key", "step_execution_mode"]

    @classmethod
    def get_manifest(cls) -> Type[WorkflowBlockManifest]:
        """Manifest class describing this block's inputs and outputs."""
        return BlockManifest

    def run(
        self,
        images: Batch[WorkflowImageData],
        do_run_face_detection: bool,
    ) -> BlockResult:
        """Dispatch to local or remote execution based on execution mode.

        Raises:
            ValueError: If the configured step execution mode is unsupported.
        """
        if self._step_execution_mode is StepExecutionMode.LOCAL:
            return self.run_locally(
                images=images,
                do_run_face_detection=do_run_face_detection,
            )
        elif self._step_execution_mode is StepExecutionMode.REMOTE:
            return self.run_remotely(
                images=images,
                do_run_face_detection=do_run_face_detection,
            )
        else:
            raise ValueError(
                f"Unsupported step execution mode: {self._step_execution_mode}"
            )

    def run_remotely(
        self,
        images: Batch[WorkflowImageData],
        do_run_face_detection: bool,
    ) -> BlockResult:
        """Run gaze detection through a remote inference HTTP API.

        NOTE(review): do_run_face_detection is accepted but not forwarded to
        client.detect_gazes() - confirm whether the remote endpoint supports
        toggling face detection.
        """
        # Hosted platform uses the core-model URL with the v0 API; any other
        # target is treated as a self-hosted server reachable via the v1 API.
        api_url = (
            LOCAL_INFERENCE_API_URL
            if WORKFLOWS_REMOTE_API_TARGET != "hosted"
            else HOSTED_CORE_MODEL_URL
        )
        client = InferenceHTTPClient(
            api_url=api_url,
            api_key=self._api_key,
        )
        if WORKFLOWS_REMOTE_API_TARGET == "hosted":
            client.select_api_v0()
        else:
            client.select_api_v1()

        inference_images = [i.base64_image for i in images]
        predictions = client.detect_gazes(inference_input=inference_images)

        # A single-image request may come back as a bare dict - normalize.
        if not isinstance(predictions, list):
            predictions = [predictions]

        # Process remote predictions into the expected format
        return self._process_remote_predictions(
            images=images,
            predictions=predictions,
        )

    def _process_remote_predictions(
        self,
        images: Batch[WorkflowImageData],
        predictions: List[dict],
    ) -> BlockResult:
        """Process predictions from remote execution into the expected format."""
        face_predictions = []
        yaw_degrees = []
        pitch_degrees = []

        for single_image, prediction in zip(images, predictions):
            height, width = single_image.numpy_image.shape[:2]

            # Inference-format payload for this image: face boxes plus the
            # image size required by the downstream converters.
            image_face_preds = {
                "predictions": [],
                "image": {"width": width, "height": height},
            }
            batch_yaw = []
            batch_pitch = []

            for pred in prediction.get("predictions", []):
                face = pred.get("face", {})

                # Re-shape the face payload into inference detection format,
                # exposing landmarks as numbered keypoints; missing fields
                # default to 0 so malformed responses do not crash the block.
                face_pred = {
                    "x": face.get("x", 0),
                    "y": face.get("y", 0),
                    "width": face.get("width", 0),
                    "height": face.get("height", 0),
                    "confidence": face.get("confidence", 0),
                    "class": "face",
                    "class_id": 0,
                    "keypoints": [
                        {
                            "x": l.get("x", 0),
                            "y": l.get("y", 0),
                            "confidence": face.get("confidence", 0),
                            "class": str(i),
                            "class_id": i,
                        }
                        for i, l in enumerate(face.get("landmarks", []))
                    ],
                }

                image_face_preds["predictions"].append(face_pred)

                # Convert yaw/pitch from radians (remote response) to degrees.
                batch_yaw.append(pred.get("yaw", 0) * 180 / np.pi)
                batch_pitch.append(pred.get("pitch", 0) * 180 / np.pi)

            face_predictions.append(image_face_preds)
            yaw_degrees.append(batch_yaw)
            pitch_degrees.append(batch_pitch)

        # Process predictions
        face_preds = convert_inference_detections_batch_to_sv_detections(
            face_predictions
        )

        # Add keypoints to supervision detections
        for prediction, detections in zip(face_predictions, face_preds):
            add_inference_keypoints_to_sv_detections(
                inference_prediction=prediction["predictions"],
                detections=detections,
            )

        face_preds = attach_prediction_type_info_to_sv_detections_batch(
            predictions=face_preds,
            prediction_type="facial-landmark",
        )
        face_preds = attach_parents_coordinates_to_batch_of_sv_detections(
            images=images,
            predictions=face_preds,
        )

        return [
            {
                "face_predictions": face_pred,
                "yaw_degrees": yaw,
                "pitch_degrees": pitch,
            }
            for face_pred, yaw, pitch in zip(face_preds, yaw_degrees, pitch_degrees)
        ]

    def run_locally(
        self,
        images: Batch[WorkflowImageData],
        do_run_face_detection: bool,
    ) -> BlockResult:
        """Run gaze detection with the in-process model manager.

        Images are processed one-by-one: each request loads (or reuses) the
        "gaze" core model and runs a synchronous inference.
        """
        predictions = []

        for single_image in images:
            inference_request = GazeDetectionInferenceRequest(
                image=single_image.to_inference_format(numpy_preferred=True),
                do_run_face_detection=do_run_face_detection,
                api_key=self._api_key,
            )
            gaze_model_id = load_core_model(
                model_manager=self._model_manager,
                inference_request=inference_request,
                core_model="gaze",
            )
            prediction = self._model_manager.infer_from_request_sync(
                gaze_model_id, inference_request
            )
            predictions.append(prediction)

        # Convert predictions to supervision format and get angles
        face_preds, yaw_degrees, pitch_degrees = (
            convert_gaze_detections_to_sv_detections_and_angles(
                images=images,
                gaze_predictions=predictions,
            )
        )

        return [
            {
                "face_predictions": face_pred,
                "yaw_degrees": yaw,
                "pitch_degrees": pitch,
            }
            for face_pred, yaw, pitch in zip(face_preds, yaw_degrees, pitch_degrees)
        ]

Functions

convert_gaze_detections_to_sv_detections_and_angles

convert_gaze_detections_to_sv_detections_and_angles(
    images, gaze_predictions
)

Convert gaze detection results to supervision detections and angle lists.

Source code in inference/core/workflows/core_steps/models/foundation/gaze/v1.py
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
def convert_gaze_detections_to_sv_detections_and_angles(
    images: Batch[WorkflowImageData],
    gaze_predictions: List[dict],
) -> Tuple[List[sv.Detections], List[List[float]], List[List[float]]]:
    """Convert gaze detection results to supervision detections and angle lists."""

    def _face_to_inference_prediction(face: dict) -> dict:
        # Face bounding box reshaped into inference detection format, with
        # landmarks exposed as numbered keypoints.
        return {
            "x": face["x"],
            "y": face["y"],
            "width": face["width"],
            "height": face["height"],
            "confidence": face["confidence"],
            "class": "face",
            "class_id": 0,
            "keypoints": [
                {
                    "x": landmark["x"],
                    "y": landmark["y"],
                    "confidence": face["confidence"],
                    "class": str(index),
                    "class_id": index,
                }
                for index, landmark in enumerate(face["landmarks"])
            ],
        }

    face_predictions = []
    yaw_degrees = []
    pitch_degrees = []

    for image, predictions in zip(images, gaze_predictions):
        height, width = image.numpy_image.shape[:2]

        # Inference-format payload for this image.
        image_face_preds = {
            "predictions": [],
            "image": {"width": width, "height": height},
        }
        image_yaw = []
        image_pitch = []

        for prediction_object in predictions:  # predictions is already a list
            serialized = prediction_object.model_dump(by_alias=True, exclude_none=True)
            for pred in serialized["predictions"]:
                image_face_preds["predictions"].append(
                    _face_to_inference_prediction(face=pred["face"])
                )
                # Angles arrive in radians - store them in degrees.
                image_yaw.append(pred["yaw"] * 180 / np.pi)
                image_pitch.append(pred["pitch"] * 180 / np.pi)

        face_predictions.append(image_face_preds)
        yaw_degrees.append(image_yaw)
        pitch_degrees.append(image_pitch)

    # Convert to supervision detections and enrich with keypoints and metadata.
    face_preds = convert_inference_detections_batch_to_sv_detections(face_predictions)
    for prediction, detections in zip(face_predictions, face_preds):
        add_inference_keypoints_to_sv_detections(
            inference_prediction=prediction["predictions"],
            detections=detections,
        )
    face_preds = attach_prediction_type_info_to_sv_detections_batch(
        predictions=face_preds,
        prediction_type="facial-landmark",
    )
    face_preds = attach_parents_coordinates_to_batch_of_sv_detections(
        images=images,
        predictions=face_preds,
    )

    return face_preds, yaw_degrees, pitch_degrees

core/workflows/core_steps/models/foundation/google_gemini

inference.core.workflows.core_steps.models.foundation.google_gemini.v3

Classes

Functions

execute_gemini_request

execute_gemini_request(
    roboflow_api_key, google_api_key, prompt, model_version
)

Route to proxied or direct execution based on API key format.

Source code in inference/core/workflows/core_steps/models/foundation/google_gemini/v3.py
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
def execute_gemini_request(
    roboflow_api_key: Optional[str],
    google_api_key: str,
    prompt: dict,
    model_version: str,
) -> str:
    """Route to proxied or direct execution based on API key format."""
    # Keys issued by Roboflow (rf_key:* prefixes) must go through the proxy;
    # any other key is used directly against the Google API.
    if not google_api_key.startswith(("rf_key:account", "rf_key:user:")):
        return _execute_direct_gemini_request(
            google_api_key=google_api_key,
            prompt=prompt,
            model_version=model_version,
        )
    return _execute_proxied_gemini_request(
        roboflow_api_key=roboflow_api_key,
        google_api_key=google_api_key,
        prompt=prompt,
        model_version=model_version,
    )

core/workflows/core_steps/models/foundation/openai

inference.core.workflows.core_steps.models.foundation.openai.v3

Classes

Functions

inference.core.workflows.core_steps.models.foundation.openai.v4

Classes

Functions

core/workflows/core_steps/models/foundation/segment_anything3_3d

inference.core.workflows.core_steps.models.foundation.segment_anything3_3d.v1

Classes

Functions

extract_masks_from_input

extract_masks_from_input(mask_input)

Extract binary masks from sv.Detections, pass through other formats.

Source code in inference/core/workflows/core_steps/models/foundation/segment_anything3_3d/v1.py
229
230
231
232
233
234
235
236
237
def extract_masks_from_input(mask_input: Any) -> Any:
    """Extract binary masks from sv.Detections, pass through other formats."""
    # Anything that is not sv.Detections is forwarded unchanged.
    if not isinstance(mask_input, sv.Detections):
        return mask_input
    if len(mask_input) == 0:
        raise ValueError("sv.Detections contains no detections.")
    masks = mask_input.mask
    if masks is None or len(masks) == 0:
        raise ValueError("sv.Detections has no mask data.")
    return list(masks)

core/workflows/core_steps/models/foundation/stability_ai/inpainting

inference.core.workflows.core_steps.models.foundation.stability_ai.inpainting.v1

Credits to: https://github.com/Fafruch for the original idea

Classes

core/workflows/core_steps/models/foundation/stability_ai/outpainting

inference.core.workflows.core_steps.models.foundation.stability_ai.outpainting.v1

Credits to: https://github.com/Fafruch for the original idea

Classes

core/workflows/core_steps/sinks/email_notification

inference.core.workflows.core_steps.sinks.email_notification.v2

Classes

Functions

apply_operations_to_message_parameters

apply_operations_to_message_parameters(
    message_parameters, message_parameters_operations
)

Apply per-parameter operation chains to message parameter values.

For each parameter in message_parameters, if operations are defined in message_parameters_operations for that parameter name, the operations are applied in order (e.g. ToString, StringToUpperCase, LookupTable). Parameters with no operations are returned unchanged.

Supports all value types, including WorkflowImageData: image operations such as ExtractImageProperty, ConvertImageToBase64, and ConvertImageToJPEG can be used to transform image parameters before they are serialized or interpolated into the message.

Returns:

Type Description
Dict[str, Any]

A dict with the same keys as message_parameters and values that are

Dict[str, Any]

either the original value (no operations) or the result of the

Dict[str, Any]

operations chain.

Source code in inference/core/workflows/core_steps/sinks/email_notification/v2.py
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
def apply_operations_to_message_parameters(
    message_parameters: Dict[str, Any],
    message_parameters_operations: Dict[str, List[AllOperationsType]],
) -> Dict[str, Any]:
    """
    Apply per-parameter operation chains to message parameter values.

    For each entry in message_parameters, any operations configured for that
    parameter name in message_parameters_operations are applied in order
    (e.g. ToString, StringToUpperCase, LookupTable). Parameters with no
    configured operations pass through untouched.

    Supports all value types, including WorkflowImageData: image operations
    such as ExtractImageProperty, ConvertImageToBase64, and ConvertImageToJPEG
    can be used to transform image parameters before they are serialized or
    interpolated into the message.

    Returns:
        A dict with the same keys as message_parameters whose values are
        either the original value (no operations) or the result of the
        operations chain.
    """

    def _resolve(parameter_name: str, parameter_value: Any) -> Any:
        operations = message_parameters_operations.get(parameter_name)
        if not operations:
            # No chain configured - value passes through unchanged.
            return parameter_value
        operations_chain = build_operations_chain(operations=operations)
        return operations_chain(parameter_value, global_parameters={})

    return {
        parameter_name: _resolve(parameter_name, parameter_value)
        for parameter_name, parameter_value in message_parameters.items()
    }

format_email_message

format_email_message(
    message,
    message_parameters,
    message_parameters_operations,
)

Format email message by replacing parameter placeholders with actual values.

Source code in inference/core/workflows/core_steps/sinks/email_notification/v2.py
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
def format_email_message(
    message: str,
    message_parameters: Dict[str, Any],
    message_parameters_operations: Dict[str, List[AllOperationsType]],
) -> str:
    """Format email message by replacing parameter placeholders with actual values.

    Args:
        message: Message template containing placeholders matched by
            PARAMETER_REGEX.
        message_parameters: Values available for interpolation.
        message_parameters_operations: Optional per-parameter operation chains
            applied before interpolation.

    Returns:
        The message with every referenced placeholder replaced by the
        (possibly transformed) parameter value, stringified with str().
    """
    matching_parameters = PARAMETER_REGEX.findall(message)
    parameters_to_get_values = {
        p[1] for p in matching_parameters if p[1] in message_parameters
    }

    # Only run operation chains for parameters actually referenced in the
    # message - mirrors format_email_message_html_with_images and avoids
    # executing (possibly expensive or failing) operations on unused values.
    parameters_values = apply_operations_to_message_parameters(
        message_parameters={
            name: message_parameters[name] for name in parameters_to_get_values
        },
        message_parameters_operations=message_parameters_operations,
    )

    parameter_to_placeholders = defaultdict(list)
    for placeholder, parameter_name in matching_parameters:
        if parameter_name not in parameters_to_get_values:
            continue
        parameter_to_placeholders[parameter_name].append(placeholder)
    for parameter_name, placeholders in parameter_to_placeholders.items():
        for placeholder in placeholders:
            message = message.replace(
                placeholder, str(parameters_values[parameter_name])
            )
    return message

format_email_message_html_with_images

format_email_message_html_with_images(
    message,
    message_parameters,
    message_parameters_operations,
)

Format email message as HTML with inline images.

Source code in inference/core/workflows/core_steps/sinks/email_notification/v2.py
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
def format_email_message_html_with_images(
    message: str,
    message_parameters: Dict[str, Any],
    message_parameters_operations: Dict[str, List[AllOperationsType]],
) -> Tuple[str, Dict[str, bytes]]:
    """Format email message as HTML with inline images.

    Placeholders bound to WorkflowImageData parameters are replaced with
    <img> tags referencing CID attachments; every other parameter value is
    HTML-escaped before interpolation. Newlines are converted to <br> tags.

    Returns:
        Tuple of (html_message, image_attachments), where image_attachments
        maps content-ID -> JPEG bytes for inline embedding.
    """
    # Hoisted here: previously this import was re-executed inside the
    # per-parameter loop.
    import html

    matching_parameters = PARAMETER_REGEX.findall(message)
    parameters_to_get_values = {
        p[1] for p in matching_parameters if p[1] in message_parameters
    }

    parameters_values = {}
    image_attachments = {}

    for parameter_name in parameters_to_get_values:
        parameter_value = message_parameters[parameter_name]

        # Apply operations if any
        operations = message_parameters_operations.get(parameter_name)
        if operations:
            operations_chain = build_operations_chain(operations=operations)
            parameter_value = operations_chain(parameter_value, global_parameters={})

        if isinstance(parameter_value, WorkflowImageData):
            # Convert to JPEG and create CID
            jpeg_bytes = encode_image_to_jpeg_bytes(parameter_value.numpy_image)
            cid = f"image_{parameter_name}"
            image_attachments[cid] = jpeg_bytes
            parameters_values[parameter_name] = (
                f'<img src="cid:{cid}" alt="{parameter_name}" style="max-width: 600px; height: auto;">'
            )
        else:
            # Escape to avoid injecting markup from parameter values.
            parameters_values[parameter_name] = html.escape(str(parameter_value))

    # Replace placeholders
    parameter_to_placeholders = defaultdict(list)
    for placeholder, parameter_name in matching_parameters:
        if parameter_name in parameters_to_get_values:
            parameter_to_placeholders[parameter_name].append(placeholder)

    html_message = message
    for parameter_name, placeholders in parameter_to_placeholders.items():
        for placeholder in placeholders:
            html_message = html_message.replace(
                placeholder, str(parameters_values[parameter_name])
            )

    # Convert newlines to <br> tags for HTML
    html_message = html_message.replace("\n", "<br>\n")

    return html_message, image_attachments

process_attachments

process_attachments(attachments)

Process attachments dict to convert WorkflowImageData to JPEG bytes. Returns a dict with filename -> bytes mapping.

Source code in inference/core/workflows/core_steps/sinks/email_notification/v2.py
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
def process_attachments(attachments: Dict[str, Any]) -> Dict[str, bytes]:
    """
    Process attachments dict to convert WorkflowImageData to JPEG bytes.
    Returns a dict with filename -> bytes mapping.
    """
    processed: Dict[str, bytes] = {}
    for filename, value in attachments.items():
        if isinstance(value, WorkflowImageData):
            # Images are serialized to JPEG bytes.
            payload = encode_image_to_jpeg_bytes(value.numpy_image)
        elif isinstance(value, bytes):
            # Raw bytes pass through untouched.
            payload = value
        elif isinstance(value, str):
            # Textual content (e.g. CSV) is UTF-8 encoded.
            payload = value.encode("utf-8")
        else:
            # Fallback: stringify first, then encode.
            payload = str(value).encode("utf-8")
        processed[filename] = payload
    return processed

send_email_using_smtp_server_v2

send_email_using_smtp_server_v2(
    sender_email,
    receiver_email,
    cc_receiver_email,
    bcc_receiver_email,
    subject,
    message,
    attachments,
    smtp_server,
    smtp_port,
    sender_email_password,
    inline_images,
    is_html,
)

V2-specific SMTP email sender with inline image support. This function is used only by v2 block and does not modify v1 behavior.

Source code in inference/core/workflows/core_steps/sinks/email_notification/v2.py
946
947
948
949
950
951
952
953
954
955
956
957
958
959
960
961
962
963
964
965
966
967
968
969
970
971
972
973
974
975
976
977
978
979
980
981
982
983
984
def send_email_using_smtp_server_v2(
    sender_email: str,
    receiver_email: List[str],
    cc_receiver_email: Optional[List[str]],
    bcc_receiver_email: Optional[List[str]],
    subject: str,
    message: str,
    attachments: Dict[str, bytes],
    smtp_server: str,
    smtp_port: int,
    sender_email_password: str,
    inline_images: Dict[str, bytes],
    is_html: bool,
) -> Tuple[bool, str]:
    """
    V2-specific SMTP email sender with inline image support.
    This function is used only by v2 block and does not modify v1 behavior.

    Returns:
        Tuple of (error_occurred, status_message).
    """
    try:
        _send_email_using_smtp_server_v2(
            sender_email=sender_email,
            receiver_email=receiver_email,
            cc_receiver_email=cc_receiver_email,
            bcc_receiver_email=bcc_receiver_email,
            subject=subject,
            message=message,
            attachments=attachments,
            smtp_server=smtp_server,
            smtp_port=smtp_port,
            sender_email_password=sender_email_password,
            inline_images=inline_images,
            is_html=is_html,
        )
    except Exception as error:
        # Best-effort sink: report failure via the return tuple, do not raise.
        logging.warning(
            f"Could not send e-mail using custom SMTP server. Error: {str(error)}"
        )
        return True, f"Failed to send e-mail. Internal error details: {error}"
    return False, "Notification sent successfully"

send_email_via_roboflow_proxy

send_email_via_roboflow_proxy(
    roboflow_api_key,
    receiver_email,
    cc_receiver_email,
    bcc_receiver_email,
    subject,
    message,
    message_parameters,
    message_parameters_operations,
    attachments,
)

Send email through Roboflow's proxy service.

Source code in inference/core/workflows/core_steps/sinks/email_notification/v2.py
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
913
914
915
916
917
918
919
920
921
922
923
924
925
926
927
928
929
930
931
932
933
934
935
936
937
938
939
940
941
942
943
def _encode_attachments_for_proxy(attachments: Dict[str, Any]) -> Dict[str, str]:
    """Base64-encode attachment values for JSON transmission.

    WorkflowImageData values are JPEG-encoded first (and the filename is given
    a ``.jpg`` extension when missing); raw ``bytes`` and ``str`` values are
    encoded directly; anything else is stringified as a fallback.
    """
    import base64

    processed_attachments: Dict[str, str] = {}
    for filename, value in attachments.items():
        if isinstance(value, WorkflowImageData):
            # Convert image to JPEG bytes before base64-encoding.
            jpeg_bytes = encode_image_to_jpeg_bytes(value.numpy_image)
            # Ensure filename has .jpg extension
            if not filename.lower().endswith((".jpg", ".jpeg")):
                filename = f"{filename}.jpg"
            processed_attachments[filename] = base64.b64encode(jpeg_bytes).decode(
                "utf-8"
            )
        elif isinstance(value, bytes):
            # Already bytes, base64 encode
            processed_attachments[filename] = base64.b64encode(value).decode("utf-8")
        elif isinstance(value, str):
            # String data (e.g., CSV content), base64 encode
            processed_attachments[filename] = base64.b64encode(
                value.encode("utf-8")
            ).decode("utf-8")
        else:
            # Fallback: convert to string then bytes then base64
            processed_attachments[filename] = base64.b64encode(
                str(value).encode("utf-8")
            ).decode("utf-8")
    return processed_attachments


def send_email_via_roboflow_proxy(
    roboflow_api_key: str,
    receiver_email: List[str],
    cc_receiver_email: Optional[List[str]],
    bcc_receiver_email: Optional[List[str]],
    subject: str,
    message: str,
    message_parameters: Dict[str, Any],
    message_parameters_operations: Dict[str, List[AllOperationsType]],
    attachments: Dict[str, Any],
) -> Tuple[bool, str]:
    """Send email through Roboflow's proxy service.

    Args:
        roboflow_api_key: API key used to authenticate against the proxy.
        receiver_email: Primary recipient addresses.
        cc_receiver_email: Optional CC recipient addresses.
        bcc_receiver_email: Optional BCC recipient addresses.
        subject: E-mail subject line.
        message: Message template, possibly containing parameter placeholders.
        message_parameters: Values substituted into the message template.
        message_parameters_operations: Operations applied to parameter values
            before substitution.
        attachments: Mapping of filename -> attachment content.

    Returns:
        Tuple of (error_status, status_message) - error_status is False on success.
    """
    from inference.core.exceptions import (
        RoboflowAPIForbiddenError,
        RoboflowAPIUnsuccessfulRequestError,
    )

    # Custom error handler that preserves the API's error message
    def handle_email_proxy_error(status_code: int, http_error: Exception) -> None:
        """Extract and preserve the actual error message from the API response."""
        try:
            response = http_error.response
            error_data = response.json()
            # API returns 'details' field with the actual message, 'error' is generic
            # Prioritize 'details' over 'error' for more specific messages
            api_error_message = (
                error_data.get("details") or error_data.get("error") or str(http_error)
            )
        except Exception:
            api_error_message = str(http_error)

        # 403 signals an access restriction; all other handled statuses
        # (413, 429, ...) map to a generic unsuccessful-request error that
        # still carries the API's message.
        if status_code == 403:
            raise RoboflowAPIForbiddenError(api_error_message) from http_error
        raise RoboflowAPIUnsuccessfulRequestError(api_error_message) from http_error

    # Map status codes to our custom handler
    custom_error_handlers = {
        403: lambda e: handle_email_proxy_error(403, e),
        413: lambda e: handle_email_proxy_error(413, e),
        429: lambda e: handle_email_proxy_error(429, e),
    }

    try:
        message_parameters_after_operations = apply_operations_to_message_parameters(
            message_parameters=message_parameters,
            message_parameters_operations=message_parameters_operations,
        )
        # Serialize any WorkflowImageData objects to base64 strings for JSON transmission
        serialized_parameters = serialize_image_data_parameters(
            message_parameters_after_operations
        )

        payload = {
            "receiver_email": receiver_email,
            "subject": subject,
            "message": message,
            "message_parameters": serialized_parameters,
        }

        if cc_receiver_email:
            payload["cc_receiver_email"] = cc_receiver_email
        if bcc_receiver_email:
            payload["bcc_receiver_email"] = bcc_receiver_email
        if attachments:
            payload["attachments"] = _encode_attachments_for_proxy(attachments)

        post_to_roboflow_api(
            endpoint="apiproxy/email",
            api_key=roboflow_api_key,
            payload=payload,
            http_errors_handlers=custom_error_handlers,
        )

        return False, "Notification sent successfully via Roboflow proxy"
    except RoboflowAPIForbiddenError as error:
        # Handle 403 errors (whitelist violations)
        error_message = str(error)
        logging.warning(
            f"Email rejected by proxy due to access restrictions: {error_message}"
        )

        # Check if it's a workspace member restriction
        # The API returns detailed error messages about non-workspace members
        if "non-workspace members" in error_message.lower():
            return True, (
                "To prevent spam, you can only send emails to members of your Roboflow Workspace via the Roboflow Managed API Key. "
                "Add this email to your Workspace or switch to sending via your own SMTP server."
            )
        else:
            return True, f"Failed to send email: access forbidden. {error_message}"
    except RoboflowAPIUnsuccessfulRequestError as error:
        # Handle rate limiting (429) and other API errors
        error_message = str(error)
        logging.warning(f"Email proxy API error: {error_message}")

        # Check for payload too large (413)
        if (
            "413" in error_message
            or "payload too large" in error_message.lower()
            or "too large" in error_message.lower()
        ):
            return True, (
                "Failed to send email: attachment size exceeds the 5MB limit. "
                "For image attachments, use the Image Preprocessing block to resize images before sending. "
                "For other attachments (like CSV files), reduce the file size or send smaller data."
            )
        # Check if it's a rate limit error
        elif "rate limit" in error_message.lower():
            return True, (
                "Failed to send email: rate limit exceeded. "
                "The workspace has exceeded its email sending limits. "
                "Please wait before sending more emails or contact support to increase your limits."
            )
        elif "credits exceeded" in error_message.lower():
            return True, (
                "Failed to send email: workspace credits exceeded. "
                "Please add more credits to your workspace to continue sending emails."
            )
        else:
            return True, f"Failed to send email via proxy. {error_message}"
    except Exception as error:
        logging.warning(
            f"Could not send e-mail via Roboflow proxy. Error: {str(error)}"
        )
        return True, f"Failed to send e-mail via proxy. Internal error details: {error}"

serialize_image_data

serialize_image_data(value)

Serialize WorkflowImageData objects to base64 strings for JSON transmission. Returns the value unchanged if it's not a WorkflowImageData object.

Source code in inference/core/workflows/core_steps/sinks/email_notification/v2.py
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
def serialize_image_data(value: Any) -> Any:
    """
    Recursively convert WorkflowImageData objects into base64 strings so the
    structure can be serialized to JSON; any other value passes through
    unchanged.
    """
    if isinstance(value, WorkflowImageData):
        # Prefer the image's existing base64 payload when one is present.
        encoded = value.base64_image
        if encoded:
            return encoded
        # Otherwise fall back to JPEG-encoding the numpy representation.
        pixels = value.numpy_image
        if pixels is not None:
            import cv2

            _, jpeg_buffer = cv2.imencode(".jpg", pixels)
            import base64

            return base64.b64encode(jpeg_buffer).decode("utf-8")
    if isinstance(value, dict):
        return {key: serialize_image_data(item) for key, item in value.items()}
    if isinstance(value, list):
        return [serialize_image_data(element) for element in value]
    return value

serialize_image_data_parameters

serialize_image_data_parameters(message_parameters)

Convert any WorkflowImageData objects in message_parameters to base64 strings so they can be serialized to JSON for the API call.

Source code in inference/core/workflows/core_steps/sinks/email_notification/v2.py
746
747
748
749
750
751
752
753
def serialize_image_data_parameters(
    message_parameters: Dict[str, Any]
) -> Dict[str, Any]:
    """
    Return a copy of `message_parameters` in which every WorkflowImageData
    value (at any nesting level) has been replaced by its base64 string, so
    the mapping can be serialized to JSON for the API call.
    """
    serialized: Dict[str, Any] = {}
    for name, value in message_parameters.items():
        serialized[name] = serialize_image_data(value)
    return serialized

core/workflows/core_steps/sinks/roboflow/dataset_upload

inference.core.workflows.core_steps.sinks.roboflow.dataset_upload.v1


  • WARNING! *

This module contains the utility functions used by RoboflowDatasetUploadBlockV2.

We do not recommend making multiple blocks dependent on the same code, but the change between v1 and v2 was basically the default value of some parameter - hence we decided not to replicate the code.

If you need to modify this module beware that you may introduce change to RoboflowDatasetUploadBlockV2! If that happens, probably that's the time to disentangle those blocks and copy the code.

Classes

core/workflows/core_steps/sinks/twilio/sms

inference.core.workflows.core_steps.sinks.twilio.sms.v2

Classes

Functions

format_message

format_message(
    message,
    message_parameters,
    message_parameters_operations,
)

Format SMS/MMS message by replacing parameter placeholders with actual values.

Returns:

Type Description
Tuple[str, bool]

Tuple of (formatted_message, needs_mms) where needs_mms is True if the message
exceeds the SMS character limit and should be sent as MMS.

Source code in inference/core/workflows/core_steps/sinks/twilio/sms/v2.py
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
def format_message(
    message: str,
    message_parameters: Dict[str, Any],
    message_parameters_operations: Dict[str, List[AllOperationsType]],
) -> Tuple[str, bool]:
    """Format SMS/MMS message by replacing parameter placeholders with actual values.

    Returns:
        Tuple of (formatted_message, needs_mms) where needs_mms is True if message
        exceeds SMS character limit and should be sent as MMS.
    """
    found_placeholders = PARAMETER_REGEX.findall(message)
    # Only resolve parameters that the caller actually supplied.
    referenced_names = {
        name for _, name in found_placeholders if name in message_parameters
    }
    resolved_values: Dict[str, Any] = {}
    for name in referenced_names:
        raw_value = message_parameters[name]
        ops = message_parameters_operations.get(name)
        if ops:
            # Apply the configured operations chain before substitution.
            chain = build_operations_chain(operations=ops)
            resolved_values[name] = chain(raw_value, global_parameters={})
        else:
            resolved_values[name] = raw_value
    placeholders_by_name = defaultdict(list)
    for placeholder, name in found_placeholders:
        if name in referenced_names:
            placeholders_by_name[name].append(placeholder)
    for name, placeholders in placeholders_by_name.items():
        for placeholder in placeholders:
            message = message.replace(placeholder, str(resolved_values[name]))

    # A message longer than a single SMS must be sent over the MMS channel.
    needs_mms = len(message) > SMS_CHAR_LIMIT

    # Hard-truncate anything beyond the MMS limit, appending a marker.
    if len(message) > MMS_CHAR_LIMIT:
        kept_prefix = message[: MMS_CHAR_LIMIT - 1 - len(TRUNCATION_MARKER)]
        message = f"{kept_prefix} {TRUNCATION_MARKER}"

    return message, needs_mms

process_media_urls_for_twilio

process_media_urls_for_twilio(media_url)

Process media URLs for Twilio MMS. Converts WorkflowImageData to temporary public URLs.

Source code in inference/core/workflows/core_steps/sinks/twilio/sms/v2.py
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
def process_media_urls_for_twilio(
    media_url: Union[str, List[Union[str, WorkflowImageData]], WorkflowImageData]
) -> Optional[List[str]]:
    """
    Normalize media input into a list of URLs for Twilio MMS.
    WorkflowImageData items are uploaded to an ephemeral host to obtain public
    URLs; plain strings are passed through as-is. Returns None when no usable
    media remains.
    """
    if isinstance(media_url, WorkflowImageData):
        uploaded_url = _upload_image_to_ephemeral_host(media_url)
        if not uploaded_url:
            logging.warning("Failed to upload WorkflowImageData to temporary storage")
            return None
        return [uploaded_url]
    if isinstance(media_url, str):
        return [media_url]
    if isinstance(media_url, list):
        collected: List[str] = []
        for entry in media_url:
            if not isinstance(entry, WorkflowImageData):
                collected.append(entry)
                continue
            uploaded_url = _upload_image_to_ephemeral_host(entry)
            if uploaded_url:
                collected.append(uploaded_url)
            else:
                logging.warning(
                    "Failed to upload WorkflowImageData to temporary storage"
                )
        return collected if collected else None
    return None

send_sms_using_twilio_client

send_sms_using_twilio_client(
    client,
    message,
    sender_number,
    receiver_number,
    media_urls,
)

Send SMS/MMS using Twilio client directly.

Source code in inference/core/workflows/core_steps/sinks/twilio/sms/v2.py
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
def send_sms_using_twilio_client(
    client: Client,
    message: str,
    sender_number: str,
    receiver_number: str,
    media_urls: Optional[List[str]],
) -> Tuple[bool, str]:
    """Send an SMS (or MMS when media URLs are provided) directly via Twilio.

    Returns:
        Tuple of (error_status, status_message) - error_status is False on success.
    """
    try:
        create_kwargs = dict(
            body=message,
            from_=sender_number,
            to=receiver_number,
        )
        # Attaching media upgrades the message from SMS to MMS.
        if media_urls:
            create_kwargs["media_url"] = media_urls

        client.messages.create(**create_kwargs)
        return False, "Notification sent successfully"
    except Exception as error:
        logging.warning(f"Could not send Twilio SMS notification. Error: {str(error)}")
        return (
            True,
            f"Failed to send Twilio SMS notification. Internal error details: {error}",
        )

send_sms_via_roboflow_proxy

send_sms_via_roboflow_proxy(
    roboflow_api_key,
    receiver_number,
    message,
    message_parameters,
    message_parameters_operations,
    media_url,
)

Send SMS/MMS through Roboflow's proxy service.

Source code in inference/core/workflows/core_steps/sinks/twilio/sms/v2.py
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
def send_sms_via_roboflow_proxy(
    roboflow_api_key: str,
    receiver_number: str,
    message: str,
    message_parameters: Dict[str, Any],
    message_parameters_operations: Dict[str, List[AllOperationsType]],
    media_url: Optional[Union[str, List[str], WorkflowImageData]],
) -> Tuple[bool, str]:
    """Send SMS/MMS through Roboflow's proxy service."""

    def _raise_with_api_message(status_code: int, http_error: Exception) -> None:
        """Re-raise the HTTP error as a Roboflow API exception, preserving the
        most specific message the API returned ('details' over generic 'error')."""
        try:
            body = http_error.response.json()
            detail = body.get("details") or body.get("error") or str(http_error)
        except Exception:
            detail = str(http_error)

        # 403 maps to a forbidden error; every other handled status (429, ...)
        # maps to a generic unsuccessful-request error.
        if status_code == 403:
            raise RoboflowAPIForbiddenError(detail) from http_error
        raise RoboflowAPIUnsuccessfulRequestError(detail) from http_error

    custom_error_handlers = {
        403: lambda e: _raise_with_api_message(403, e),
        429: lambda e: _raise_with_api_message(429, e),
    }

    try:
        # Resolve placeholders client-side so the proxy receives the final text.
        formatted_message, needs_mms = format_message(
            message=message,
            message_parameters=message_parameters,
            message_parameters_operations=message_parameters_operations,
        )

        payload: Dict[str, Any] = {
            "receiver_number": receiver_number,
            "message": formatted_message,
        }

        # Split media into URL references and inline base64 image payloads.
        has_media = False
        if media_url is not None:
            media_urls, media_base64 = serialize_media_for_api(media_url)
            if media_urls:
                payload["media_urls"] = media_urls
                has_media = True
            if media_base64:
                payload["media_base64"] = media_base64
                has_media = True

        # Long messages without media still need the MMS channel server-side.
        if needs_mms and not has_media:
            payload["force_mms"] = True

        post_to_roboflow_api(
            endpoint="apiproxy/twilio",
            api_key=roboflow_api_key,
            payload=payload,
            http_errors_handlers=custom_error_handlers,
        )

        return False, "Notification sent successfully via Roboflow proxy"
    except RoboflowAPIForbiddenError as error:
        detail = str(error)
        logging.warning(
            f"SMS rejected by proxy due to access restrictions: {detail}"
        )
        return True, f"Failed to send SMS: access forbidden. {detail}"
    except RoboflowAPIUnsuccessfulRequestError as error:
        detail = str(error)
        logging.warning(f"SMS proxy API error: {detail}")

        lowered = detail.lower()
        if "rate limit" in lowered:
            return True, (
                "Failed to send SMS: rate limit exceeded. "
                "The workspace has exceeded its SMS sending limits. "
                "Please wait before sending more messages."
            )
        if "credits exceeded" in lowered:
            return True, (
                "Failed to send SMS: workspace credits exceeded. "
                "Please add more credits to your workspace to continue sending messages."
            )
        return True, f"Failed to send SMS via proxy. {detail}"
    except Exception as error:
        logging.warning(f"Could not send SMS via Roboflow proxy. Error: {str(error)}")
        return True, f"Failed to send SMS via proxy. Internal error details: {error}"

serialize_media_for_api

serialize_media_for_api(media_url)

Serialize media for API transmission. Separates URL-based media from base64 image data.

Returns:

Type Description
Tuple[Optional[List[str]], Optional[List[Dict[str, str]]]]

Tuple of (media_urls, media_base64) where:
  • media_urls: List of string URLs
  • media_base64: List of {"base64": str, "mimeType": str} objects
Source code in inference/core/workflows/core_steps/sinks/twilio/sms/v2.py
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
def serialize_media_for_api(
    media_url: Union[str, List[str], WorkflowImageData, None]
) -> Tuple[Optional[List[str]], Optional[List[Dict[str, str]]]]:
    """
    Serialize media for API transmission.
    Separates URL-based media from base64 image data.

    Returns:
        Tuple of (media_urls, media_base64) where:
        - media_urls: List of string URLs
        - media_base64: List of {"base64": str, "mimeType": str} objects
    """
    if media_url is None:
        return None, None

    url_entries: List[str] = []
    inline_images: List[Dict[str, str]] = []

    candidates = media_url if isinstance(media_url, list) else [media_url]

    for candidate in candidates:
        if isinstance(candidate, str):
            url_entries.append(candidate)
        elif isinstance(candidate, WorkflowImageData):
            # Images carry no public URL, so ship them inline as base64 JPEG.
            jpeg_payload = encode_image_to_jpeg_bytes(candidate.numpy_image)
            inline_images.append(
                {
                    "base64": base64.b64encode(jpeg_payload).decode("utf-8"),
                    "mimeType": "image/jpeg",
                }
            )

    return (
        url_entries if url_entries else None,
        inline_images if inline_images else None,
    )

core/workflows/core_steps/transformations/detections_merge

inference.core.workflows.core_steps.transformations.detections_merge.v1

Functions

calculate_union_bbox

calculate_union_bbox(detections)

Calculate a single bounding box that contains all input detections.

Source code in inference/core/workflows/core_steps/transformations/detections_merge/v1.py
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
def calculate_union_bbox(detections: sv.Detections) -> np.ndarray:
    """Calculate a single bounding box that contains all input detections."""
    if len(detections) == 0:
        # Preserve a (0, 4) float32 array so downstream code can rely on shape.
        return np.array([], dtype=np.float32).reshape(0, 4)

    boxes = detections.xyxy

    # The union spans from the smallest top-left to the largest bottom-right.
    top_left = boxes[:, :2].min(axis=0)
    bottom_right = boxes[:, 2:].max(axis=0)

    return np.array([[top_left[0], top_left[1], bottom_right[0], bottom_right[1]]])

get_lowest_confidence_index

get_lowest_confidence_index(detections)

Get the index of the detection with the lowest confidence.

Source code in inference/core/workflows/core_steps/transformations/detections_merge/v1.py
134
135
136
137
138
def get_lowest_confidence_index(detections: sv.Detections) -> int:
    """Get the index of the detection with the lowest confidence."""
    confidences = detections.confidence
    # Without confidence scores there is nothing to compare; default to index 0.
    if confidences is None:
        return 0
    return int(confidences.argmin())

core/workflows/core_steps/transformations/image_slicer

inference.core.workflows.core_steps.transformations.image_slicer.v1

Classes

Functions

generate_offsets

generate_offsets(resolution_wh, slice_wh, overlap_ratio_wh)

Original code: https://github.com/roboflow/supervision/blob/5123085037ec594524fc8f9d9b71b1cd9f487e8d/supervision/detection/tools/inference_slicer.py#L204-L203 to avoid fragile contract with supervision, as this function is not element of public API.

Generate offset coordinates for slicing an image based on the given resolution, slice dimensions, and overlap ratios.

Parameters:

Name Type Description Default
resolution_wh Tuple[int, int]

A tuple representing the width and height of the image to be sliced.

required
slice_wh Tuple[int, int]

Dimensions of each slice measured in pixels. The

required
overlap_ratio_wh Optional[Tuple[float, float]]

A tuple representing the desired overlap ratio for width and height between consecutive slices. Each value should be in the range [0, 1), where 0 means no overlap and a value close to 1 means high overlap.

required
Note

The function ensures that slices do not exceed the boundaries of the original image. As a result, the final slices in the row and column dimensions might be smaller than the specified slice dimensions if the image's width or height is not a multiple of the slice's width or height minus the overlap.

Source code in inference/core/workflows/core_steps/transformations/image_slicer/v1.py
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
def generate_offsets(
    resolution_wh: Tuple[int, int],
    slice_wh: Tuple[int, int],
    overlap_ratio_wh: Optional[Tuple[float, float]],
) -> np.ndarray:
    """
    Original code: https://github.com/roboflow/supervision/blob/5123085037ec594524fc8f9d9b71b1cd9f487e8d/supervision/detection/tools/inference_slicer.py#L204-L203
    to avoid fragile contract with supervision, as this function is not element of public
    API.

    Generate offset coordinates for slicing an image based on the given resolution,
    slice dimensions, and overlap ratios.

    Args:
        resolution_wh (Tuple[int, int]): A tuple representing the width and height
            of the image to be sliced.
        slice_wh (Tuple[int, int]): Dimensions of each slice measured in pixels. The
            tuple should be in the format `(width, height)`.
        overlap_ratio_wh (Optional[Tuple[float, float]]): A tuple representing the
            desired overlap ratio for width and height between consecutive slices.
            Each value should be in the range [0, 1), where 0 means no overlap and
            a value close to 1 means high overlap. `None` is treated as no overlap.
    Returns:
        np.ndarray: An array of shape `(n, 4)` containing coordinates for each
            slice in the format `[xmin, ymin, xmax, ymax]`.

    Note:
        The function ensures that slices do not exceed the boundaries of the
            original image. As a result, the final slices in the row and column
            dimensions might be smaller than the specified slice dimensions if the
            image's width or height is not a multiple of the slice's width or
            height minus the overlap.
    """
    # The signature declares the overlap ratio as Optional, so honour None by
    # falling back to zero overlap instead of raising a TypeError on subscript.
    if overlap_ratio_wh is None:
        overlap_ratio_wh = (0.0, 0.0)
    slice_width, slice_height = slice_wh
    image_width, image_height = resolution_wh
    overlap_width = int(overlap_ratio_wh[0] * slice_width)
    overlap_height = int(overlap_ratio_wh[1] * slice_height)
    width_stride = slice_width - overlap_width
    height_stride = slice_height - overlap_height
    ws = np.arange(0, image_width, width_stride)
    hs = np.arange(0, image_height, height_stride)
    xmin, ymin = np.meshgrid(ws, hs)
    xmax = np.clip(xmin + slice_width, 0, image_width)
    ymax = np.clip(ymin + slice_height, 0, image_height)
    return np.stack([xmin, ymin, xmax, ymax], axis=-1).reshape(-1, 4)

inference.core.workflows.core_steps.transformations.image_slicer.v2

Classes

Functions

generate_offsets

generate_offsets(resolution_wh, slice_wh, overlap_ratio_wh)

This is modification of the function from block v1, which makes sure that the "border" crops are pushed towards the center of the image, making sure: * all crops will be the same size * deduplication of crops coordinates is done

Source code in inference/core/workflows/core_steps/transformations/image_slicer/v2.py
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
def generate_offsets(
    resolution_wh: Tuple[int, int],
    slice_wh: Tuple[int, int],
    overlap_ratio_wh: Tuple[float, float],
) -> np.ndarray:
    """
    This is modification of the function from block v1, which
    makes sure that the "border" crops are pushed towards the center of
    the image, making sure:
        * all crops will be the same size
        * deduplication of crops coordinates is done
    """
    image_width, image_height = resolution_wh
    # A slice can never be larger than the image itself.
    slice_width = min(slice_wh[0], image_width)
    slice_height = min(slice_wh[1], image_height)
    width_stride = slice_width - int(overlap_ratio_wh[0] * slice_width)
    height_stride = slice_height - int(overlap_ratio_wh[1] * slice_height)
    raw_ws = np.arange(0, image_width, width_stride)
    raw_hs = np.arange(0, image_height, height_stride)
    # Shift each anchor back by however much its crop would overhang the image,
    # so border crops keep the full slice size instead of shrinking.
    overhang_ws = np.clip(raw_ws + slice_width - image_width, 0, slice_width)
    overhang_hs = np.clip(raw_hs + slice_height - image_height, 0, slice_height)
    xmin, ymin = np.meshgrid(raw_ws - overhang_ws, raw_hs - overhang_hs)
    xmax = np.clip(xmin + slice_width, 0, image_width)
    ymax = np.clip(ymin + slice_height, 0, image_height)
    candidates = np.stack([xmin, ymin, xmax, ymax], axis=-1).reshape(-1, 4)
    # Pushing anchors inward may create duplicates; keep first occurrences only.
    unique_boxes = []
    seen = set()
    for box in candidates:
        key = tuple(box)
        if key not in seen:
            seen.add(key)
            unique_boxes.append(box)
    return np.array(unique_boxes)

core/workflows/core_steps/transformations/qr_code_generator

inference.core.workflows.core_steps.transformations.qr_code_generator.v1

Classes

Functions

generate_qr_code

generate_qr_code(
    text,
    version=None,
    box_size=10,
    error_correct="M",
    border=4,
    fill_color="BLACK",
    back_color="WHITE",
)

Generate a QR code PNG image from text input.

Source code in inference/core/workflows/core_steps/transformations/qr_code_generator/v1.py
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
def generate_qr_code(
    text: str,
    version: Optional[int] = None,
    box_size: int = 10,
    error_correct: str = "M",
    border: int = 4,
    fill_color: str = "BLACK",
    back_color: str = "WHITE",
) -> WorkflowImageData:
    """Generate a QR code PNG image from text input."""
    global _ERROR_LEVELS, _QR_CACHE

    # Serve a previously generated image when the exact same inputs were seen.
    cached = _QR_CACHE.get(
        text, version, box_size, error_correct, border, fill_color, back_color
    )
    if cached is not None:
        return cached

    try:
        import qrcode
    except ImportError:
        raise ImportError(
            "qrcode library is required for QR code generation. "
            "Install it with: pip install qrcode"
        )
    if _ERROR_LEVELS is None:
        _ERROR_LEVELS = _get_error_levels()

    # Resolve colors via the shared utility (handles hex, rgb, bgr and
    # standard names). On failure, fall back to the raw string so the
    # qrcode library can interpret CSS3 color names on its own.
    try:
        fill = str_to_color(fill_color).as_rgb()  # (R, G, B) tuple
    except (ValueError, AttributeError):
        fill = fill_color
    try:
        back = str_to_color(back_color).as_rgb()  # (R, G, B) tuple
    except (ValueError, AttributeError):
        back = back_color

    correction = _ERROR_LEVELS.get(
        error_correct.upper(), qrcode.constants.ERROR_CORRECT_M
    )

    # Build the QR code; with no explicit version, let the library pick
    # the smallest version that fits the payload (fit=True).
    qr_code = qrcode.QRCode(
        version=version,
        error_correction=correction,
        box_size=box_size,
        border=border,
    )
    qr_code.add_data(text)
    qr_code.make(fit=(version is None))

    # Render with the default image factory and force RGB mode.
    pil_image = qr_code.make_image(fill_color=fill, back_color=back).convert("RGB")

    # PIL -> numpy directly (faster than an encode/decode round trip),
    # then flip channels RGB -> BGR to match the WorkflowImageData
    # (OpenCV) convention.
    bgr_array = np.array(pil_image)[:, :, ::-1]

    # Defensive guard retained from the earlier OpenCV-decode code path.
    if bgr_array is None or bgr_array.size == 0:
        raise ValueError("Failed to generate QR code image")

    image_data = WorkflowImageData(
        parent_metadata=ImageParentMetadata(parent_id=f"qr_code.{uuid4()}"),
        numpy_image=bgr_array,
    )

    # Remember the result for future identical requests.
    _QR_CACHE.put(
        text, version, box_size, error_correct, border, fill_color, back_color,
        image_data,
    )

    return image_data

core/workflows/core_steps/transformations/stitch_ocr_detections

inference.core.workflows.core_steps.transformations.stitch_ocr_detections.v1

Functions

get_line_separator

get_line_separator(reading_direction)

Get the appropriate separator based on reading direction.

Source code in inference/core/workflows/core_steps/transformations/stitch_ocr_detections/v1.py
341
342
343
def get_line_separator(reading_direction: str) -> str:
    """Return the separator placed between lines for *reading_direction*.

    Horizontal reading directions separate lines with a newline; vertical
    directions separate columns with a single space.
    """
    if reading_direction in ("left_to_right", "right_to_left"):
        return "\n"
    return " "

group_detections_by_line

group_detections_by_line(
    xyxy, reading_direction, tolerance
)

Group detections into lines based on primary coordinate.

Source code in inference/core/workflows/core_steps/transformations/stitch_ocr_detections/v1.py
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
def group_detections_by_line(
    xyxy: np.ndarray,
    reading_direction: str,
    tolerance: int,
) -> Dict[float, Dict[str, List]]:
    """Group detections into lines based on primary coordinate."""
    # Coordinates were already swapped by prepare_coordinates, so the
    # primary (line-defining) coordinate is always at column 1.
    line_positions = np.round(xyxy[:, 1] / tolerance) * tolerance

    grouped: Dict[float, Dict[str, List]] = {}
    for index, (box, position) in enumerate(zip(xyxy, line_positions)):
        bucket = grouped.setdefault(position, {"xyxy": [], "idx": []})
        bucket["xyxy"].append(box)
        bucket["idx"].append(index)

    return grouped

prepare_coordinates

prepare_coordinates(xyxy, reading_direction)

Prepare coordinates based on reading direction.

Source code in inference/core/workflows/core_steps/transformations/stitch_ocr_detections/v1.py
294
295
296
297
298
299
300
301
302
def prepare_coordinates(
    xyxy: np.ndarray,
    reading_direction: str,
) -> np.ndarray:
    """Prepare coordinates based on reading direction."""
    vertical = reading_direction in ("vertical_top_to_bottom", "vertical_bottom_to_top")
    if not vertical:
        return xyxy
    # For vertical reading, swap the axes so downstream code can always
    # treat column 1 as the line axis: [x1,y1,x2,y2] -> [y1,x1,y2,x2]
    return xyxy[:, [1, 0, 3, 2]]

sort_line_detections

sort_line_detections(line_xyxy, reading_direction)

Sort detections within a line based on reading direction.

Source code in inference/core/workflows/core_steps/transformations/stitch_ocr_detections/v1.py
329
330
331
332
333
334
335
336
337
338
def sort_line_detections(
    line_xyxy: np.ndarray,
    reading_direction: str,
) -> np.ndarray:
    """Sort detections within a line based on reading direction."""
    # Column 0 is always the reading axis after prepare_coordinates.
    keys = line_xyxy[:, 0]
    if reading_direction in ("left_to_right", "vertical_top_to_bottom"):
        return keys.argsort()
    # right_to_left / vertical_bottom_to_top read against the axis.
    return (-keys).argsort()

stitch_ocr_detections

stitch_ocr_detections(
    detections,
    reading_direction="left_to_right",
    tolerance=10,
    delimiter="",
)

Stitch OCR detections into coherent text based on spatial arrangement.

Parameters:

Name Type Description Default
detections Detections

Supervision Detections object containing OCR results

required
reading_direction str

Direction to read text ("left_to_right", "right_to_left", "vertical_top_to_bottom", "vertical_bottom_to_top")

'left_to_right'
tolerance int

Vertical tolerance for grouping text into lines

10

Returns:

Type Description
Dict[str, str]

Dict containing stitched OCR text under 'ocr_text' key

Source code in inference/core/workflows/core_steps/transformations/stitch_ocr_detections/v1.py
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
def stitch_ocr_detections(
    detections: sv.Detections,
    reading_direction: str = "left_to_right",
    tolerance: int = 10,
    delimiter: str = "",
) -> Dict[str, str]:
    """
    Combine individual OCR detections into a single text string.

    Detections are grouped into lines by their primary coordinate, each
    line is ordered according to the reading direction, and the class
    names are joined with ``delimiter``.

    Args:
        detections: Supervision Detections object containing OCR results
        reading_direction: Direction to read text ("left_to_right", "right_to_left",
                         "vertical_top_to_bottom", "vertical_bottom_to_top")
        tolerance: Vertical tolerance for grouping text into lines
        delimiter: String inserted between consecutive text elements

    Returns:
        Dict containing stitched OCR text under 'ocr_text' key
    """
    if len(detections) == 0:
        return {"ocr_text": ""}

    boxes = detections.xyxy.round().astype(dtype=int)
    class_names = detections.data["class_name"]

    # Swap axes for vertical reading so grouping/sorting is uniform below.
    boxes = prepare_coordinates(boxes, reading_direction)
    boxes_by_line = group_detections_by_line(boxes, reading_direction, tolerance)

    # Lines are emitted top-to-bottom, except for bottom-to-top vertical
    # reading which reverses the line order.
    line_keys = sorted(
        boxes_by_line.keys(),
        reverse=reading_direction in ["vertical_bottom_to_top"],
    )

    ordered_class_names = []
    last_position = len(line_keys) - 1
    for position, key in enumerate(line_keys):
        entry = boxes_by_line[key]
        line_boxes = np.array(entry["xyxy"])
        original_indices = np.array(entry["idx"])

        # Order detections along the reading axis within this line.
        order = sort_line_detections(line_boxes, reading_direction)
        ordered_class_names.extend(class_names[original_indices[order]])

        # Separator between consecutive lines (the separator itself is
        # joined with the delimiter, same as the original behavior).
        if position != last_position:
            ordered_class_names.append(get_line_separator(reading_direction))

    return {"ocr_text": delimiter.join(ordered_class_names)}

inference.core.workflows.core_steps.transformations.stitch_ocr_detections.v2

Classes

CollimateDetection

Helper class for collimate algorithm to store detection properties.

Source code in inference/core/workflows/core_steps/transformations/stitch_ocr_detections/v2.py
688
689
690
691
692
693
694
695
696
697
698
699
700
class CollimateDetection:
    """Lightweight record of one detection used by the collimate algorithm.

    Stores the box center, its size, the detected class name, and the
    detection's original index so results can be traced back.
    """

    def __init__(self, xyxy: np.ndarray, class_name: str, idx: int):
        x1, y1, x2, y2 = xyxy[0], xyxy[1], xyxy[2], xyxy[3]
        self.x = (x1 + x2) / 2  # box center, reading axis
        self.y = (y1 + y2) / 2  # box center, secondary axis
        self.width = x2 - x1
        self.height = y2 - y1
        self.class_name = class_name
        self.idx = idx  # index into the original detections

    def __repr__(self) -> str:
        return f"{self.class_name}"

StitchingAlgorithm

Bases: str, Enum

Algorithm for grouping detections into words/lines.

TOLERANCE: Uses fixed pixel tolerance for line grouping (original algorithm). Good for consistent font sizes and line spacing.

OTSU: Uses Otsu's method on normalized gaps to find natural breaks. Resolution-invariant and works well with bimodal distributions (e.g., character-level vs word-level spacing).

COLLIMATE: Uses greedy parent-child traversal to group detections. Good for skewed or curved text where bucket-based approaches fail.

Source code in inference/core/workflows/core_steps/transformations/stitch_ocr_detections/v2.py
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
class StitchingAlgorithm(str, Enum):
    """Strategy used to group OCR detections into words and lines.

    TOLERANCE: Fixed pixel tolerance for line grouping (original
        algorithm); best for consistent font sizes and line spacing.

    OTSU: Otsu's method applied to normalized gaps to find natural
        breaks; resolution-invariant and effective on bimodal spacing
        distributions (character-level vs word-level gaps).

    COLLIMATE: Greedy parent-child traversal; handles skewed or curved
        text where bucket-based approaches fail.
    """

    TOLERANCE = "tolerance"
    OTSU = "otsu"
    COLLIMATE = "collimate"

Functions

adaptive_word_grouping

adaptive_word_grouping(
    detections,
    reading_direction,
    delimiter="",
    threshold_multiplier=1.0,
)

Stitch OCR detections using adaptive gap analysis with Otsu thresholding.

This approach is resolution-invariant because it normalizes gaps by local character dimensions. It works well with bimodal gap distributions (e.g., character-level vs word-level spacing).

The algorithm computes a global threshold across all lines to leverage the full dataset of gaps, which provides more robust Otsu thresholding than per-line computation.

Parameters:

Name Type Description Default
detections Detections

Supervision Detections object containing OCR results

required
reading_direction str

Direction to read text

required
delimiter str

String to insert between text elements

''
threshold_multiplier float

Multiplier applied to Otsu threshold (>1.0 = fewer word breaks, <1.0 = more word breaks)

1.0

Returns:

Type Description
Dict[str, str]

Dict containing stitched OCR text under 'ocr_text' key

Source code in inference/core/workflows/core_steps/transformations/stitch_ocr_detections/v2.py
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
def adaptive_word_grouping(
    detections: sv.Detections,
    reading_direction: str,
    delimiter: str = "",
    threshold_multiplier: float = 1.0,
) -> Dict[str, str]:
    """Stitch OCR detections using adaptive gap analysis with Otsu thresholding.

    This approach is resolution-invariant because it normalizes gaps by local
    character dimensions. It works well with bimodal gap distributions
    (e.g., character-level vs word-level spacing).

    The algorithm computes a global threshold across all lines to leverage
    the full dataset of gaps, which provides more robust Otsu thresholding
    than per-line computation.

    Args:
        detections: Supervision Detections object containing OCR results
        reading_direction: Direction to read text
        delimiter: String to insert between text elements
        threshold_multiplier: Multiplier applied to Otsu threshold (>1.0 = fewer word breaks, <1.0 = more word breaks)

    Returns:
        Dict containing stitched OCR text under 'ocr_text' key
    """
    if len(detections) == 0:
        return {"ocr_text": ""}

    xyxy = detections.xyxy
    class_names = detections.data["class_name"]

    # Determine if we're working with vertical text
    is_vertical = reading_direction in [
        "vertical_top_to_bottom",
        "vertical_bottom_to_top",
    ]

    # For vertical text, swap x/y for processing so that "x" is always the
    # reading axis and "y" is always the line-stacking axis below
    if is_vertical:
        # Swap coordinates: treat y as x for sorting
        x_centers = (xyxy[:, 1] + xyxy[:, 3]) / 2  # y becomes primary axis
        y_centers = (xyxy[:, 0] + xyxy[:, 2]) / 2  # x becomes secondary axis
        widths = xyxy[:, 3] - xyxy[:, 1]  # height becomes "width"
        heights = xyxy[:, 2] - xyxy[:, 0]  # width becomes "height"
    else:
        x_centers = (xyxy[:, 0] + xyxy[:, 2]) / 2
        y_centers = (xyxy[:, 1] + xyxy[:, 3]) / 2
        widths = xyxy[:, 2] - xyxy[:, 0]
        heights = xyxy[:, 3] - xyxy[:, 1]

    # First, group detections into lines based on y-coordinate clustering.
    # Use adaptive threshold based on median height: half the median box
    # height along the secondary axis.
    median_height = np.median(heights)
    line_tolerance = median_height * 0.5

    # Sort by y to group into lines
    y_sorted_indices = np.argsort(y_centers)

    lines = []
    current_line = [y_sorted_indices[0]]
    current_line_y = y_centers[y_sorted_indices[0]]

    for idx in y_sorted_indices[1:]:
        if abs(y_centers[idx] - current_line_y) <= line_tolerance:
            current_line.append(idx)
            # Update line y as running average. NOTE(review): a candidate is
            # compared against the mean of the line so far, not its first
            # member, which lets the line track slight vertical drift.
            current_line_y = np.mean([y_centers[i] for i in current_line])
        else:
            lines.append(current_line)
            current_line = [idx]
            current_line_y = y_centers[idx]
    lines.append(current_line)

    # Sort lines by y position (reversed only for bottom-to-top reading)
    line_y_positions = [np.mean([y_centers[i] for i in line]) for line in lines]
    if reading_direction in ["vertical_bottom_to_top"]:
        sorted_line_indices = np.argsort(line_y_positions)[::-1]
    else:
        sorted_line_indices = np.argsort(line_y_positions)

    # First pass: compute normalized gaps for ALL lines to get global threshold
    all_normalized_gaps = []
    line_data = []  # Store sorted line info for second pass

    for line_idx in sorted_line_indices:
        line = lines[line_idx]

        if len(line) == 1:
            # Single-detection lines contribute no gaps; the None markers
            # are recognized in the second pass.
            line_data.append((line, None, None, None))
            continue

        # Sort detections in line by x position
        line_x_centers = x_centers[line]
        line_widths = widths[line]

        if reading_direction in ["right_to_left", "vertical_bottom_to_top"]:
            x_sorted_order = np.argsort(line_x_centers)[::-1]
        else:
            x_sorted_order = np.argsort(line_x_centers)

        sorted_line = [line[i] for i in x_sorted_order]
        sorted_x_centers = line_x_centers[x_sorted_order]
        sorted_widths = line_widths[x_sorted_order]

        # Compute normalized gaps for this line; normalized_gaps[i-1] is the
        # gap between elements i-1 and i in reading order
        normalized_gaps = []
        for i in range(1, len(sorted_line)):
            prev_idx, curr_idx = i - 1, i
            # Raw gap between detection edges: center distance minus the
            # two half-widths (can be negative for overlapping boxes)
            if reading_direction in ["right_to_left", "vertical_bottom_to_top"]:
                raw_gap = (
                    sorted_x_centers[prev_idx]
                    - sorted_x_centers[curr_idx]
                    - (sorted_widths[prev_idx] + sorted_widths[curr_idx]) / 2
                )
            else:
                raw_gap = (
                    sorted_x_centers[curr_idx]
                    - sorted_x_centers[prev_idx]
                    - (sorted_widths[prev_idx] + sorted_widths[curr_idx]) / 2
                )

            # Normalize by local character scale (mean of the two widths),
            # which makes the gap measure resolution-invariant
            local_scale = (sorted_widths[prev_idx] + sorted_widths[curr_idx]) / 2
            if local_scale > 0:
                normalized_gaps.append(raw_gap / local_scale)
            else:
                normalized_gaps.append(0.0)

        normalized_gaps = np.array(normalized_gaps)
        all_normalized_gaps.extend(normalized_gaps.tolist())
        line_data.append(
            (sorted_line, sorted_x_centers, sorted_widths, normalized_gaps)
        )

    # Compute global threshold using all gaps, then apply multiplier
    all_normalized_gaps = np.array(all_normalized_gaps)
    global_threshold, is_bimodal = find_otsu_threshold(all_normalized_gaps)
    global_threshold *= threshold_multiplier

    # Second pass: use global threshold to group words
    all_text_parts = []

    for sorted_line, sorted_x_centers, sorted_widths, normalized_gaps in line_data:
        if normalized_gaps is None:
            # Single detection in line
            all_text_parts.append(class_names[sorted_line[0]])
            continue

        # If distribution is not bimodal (likely single word or uniform spacing),
        # treat all detections as a single word to avoid incorrect splitting
        if not is_bimodal:
            word_text = delimiter.join([class_names[idx] for idx in sorted_line])
            all_text_parts.append(word_text)
            continue

        # Group into words based on global threshold; a gap larger than the
        # threshold starts a new word
        words = [[sorted_line[0]]]
        for i, det_idx in enumerate(sorted_line[1:]):
            if normalized_gaps[i] > global_threshold:
                words.append([det_idx])
            else:
                words[-1].append(det_idx)

        # Build text for this line
        line_text_parts = []
        for word in words:
            word_text = delimiter.join([class_names[idx] for idx in word])
            line_text_parts.append(word_text)

        # Join words with space (or delimiter if specified and non-empty)
        word_separator = " " if delimiter == "" else delimiter
        all_text_parts.append(word_separator.join(line_text_parts))

    # Join lines with appropriate separator
    line_separator = get_line_separator(reading_direction)
    return {"ocr_text": line_separator.join(all_text_parts)}

collimate_word_grouping

collimate_word_grouping(
    detections,
    reading_direction,
    delimiter="",
    tolerance=10,
)

Stitch OCR detections using greedy parent-child traversal (collimate algorithm).

This algorithm is good for skewed or curved text where traditional bucket-based line grouping may fail. It works by: 1. Sorting detections by primary reading coordinate 2. Starting with the first detection as a "parent" 3. Finding all detections that "follow" the parent (within tolerance) 4. Building lines/columns through greedy traversal

Parameters:

Name Type Description Default
detections Detections

Supervision Detections object containing OCR results

required
reading_direction str

Direction to read text

required
delimiter str

String to insert between characters within words

''
tolerance int

Pixel tolerance for alignment

10

Returns:

Type Description
Dict[str, str]

Dict containing stitched OCR text under 'ocr_text' key

Source code in inference/core/workflows/core_steps/transformations/stitch_ocr_detections/v2.py
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
def collimate_word_grouping(
    detections: sv.Detections,
    reading_direction: str,
    delimiter: str = "",
    tolerance: int = 10,
) -> Dict[str, str]:
    """Stitch OCR detections via greedy parent-child traversal (collimate).

    Suited to skewed or curved text where bucket-based line grouping may
    fail. Detections are sorted along the reading axis, then lines are
    grown by repeatedly attaching each pending detection that "follows"
    the current tail of an existing line (within ``tolerance``); a
    detection that follows nothing seeds a new line.

    Args:
        detections: Supervision Detections object containing OCR results
        reading_direction: Direction to read text
        delimiter: String to insert between characters within words
        tolerance: Pixel tolerance for alignment

    Returns:
        Dict containing stitched OCR text under 'ocr_text' key
    """
    if len(detections) == 0:
        return {"ocr_text": ""}

    boxes = detections.xyxy
    class_names = detections.data["class_name"]

    # Wrap each detection with its center/size and original index.
    wrapped = [
        CollimateDetection(boxes[i], class_names[i], i) for i in range(len(detections))
    ]

    # Order along the primary reading coordinate.
    wrapped = _sort_detections_for_collimate(wrapped, reading_direction)

    if len(wrapped) == 0:
        return {"ocr_text": ""}

    # Greedy traversal: each pass tries to extend every line's tail with a
    # pending detection that follows it; when nothing extends, open a new
    # line with the next pending detection.
    pending = list(wrapped)
    lines: List[List[CollimateDetection]] = [[pending.pop(0)]]

    while len(pending) > 0:
        extended = False

        for line in lines:
            tail = line[-1]

            for candidate in pending.copy():
                if _detection_follows(tail, candidate, reading_direction, tolerance):
                    extended = True
                    line.append(candidate)
                    tail = candidate  # chain continues from the new tail
                    pending.remove(candidate)

        if not extended and len(pending) > 0:
            lines.append([pending.pop(0)])

    # Order lines by their average secondary coordinate: rows go
    # top-to-bottom; columns go left-to-right, reversed only for
    # bottom-to-top vertical reading.
    is_vertical = reading_direction in [
        "vertical_top_to_bottom",
        "vertical_bottom_to_top",
    ]
    if is_vertical:
        reverse = reading_direction == "vertical_bottom_to_top"
    else:
        reverse = False

    lines = sorted(
        lines,
        key=lambda line: _get_line_avg_coord(line, reading_direction),
        reverse=reverse,
    )

    # Characters within a line are concatenated with the delimiter; lines
    # are joined with the direction-appropriate separator.
    line_texts = [delimiter.join(d.class_name for d in line) for line in lines]
    return {"ocr_text": get_line_separator(reading_direction).join(line_texts)}

find_otsu_threshold

find_otsu_threshold(gaps)

Find natural break between intra-word and inter-word gaps using Otsu's method.

This is a resolution-invariant approach that finds the optimal threshold to separate two classes of gaps (e.g., gaps within words vs gaps between words).

Also detects whether the distribution is bimodal (two distinct groups) or unimodal (single group, suggesting single word or uniform spacing).

Parameters:

Name Type Description Default
gaps ndarray

Array of normalized gap values

required

Returns:

Type Description
tuple[float, bool]

Tuple of (threshold, is_bimodal):
  • threshold: Optimal threshold value that maximizes between-class variance
  • is_bimodal: True if distribution appears bimodal, False if unimodal
Source code in inference/core/workflows/core_steps/transformations/stitch_ocr_detections/v2.py
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
def find_otsu_threshold(gaps: np.ndarray) -> tuple[float, bool]:
    """Locate the natural break between intra-word and inter-word gaps.

    Applies Otsu's criterion (maximize between-class variance) over the
    histogram bin centers of the normalized gap values, and additionally
    judges whether the distribution looks bimodal at all.

    Args:
        gaps: Array of normalized gap values

    Returns:
        Tuple of (threshold, is_bimodal):
        - threshold: Optimal threshold value that maximizes between-class variance
        - is_bimodal: True if distribution appears bimodal, False if unimodal
    """
    if len(gaps) < 2:
        return 0.0, False

    # Candidate thresholds are the histogram bin centers.
    _, edges = np.histogram(gaps, bins=min(50, len(gaps)))
    candidates = (edges[:-1] + edges[1:]) / 2

    best_thresh = 0.0
    best_variance = 0.0
    best_below_mean = 0.0
    best_above_mean = 0.0

    for candidate in candidates:
        lower = gaps[gaps <= candidate]
        upper = gaps[gaps > candidate]

        # A split must leave members on both sides to be meaningful.
        if len(lower) == 0 or len(upper) == 0:
            continue

        # Otsu's between-class variance (unnormalized class weights).
        spread = len(lower) * len(upper) * (lower.mean() - upper.mean()) ** 2

        if spread > best_variance:
            best_variance = spread
            best_thresh = candidate
            best_below_mean = lower.mean()
            best_above_mean = upper.mean()

    # Bimodality heuristics. Real word gaps are typically 0.5+ in
    # normalized units; a distribution whose "above" class is not clearly
    # positive is treated as a single word even if outliers (e.g.
    # overlapping characters with negative gaps) inflate the separation.
    overall_std = gaps.std()

    mean_separation = abs(best_above_mean - best_below_mean)
    separation_ratio = mean_separation / overall_std if overall_std > 0 else 0

    # Primary criterion: the upper class must look like genuine word gaps.
    has_positive_word_gaps = best_above_mean > 0.3
    # Secondary criterion: good relative AND absolute separation.
    has_good_relative_separation = separation_ratio > 1.5 and mean_separation > 0.3

    is_bimodal = has_positive_word_gaps and (
        mean_separation > 0.3 or has_good_relative_separation
    )

    return best_thresh, is_bimodal

get_line_separator

get_line_separator(reading_direction)

Get the appropriate separator based on reading direction.

Source code in inference/core/workflows/core_steps/transformations/stitch_ocr_detections/v2.py
426
427
428
def get_line_separator(reading_direction: str) -> str:
    """Get the appropriate separator based on reading direction."""
    horizontal_directions = ("left_to_right", "right_to_left")
    # Horizontal text breaks into newline-separated rows; vertical text
    # columns are joined with spaces.
    return "\n" if reading_direction in horizontal_directions else " "

group_detections_by_line

group_detections_by_line(
    xyxy, reading_direction, tolerance
)

Group detections into lines based on primary coordinate.

Source code in inference/core/workflows/core_steps/transformations/stitch_ocr_detections/v2.py
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
def group_detections_by_line(
    xyxy: np.ndarray,
    reading_direction: str,
    tolerance: int,
) -> Dict[float, Dict[str, List]]:
    """Group detections into lines based on primary coordinate."""
    # prepare_coordinates has already swapped axes for vertical text, so
    # column 1 always holds the line-defining coordinate.
    snapped = np.round(xyxy[:, 1] / tolerance) * tolerance

    lines_by_position: Dict[float, Dict[str, List]] = {}
    for index, (box, line_key) in enumerate(zip(xyxy, snapped)):
        if line_key in lines_by_position:
            lines_by_position[line_key]["xyxy"].append(box)
            lines_by_position[line_key]["idx"].append(index)
        else:
            lines_by_position[line_key] = {"xyxy": [box], "idx": [index]}

    return lines_by_position

prepare_coordinates

prepare_coordinates(xyxy, reading_direction)

Prepare coordinates based on reading direction.

Source code in inference/core/workflows/core_steps/transformations/stitch_ocr_detections/v2.py
379
380
381
382
383
384
385
386
387
def prepare_coordinates(
    xyxy: np.ndarray,
    reading_direction: str,
) -> np.ndarray:
    """Prepare coordinates based on reading direction."""
    if reading_direction not in ("vertical_top_to_bottom", "vertical_bottom_to_top"):
        return xyxy
    # Vertical text: swap axes so the rest of the pipeline can treat the
    # reading axis uniformly: [x1,y1,x2,y2] -> [y1,x1,y2,x2]
    swapped_column_order = [1, 0, 3, 2]
    return xyxy[:, swapped_column_order]

sort_line_detections

sort_line_detections(line_xyxy, reading_direction)

Sort detections within a line based on reading direction.

Source code in inference/core/workflows/core_steps/transformations/stitch_ocr_detections/v2.py
414
415
416
417
418
419
420
421
422
423
def sort_line_detections(
    line_xyxy: np.ndarray,
    reading_direction: str,
) -> np.ndarray:
    """Sort detections within a line based on reading direction."""
    # Coordinates were normalized by prepare_coordinates, so column 0 is
    # always the in-line reading axis; negate it to reverse the order.
    forward = reading_direction in ["left_to_right", "vertical_top_to_bottom"]
    sort_keys = line_xyxy[:, 0] if forward else -line_xyxy[:, 0]
    return sort_keys.argsort()

stitch_ocr_detections

stitch_ocr_detections(
    detections,
    reading_direction="left_to_right",
    tolerance=10,
    delimiter="",
)

Stitch OCR detections into coherent text based on spatial arrangement.

Parameters:

Name Type Description Default
detections Detections

Supervision Detections object containing OCR results

required
reading_direction str

Direction to read text ("left_to_right", "right_to_left", "vertical_top_to_bottom", "vertical_bottom_to_top")

'left_to_right'
tolerance int

Vertical tolerance for grouping text into lines

10
delimiter str

String placed between consecutive stitched text pieces

''

Returns:

Type Description
Dict[str, str]

Dict containing stitched OCR text under 'ocr_text' key

Source code in inference/core/workflows/core_steps/transformations/stitch_ocr_detections/v2.py
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
def stitch_ocr_detections(
    detections: sv.Detections,
    reading_direction: str = "left_to_right",
    tolerance: int = 10,
    delimiter: str = "",
) -> Dict[str, str]:
    """
    Stitch OCR detections into coherent text based on spatial arrangement.

    Args:
        detections: Supervision Detections object containing OCR results
        reading_direction: Direction to read text ("left_to_right", "right_to_left",
                         "vertical_top_to_bottom", "vertical_bottom_to_top")
        tolerance: Vertical tolerance for grouping text into lines
        delimiter: String placed between consecutive stitched text pieces

    Returns:
        Dict containing stitched OCR text under 'ocr_text' key
    """
    if not len(detections):
        return {"ocr_text": ""}

    class_names = detections.data["class_name"]

    # Normalize coordinates so vertical directions can be handled like
    # horizontal ones downstream.
    boxes = prepare_coordinates(
        detections.xyxy.round().astype(dtype=int), reading_direction
    )

    boxes_by_line = group_detections_by_line(boxes, reading_direction, tolerance)
    # vertical_bottom_to_top walks lines in decreasing primary coordinate.
    line_keys = sorted(
        boxes_by_line, reverse=reading_direction in ["vertical_bottom_to_top"]
    )

    # Concatenate class names line by line, in reading order within each line.
    pieces = []
    last_line = len(line_keys) - 1
    for line_number, line_key in enumerate(line_keys):
        line = boxes_by_line[line_key]
        order = sort_line_detections(np.array(line["xyxy"]), reading_direction)
        pieces.extend(class_names[np.array(line["idx"])[order]])
        if line_number != last_line:
            pieces.append(get_line_separator(reading_direction))

    return {"ocr_text": delimiter.join(pieces)}

core/workflows/core_steps/visualizations/classification_label

inference.core.workflows.core_steps.visualizations.classification_label.v1

Classes

Functions

create_label_visualization

create_label_visualization(
    sorted_predictions,
    text_position,
    text,
    w,
    h,
    initial_offset,
    total_spacing,
    text_scale,
    text_padding,
)

Create visualization layout for classification labels.

Source code in inference/core/workflows/core_steps/visualizations/classification_label/v1.py
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
def create_label_visualization(
    sorted_predictions: List[dict],
    text_position: str,
    text: str,
    w: int,
    h: int,
    initial_offset: float,
    total_spacing: float,
    text_scale: float,
    text_padding: int,
) -> Tuple[np.ndarray, List[str], List[dict]]:
    """Create visualization layout for classification labels.

    Dispatches to the bottom/center/top layout handler based on
    ``text_position``; any position that is neither bottom nor center is
    treated as a top position.
    """
    bottom_positions = {"BOTTOM_LEFT", "BOTTOM_CENTER", "BOTTOM_RIGHT"}
    center_positions = {"CENTER", "CENTER_LEFT", "CENTER_RIGHT"}

    if text_position in bottom_positions:
        return handle_bottom_position(
            sorted_predictions, text, w, h, initial_offset, total_spacing
        )
    if text_position in center_positions:
        return handle_center_position(
            sorted_predictions,
            text,
            text_position,
            w,
            h,
            total_spacing,
            text_scale,
            text_padding,
        )
    return handle_top_position(
        sorted_predictions, text, w, h, initial_offset, total_spacing
    )

detect_prediction_type

detect_prediction_type(predictions)

Detect whether predictions are single-label or multi-label based on structure.

Parameters:

Name Type Description Default
predictions dict

The predictions dictionary

required

Returns:

Name Type Description
str str

'single-label' or 'multi-label'

Source code in inference/core/workflows/core_steps/visualizations/classification_label/v1.py
478
479
480
481
482
483
484
485
486
487
488
489
490
def detect_prediction_type(predictions: dict) -> str:
    """
    Detect whether predictions are single-label or multi-label based on structure.

    Args:
        predictions (dict): The predictions dictionary

    Returns:
        str: 'single-label' or 'multi-label'
    """
    # Single-label results store predictions as a list; multi-label results
    # store them keyed by class name (or the key may be absent entirely).
    is_single_label = isinstance(predictions.get("predictions"), list)
    return "single-label" if is_single_label else "multi-label"

format_labels

format_labels(predictions, text='Class and Confidence')

Format labels based on specified text option.

Parameters:

Name Type Description Default
predictions list

List of prediction dictionaries containing 'class' and 'confidence'

required
text str

One of "class", "confidence", or "class and confidence"

'Class and Confidence'

Returns:

Name Type Description
list

Formatted label strings

Source code in inference/core/workflows/core_steps/visualizations/classification_label/v1.py
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
def format_labels(predictions, text="Class and Confidence"):
    """
    Format labels based on specified text option.

    Args:
        predictions (list): List of prediction dictionaries containing 'class' and 'confidence'
        text (str): One of "Class", "Confidence", or "Class and Confidence"
            (case-sensitive)

    Returns:
        list: Formatted label strings

    Raises:
        ValueError: If `text` is not one of the supported options.
    """
    if text == "Class":
        labels = [f"{p['class']}" for p in predictions]
    elif text == "Confidence":
        labels = [f"{p['confidence']:.2f}" for p in predictions]
    elif text == "Class and Confidence":
        labels = [f"{p['class']} {p['confidence']:.2f}" for p in predictions]
    else:
        # Message reflects the actual (capitalized) values accepted above;
        # the previous lowercase message advertised values that were rejected.
        raise ValueError(
            "text must be one of: 'Class', 'Confidence', or 'Class and Confidence'"
        )

    return labels

format_multi_label_predictions

format_multi_label_predictions(predictions)

Transform multi-label predictions from predicted_classes into standard format.

Parameters:

Name Type Description Default
predictions dict

The predictions dictionary

required

Returns:

Type Description
List[dict]

List[dict]: Formatted predictions list

Source code in inference/core/workflows/core_steps/visualizations/classification_label/v1.py
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
def format_multi_label_predictions(predictions: dict) -> List[dict]:
    """
    Transform multi-label predictions from predicted_classes into standard format.

    Args:
        predictions (dict): The predictions dictionary

    Returns:
        List[dict]: Formatted predictions list
    """
    # Flatten the per-class mapping into a list of records, keeping the
    # order of `predicted_classes`.
    per_class = predictions["predictions"]
    return [
        {
            "class": class_name,
            "class_id": per_class[class_name]["class_id"],
            "confidence": per_class[class_name]["confidence"],
        }
        for class_name in predictions["predicted_classes"]
    ]

handle_bottom_position

handle_bottom_position(
    sorted_predictions,
    text,
    w,
    h,
    initial_offset,
    total_spacing,
)

Handle visualization layout for bottom positions.

Source code in inference/core/workflows/core_steps/visualizations/classification_label/v1.py
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
def handle_bottom_position(
    sorted_predictions: List[dict],
    text: str,
    w: int,
    h: int,
    initial_offset: float,
    total_spacing: float,
) -> Tuple[np.ndarray, List[str], List[dict]]:
    """Handle visualization layout for bottom positions."""
    # Bottom-anchored labels stack upwards from the bottom edge, so the
    # predictions are rendered in reverse order.
    flipped = list(reversed(sorted_predictions))
    anchor_rows = [
        [0, 0, w, h - (initial_offset + slot * total_spacing)]
        for slot in range(len(flipped))
    ]
    return np.array(anchor_rows), format_labels(flipped, text), flipped

handle_center_position

handle_center_position(
    sorted_predictions,
    text,
    text_position,
    w,
    h,
    total_spacing,
    text_scale,
    text_padding,
)

Handle visualization layout for center positions.

Source code in inference/core/workflows/core_steps/visualizations/classification_label/v1.py
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
def handle_center_position(
    sorted_predictions: List[dict],
    text: str,
    text_position: str,
    w: int,
    h: int,
    total_spacing: float,
    text_scale: float,
    text_padding: int,
) -> Tuple[np.ndarray, List[str], List[dict]]:
    """Handle visualization layout for center positions."""
    labels = format_labels(sorted_predictions, text)
    count = len(sorted_predictions)
    stack_height = total_spacing * count
    # Vertically center the stack, clamping so it never starts above 0.
    start_y = max(0, min((h - stack_height) / 2, h - stack_height))

    # Rough estimate of the horizontal space a label needs (~15px per char
    # at scale 1.0), plus padding that grows as text_padding shrinks.
    char_width = 15
    longest_label = max(len(label) for label in labels)
    label_width = (longest_label * char_width * text_scale) + (text_padding * 2)
    extra_padding = 20 + max(0, 10 - text_padding) * 3

    # Side positions shrink the anchor boxes horizontally so the label text
    # fits inside the frame; plain CENTER spans the full width.
    left_edge = label_width + extra_padding if text_position == "CENTER_LEFT" else 0
    right_edge = (
        w - (label_width + extra_padding) if text_position == "CENTER_RIGHT" else w
    )
    xyxy = np.array(
        [
            [
                left_edge,
                start_y + slot * total_spacing,
                right_edge,
                start_y + (slot + 1) * total_spacing,
            ]
            for slot in range(count)
        ]
    )

    return xyxy, labels, sorted_predictions

handle_top_position

handle_top_position(
    sorted_predictions,
    text,
    w,
    h,
    initial_offset,
    total_spacing,
)

Handle visualization layout for top positions.

Source code in inference/core/workflows/core_steps/visualizations/classification_label/v1.py
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
def handle_top_position(
    sorted_predictions: List[dict],
    text: str,
    w: int,
    h: int,
    initial_offset: float,
    total_spacing: float,
) -> Tuple[np.ndarray, List[str], List[dict]]:
    """Handle visualization layout for top positions."""
    # Each anchor box starts a bit lower than the previous one, so labels
    # stack downwards from the top edge.
    anchor_rows = [
        [0, initial_offset + slot * total_spacing, w, h]
        for slot in range(len(sorted_predictions))
    ]
    return np.array(anchor_rows), format_labels(sorted_predictions, text), sorted_predictions

validate_prediction_format

validate_prediction_format(predictions, task_type)

Validate that the predictions format matches the specified task type.

Parameters:

Name Type Description Default
predictions dict

The predictions dictionary

required
task_type str

The specified task type ('single-label' or 'multi-label')

required

Raises:

Type Description
ValueError

If prediction format doesn't match task type

Source code in inference/core/workflows/core_steps/visualizations/classification_label/v1.py
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
def validate_prediction_format(predictions: dict, task_type: str) -> None:
    """
    Validate that the predictions format matches the specified task type.

    Args:
        predictions (dict): The predictions dictionary
        task_type (str): The specified task type ('single-label' or 'multi-label')

    Raises:
        ValueError: If prediction format doesn't match task type
    """
    actual_type = detect_prediction_type(predictions)
    if actual_type == task_type:
        return

    # Tell the user which direction the mismatch goes so they can fix the
    # task_type setting rather than the predictions.
    if actual_type == "single-label":
        message = "Received single-label predictions but task_type is set to 'multi-label'. Please correct the task_type setting."
    else:
        message = "Received multi-label predictions but task_type is set to 'single-label'. Please correct the task_type setting."
    raise ValueError(message)

core/workflows/core_steps/visualizations/common/annotators

inference.core.workflows.core_steps.visualizations.common.annotators.background_color

Classes

BackgroundColorAnnotator

Bases: BaseAnnotator

A class for drawing background colors outside of detected box or mask regions.

Warning

This annotator uses sv.Detections.mask.

Source code in inference/core/workflows/core_steps/visualizations/common/annotators/background_color.py
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
class BackgroundColorAnnotator(BaseAnnotator):
    """
    A class for drawing background colors outside of detected box or mask regions.
    !!! warning
        This annotator uses `sv.Detections.mask`.
    """

    def __init__(
        self,
        color: Color = Color.BLACK,
        opacity: float = 0.5,
        force_box: bool = False,
    ):
        """
        Args:
            color (Color): The color to use for annotating detections.
            opacity (float): Opacity of the overlay mask. Must be between `0` and `1`.
            force_box (bool): If `True`, bounding boxes are used to cut out the
                detection regions even when segmentation masks are available.
        """
        self.color: Color = color
        self.opacity = opacity
        self.force_box = force_box

    def annotate(self, scene: np.ndarray, detections: Detections) -> np.ndarray:
        """
        Annotates the given scene with masks based on the provided detections.
        Args:
            scene (ImageType): The image where masks will be drawn.
                `ImageType` is a flexible type, accepting either `numpy.ndarray`
                or `PIL.Image.Image`.
            detections (Detections): Object detections to annotate.
        Returns:
            The annotated image, matching the type of `scene` (`numpy.ndarray`
                or `PIL.Image.Image`)
        Example:
            ```python
            import supervision as sv
            image = ...
            detections = sv.Detections(...)
            background_color_annotator = sv.BackgroundColorAnnotator()
            annotated_frame = background_color_annotator.annotate(
                scene=image.copy(),
                detections=detections
            )
            ```
        ![background-color-annotator-example](https://media.roboflow.com/
        supervision-annotator-examples/background-color-annotator-example-purple.png)
        """

        # Full-frame overlay filled with the background color.
        colored_mask = np.full_like(scene, self.color.as_bgr(), dtype=np.uint8)

        # Blend scene and overlay; the result is written back into
        # colored_mask in place via dst=.
        cv2.addWeighted(
            scene, 1 - self.opacity, colored_mask, self.opacity, 0, dst=colored_mask
        )

        # Restore the original pixels inside each detection region so that
        # only the background ends up tinted.
        if detections.mask is None or self.force_box:
            for detection_idx in range(len(detections)):
                x1, y1, x2, y2 = detections.xyxy[detection_idx].astype(int)
                colored_mask[y1:y2, x1:x2] = scene[y1:y2, x1:x2]
        else:
            for mask in detections.mask:
                colored_mask[mask] = scene[mask]

        return colored_mask
Functions
__init__
__init__(color=Color.BLACK, opacity=0.5, force_box=False)

Parameters:

Name Type Description Default
color Color

The color to use for annotating detections.

BLACK
opacity float

Opacity of the overlay mask. Must be between 0 and 1.

0.5
Source code in inference/core/workflows/core_steps/visualizations/common/annotators/background_color.py
14
15
16
17
18
19
20
21
22
23
24
25
26
27
def __init__(
    self,
    color: Color = Color.BLACK,
    opacity: float = 0.5,
    force_box: bool = False,
):
    """
    Args:
        color (Color): The color to use for annotating detections.
        opacity (float): Opacity of the overlay mask. Must be between `0` and `1`.
        force_box (bool): If `True`, bounding boxes are used to cut out the
            detection regions even when segmentation masks are available.
    """
    self.color: Color = color
    self.opacity = opacity
    self.force_box = force_box
annotate
annotate(scene, detections)

Annotates the given scene with masks based on the provided detections. Args: scene (ImageType): The image where masks will be drawn. ImageType is a flexible type, accepting either numpy.ndarray or PIL.Image.Image. detections (Detections): Object detections to annotate. Returns: The annotated image, matching the type of scene (numpy.ndarray or PIL.Image.Image) Example:

import supervision as sv
image = ...
detections = sv.Detections(...)
background_color_annotator = sv.BackgroundColorAnnotator()
annotated_frame = background_color_annotator.annotate(
    scene=image.copy(),
    detections=detections
)
background-color-annotator-example

Source code in inference/core/workflows/core_steps/visualizations/common/annotators/background_color.py
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
def annotate(self, scene: np.ndarray, detections: Detections) -> np.ndarray:
    """
    Annotates the given scene with masks based on the provided detections.
    Args:
        scene (ImageType): The image where masks will be drawn.
            `ImageType` is a flexible type, accepting either `numpy.ndarray`
            or `PIL.Image.Image`.
        detections (Detections): Object detections to annotate.
    Returns:
        The annotated image, matching the type of `scene` (`numpy.ndarray`
            or `PIL.Image.Image`)
    Example:
        ```python
        import supervision as sv
        image = ...
        detections = sv.Detections(...)
        background_color_annotator = sv.BackgroundColorAnnotator()
        annotated_frame = background_color_annotator.annotate(
            scene=image.copy(),
            detections=detections
        )
        ```
    ![background-color-annotator-example](https://media.roboflow.com/
    supervision-annotator-examples/background-color-annotator-example-purple.png)
    """

    # Full-frame overlay filled with the background color.
    colored_mask = np.full_like(scene, self.color.as_bgr(), dtype=np.uint8)

    # Blend scene and overlay; the result is written back into
    # colored_mask in place via dst=.
    cv2.addWeighted(
        scene, 1 - self.opacity, colored_mask, self.opacity, 0, dst=colored_mask
    )

    # Restore the original pixels inside each detection region so that
    # only the background ends up tinted.
    if detections.mask is None or self.force_box:
        for detection_idx in range(len(detections)):
            x1, y1, x2, y2 = detections.xyxy[detection_idx].astype(int)
            colored_mask[y1:y2, x1:x2] = scene[y1:y2, x1:x2]
    else:
        for mask in detections.mask:
            colored_mask[mask] = scene[mask]

    return colored_mask

inference.core.workflows.core_steps.visualizations.common.annotators.halo

Classes

HaloAnnotator

Bases: BaseAnnotator

A class for drawing Halos on an image using provided detections.

Warning

This annotator uses sv.Detections.mask.

Source code in inference/core/workflows/core_steps/visualizations/common/annotators/halo.py
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
class HaloAnnotator(BaseAnnotator):
    """
    A class for drawing Halos on an image using provided detections.

    !!! warning

        This annotator uses `sv.Detections.mask`.
    """

    def __init__(
        self,
        color: Union[Color, ColorPalette] = ColorPalette.DEFAULT,
        opacity: float = 0.8,
        kernel_size: int = 40,
        color_lookup: ColorLookup = ColorLookup.CLASS,
    ):
        """
        Args:
            color (Union[Color, ColorPalette]): The color or color palette to use for
                annotating detections.
            opacity (float): Opacity of the overlay mask. Must be between `0` and `1`.
            kernel_size (int): The size of the average pooling kernel used for creating
                the halo.
            color_lookup (ColorLookup): Strategy for mapping colors to annotations.
                Options are `INDEX`, `CLASS`, `TRACK`.
        """
        self.color: Union[Color, ColorPalette] = color
        self.opacity = opacity
        self.color_lookup: ColorLookup = color_lookup
        self.kernel_size: int = kernel_size

    @ensure_cv2_image_for_annotation
    def annotate(
        self,
        scene: ImageType,
        detections: Detections,
        custom_color_lookup: Optional[np.ndarray] = None,
    ) -> ImageType:
        """
        Annotates the given scene with halos based on the provided detections.

        Args:
            scene (ImageType): The image where masks will be drawn.
                `ImageType` is a flexible type, accepting either `numpy.ndarray`
                or `PIL.Image.Image`.
            detections (Detections): Object detections to annotate.
            custom_color_lookup (Optional[np.ndarray]): Custom color lookup array.
                Allows to override the default color mapping strategy.

        Returns:
            The annotated image, matching the type of `scene` (`numpy.ndarray`
                or `PIL.Image.Image`)

        Example:
            ```python
            import supervision as sv

            image = ...
            detections = sv.Detections(...)

            halo_annotator = sv.HaloAnnotator()
            annotated_frame = halo_annotator.annotate(
                scene=image.copy(),
                detections=detections
            )
            ```

        ![halo-annotator-example](https://media.roboflow.com/
        supervision-annotator-examples/halo-annotator-example-purple.png)
        """
        assert isinstance(scene, np.ndarray)
        # Nothing to draw: returning early avoids dividing by a zero-valued
        # gray.max() below, which would fill the frame with NaN garbage.
        if len(detections) == 0:
            return scene
        colored_mask = np.zeros_like(scene, dtype=np.uint8)
        fmask = np.array([False] * scene.shape[0] * scene.shape[1]).reshape(
            scene.shape[0], scene.shape[1]
        )

        # Paint larger detections first so smaller ones stay visible on top.
        for detection_idx in np.flip(np.argsort(detections.area)):
            color = resolve_color(
                color=self.color,
                detections=detections,
                detection_idx=detection_idx,
                color_lookup=(
                    self.color_lookup
                    if custom_color_lookup is None
                    else custom_color_lookup
                ),
            )
            if detections.mask is None:
                # No masks available - fall back to the bounding box region.
                x1, y1, x2, y2 = detections.xyxy[detection_idx].astype(int)
                mask = np.zeros(scene.shape[:2], dtype=bool)
                mask[y1:y2, x1:x2] = True
            else:
                mask = detections.mask[detection_idx]
            fmask = np.logical_or(fmask, mask)
            color_bgr = color.as_bgr()
            colored_mask[mask] = color_bgr

        # Blur the colored regions outward to create the halo, then clear
        # the interior so only the fringe remains.
        colored_mask = cv2.blur(colored_mask, (self.kernel_size, self.kernel_size))
        colored_mask[fmask] = [0, 0, 0]
        gray = cv2.cvtColor(colored_mask, cv2.COLOR_BGR2GRAY)
        gray_max = gray.max()
        if gray_max == 0:
            # Degenerate (zero-area) masks leave the halo completely black;
            # bail out instead of dividing by zero.
            return scene
        alpha = self.opacity * gray / gray_max
        alpha_mask = alpha[:, :, np.newaxis]
        blended_scene = np.uint8(scene * (1 - alpha_mask) + colored_mask * self.opacity)
        np.copyto(scene, blended_scene)
        return scene
Functions
__init__
__init__(
    color=ColorPalette.DEFAULT,
    opacity=0.8,
    kernel_size=40,
    color_lookup=ColorLookup.CLASS,
)

Parameters:

Name Type Description Default
color Union[Color, ColorPalette]

The color or color palette to use for annotating detections.

DEFAULT
opacity float

Opacity of the overlay mask. Must be between 0 and 1.

0.8
kernel_size int

The size of the average pooling kernel used for creating the halo.

40
color_lookup ColorLookup

Strategy for mapping colors to annotations. Options are INDEX, CLASS, TRACK.

CLASS
Source code in inference/core/workflows/core_steps/visualizations/common/annotators/halo.py
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
def __init__(
    self,
    color: Union[Color, ColorPalette] = ColorPalette.DEFAULT,
    opacity: float = 0.8,
    kernel_size: int = 40,
    color_lookup: ColorLookup = ColorLookup.CLASS,
):
    """
    Args:
        color (Union[Color, ColorPalette]): The color or color palette to use for
            annotating detections.
        opacity (float): Opacity of the overlay mask. Must be between `0` and `1`.
        kernel_size (int): The size of the average pooling kernel used for creating
            the halo.
        color_lookup (ColorLookup): Strategy for mapping colors to annotations.
            Options are `INDEX`, `CLASS`, `TRACK`.
    """
    # Parameters are stored as-is; validation of ranges is left to callers.
    self.color: Union[Color, ColorPalette] = color
    self.opacity = opacity
    self.color_lookup: ColorLookup = color_lookup
    self.kernel_size: int = kernel_size
annotate
annotate(scene, detections, custom_color_lookup=None)

Annotates the given scene with halos based on the provided detections.

Parameters:

Name Type Description Default
scene ImageType

The image where masks will be drawn. ImageType is a flexible type, accepting either numpy.ndarray or PIL.Image.Image.

required
detections Detections

Object detections to annotate.

required
custom_color_lookup Optional[ndarray]

Custom color lookup array. Allows to override the default color mapping strategy.

None

Returns:

Type Description
ImageType

The annotated image, matching the type of scene (numpy.ndarray or PIL.Image.Image)

Example
import supervision as sv

image = ...
detections = sv.Detections(...)

halo_annotator = sv.HaloAnnotator()
annotated_frame = halo_annotator.annotate(
    scene=image.copy(),
    detections=detections
)

halo-annotator-example

Source code in inference/core/workflows/core_steps/visualizations/common/annotators/halo.py
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
@ensure_cv2_image_for_annotation
def annotate(
    self,
    scene: ImageType,
    detections: Detections,
    custom_color_lookup: Optional[np.ndarray] = None,
) -> ImageType:
    """
    Annotates the given scene with halos based on the provided detections.

    Args:
        scene (ImageType): The image where masks will be drawn.
            `ImageType` is a flexible type, accepting either `numpy.ndarray`
            or `PIL.Image.Image`.
        detections (Detections): Object detections to annotate.
        custom_color_lookup (Optional[np.ndarray]): Custom color lookup array.
            Allows to override the default color mapping strategy.

    Returns:
        The annotated image, matching the type of `scene` (`numpy.ndarray`
            or `PIL.Image.Image`)

    Example:
        ```python
        import supervision as sv

        image = ...
        detections = sv.Detections(...)

        halo_annotator = sv.HaloAnnotator()
        annotated_frame = halo_annotator.annotate(
            scene=image.copy(),
            detections=detections
        )
        ```

    ![halo-annotator-example](https://media.roboflow.com/
    supervision-annotator-examples/halo-annotator-example-purple.png)
    """
    assert isinstance(scene, np.ndarray)
    # Nothing to draw: returning early avoids dividing by a zero-valued
    # gray.max() below, which would fill the frame with NaN garbage.
    if len(detections) == 0:
        return scene
    colored_mask = np.zeros_like(scene, dtype=np.uint8)
    fmask = np.array([False] * scene.shape[0] * scene.shape[1]).reshape(
        scene.shape[0], scene.shape[1]
    )

    # Paint larger detections first so smaller ones stay visible on top.
    for detection_idx in np.flip(np.argsort(detections.area)):
        color = resolve_color(
            color=self.color,
            detections=detections,
            detection_idx=detection_idx,
            color_lookup=(
                self.color_lookup
                if custom_color_lookup is None
                else custom_color_lookup
            ),
        )
        if detections.mask is None:
            # No masks available - fall back to the bounding box region.
            x1, y1, x2, y2 = detections.xyxy[detection_idx].astype(int)
            mask = np.zeros(scene.shape[:2], dtype=bool)
            mask[y1:y2, x1:x2] = True
        else:
            mask = detections.mask[detection_idx]
        fmask = np.logical_or(fmask, mask)
        color_bgr = color.as_bgr()
        colored_mask[mask] = color_bgr

    # Blur the colored regions outward to create the halo, then clear
    # the interior so only the fringe remains.
    colored_mask = cv2.blur(colored_mask, (self.kernel_size, self.kernel_size))
    colored_mask[fmask] = [0, 0, 0]
    gray = cv2.cvtColor(colored_mask, cv2.COLOR_BGR2GRAY)
    gray_max = gray.max()
    if gray_max == 0:
        # Degenerate (zero-area) masks leave the halo completely black;
        # bail out instead of dividing by zero.
        return scene
    alpha = self.opacity * gray / gray_max
    alpha_mask = alpha[:, :, np.newaxis]
    blended_scene = np.uint8(scene * (1 - alpha_mask) + colored_mask * self.opacity)
    np.copyto(scene, blended_scene)
    return scene

inference.core.workflows.core_steps.visualizations.common.annotators.model_comparison

Classes

ModelComparisonAnnotator

Bases: BaseAnnotator

A class for annotating images by highlighting regions predicted by two different models. This annotator visually distinguishes areas uniquely predicted by each model as well as the background where neither model made a prediction.

Attributes:

Name Type Description
color_a Color

Color used to highlight predictions made only by Model A.

color_b Color

Color used to highlight predictions made only by Model B.

background_color Color

Color used for parts of the image where neither model made a prediction.

opacity float

Opacity level of the overlays, ranging between 0 and 1.

force_box bool

If True, forces the use of bounding boxes for predictions even if masks are available.

Source code in inference/core/workflows/core_steps/visualizations/common/annotators/model_comparison.py
  7
  8
  9
 10
 11
 12
 13
 14
 15
 16
 17
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
class ModelComparisonAnnotator(BaseAnnotator):
    """
    A class for annotating images by highlighting regions predicted by two different models.
    This annotator visually distinguishes areas uniquely predicted by each model as well as
    the background where neither model made a prediction.

    Attributes:
        color_a (Color): Color used to highlight predictions made only by Model A.
        color_b (Color): Color used to highlight predictions made only by Model B.
        background_color (Color): Color used for parts of the image where neither model made a prediction.
        opacity (float): Opacity level of the overlays, ranging between 0 and 1.
        force_box (bool): If True, forces the use of bounding boxes for predictions even if masks are available.
    """

    def __init__(
        self,
        color_a: Color = Color.GREEN,
        color_b: Color = Color.RED,
        background_color: Color = Color.BLACK,
        opacity: float = 0.7,
        force_box: bool = False,
    ):
        """
        Initializes the ModelComparisonAnnotator with the specified colors, opacity, and behavior.

        Args:
            color_a (Color): Color used to highlight predictions made only by Model A.
            color_b (Color): Color used to highlight predictions made only by Model B.
            background_color (Color): Color for parts of the image not covered by any prediction.
            opacity (float): Opacity of the overlay mask, must be between 0 and 1.
            force_box (bool): Whether to use bounding boxes instead of masks if masks are available.
        """
        self.color_a: Color = color_a
        self.color_b: Color = color_b
        self.background_color: Color = background_color
        self.opacity = opacity
        self.force_box = force_box

    @staticmethod
    def _rasterize(
        detections: Detections,
        covered: np.ndarray,
        uncovered: np.ndarray,
        force_box: bool,
    ) -> None:
        """Mark pixels covered by `detections` in `covered` and clear them in `uncovered`.

        Uses segmentation masks when available, falling back to bounding boxes when
        masks are missing or `force_box` is set. Box coordinates are clamped to the
        image bounds: a negative coordinate (box partially outside the image) would
        otherwise be interpreted by numpy slicing as a wrap-around index and mark
        the wrong region, or nothing at all.
        """
        height, width = covered.shape
        if detections.mask is None or force_box:
            for detection_idx in range(len(detections)):
                x1, y1, x2, y2 = detections.xyxy[detection_idx].astype(int)
                x1, y1 = max(0, x1), max(0, y1)
                x2, y2 = min(width, x2), min(height, y2)
                covered[y1:y2, x1:x2] = 1
                uncovered[y1:y2, x1:x2] = 0
        else:
            for mask in detections.mask:
                mask_bool = mask.astype(bool)
                covered[mask_bool] = 1
                uncovered[mask_bool] = 0

    def annotate(
        self, scene: np.ndarray, detections_a: Detections, detections_b: Detections
    ) -> np.ndarray:
        """
        Annotates the given scene with highlights representing predictions from two models.

        Regions predicted only by Model A, only by Model B, and by neither model are
        tinted with `color_a`, `color_b` and `background_color` respectively. Regions
        predicted by both models are intentionally left unmodified.

        Args:
            scene (np.ndarray): Original image as a NumPy array (H x W x C).
            detections_a (Detections): Predictions from Model A.
            detections_b (Detections): Predictions from Model B.

        Returns:
            np.ndarray: Annotated image as a NumPy array.
        """
        # Single-channel occupancy masks. `neither_predicted` starts all-ones
        # and is cleared wherever either model predicts.
        neither_predicted = np.ones(scene.shape[:2], dtype=np.uint8)
        a_predicted = np.zeros(scene.shape[:2], dtype=np.uint8)
        b_predicted = np.zeros(scene.shape[:2], dtype=np.uint8)

        self._rasterize(detections_a, a_predicted, neither_predicted, self.force_box)
        self._rasterize(detections_b, b_predicted, neither_predicted, self.force_box)

        # Exclusive regions: pixels covered by exactly one of the two models.
        only_a_predicted = np.logical_and(a_predicted, np.logical_not(b_predicted))
        only_b_predicted = np.logical_and(b_predicted, np.logical_not(a_predicted))

        # Create full-color overlay images for each region type.
        overlay_background = np.full_like(
            scene, self.background_color.as_bgr(), dtype=np.uint8
        )
        overlay_a = np.full_like(scene, self.color_a.as_bgr(), dtype=np.uint8)
        overlay_b = np.full_like(scene, self.color_b.as_bgr(), dtype=np.uint8)

        def apply_overlay(base_img, overlay_img, mask, opacity):
            """Blend `overlay_img` into `base_img` at `opacity` where `mask` is set."""
            blended = cv2.addWeighted(base_img, 1 - opacity, overlay_img, opacity, 0)
            # A 2-D boolean mask on an H x W x C array selects whole pixels.
            mask_bool = np.asarray(mask, dtype=bool)
            base_img[mask_bool] = blended[mask_bool]
            return base_img

        scene = apply_overlay(
            scene, overlay_background, neither_predicted, self.opacity
        )
        scene = apply_overlay(scene, overlay_a, only_a_predicted, self.opacity)
        scene = apply_overlay(scene, overlay_b, only_b_predicted, self.opacity)

        # Areas where both models predicted remain unchanged (no overlay).
        return scene
Functions
__init__
__init__(
    color_a=Color.GREEN,
    color_b=Color.RED,
    background_color=Color.BLACK,
    opacity=0.7,
    force_box=False,
)

Initializes the ModelComparisonAnnotator with the specified colors, opacity, and behavior.

Parameters:

Name Type Description Default
color_a Color

Color used to highlight predictions made only by Model A.

GREEN
color_b Color

Color used to highlight predictions made only by Model B.

RED
background_color Color

Color for parts of the image not covered by any prediction.

BLACK
opacity float

Opacity of the overlay mask, must be between 0 and 1.

0.7
force_box bool

Whether to use bounding boxes instead of masks if masks are available.

False
Source code in inference/core/workflows/core_steps/visualizations/common/annotators/model_comparison.py
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
def __init__(
    self,
    color_a: Color = Color.GREEN,
    color_b: Color = Color.RED,
    background_color: Color = Color.BLACK,
    opacity: float = 0.7,
    force_box: bool = False,
):
    """
    Set up the annotator's colors, overlay opacity and rasterization mode.

    Args:
        color_a (Color): Highlight color for regions predicted only by Model A.
        color_b (Color): Highlight color for regions predicted only by Model B.
        background_color (Color): Fill color for regions neither model covers.
        opacity (float): Overlay opacity, in the [0, 1] range.
        force_box (bool): When True, always rasterize bounding boxes even when
            segmentation masks are available.
    """
    self.color_a = color_a
    self.color_b = color_b
    self.background_color = background_color
    self.opacity = opacity
    self.force_box = force_box
annotate
annotate(scene, detections_a, detections_b)

Annotates the given scene with highlights representing predictions from two models.

Parameters:

Name Type Description Default
scene ndarray

Original image as a NumPy array (H x W x C).

required
detections_a Detections

Predictions from Model A.

required
detections_b Detections

Predictions from Model B.

required

Returns:

Type Description
ndarray

np.ndarray: Annotated image as a NumPy array.

Source code in inference/core/workflows/core_steps/visualizations/common/annotators/model_comparison.py
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
def annotate(
    self, scene: np.ndarray, detections_a: Detections, detections_b: Detections
) -> np.ndarray:
    """
    Annotates the given scene with highlights representing predictions from two models.

    Regions predicted only by Model A, only by Model B, and by neither model are
    tinted with `color_a`, `color_b` and `background_color` respectively. Regions
    predicted by both models are left unmodified.

    Args:
        scene (np.ndarray): Original image as a NumPy array (H x W x C).
        detections_a (Detections): Predictions from Model A.
        detections_b (Detections): Predictions from Model B.

    Returns:
        np.ndarray: Annotated image as a NumPy array.
    """

    img_h, img_w = scene.shape[:2]

    # Initialize single-channel masks
    neither_predicted = np.ones(
        scene.shape[:2], dtype=np.uint8
    )  # 1 where neither model predicts
    a_predicted = np.zeros(scene.shape[:2], dtype=np.uint8)
    b_predicted = np.zeros(scene.shape[:2], dtype=np.uint8)

    # Populate masks based on detections from Model A.
    # Box coordinates are clamped to the image bounds: a negative coordinate
    # (box partially outside the image) would otherwise be interpreted by
    # numpy slicing as a wrap-around index and mark the wrong region.
    if detections_a.mask is None or self.force_box:
        for detection_idx in range(len(detections_a)):
            x1, y1, x2, y2 = detections_a.xyxy[detection_idx].astype(int)
            x1, y1 = max(0, x1), max(0, y1)
            x2, y2 = min(img_w, x2), min(img_h, y2)
            a_predicted[y1:y2, x1:x2] = 1
            neither_predicted[y1:y2, x1:x2] = 0
    else:
        for mask in detections_a.mask:
            a_predicted[mask.astype(bool)] = 1
            neither_predicted[mask.astype(bool)] = 0

    # Populate masks based on detections from Model B (same clamping rationale)
    if detections_b.mask is None or self.force_box:
        for detection_idx in range(len(detections_b)):
            x1, y1, x2, y2 = detections_b.xyxy[detection_idx].astype(int)
            x1, y1 = max(0, x1), max(0, y1)
            x2, y2 = min(img_w, x2), min(img_h, y2)
            b_predicted[y1:y2, x1:x2] = 1
            neither_predicted[y1:y2, x1:x2] = 0
    else:
        for mask in detections_b.mask:
            b_predicted[mask.astype(bool)] = 1
            neither_predicted[mask.astype(bool)] = 0

    # Exclusive regions: pixels covered by exactly one of the two models
    only_a_predicted = np.logical_and(a_predicted, np.logical_not(b_predicted))
    only_b_predicted = np.logical_and(b_predicted, np.logical_not(a_predicted))

    # Prepare overlay colors
    background_color_bgr = self.background_color.as_bgr()  # Tuple like (B, G, R)
    color_a_bgr = self.color_a.as_bgr()
    color_b_bgr = self.color_b.as_bgr()

    # Create full-color overlay images
    overlay_background = np.full_like(scene, background_color_bgr, dtype=np.uint8)
    overlay_a = np.full_like(scene, color_a_bgr, dtype=np.uint8)
    overlay_b = np.full_like(scene, color_b_bgr, dtype=np.uint8)

    # Function to blend and apply overlay based on mask
    def apply_overlay(base_img, overlay_img, mask, opacity):
        """
        Blends the overlay with the base image where the mask is set.

        Args:
            base_img (np.ndarray): Original image.
            overlay_img (np.ndarray): Overlay color image.
            mask (np.ndarray): Single-channel mask where to apply the overlay.
            opacity (float): Opacity of the overlay (0 to 1).

        Returns:
            np.ndarray: Image with overlay applied.
        """
        # Blend the entire images, then copy the blended pixels where mask holds.
        blended = cv2.addWeighted(base_img, 1 - opacity, overlay_img, opacity, 0)
        # A 2-D boolean mask on an H x W x C array selects whole pixels.
        mask_bool = np.asarray(mask, dtype=bool)
        base_img[mask_bool] = blended[mask_bool]
        return base_img

    # Apply background overlay where neither model predicted
    scene = apply_overlay(
        scene, overlay_background, neither_predicted, self.opacity
    )

    # Apply overlay for only Model A predictions
    scene = apply_overlay(scene, overlay_a, only_a_predicted, self.opacity)

    # Apply overlay for only Model B predictions
    scene = apply_overlay(scene, overlay_b, only_b_predicted, self.opacity)

    # Areas where both models predicted remain unchanged (no overlay)

    return scene

inference.core.workflows.core_steps.visualizations.common.annotators.polygon

Classes

PolygonAnnotator

Bases: BaseAnnotator

A class for drawing polygons on an image using provided detections.

Warning

This annotator uses sv.Detections.mask.

Source code in inference/core/workflows/core_steps/visualizations/common/annotators/polygon.py
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
class PolygonAnnotator(BaseAnnotator):
    """
    Draws detection outlines as polygons derived from segmentation masks,
    falling back to bounding boxes for detections without masks.

    !!! warning

        This annotator uses `sv.Detections.mask`.
    """

    def __init__(
        self,
        color: Union[Color, ColorPalette] = ColorPalette.DEFAULT,
        thickness: int = 2,
        color_lookup: ColorLookup = ColorLookup.CLASS,
    ):
        """
        Args:
            color (Union[Color, ColorPalette]): Color or palette used to draw
                each detection's outline.
            thickness (int): Line thickness of the drawn polygons.
            color_lookup (ColorLookup): How colors are assigned to detections;
                one of `INDEX`, `CLASS`, `TRACK`.
        """
        self.color = color
        self.thickness = thickness
        self.color_lookup = color_lookup

    @ensure_cv2_image_for_annotation
    def annotate(
        self,
        scene: ImageType,
        detections: Detections,
        custom_color_lookup: Optional[np.ndarray] = None,
    ) -> ImageType:
        """
        Draws polygon (or box) outlines for every detection onto the scene.

        Args:
            scene (ImageType): The image where polygons will be drawn.
                `ImageType` is a flexible type, accepting either `numpy.ndarray`
                or `PIL.Image.Image`.
            detections (Detections): Object detections to annotate.
            custom_color_lookup (Optional[np.ndarray]): Custom color lookup array
                that overrides the default color mapping strategy when provided.

        Returns:
            The annotated image, matching the type of `scene` (`numpy.ndarray`
                or `PIL.Image.Image`)

        Example:
            ```python
            import supervision as sv

            image = ...
            detections = sv.Detections(...)

            polygon_annotator = sv.PolygonAnnotator()
            annotated_frame = polygon_annotator.annotate(
                scene=image.copy(),
                detections=detections
            )
            ```

        ![polygon-annotator-example](https://media.roboflow.com/
        supervision-annotator-examples/polygon-annotator-example-purple.png)
        """
        assert isinstance(scene, np.ndarray)

        # The lookup strategy is identical for every detection; resolve it once.
        lookup = (
            self.color_lookup if custom_color_lookup is None else custom_color_lookup
        )

        for idx in range(len(detections)):
            detection_color = resolve_color(
                color=self.color,
                detections=detections,
                detection_idx=idx,
                color_lookup=lookup,
            )

            if detections.mask is not None:
                for polygon in mask_to_polygons(mask=detections.mask[idx]):
                    scene = draw_polygon(
                        scene=scene,
                        polygon=polygon,
                        color=detection_color,
                        thickness=self.thickness,
                    )
            else:
                # No mask available: outline the bounding box instead.
                x1, y1, x2, y2 = detections.xyxy[idx].astype(int)
                cv2.rectangle(
                    img=scene,
                    pt1=(x1, y1),
                    pt2=(x2, y2),
                    color=detection_color.as_bgr(),
                    thickness=self.thickness,
                )

        return scene
Functions
__init__
__init__(
    color=ColorPalette.DEFAULT,
    thickness=2,
    color_lookup=ColorLookup.CLASS,
)

Parameters:

Name Type Description Default
color Union[Color, ColorPalette]

The color or color palette to use for annotating detections.

DEFAULT
thickness int

Thickness of the polygon lines.

2
color_lookup ColorLookup

Strategy for mapping colors to annotations. Options are INDEX, CLASS, TRACK.

CLASS
Source code in inference/core/workflows/core_steps/visualizations/common/annotators/polygon.py
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
def __init__(
    self,
    color: Union[Color, ColorPalette] = ColorPalette.DEFAULT,
    thickness: int = 2,
    color_lookup: ColorLookup = ColorLookup.CLASS,
):
    """
    Args:
        color (Union[Color, ColorPalette]): Color or palette used to draw
            each detection's outline.
        thickness (int): Line thickness of the drawn polygons.
        color_lookup (ColorLookup): How colors are assigned to detections;
            one of `INDEX`, `CLASS`, `TRACK`.
    """
    self.color = color
    self.thickness = thickness
    self.color_lookup = color_lookup
annotate
annotate(scene, detections, custom_color_lookup=None)

Annotates the given scene with polygons based on the provided detections.

Parameters:

Name Type Description Default
scene ImageType

The image where polygons will be drawn. ImageType is a flexible type, accepting either numpy.ndarray or PIL.Image.Image.

required
detections Detections

Object detections to annotate.

required
custom_color_lookup Optional[ndarray]

Custom color lookup array. Allows to override the default color mapping strategy.

None

Returns:

Type Description
ImageType

The annotated image, matching the type of scene (numpy.ndarray or PIL.Image.Image)

Example
import supervision as sv

image = ...
detections = sv.Detections(...)

polygon_annotator = sv.PolygonAnnotator()
annotated_frame = polygon_annotator.annotate(
    scene=image.copy(),
    detections=detections
)

polygon-annotator-example

Source code in inference/core/workflows/core_steps/visualizations/common/annotators/polygon.py
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
@ensure_cv2_image_for_annotation
def annotate(
    self,
    scene: ImageType,
    detections: Detections,
    custom_color_lookup: Optional[np.ndarray] = None,
) -> ImageType:
    """
    Draws polygon (or box) outlines for every detection onto the scene.

    Args:
        scene (ImageType): The image where polygons will be drawn.
            `ImageType` is a flexible type, accepting either `numpy.ndarray`
            or `PIL.Image.Image`.
        detections (Detections): Object detections to annotate.
        custom_color_lookup (Optional[np.ndarray]): Custom color lookup array
            that overrides the default color mapping strategy when provided.

    Returns:
        The annotated image, matching the type of `scene` (`numpy.ndarray`
            or `PIL.Image.Image`)

    Example:
        ```python
        import supervision as sv

        image = ...
        detections = sv.Detections(...)

        polygon_annotator = sv.PolygonAnnotator()
        annotated_frame = polygon_annotator.annotate(
            scene=image.copy(),
            detections=detections
        )
        ```

    ![polygon-annotator-example](https://media.roboflow.com/
    supervision-annotator-examples/polygon-annotator-example-purple.png)
    """
    assert isinstance(scene, np.ndarray)

    # The lookup strategy is identical for every detection; resolve it once.
    lookup = (
        self.color_lookup if custom_color_lookup is None else custom_color_lookup
    )

    for idx in range(len(detections)):
        detection_color = resolve_color(
            color=self.color,
            detections=detections,
            detection_idx=idx,
            color_lookup=lookup,
        )

        if detections.mask is not None:
            for polygon in mask_to_polygons(mask=detections.mask[idx]):
                scene = draw_polygon(
                    scene=scene,
                    polygon=polygon,
                    color=detection_color,
                    thickness=self.thickness,
                )
        else:
            # No mask available: outline the bounding box instead.
            x1, y1, x2, y2 = detections.xyxy[idx].astype(int)
            cv2.rectangle(
                img=scene,
                pt1=(x1, y1),
                pt2=(x2, y2),
                color=detection_color.as_bgr(),
                thickness=self.thickness,
            )

    return scene

core/workflows/core_steps/visualizations/text_display

inference.core.workflows.core_steps.visualizations.text_display.utils

Functions

align_offset

align_offset(text_align, max_width, line_width)

Calculate horizontal offset for text alignment.

Source code in inference/core/workflows/core_steps/visualizations/text_display/utils.py
88
89
90
91
92
93
94
95
def align_offset(text_align: str, max_width: int, line_width: int) -> int:
    """Return the x-offset placing a line of `line_width` inside `max_width`.

    Any unrecognized `text_align` value falls back to left alignment (offset 0).
    """
    slack = max_width - line_width
    return {"center": slack // 2, "right": slack}.get(text_align, 0)

calculate_relative_position

calculate_relative_position(
    anchor,
    offset_x,
    offset_y,
    box_width,
    box_height,
    img_width,
    img_height,
)

Calculate the top-left corner position for a box positioned relative to an image anchor.

Parameters:

Name Type Description Default
anchor str

Anchor point name (e.g., "top_left", "center", "bottom_right")

required
offset_x int

Horizontal offset from anchor point (positive = right)

required
offset_y int

Vertical offset from anchor point (positive = down)

required
box_width int

Width of the box to position

required
box_height int

Height of the box to position

required
img_width int

Width of the image

required
img_height int

Height of the image

required

Returns:

Type Description
Tuple[int, int]

Tuple of (x, y) coordinates for the top-left corner of the box

Raises:

Type Description
ValueError

If anchor is not recognized

Source code in inference/core/workflows/core_steps/visualizations/text_display/utils.py
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
def calculate_relative_position(
    anchor: str,
    offset_x: int,
    offset_y: int,
    box_width: int,
    box_height: int,
    img_width: int,
    img_height: int,
) -> Tuple[int, int]:
    """Compute the top-left corner of a box pinned to a relative image anchor.

    Args:
        anchor: Anchor point name (e.g., "top_left", "center", "bottom_right")
        offset_x: Horizontal offset from anchor point (positive = right)
        offset_y: Vertical offset from anchor point (positive = down)
        box_width: Width of the box to position
        box_height: Height of the box to position
        img_width: Width of the image
        img_height: Height of the image

    Returns:
        Tuple of (x, y) coordinates for the top-left corner of the box

    Raises:
        ValueError: If anchor is not recognized
    """
    try:
        fraction_x, fraction_y = ANCHORS[anchor.lower()]
    except KeyError as e:
        raise ValueError(
            f"Unknown anchor: {anchor!r}. Must be one of {sorted(ANCHORS.keys())}"
        ) from e

    # The same fractional position is applied to both the image (to locate the
    # anchor point) and the box (so the box's matching corner/edge lands on it).
    box_x = (
        int(round(fraction_x * img_width))
        - int(round(fraction_x * box_width))
        + offset_x
    )
    box_y = (
        int(round(fraction_y * img_height))
        - int(round(fraction_y * box_height))
        + offset_y
    )
    return box_x, box_y

clamp_box

clamp_box(box_x, box_y, box_w, box_h, img_w, img_h)

Clamp box position to image bounds.

Source code in inference/core/workflows/core_steps/visualizations/text_display/utils.py
79
80
81
82
83
84
85
def clamp_box(
    box_x: int, box_y: int, box_w: int, box_h: int, img_w: int, img_h: int
) -> Tuple[int, int]:
    """Shift a box so it fits within the image; oversized boxes pin to 0."""

    def _clamp_axis(position: int, size: int, limit: int) -> int:
        # A box larger than the image cannot fit; anchor it at the origin.
        if size > limit:
            return 0
        return max(0, min(position, limit - size))

    return _clamp_axis(box_x, box_w, img_w), _clamp_axis(box_y, box_h, img_h)

compute_layout

compute_layout(
    *,
    formatted_text,
    font,
    font_scale,
    font_thickness,
    padding,
    position_mode,
    position_x,
    position_y,
    anchor,
    offset_x,
    offset_y,
    img_w,
    img_h
)

Compute text layout including dimensions and position.

Source code in inference/core/workflows/core_steps/visualizations/text_display/utils.py
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
def compute_layout(
    *,
    formatted_text: str,
    font,
    font_scale: float,
    font_thickness: int,
    padding: int,
    position_mode: str,
    position_x: int,
    position_y: int,
    anchor: str,
    offset_x: int,
    offset_y: int,
    img_w: int,
    img_h: int,
) -> TextLayout:
    """Measure the text block and resolve its on-image position as a TextLayout."""
    lines = formatted_text.split("\n") if formatted_text else [""]

    # Size lines from a reference pair containing an ascender and a descender.
    (_, ref_h), ref_base = cv2.getTextSize("Ag", font, font_scale, font_thickness)
    line_advance = ref_h + ref_base
    line_spacing = max(1, int(round(0.25 * line_advance)))

    # Blank lines contribute zero width but still advance vertically.
    line_widths = []
    for line in lines:
        if line.strip():
            width = cv2.getTextSize(line, font, font_scale, font_thickness)[0][0]
        else:
            width = 0
        line_widths.append(width)
    max_width = max(line_widths, default=0)

    num_lines = len(lines)
    total_h = num_lines * line_advance + max(0, num_lines - 1) * line_spacing

    box_w = max_width + 2 * padding
    box_h = total_h + 2 * padding

    if position_mode == "absolute":
        box_x, box_y = position_x, position_y
    else:
        box_x, box_y = calculate_relative_position(
            anchor=anchor,
            offset_x=offset_x,
            offset_y=offset_y,
            box_width=box_w,
            box_height=box_h,
            img_width=img_w,
            img_height=img_h,
        )

    # Keep the box inside the image regardless of how it was positioned.
    box_x, box_y = clamp_box(box_x, box_y, box_w, box_h, img_w, img_h)

    return TextLayout(
        lines=lines,
        line_widths=line_widths,
        max_width=max_width,
        ref_height=ref_h,
        line_advance=line_advance,
        line_spacing=line_spacing,
        box_x=box_x,
        box_y=box_y,
        box_w=box_w,
        box_h=box_h,
    )

draw_background

draw_background(
    img,
    x1,
    y1,
    x2,
    y2,
    bg_color_bgr,
    background_opacity,
    border_radius,
)

Draw background rectangle with optional transparency and rounded corners.

Source code in inference/core/workflows/core_steps/visualizations/text_display/utils.py
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
def draw_background(
    img: np.ndarray,
    x1: int,
    y1: int,
    x2: int,
    y2: int,
    bg_color_bgr: Optional[Tuple[int, int, int]],
    background_opacity: float,
    border_radius: int,
) -> None:
    """Draw background rectangle with optional transparency and rounded corners."""
    # Nothing to draw: no color, degenerate box, or fully transparent fill.
    if bg_color_bgr is None or x2 <= x1 or y2 <= y1:
        return
    if background_opacity <= 0.0:
        return

    if background_opacity < 1.0:
        # Partial transparency requires alpha blending.
        draw_background_with_alpha(
            img=img,
            pt1=(x1, y1),
            pt2=(x2, y2),
            color=bg_color_bgr,
            alpha=background_opacity,
            border_radius=border_radius,
        )
        return

    # Fully opaque: draw directly. OpenCV endpoints are inclusive, so the
    # exclusive end coordinates are reduced by one pixel.
    if border_radius > 0:
        draw_rounded_rectangle(
            img=img,
            pt1=(x1, y1),
            pt2=(x2 - 1, y2 - 1),
            color=bg_color_bgr,
            radius=border_radius,
        )
    else:
        cv2.rectangle(
            img,
            (x1, y1),
            (x2 - 1, y2 - 1),
            bg_color_bgr,
            -1,
        )

draw_background_with_alpha

draw_background_with_alpha(
    img, pt1, pt2, color, alpha, border_radius
)

Draw a filled rectangle with alpha blending using overlay compositing.

Uses proper overlay-based alpha blending for smooth antialiased edges, especially important for rounded rectangles.

Process: 1. Extract the affected region 2. Create overlay and draw shape on it 3. Alpha-blend overlay with original region 4. Write blended result back

Source code in inference/core/workflows/core_steps/visualizations/text_display/utils.py
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
def draw_background_with_alpha(
    img: np.ndarray,
    pt1: Tuple[int, int],
    pt2: Tuple[int, int],
    color: Tuple[int, int, int],
    alpha: float,
    border_radius: int,
) -> None:
    """Draw a filled rectangle with alpha blending using overlay compositing.

    The shape is drawn onto a copy of the affected region, alpha-blended with
    the original pixels, and written back — giving smooth antialiased edges,
    which matters most for rounded rectangles.
    """
    img_h, img_w = img.shape[:2]

    # Restrict the rectangle to the visible part of the image.
    left = max(0, pt1[0])
    top = max(0, pt1[1])
    right = min(img_w, pt2[0])
    bottom = min(img_h, pt2[1])

    if right <= left or bottom <= top:
        return

    region = img[top:bottom, left:right]
    painted = region.copy()

    # Shape coordinates are relative to the region; OpenCV endpoints are
    # inclusive, so the maximum index is size - 1.
    last_x = (right - left) - 1
    last_y = (bottom - top) - 1

    if border_radius > 0:
        draw_rounded_rectangle(
            img=painted,
            pt1=(0, 0),
            pt2=(last_x, last_y),
            color=color,
            radius=border_radius,
        )
    else:
        cv2.rectangle(
            painted,
            (0, 0),
            (last_x, last_y),
            color,
            -1,
        )

    # result = painted * alpha + original * (1 - alpha), written back in place.
    img[top:bottom, left:right] = cv2.addWeighted(painted, alpha, region, 1 - alpha, 0)

draw_rounded_rectangle

draw_rounded_rectangle(img, pt1, pt2, color, radius)

Draw a filled rounded rectangle on an image.

Source code in inference/core/workflows/core_steps/visualizations/text_display/utils.py
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
def draw_rounded_rectangle(
    img: np.ndarray,
    pt1: Tuple[int, int],
    pt2: Tuple[int, int],
    color: Tuple[int, int, int],
    radius: int,
) -> None:
    """Draw a filled rectangle with rounded corners onto ``img`` in place.

    Args:
        img: Target image (modified in place).
        pt1: Top-left corner (x, y).
        pt2: Bottom-right corner (x, y).
        color: Fill color (BGR).
        radius: Requested corner radius; clamped to half the shorter side.
            A non-positive effective radius falls back to a plain rectangle.
    """
    left, top = pt1
    right, bottom = pt2

    # Degenerate rectangle: nothing to draw.
    if right <= left or bottom <= top:
        return

    # The corner radius can never exceed half of the shorter side.
    effective_radius = min(radius, (right - left) // 2, (bottom - top) // 2)

    if effective_radius <= 0:
        # No rounding requested (or possible): plain filled rectangle.
        cv2.rectangle(img, pt1, pt2, color, -1)
        return

    r = effective_radius
    # Two overlapping axis-aligned rectangles cover everything except
    # the four corner squares...
    cv2.rectangle(img, (left + r, top), (right - r, bottom), color, -1)
    cv2.rectangle(img, (left, top + r), (right, bottom - r), color, -1)

    # ...and four filled quarter-circles round off the corners.
    corner_arcs = (
        ((left + r, top + r), 180),  # top-left
        ((right - r, top + r), 270),  # top-right
        ((left + r, bottom - r), 90),  # bottom-left
        ((right - r, bottom - r), 0),  # bottom-right
    )
    for center, start_angle in corner_arcs:
        cv2.ellipse(img, center, (r, r), start_angle, 0, 90, color, -1)

draw_text_lines

draw_text_lines(
    img,
    *,
    layout,
    padding,
    text_align,
    font,
    font_scale,
    font_thickness,
    color_bgr
)

Draw text lines on the image.

Source code in inference/core/workflows/core_steps/visualizations/text_display/utils.py
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
def draw_text_lines(
    img: np.ndarray,
    *,
    layout: TextLayout,
    padding: int,
    text_align: str,
    font,
    font_scale: float,
    font_thickness: int,
    color_bgr: Tuple[int, int, int],
) -> None:
    """Render each line of ``layout`` onto ``img`` in place.

    Blank lines still advance the vertical cursor; lines whose baseline falls
    entirely outside the image are skipped. Alignment within the text box is
    delegated to ``align_offset``.
    """
    height, width = img.shape[:2]
    y_cursor = layout.box_y + padding
    left_edge = layout.box_x + padding
    last_index = len(layout.lines) - 1

    for index, text_line in enumerate(layout.lines):
        if text_line.strip():
            line_width = layout.line_widths[index]
            x = left_edge + align_offset(text_align, layout.max_width, line_width)
            baseline_y = y_cursor + layout.ref_height

            # Only draw lines that intersect the visible image area.
            if baseline_y > 0 and y_cursor < height and x < width:
                cv2.putText(
                    img,
                    text_line,
                    (x, baseline_y),
                    font,
                    font_scale,
                    color_bgr,
                    font_thickness,
                    cv2.LINE_AA,
                )

        # Advance past this line; inter-line spacing applies between lines only.
        y_cursor += layout.line_advance
        if index < last_index:
            y_cursor += layout.line_spacing

inference.core.workflows.core_steps.visualizations.text_display.v1

Classes

Functions

format_text_with_parameters

format_text_with_parameters(
    text, text_parameters, text_parameters_operations
)

Format text by replacing parameter placeholders with actual values.

Uses a single-pass regex substitution for efficiency and correctness.

Source code in inference/core/workflows/core_steps/visualizations/text_display/v1.py
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
def format_text_with_parameters(
    text: str,
    text_parameters: Dict[str, Any],
    text_parameters_operations: Dict[str, List[AllOperationsType]],
) -> str:
    """Substitute parameter placeholders in ``text`` with their values.

    A single regex pass replaces every placeholder. Values with registered
    operations are transformed once and memoized, so repeated placeholders
    reuse the computed result. Placeholders naming unknown parameters are
    left verbatim.
    """
    # Memo of already-rendered parameter values (operations applied once).
    memo: Dict[str, str] = {}

    def _resolve(match: re.Match) -> str:
        name = match.group(2)
        if name not in text_parameters:
            # Unknown parameter: keep the placeholder text unchanged.
            return match.group(0)
        cached = memo.get(name)
        if cached is not None:
            return cached

        value = text_parameters[name]
        operations = text_parameters_operations.get(name)
        if operations:
            operations_chain = build_operations_chain(operations=operations)
            value = operations_chain(value, global_parameters={})

        rendered = str(value)
        memo[name] = rendered
        return rendered

    return PARAMETER_REGEX.sub(_resolve, text)

core/workflows/execution_engine/introspection

inference.core.workflows.execution_engine.introspection.blocks_loader

Functions

clear_caches

clear_caches()

Clear all LRU caches in this module. Useful for testing or when environment configuration changes.

Source code in inference/core/workflows/execution_engine/introspection/blocks_loader.py
62
63
64
65
66
67
68
69
70
71
def clear_caches() -> None:
    """
    Clear all LRU caches in this module.
    Useful for testing or when environment configuration changes.
    """
    # Iterate over every memoized callable and drop its cached entries.
    for cached_function in (
        _cached_describe_available_blocks,
        load_core_workflow_blocks,
        _cached_load_all_defined_kinds,
        _cached_model_json_schema,
        _cached_describe_outputs,
    ):
        cached_function.cache_clear()

inference.core.workflows.execution_engine.introspection.schema_parser

Functions

clear_cache

clear_cache()

Clear the parse_block_manifest cache.

Source code in inference/core/workflows/execution_engine/introspection/schema_parser.py
56
57
58
def clear_cache() -> None:
    """Clear the parse_block_manifest cache.

    Drops any memoized results of ``parse_block_manifest`` so that
    subsequent calls recompute from scratch; useful in tests.
    """
    parse_block_manifest.cache_clear()

core/workflows/execution_engine/v1/compiler

inference.core.workflows.execution_engine.v1.compiler.cache

Classes

BasicWorkflowsCache

Bases: Generic[V]

Base cache which is capable of hashing compound payloads based on list of injected hash functions. Hash functions are to produce stable hashing strings. Each function is invoked on get_hash_key(...) kwarg (use named args only!), output string is concatenated and md5 value is calculated.

Cache is size bounded, each entry lives until cache_size new entries appear.

Raises WorkflowEnvironmentConfigurationError when get_hash_key(...) is not provided with params corresponding to all hash functions.

Thread safe thanks to thread lock on get(...) and cache(...).

Source code in inference/core/workflows/execution_engine/v1/compiler/cache.py
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
class BasicWorkflowsCache(Generic[V]):
    """
    Base cache which is capable of hashing compound payloads based on
    list of injected hash functions. Hash functions are to produce stable hashing strings.
    Each function is invoked on `get_hash_key(...)` kwarg (use named args only!),
    output string is concatenated and md5 value is calculated.

    Cache is size bounded, each entry lives until `cache_size` new entries appear.

    Raises `WorkflowEnvironmentConfigurationError` when `get_hash_key(...)` is not
    provided with params corresponding to all hash functions.

    Thread safe thanks to thread lock on `get(...)` and `cache(...)`.
    """

    def __init__(
        self,
        cache_size: int,
        hash_functions: List[Tuple[str, Callable[[Any], str]]],
    ):
        """Initialize the cache.

        Args:
            cache_size: Maximum number of entries; values below 1 are treated as 1.
            hash_functions: Ordered (kwarg_name, hashing_function) pairs used by
                `get_hash_key(...)` to derive a stable key from named arguments.
        """
        # FIFO of cached keys; the oldest key is evicted once capacity is reached.
        self._keys_buffer = deque(maxlen=max(cache_size, 1))
        self._cache: Dict[str, V] = {}
        self._hash_functions = hash_functions
        self._cache_lock = Lock()

    def get_hash_key(self, **kwargs) -> str:
        """Derive a stable md5 key from the named arguments.

        Raises:
            WorkflowEnvironmentConfigurationError: when a kwarg required by one
                of the configured hash functions is missing.
        """
        hash_chunks = []
        for key_name, hashing_function in self._hash_functions:
            if key_name not in kwargs:
                raise WorkflowEnvironmentConfigurationError(
                    public_message=f"Cache is misconfigured - missing required hash key parameter `{key_name}`.",
                    context="workflows_cache | hash_key_generation",
                )
            hash_value = hashing_function(kwargs[key_name])
            hash_chunks.append(hash_value)
        return hashlib.md5("<|>".join(hash_chunks).encode("utf-8")).hexdigest()

    def get(self, key: str) -> Optional[V]:
        """Return the cached value for `key`, or None when absent."""
        with self._cache_lock:
            return self._cache.get(key)

    def cache(self, key: str, value: V) -> None:
        """Store `value` under `key`, evicting the oldest entry when full."""
        with self._cache_lock:
            if key in self._cache:
                # Refresh the value without appending a duplicate key to the
                # eviction buffer. Previously, duplicates caused a later
                # eviction to delete the live entry prematurely and, once the
                # stale duplicate was popped, a KeyError on deletion.
                self._cache[key] = value
                return
            if len(self._keys_buffer) == self._keys_buffer.maxlen:
                evicted_key = self._keys_buffer.popleft()
                # pop(..., None) guards against any buffer/dict divergence.
                self._cache.pop(evicted_key, None)
            self._keys_buffer.append(key)
            self._cache[key] = value

inference.core.workflows.execution_engine.v1.compiler.graph_constructor

Functions

establish_step_execution_dimensionality

establish_step_execution_dimensionality(
    inputs_dimensionalities,
    control_flow_lineage_support,
    output_dimensionality_offset,
)

Determine how many batch dimensions (execution slices) a step runs with.

Used during workflow compilation in denote_data_flow_for_step. The result is stored on StepNode.step_execution_dimensionality and consumed at execution time to: - Drive how many times the step is executed (which batch indices/slices). - Align and expand inputs (e.g. auto-batch casting) to match this size. - Validate that parameter dimensionalities are compatible (runtime checks in step_input_assembler and manager).

Logic: - If no input has non-zero dimensionality but the step is gated by control flow (control_flow_lineage_support non-empty), the dimensionality is the number of control-flow branches. - Otherwise, the minimum non-zero input dimensionality is used; if output_dimensionality_offset < 0 (step reduces batch dimension), one is subtracted.

Parameters:

Name Type Description Default
inputs_dimensionalities Dict[str, Set[int]]

Per-input sets of dimensionalities (from get_inputs_dimensionalities).

required
control_flow_lineage_support List[str]

Lineage identifiers for control-flow branches that gate this step (from establish_batch_oriented_step_lineage).

required
output_dimensionality_offset int

Block's output dimensionality offset (positive = expand, negative = reduce batch dimension).

required

Returns:

Type Description
int

The number of batch dimensions (execution slices) for this step.

Source code in inference/core/workflows/execution_engine/v1/compiler/graph_constructor.py
1854
1855
1856
1857
1858
1859
1860
1861
1862
1863
1864
1865
1866
1867
1868
1869
1870
1871
1872
1873
1874
1875
1876
1877
1878
1879
1880
1881
1882
1883
1884
1885
1886
1887
1888
1889
1890
1891
1892
1893
1894
1895
1896
1897
1898
1899
1900
1901
1902
def establish_step_execution_dimensionality(
    inputs_dimensionalities: Dict[str, Set[int]],
    control_flow_lineage_support: List[str],
    output_dimensionality_offset: int,
) -> int:
    """
    Determine how many batch dimensions (execution slices) a step runs with.

    Used during workflow compilation in denote_data_flow_for_step; the result
    is stored on StepNode.step_execution_dimensionality and consumed at
    execution time to drive how many times the step is executed, to align and
    expand inputs, and to validate parameter dimensionalities.

    Rules:
    - No input carries a non-zero dimensionality, but the step is gated by
      control flow: the dimensionality equals the number of control-flow
      branches.
    - Otherwise the minimum non-zero input dimensionality is taken, reduced
      by one when output_dimensionality_offset < 0 (the step collapses a
      batch dimension).
    - With neither batch inputs nor control-flow gating, the result is 0.

    Args:
        inputs_dimensionalities: Per-input sets of dimensionalities.
        control_flow_lineage_support: Lineage identifiers of control-flow
            branches gating this step.
        output_dimensionality_offset: Block's output dimensionality offset
            (negative means the step reduces the batch dimension).

    Returns:
        The number of batch dimensions (execution slices) for this step.
    """
    batch_dimensionalities = {
        dimensionality
        for declared in inputs_dimensionalities.values()
        for dimensionality in declared
        if dimensionality > 0
    }
    if not batch_dimensionalities:
        # Purely control-flow-driven steps execute once per branch;
        # otherwise the step is non-batch (dimensionality 0).
        return len(control_flow_lineage_support) if control_flow_lineage_support else 0
    result = min(batch_dimensionalities)
    if output_dimensionality_offset < 0:
        result -= 1
    return result

get_lineage_derived_from_control_flow

get_lineage_derived_from_control_flow(
    control_flow_steps_selectors, execution_graph
)

Return unique non-empty data lineages from the given control flow steps.

Each lineage is taken from the step's data_lineage in the execution graph. Lineages are deduplicated by lineage id (see identify_lineage); empty lineages are omitted. Used when establishing batch-oriented step lineage.

Parameters:

Name Type Description Default
control_flow_steps_selectors List[str]

Step selectors (node ids) of control flow steps whose data_lineage is to be collected.

required
execution_graph DiGraph

The workflow execution graph containing step nodes and their data_lineage.

required

Returns:

Type Description
List[List[str]]

List of distinct non-empty data lineages, one per unique lineage id.

Source code in inference/core/workflows/execution_engine/v1/compiler/graph_constructor.py
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
def get_lineage_derived_from_control_flow(
    control_flow_steps_selectors: List[str],
    execution_graph: nx.DiGraph,
) -> List[List[str]]:
    """
    Collect the distinct non-empty data lineages of the given control flow steps.

    Delegates to _collect_unique_control_flow_lineages_with_step_mapping, which
    reads each step's data_lineage from the execution graph, drops empty
    lineages, and deduplicates by lineage id (see identify_lineage). The
    step-to-lineage mapping produced alongside is discarded here.

    Args:
        control_flow_steps_selectors: Step selectors (node ids) of control
            flow steps whose data_lineage is to be collected.
        execution_graph: The workflow execution graph containing step nodes
            and their data_lineage.

    Returns:
        List of distinct non-empty data lineages, one per unique lineage id.
    """
    lineages, _step_mapping = _collect_unique_control_flow_lineages_with_step_mapping(
        control_flow_steps_selectors=control_flow_steps_selectors,
        execution_graph=execution_graph,
    )
    return lineages

verify_compatibility_of_input_data_lineage_with_control_flow_lineage

verify_compatibility_of_input_data_lineage_with_control_flow_lineage(
    step_name,
    inputs_lineage,
    control_flow_steps_selectors,
    execution_graph,
)

Ensure control flow steps' data lineage is compatible with the step's inputs.

Control flow steps that affect this step must operate on data that is compatible with the data fed to the step; otherwise the step could never execute. Compares unique control flow lineages against input lineage prefixes and raises ControlFlowDefinitionError if any control flow lineage is not covered by the inputs.

If inputs_lineage is empty, there is no sense to verify compatibility. The lineage of the step should be established based on the control flow lineages.

Parameters:

Name Type Description Default
step_name str

Name of the step being verified (used in error messages).

required
inputs_lineage List[List[str]]

Data lineages derived from the step's input data.

required
control_flow_steps_selectors List[str]

Step selectors of control flow steps that affect this step's execution.

required
execution_graph DiGraph

The workflow execution graph.

required

Raises:

Type Description
ControlFlowDefinitionError

When a control flow step's lineage is not compatible with the step's input lineage (step would never execute).

Source code in inference/core/workflows/execution_engine/v1/compiler/graph_constructor.py
1424
1425
1426
1427
1428
1429
1430
1431
1432
1433
1434
1435
1436
1437
1438
1439
1440
1441
1442
1443
1444
1445
1446
1447
1448
1449
1450
1451
1452
1453
1454
1455
1456
1457
1458
1459
1460
1461
1462
1463
1464
1465
1466
1467
1468
1469
1470
1471
1472
1473
1474
1475
1476
1477
1478
1479
def verify_compatibility_of_input_data_lineage_with_control_flow_lineage(
    step_name: str,
    inputs_lineage: List[List[str]],
    control_flow_steps_selectors: List[str],
    execution_graph: DiGraph,
) -> None:
    """
    Ensure control flow steps' data lineage is compatible with the step's inputs.

    Control flow steps that affect this step must operate on data that is
    compatible with the data fed to the step; otherwise the step could never
    execute. Compares unique control flow lineages against input lineage
    prefixes and raises ControlFlowDefinitionError if any control flow lineage
    is not covered by the inputs.

    If inputs_lineage is empty, there is no sense to verify compatibility. The lineage of the
    step should be established based on the control flow lineages.

    Args:
        step_name: Name of the step being verified (used in error messages).
        inputs_lineage: Data lineages derived from the step's input data.
        control_flow_steps_selectors: Step selectors of control flow steps
            that affect this step's execution.
        execution_graph: The workflow execution graph.

    Raises:
        ControlFlowDefinitionError: When a control flow step's lineage is not
            compatible with the step's input lineage (step would never execute).
    """
    # Fast path first: previously the (potentially expensive) control flow
    # lineage collection below ran before this check, only for its result to
    # be discarded when inputs_lineage was empty.
    if not inputs_lineage:
        return

    (
        batch_oriented_control_flow_lineages,
        lineage_id2control_flow_steps,
    ) = _collect_unique_control_flow_lineages_with_step_mapping(
        control_flow_steps_selectors=control_flow_steps_selectors,
        execution_graph=execution_graph,
    )
    all_input_lineage_prefixes = get_all_batch_lineage_prefixes(lineages=inputs_lineage)
    all_input_lineage_prefixes_hashes = {
        identify_lineage(lineage=lineage) for lineage in all_input_lineage_prefixes
    }
    for control_flow_lineage in batch_oriented_control_flow_lineages:
        control_flow_lineage_id = identify_lineage(lineage=control_flow_lineage)
        if control_flow_lineage_id not in all_input_lineage_prefixes_hashes:
            problematic_flow_control_steps = lineage_id2control_flow_steps[
                control_flow_lineage_id
            ]
            raise ControlFlowDefinitionError(
                public_message=f"Step {step_name} execution is impacted by control flow outcome of the following "
                f"steps {problematic_flow_control_steps} which make decision based on data that is "
                f"not compatible with data fed to the step {step_name} - which would cause the step "
                f"to never execute. This behaviour is invalid and prevented upfront by Workflows compiler.",
                context="workflow_compilation | execution_graph_construction | verification_of_control_flow_lineage",
            )

inference.core.workflows.execution_engine.v1.compiler.graph_traversal

Functions

traverse_graph_ensuring_parents_are_reached_first

traverse_graph_ensuring_parents_are_reached_first(
    graph, start_node
)

This function works under assumption of common super-input node in the graph - otherwise, there is no common entry point to put as start_node.

Source code in inference/core/workflows/execution_engine/v1/compiler/graph_traversal.py
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
def traverse_graph_ensuring_parents_are_reached_first(
    graph: DiGraph,
    start_node: str,
) -> List[str]:
    """
    Produce a node ordering where parents precede their children.

    Nodes are annotated with their maximum distance from `start_node`, grouped
    by that distance in ascending order, and then flattened. Works under the
    assumption of a common super-input node in the graph - otherwise there is
    no common entry point to use as `start_node`.
    """
    DISTANCE_KEY = "distance"
    # Annotate a copy so the caller's graph is left untouched.
    annotated_graph = assign_max_distances_from_start(
        graph=graph.copy(),
        start_node=start_node,
        distance_key=DISTANCE_KEY,
    )
    ordered_groups = group_nodes_by_sorted_key_value(
        graph=annotated_graph, key=DISTANCE_KEY
    )
    ordering: List[str] = []
    for group in ordered_groups:
        ordering.extend(group)
    return ordering

inference.core.workflows.execution_engine.v1.compiler.syntactic_parser

Classes

Functions

clear_cache

clear_cache()

Clear the workflow schema cache.

Source code in inference/core/workflows/execution_engine/v1/compiler/syntactic_parser.py
147
148
149
def clear_cache() -> None:
    """Clear the workflow schema cache.

    Drops any memoized results of ``_cached_workflow_schema`` so subsequent
    calls rebuild the schema; useful in tests.
    """
    _cached_workflow_schema.cache_clear()

core/workflows/execution_engine/v1/dynamic_blocks

inference.core.workflows.execution_engine.v1.dynamic_blocks.block_assembler

Functions

ensure_dynamic_blocks_allowed

ensure_dynamic_blocks_allowed(dynamic_blocks_definitions)

Ensure that dynamic blocks are allowed based on configuration.

Dynamic blocks are allowed if: 1. Local custom Python execution is enabled (ALLOW_CUSTOM_PYTHON_EXECUTION_IN_WORKFLOWS=True) 2. OR Modal execution mode is set (WORKFLOWS_CUSTOM_PYTHON_EXECUTION_MODE=modal)

This allows secure execution via Modal sandboxes even when local execution is disabled.

Source code in inference/core/workflows/execution_engine/v1/dynamic_blocks/block_assembler.py
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
def ensure_dynamic_blocks_allowed(dynamic_blocks_definitions: List[dict]) -> None:
    """Ensure that dynamic blocks are allowed based on configuration.

    Dynamic blocks may run when either:
    1. local custom Python execution is enabled
       (ALLOW_CUSTOM_PYTHON_EXECUTION_IN_WORKFLOWS=True), or
    2. Modal execution mode is selected
       (WORKFLOWS_CUSTOM_PYTHON_EXECUTION_MODE=modal), which allows secure
       execution via Modal sandboxes even when local execution is disabled.

    Raises:
        WorkflowEnvironmentConfigurationError: when dynamic blocks are present
            but neither execution path is enabled.
    """
    if not dynamic_blocks_definitions:
        # No dynamic blocks in the workflow - nothing to validate.
        return

    modal_execution_enabled = WORKFLOWS_CUSTOM_PYTHON_EXECUTION_MODE == "modal"
    if ALLOW_CUSTOM_PYTHON_EXECUTION_IN_WORKFLOWS or modal_execution_enabled:
        return

    raise WorkflowEnvironmentConfigurationError(
        public_message="Cannot use dynamic blocks with custom Python code in this installation of `workflows`. "
        "This can be changed by either setting environmental variable "
        "`ALLOW_CUSTOM_PYTHON_EXECUTION_IN_WORKFLOWS=True` for local execution "
        "or `WORKFLOWS_CUSTOM_PYTHON_EXECUTION_MODE=modal` for secure remote execution.",
        context="workflow_compilation | dynamic_blocks_compilation",
    )

inference.core.workflows.execution_engine.v1.dynamic_blocks.modal_executor

Modal executor for Custom Python Blocks in Workflows using Web Endpoints.

This module handles the execution of untrusted user code in Modal sandboxes using web endpoints for better security and no size limitations.

Classes

ModalExecutor

Manages execution of Custom Python Blocks in Modal sandboxes via web endpoints.

Source code in inference/core/workflows/execution_engine/v1/dynamic_blocks/modal_executor.py
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
class ModalExecutor:
    """Manages execution of Custom Python Blocks in Modal sandboxes via web endpoints."""

    def __init__(self, workspace_id: Optional[str] = None):
        """Initialize the Modal executor for a specific workspace.

        Args:
            workspace_id: The workspace ID to namespace execution, defaults to "anonymous"
        """
        self.workspace_id = workspace_id or MODAL_ANONYMOUS_WORKSPACE_NAME
        # Lazily resolved base endpoint URL; identical for all workspace_ids
        # (only the query parameter differs), so it is computed once.
        self._base_url = None

    def _get_endpoint_url(self, workspace_id: str) -> str:
        """Get the web endpoint URL for a workspace.

        Args:
            workspace_id: The workspace ID

        Returns:
            The endpoint URL with query parameter for workspace_id
        """
        # Get base URL once (it's the same for all workspace_ids)
        if self._base_url is None:
            # First check for environment variable override
            env_url = os.environ.get("MODAL_WEB_ENDPOINT_URL")
            if env_url:
                self._base_url = env_url
            else:
                # Construct the URL based on Modal's expected pattern:
                # https://{workspace}--{app}-{class}-{method_truncated}.modal.run
                workspace = MODAL_WORKSPACE_NAME
                app_name = "webexec"
                class_name = "executor"
                method_name = "execute-block"

                label = f"{app_name}-{class_name}-{method_name}"
                # Defensive: mirror Modal's label truncation rule (56 chars
                # plus "-" and a short hash suffix). The current label is
                # short enough that this branch does not trigger.
                if (
                    len(label) > 56
                ):  # Modal truncates at 56 chars and adds 7-char hash
                    import hashlib

                    hash_str = hashlib.sha256(label.encode()).hexdigest()[:6]
                    label = f"{label[:56]}-{hash_str}"

                self._base_url = f"https://{workspace}--{label}.modal.run"

        # Add workspace_id as query parameter
        return f"{self._base_url}?workspace_id={workspace_id}"

    def execute_remote(
        self,
        block_type_name: str,
        python_code: PythonCode,
        inputs: Dict[str, Any],
        workspace_id: Optional[str] = None,
    ) -> BlockResult:
        """Execute a Custom Python Block in a Modal sandbox via web endpoint.

        Args:
            block_type_name: Name of the block type
            python_code: The Python code to execute
            inputs: Input data for the function
            workspace_id: Optional workspace ID override

        Returns:
            BlockResult from the execution

        Raises:
            DynamicBlockError: If Modal is not available or Modal request fails
            Exception: If remote execution throws an exception
        """
        # Check if Modal is available
        if not MODAL_AVAILABLE:
            raise DynamicBlockError(
                public_message="Modal credentials not configured. Please set MODAL_TOKEN_ID and MODAL_TOKEN_SECRET environment variables.",
                context="modal_executor | credentials_check",
            )

        # Use provided workspace_id or fall back to instance default
        workspace = workspace_id if workspace_id else self.workspace_id

        try:
            # Get endpoint URL for this workspace
            endpoint_url = self._get_endpoint_url(workspace)

            # Custom JSON encoder for inputs
            inputs_json = serialize_for_modal_remote_execution(inputs)

            # Prepare request payload
            request_payload = {
                "code_str": python_code.run_function_code,
                "imports": python_code.imports or [],
                "run_function_name": python_code.run_function_name,
                "inputs_json": inputs_json,
            }

            # Anonymous / unauthorized workspaces are only allowed when
            # explicitly enabled via environment configuration.
            if (
                not workspace
                or workspace == "anonymous"
                or workspace == "unauthorized"
                or workspace == MODAL_ANONYMOUS_WORKSPACE_NAME
            ):
                from inference.core.env import MODAL_ALLOW_ANONYMOUS_EXECUTION

                if not MODAL_ALLOW_ANONYMOUS_EXECUTION:
                    raise DynamicBlockError(
                        public_message="Modal validation requires an API key when anonymous execution is disabled. "
                        "Please provide an API key or enable anonymous execution by setting "
                        "MODAL_ALLOW_ANONYMOUS_EXECUTION=True",
                        context="modal_executor | validation_authentication",
                    )

            # Make HTTP request to Modal endpoint
            response = requests.post(
                endpoint_url,
                json=request_payload,
                timeout=30,  # 30 second timeout
                headers={
                    "Content-Type": "application/json",
                    "Modal-Key": MODAL_TOKEN_ID,
                    "Modal-Secret": MODAL_TOKEN_SECRET,
                },
            )

            # Check HTTP status
            if response.status_code != 200:
                raise DynamicBlockError(
                    public_message=f"Modal endpoint returned status {response.status_code}: {response.text}",
                    context="modal_executor | http_request",
                )

            # Parse response
            result = response.json()

            # Check for errors
            if not result.get("success", False):
                error_msg = result.get("error", "Unknown error")
                error_type = result.get("error_type", "RuntimeError")
                line_number = result.get("line_number", None)
                function_name = result.get("function_name", None)

                if line_number and function_name:
                    message = f"Error in line {line_number}, in {function_name}: {error_type}: {error_msg}"
                else:
                    message = f"{error_type}: {error_msg}"

                # Propagate remote Exception on runtime error. Will be caught by the
                # core executor and wrapped in StepExecutionError with block metadata.
                raise Exception(message)

            # Get the result and deserialize from JSON
            json_result = result.get("result", "{}")
            return deserialize_for_modal_remote_execution(json_result)

        except requests.exceptions.RequestException as e:
            # Chain the underlying network error for easier debugging.
            raise DynamicBlockError(
                public_message=f"Failed to connect to Modal endpoint: {str(e)}",
                context="modal_executor | http_connection",
            ) from e
Functions
__init__
__init__(workspace_id=None)

Initialize the Modal executor for a specific workspace.

Parameters:

Name Type Description Default
workspace_id Optional[str]

The workspace ID to namespace execution, defaults to "anonymous"

None
Source code in inference/core/workflows/execution_engine/v1/dynamic_blocks/modal_executor.py
181
182
183
184
185
186
187
188
def __init__(self, workspace_id: Optional[str] = None):
    """Create an executor bound to a single workspace.

    Args:
        workspace_id: Workspace used to namespace execution; when omitted
            (or empty), the anonymous workspace name is used.
    """
    if workspace_id:
        self.workspace_id = workspace_id
    else:
        self.workspace_id = MODAL_ANONYMOUS_WORKSPACE_NAME
    # Endpoint URL is resolved lazily on first use.
    self._base_url = None
execute_remote
execute_remote(
    block_type_name, python_code, inputs, workspace_id=None
)

Execute a Custom Python Block in a Modal sandbox via web endpoint.

Parameters:

Name Type Description Default
block_type_name str

Name of the block type

required
python_code PythonCode

The Python code to execute

required
inputs Dict[str, Any]

Input data for the function

required
workspace_id Optional[str]

Optional workspace ID override

None

Returns:

Type Description
BlockResult

BlockResult from the execution

Raises:

Type Description
DynamicBlockError

If Modal is not available or Modal request fails

Exception

If remote execution throws an exception

Source code in inference/core/workflows/execution_engine/v1/dynamic_blocks/modal_executor.py
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
def execute_remote(
    self,
    block_type_name: str,
    python_code: PythonCode,
    inputs: Dict[str, Any],
    workspace_id: Optional[str] = None,
) -> BlockResult:
    """Execute a Custom Python Block in a Modal sandbox via web endpoint.

    Args:
        block_type_name: Name of the block type
        python_code: The Python code to execute
        inputs: Input data for the function
        workspace_id: Optional workspace ID override; falls back to the
            workspace this executor was constructed with

    Returns:
        BlockResult from the execution (deserialized from the endpoint's
        JSON payload)

    Raises:
        DynamicBlockError: If Modal credentials are missing, anonymous
            execution is disallowed, the endpoint returns a non-200 status,
            or the HTTP request fails
        Exception: If the remotely executed code raised; the remote error
            type/message (and location, when reported) are re-raised locally
            inside the exception message
    """
    # Check if Modal is available
    if not MODAL_AVAILABLE:
        raise DynamicBlockError(
            public_message="Modal credentials not configured. Please set MODAL_TOKEN_ID and MODAL_TOKEN_SECRET environment variables.",
            context="modal_executor | credentials_check",
        )

    # Use provided workspace_id or fall back to instance default
    workspace = workspace_id if workspace_id else self.workspace_id

    try:
        # Get endpoint URL for this workspace
        endpoint_url = self._get_endpoint_url(workspace)

        # Inputs may contain non-JSON-native types; use the dedicated serializer
        inputs_json = serialize_for_modal_remote_execution(inputs)

        # Prepare request payload
        request_payload = {
            "code_str": python_code.run_function_code,
            "imports": python_code.imports or [],
            "run_function_name": python_code.run_function_name,
            "inputs_json": inputs_json,
        }

        # Reject anonymous/unauthorized workspaces unless explicitly allowed.
        if (
            not workspace
            or workspace == "anonymous"
            or workspace == "unauthorized"
            or workspace == MODAL_ANONYMOUS_WORKSPACE_NAME
        ):
            # Imported lazily so env is read at call time, not import time.
            from inference.core.env import MODAL_ALLOW_ANONYMOUS_EXECUTION

            if not MODAL_ALLOW_ANONYMOUS_EXECUTION:
                raise DynamicBlockError(
                    public_message="Modal validation requires an API key when anonymous execution is disabled. "
                    "Please provide an API key or enable anonymous execution by setting "
                    "MODAL_ALLOW_ANONYMOUS_EXECUTION=True",
                    context="modal_executor | validation_authentication",
                )

        # Make HTTP request to Modal endpoint; Modal proxy-auth headers carry credentials
        response = requests.post(
            endpoint_url,
            json=request_payload,
            timeout=30,  # 30 second timeout
            headers={
                "Content-Type": "application/json",
                "Modal-Key": MODAL_TOKEN_ID,
                "Modal-Secret": MODAL_TOKEN_SECRET,
            },
        )

        # Check HTTP status
        if response.status_code != 200:
            raise DynamicBlockError(
                public_message=f"Modal endpoint returned status {response.status_code}: {response.text}",
                context="modal_executor | http_request",
            )

        # Parse response
        result = response.json()

        # The endpoint signals user-code failure via the "success" flag
        if not result.get("success", False):
            error_msg = result.get("error", "Unknown error")
            error_type = result.get("error_type", "RuntimeError")
            line_number = result.get("line_number", None)
            function_name = result.get("function_name", None)

            # Include source location when the endpoint reported one
            if line_number and function_name:
                message = f"Error in line {line_number}, in {function_name}: {error_type}: {error_msg}"
            else:
                message = f"{error_type}: {error_msg}"

            # Propagate remote Exception on runtime error. Will be caught by the
            # core executor and wrapped in StepExecutionError with block metadata.
            raise Exception(message)

        # Get the result and deserialize from JSON
        json_result = result.get("result", "{}")
        return deserialize_for_modal_remote_execution(json_result)

    except requests.exceptions.RequestException as e:
        # Network-level failures only; DynamicBlockError raised above
        # passes through this handler untouched.
        raise DynamicBlockError(
            public_message=f"Failed to connect to Modal endpoint: {str(e)}",
            context="modal_executor | http_connection",
        )

Functions

validate_code_in_modal

validate_code_in_modal(python_code, workspace_id=None)

Validate Python code syntax in a Modal sandbox via web endpoint.

Parameters:

Name Type Description Default
python_code PythonCode

The Python code to validate

required
workspace_id Optional[str]

The workspace ID for Modal App

None

Returns:

Type Description
bool

True if code is valid, raises otherwise

Raises:

Type Description
DynamicBlockError

If code validation fails

Source code in inference/core/workflows/execution_engine/v1/dynamic_blocks/modal_executor.py
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
def validate_code_in_modal(
    python_code: PythonCode, workspace_id: Optional[str] = None
) -> bool:
    """Validate Python code syntax in a Modal sandbox via web endpoint.

    Builds a small wrapper program that embeds the user code as a string
    literal and runs `compile()` + `ast.parse()` on it remotely, so that
    validation happens with the same interpreter/isolation as execution.

    Args:
        python_code: The Python code to validate
        workspace_id: The workspace ID for Modal App

    Returns:
        True if code is valid, raises otherwise

    Raises:
        DynamicBlockError: If code validation fails
    """
    # Check if Modal is available
    if not MODAL_AVAILABLE:
        raise DynamicBlockError(
            public_message="Modal credentials not configured. Please set MODAL_TOKEN_ID and MODAL_TOKEN_SECRET environment variables.",
            context="modal_executor | credentials_check",
        )

    workspace = workspace_id or MODAL_ANONYMOUS_WORKSPACE_NAME

    # Construct the full code to validate (same as in create_dynamic_module)
    full_code = python_code.run_function_code
    if python_code.init_function_code:
        full_code += "\n\n" + python_code.init_function_code

    # Escape the code for safe embedding in the validation function
    # Use repr() to properly escape quotes and special characters
    escaped_code = repr(full_code)

    # Wrapper program executed remotely; doubled braces keep the returned
    # dicts literal inside this f-string
    validation_code = PythonCode(
        type="PythonCode",
        imports=[],
        run_function_code=f"""
import ast

def validate_syntax():
    try:
        # Try to compile the user code
        code = {escaped_code}
        compile(code, "<string>", "exec")
        # Try to parse as AST to check structure
        ast.parse(code)
        return {{"valid": True}}
    except SyntaxError as e:
        return {{"valid": False, "error": str(e), "line": e.lineno}}
    except Exception as e:
        return {{"valid": False, "error": str(e)}}
""",
        run_function_name="validate_syntax",
        init_function_code=None,
        init_function_name="init",
    )

    executor = ModalExecutor(workspace_id=workspace)

    try:
        # For validation, we don't need complex inputs, just pass empty JSON
        result = executor.execute_remote(
            block_type_name="validation",
            python_code=validation_code,
            inputs={},
            workspace_id=workspace,
        )

        # Explicit `is False` check: a missing "valid" key is not treated
        # as a validation failure
        if result.get("valid") is False:
            error_msg = result.get("error", "Unknown syntax error")
            line_no = result.get("line", None)
            if line_no:
                error_msg = f"Line {line_no}: {error_msg}"
            raise DynamicBlockError(
                public_message=f"Code validation failed: {error_msg}",
                context="modal_executor | code_validation",
            )

        return True

    except Exception as e:
        # Re-raise our own error untouched; wrap anything else (transport
        # errors, remote exceptions) into a DynamicBlockError
        if isinstance(e, DynamicBlockError):
            raise
        raise DynamicBlockError(
            public_message=f"Code validation failed: {str(e)}",
            context="modal_executor | code_validation",
        )

core/workflows/execution_engine/v1/executor/execution_data_manager

inference.core.workflows.execution_engine.v1.executor.execution_data_manager.step_input_assembler

Functions

filter_to_valid_prefix_chains

filter_to_valid_prefix_chains(per_dim_sets, dimensions)

Keep only indices that form a complete parent-child chain across dimensions.

Given per-dimension sets (e.g. from intersect_masks_per_dimension), retains only indices that have a full lineage from the smallest to the largest dimension. Used for inter-level intersection.

Source code in inference/core/workflows/execution_engine/v1/executor/execution_data_manager/step_input_assembler.py
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
def filter_to_valid_prefix_chains(
    per_dim_sets: "Dict[int, Set[DynamicBatchIndex]]",
    dimensions: "Set[int]",
) -> "Dict[int, Set[DynamicBatchIndex]]":
    """Keep only indices that form a complete parent-child chain across dimensions.

    Given per-dimension sets (e.g. from intersect_masks_per_dimension), retains
    only indices that have a full lineage from the smallest to the largest
    dimension. Used for inter-level intersection.
    """
    dims = sorted(dimensions)
    levels = {d: per_dim_sets.get(d, set()) for d in dims}

    # With zero or one level there are no chains to validate.
    if len(dims) <= 1:
        return dict(levels)

    shallowest, deepest = dims[0], dims[-1]
    parent_dim = dict(zip(dims[1:], dims))  # each dim -> the dim just above it

    # Pass 1 (deep -> shallow): collect every prefix that has at least one
    # surviving descendant reaching the deepest level.
    ancestors_with_child = set()
    for d in reversed(dims):
        for index in levels[d]:
            if d == deepest or index in ancestors_with_child:
                prefix = index[:-1]
                if prefix:
                    ancestors_with_child.add(prefix)

    # Pass 2 (shallow -> deep): admit an index only when its entire prefix
    # chain was admitted at the level above it.
    surviving = {d: set() for d in dims}
    for d in dims:
        for index in levels[d]:
            if d == shallowest:
                if index in ancestors_with_child:
                    surviving[d].add(index)
            elif index[:-1] in surviving[parent_dim[d]]:
                if d == deepest or index in ancestors_with_child:
                    surviving[d].add(index)

    return surviving

get_masks_intersection_for_dimensions

get_masks_intersection_for_dimensions(
    batch_masks, dimensions
)

Intersect masks at each dimension and filter to valid prefix chains.

Source code in inference/core/workflows/execution_engine/v1/executor/execution_data_manager/step_input_assembler.py
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
def get_masks_intersection_for_dimensions(
    batch_masks: "List[Set[DynamicBatchIndex]]",
    dimensions: "Set[int]",
) -> "Dict[int, Optional[Set[DynamicBatchIndex]]]":
    """Intersect masks at each dimension and filter to valid prefix chains."""
    # No masks at all: report "no constraint" (None) for every dimension.
    if not batch_masks:
        return {dim: None for dim in dimensions}

    per_dim = intersect_masks_per_dimension(batch_masks, dimensions)

    # Chain filtering is only meaningful across two or more levels.
    if len(dimensions) > 1:
        return filter_to_valid_prefix_chains(per_dim, dimensions)
    return {dim: per_dim[dim] for dim in sorted(dimensions)}

intersect_masks_per_dimension

intersect_masks_per_dimension(batch_masks, dimensions)

Intersect masks at each dimensionality level.

For each dimension d, returns the set of indices (with length d) that appear in every mask that has at least one index at that dimension. Masks with no indices at d are ignored for that dimension. Used for intra-dimensional intersection.

Source code in inference/core/workflows/execution_engine/v1/executor/execution_data_manager/step_input_assembler.py
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
def intersect_masks_per_dimension(
    batch_masks: "List[Set[DynamicBatchIndex]]",
    dimensions: "Set[int]",
) -> "Dict[int, Set[DynamicBatchIndex]]":
    """Intersect masks at each dimensionality level.

    For a dimension d, each mask "votes" with its indices of length d; masks
    holding no index at that length abstain. The result at d is the
    intersection of all votes, or the empty set when every mask abstained.
    Used for intra-dimensional intersection.
    """
    out = {}
    for dim in sorted(dimensions):
        votes = []
        for mask in batch_masks:
            at_level = {index for index in mask if len(index) == dim}
            if at_level:
                votes.append(at_level)
        out[dim] = set.intersection(*votes) if votes else set()
    return out

enterprise/parallel

Parallel HTTP inference via Celery workers for high-throughput deployments.

inference.enterprise.parallel.dispatch_manager

Classes

ResultsChecker

Class responsible for queuing asynchronous inference runs, keeping track of running requests, and awaiting their results.

Source code in inference/enterprise/parallel/dispatch_manager.py
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
class ResultsChecker:
    """
    Class responsible for queuing asynchronous inference runs,
    keeping track of running requests, and awaiting their results.

    Results arrive on the "results" redis pub/sub channel, consumed by
    loop(); callers block in wait_for_response() until their task's
    Event is set.
    """

    def __init__(self, redis: Redis):
        # task_id -> Event, set by loop() once a final status is observed.
        self.tasks: Dict[str, Event] = {}
        # task_id -> success payload, filled by loop().
        self.dones = dict()
        # task_id -> error message, filled by loop().
        self.errors = dict()
        self.running = True
        self.redis = redis
        # Bounds the number of tasks queued but not yet completed.
        self.semaphore: BoundedSemaphore = BoundedSemaphore(NUM_PARALLEL_TASKS)

    def add_task(self, task_id: str, request: InferenceRequest):
        """
        Wait until there's an available cycle to queue a task.
        When there are cycles, add the task's id to a list to keep track of its results,
        launch the preprocess Celery task, set the task's status to in progress in redis.
        """
        # Blocks when NUM_PARALLEL_TASKS tasks are in flight; the matching
        # release happens in loop() when any result arrives.
        self.semaphore.acquire()
        self.tasks[task_id] = Event()
        preprocess.s(request.dict()).delay()

    def get_result(self, task_id: str) -> Any:
        """
        Check the done tasks and errored tasks for this task id.

        Intended to be called only after the task's Event is set (see
        wait_for_response); each result is popped, so it can be read once.
        """
        if task_id in self.dones:
            return self.dones.pop(task_id)
        elif task_id in self.errors:
            message = self.errors.pop(task_id)
            raise Exception(message)
        else:
            raise RuntimeError(
                "Task result not found in either success or error dict. Unreachable"
            )

    def loop(self):
        """
        Main loop. Check all in progress tasks for their status, and if their status is final,
        (either failure or success) then add their results to the appropriate results dictionary.
        """
        with self.redis.pubsub() as pubsub:
            pubsub.subscribe("results")
            for message in pubsub.listen():
                # Skip subscribe/unsubscribe confirmation messages.
                if message["type"] != "message":
                    continue
                message = orjson.loads(message["data"])
                task_id = message.pop("task_id")
                # Ignore results belonging to tasks this checker did not queue.
                if task_id not in self.tasks:
                    continue
                # Free a slot for add_task() as soon as a task completes.
                self.semaphore.release()
                status = message.pop("status")
                if status == FAILURE_STATE:
                    self.errors[task_id] = message["payload"]
                elif status == SUCCESS_STATE:
                    self.dones[task_id] = message["payload"]
                else:
                    raise RuntimeError(
                        "Task result not found in possible states. Unreachable"
                    )
                # Wake any waiter blocked in wait_for_response().
                self.tasks[task_id].set()

    def wait_for_response(self, key: str):
        # Block until loop() records a final result for this task, then
        # hand the (popped) result back to the caller.
        event = self.tasks[key]
        event.wait()
        del self.tasks[key]
        return self.get_result(key)
Functions
add_task
add_task(task_id, request)

Wait until there's an available cycle to queue a task. When there are cycles, add the task's id to a list to keep track of its results, launch the preprocess Celery task, and set the task's status to in progress in redis.

Source code in inference/enterprise/parallel/dispatch_manager.py
36
37
38
39
40
41
42
43
44
def add_task(self, task_id: str, request: InferenceRequest):
    """
    Wait until there's available cylce to queue a task.
    When there are cycles, add the task's id to a list to keep track of its results,
    launch the preprocess celeryt task, set the task's status to in progress in redis.
    """
    self.semaphore.acquire()
    self.tasks[task_id] = Event()
    preprocess.s(request.dict()).delay()
get_result
get_result(task_id)

Check the done tasks and errored tasks for this task id.

Source code in inference/enterprise/parallel/dispatch_manager.py
46
47
48
49
50
51
52
53
54
55
56
57
58
def get_result(self, task_id: str) -> Any:
    """
    Check the done tasks and errored tasks for this task id.
    """
    if task_id in self.dones:
        return self.dones.pop(task_id)
    elif task_id in self.errors:
        message = self.errors.pop(task_id)
        raise Exception(message)
    else:
        raise RuntimeError(
            "Task result not found in either success or error dict. Unreachable"
        )
loop
loop()

Main loop. Check all in progress tasks for their status, and if their status is final, (either failure or success) then add their results to the appropriate results dictionary.

Source code in inference/enterprise/parallel/dispatch_manager.py
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
def loop(self):
    """
    Main loop. Check all in progress tasks for their status, and if their status is final,
    (either failure or success) then add their results to the appropriate results dictionary.
    """
    with self.redis.pubsub() as pubsub:
        pubsub.subscribe("results")
        for message in pubsub.listen():
            if message["type"] != "message":
                continue
            message = orjson.loads(message["data"])
            task_id = message.pop("task_id")
            if task_id not in self.tasks:
                continue
            self.semaphore.release()
            status = message.pop("status")
            if status == FAILURE_STATE:
                self.errors[task_id] = message["payload"]
            elif status == SUCCESS_STATE:
                self.dones[task_id] = message["payload"]
            else:
                raise RuntimeError(
                    "Task result not found in possible states. Unreachable"
                )
            self.tasks[task_id].set()

Functions

inference.enterprise.parallel.infer

Classes

Functions

get_batch

get_batch(redis, model_names)

Run a heuristic to select the best batch to infer on redis[Redis]: redis client model_names[List[str]]: list of models with nonzero number of requests returns: Tuple[List[Dict], str] List[Dict] represents a batch of request dicts str is the model id

Source code in inference/enterprise/parallel/infer.py
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
def get_batch(redis: Redis, model_names: List[str]) -> Tuple[List[Dict], str]:
    """
    Run a heuristic to select the best batch to infer on.

    Args:
        redis: Redis client holding per-model request queues.
        model_names: List of models with a nonzero number of queued requests.

    Returns:
        Tuple[List[Dict], str]: the selected batch of request dicts and the
        id of the model the batch belongs to.
    """
    batch_sizes = [
        RoboflowInferenceModel.model_metadata_from_memcache_endpoint(m)["batch_size"]
        for m in model_names
    ]
    # Metadata may report a non-numeric batch size; fall back to the default.
    batch_sizes = [b if not isinstance(b, str) else BATCH_SIZE for b in batch_sizes]
    # Peek at up to batch_size oldest requests (lowest scores) per model.
    batches = [
        redis.zrange(f"infer:{m}", 0, b - 1, withscores=True)
        for m, b in zip(model_names, batch_sizes)
    ]
    model_index = select_best_inference_batch(batches, batch_sizes)
    batch = batches[model_index]
    selected_model = model_names[model_index]
    # Remove the selected requests from the queue and decrement the
    # per-model pending-request counter.
    redis.zrem(f"infer:{selected_model}", *[b[0] for b in batch])
    redis.hincrby("requests", selected_model, -len(batch))
    batch = [orjson.loads(b[0]) for b in batch]
    return batch, selected_model

write_infer_arrays_and_launch_postprocess

write_infer_arrays_and_launch_postprocess(
    arrs, request, preproc_return_metadata
)

Write inference results to shared memory and launch the postprocessing task

Source code in inference/enterprise/parallel/infer.py
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
def write_infer_arrays_and_launch_postprocess(
    arrs: Tuple[np.ndarray, ...],
    request: InferenceRequest,
    preproc_return_metadata: Dict,
):
    """Write inference results to shared memory and launch the postprocessing task"""
    # One shared-memory block per output array, sized to hold it exactly.
    blocks = [shared_memory.SharedMemory(create=True, size=a.nbytes) for a in arrs]
    with shm_manager(*blocks):
        metadata_dicts = []
        for array, block in zip(arrs, blocks):
            # Reinterpret the raw buffer as an ndarray and copy the data in.
            view = np.ndarray(array.shape, dtype=array.dtype, buffer=block.buf)
            view[:] = array[:]
            metadata_dicts.append(
                asdict(
                    SharedMemoryMetadata(
                        shm_name=block.name,
                        array_shape=array.shape,
                        array_dtype=array.dtype.name,
                    )
                )
            )

        # Hand the block descriptors (not the data) to the Celery task.
        postprocess.s(
            tuple(metadata_dicts), request.dict(), preproc_return_metadata
        ).delay()

inference.enterprise.parallel.utils

Classes

SharedMemoryMetadata dataclass

Info needed to load array from shared memory

Source code in inference/enterprise/parallel/utils.py
64
65
66
67
68
69
70
@dataclass
class SharedMemoryMetadata:
    """Info needed to load array from shared memory"""

    # Name of the shared memory block (for shared_memory.SharedMemory(name=...))
    shm_name: str
    # Shape of the stored array
    array_shape: List[int]
    # Numpy dtype name (e.g. "float32") used to reinterpret the raw buffer
    array_dtype: str

Functions

failure_handler

failure_handler(redis, *request_ids)

Context manager that updates the status/results key in redis with exception info on failure.

Source code in inference/enterprise/parallel/utils.py
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
@contextmanager
def failure_handler(redis: Redis, *request_ids: str):
    """
    Context manager that publishes a failure message for each request id when
    the wrapped body raises an exception, then re-raises it.
    """
    try:
        yield
    except Exception as error:
        # Summarize the exception as "Type: message" for downstream consumers.
        description = type(error).__name__ + ": " + str(error)
        for request_id in request_ids:
            payload = json.dumps(
                {"task_id": request_id, "status": FAILURE_STATE, "payload": description}
            )
            redis.publish("results", payload)
        raise

shm_manager

shm_manager(*shms, unlink_on_success=False)

Context manager that closes and frees shared memory objects.

Source code in inference/enterprise/parallel/utils.py
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
@contextmanager
def shm_manager(
    *shms: Union[str, shared_memory.SharedMemory], unlink_on_success: bool = False
):
    """Context manager that closes and frees shared memory objects.

    Accepts SharedMemory instances or block names (strings); names are
    attached before being yielded. On any exception, blocks opened so far are
    closed and unlinked; on success they are closed, and unlinked only when
    ``unlink_on_success`` is True.
    """
    try:
        loaded_shms = []
        for shm in shms:
            # NOTE(review): `errors` is re-created each iteration and checked
            # inside the loop, so the first failure raises immediately with a
            # single-element list — the accumulation is effectively unused.
            errors = []
            try:
                if isinstance(shm, str):
                    # A string is a block name: attach to the existing segment.
                    shm = shared_memory.SharedMemory(name=shm)
                loaded_shms.append(shm)
            except BaseException as error:
                errors.append(error)
            if errors:
                raise Exception(errors)

        yield loaded_shms
    except:
        # Failure path (any exception, including from the `with` body):
        # release and destroy everything that was successfully opened.
        for shm in loaded_shms:
            shm.close()
            shm.unlink()
        raise
    else:
        for shm in loaded_shms:
            shm.close()
            if unlink_on_success:
                shm.unlink()

enterprise/workflows/enterprise_blocks/sinks/PLC_modbus

inference.enterprise.workflows.enterprise_blocks.sinks.PLC_modbus.v1

Classes

ModbusTCPBlockV1

Bases: WorkflowBlock

A Modbus TCP communication block using pymodbus.

Supports: - 'read': Reads specified registers. - 'write': Writes values to specified registers. - 'read_and_write': Reads and writes in one execution.

On failures, errors are printed and marked as "ReadFailure" or "WriteFailure".

Source code in inference/enterprise/workflows/enterprise_blocks/sinks/PLC_modbus/v1.py
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
class ModbusTCPBlockV1(WorkflowBlock):
    """A Modbus TCP communication block using pymodbus.

    Supports:
    - 'read': Reads specified registers.
    - 'write': Writes values to specified registers.
    - 'read_and_write': Reads and writes in one execution.

    On failures, errors are printed and marked as "ReadFailure" or "WriteFailure".

    The TCP client is created lazily on the first run() call and reused for
    the lifetime of the block instance; it is closed in __del__.
    """

    def __init__(self):
        # Lazily-created modbus client; see run(). NOTE(review):
        # super().__init__() is not called — confirm WorkflowBlock needs none.
        self.client: Optional[ModbusClient] = None

    def __del__(self):
        # Best-effort cleanup: a destructor must never raise.
        if self.client:
            try:
                self.client.close()
            except Exception as exc:
                logger.debug("Failed to release modbus client: %s", exc)

    @classmethod
    def get_manifest(cls) -> Type[WorkflowBlockManifest]:
        return ModbusTCPBlockManifest

    def run(
        self,
        plc_ip: str,
        plc_port: int,
        mode: str,
        registers_to_read: List[int],
        registers_to_write: Dict[int, int],
        depends_on: any,  # NOTE(review): builtin `any` used as annotation; likely meant typing.Any
        image: Optional[WorkflowImageData] = None,
        metadata: Optional[VideoMetadata] = None,
    ) -> dict:
        """Read and/or write PLC holding registers according to `mode`.

        Returns a dict of the form {"modbus_results": [<output>]}, where
        <output> holds a "read" and/or "write" mapping of register address to
        value/status, or {"error": "ConnectionFailure"} when connecting fails.
        """
        read_results = {}
        write_results = {}

        # Connect once and cache the client. NOTE(review): later calls reuse
        # the cached connection, so changed plc_ip/plc_port values after the
        # first call are ignored — confirm this is intended.
        if not self.client:
            self.client: ModbusClient = ModbusClient(plc_ip, port=plc_port)
            if not self.client.connect():
                print("Failed to connect to PLC")
                return {"modbus_results": [{"error": "ConnectionFailure"}]}

        # If mode involves reading
        if mode in ["read", "read_and_write"]:
            for address in registers_to_read:
                try:
                    response = self.client.read_holding_registers(address)
                    if not response.isError():
                        # Single-register read: take the first value if present.
                        read_results[address] = (
                            response.registers[0] if response.registers else None
                        )
                    else:
                        print(f"Error reading register {address}: {response}")
                        read_results[address] = "ReadFailure"
                except Exception as e:
                    print(f"Exception reading register {address}: {e}")
                    read_results[address] = "ReadFailure"

        # If mode involves writing
        if mode in ["write", "read_and_write"]:
            for address, value in registers_to_write.items():
                try:
                    response = self.client.write_register(address, value)
                    if not response.isError():
                        write_results[address] = "WriteSuccess"
                    else:
                        print(
                            f"Error writing register {address} with value {value}: {response}"
                        )
                        write_results[address] = "WriteFailure"
                except Exception as e:
                    print(
                        f"Exception writing register {address} with value {value}: {e}"
                    )
                    write_results[address] = "WriteFailure"

        # Only include sections that actually produced results.
        modbus_output = {}
        if read_results:
            modbus_output["read"] = read_results
        if write_results:
            modbus_output["write"] = write_results

        return {"modbus_results": [modbus_output]}

enterprise/workflows/enterprise_blocks/sinks/PLCethernetIP

inference.enterprise.workflows.enterprise_blocks.sinks.PLCethernetIP.v1

Classes

PLCBlockManifest

Bases: WorkflowBlockManifest

Manifest for a PLC communication block using Ethernet/IP.

The block can be used in one of three modes: - 'read': Only reads specified tags. - 'write': Only writes specified tags. - 'read_and_write': Performs both reading and writing in one execution.

tags_to_read and tags_to_write are applicable depending on the mode chosen.

Source code in inference/enterprise/workflows/enterprise_blocks/sinks/PLCethernetIP/v1.py
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
class PLCBlockManifest(WorkflowBlockManifest):
    """Manifest for a PLC communication block using Ethernet/IP.

    The block can be used in one of three modes:
    - 'read': Only reads specified tags.
    - 'write': Only writes specified tags.
    - 'read_and_write': Performs both reading and writing in one execution.

    `tags_to_read` and `tags_to_write` are applicable depending on the mode chosen.
    """

    # Workflow UI/registry metadata; `local_only` + `enterprise_only` restrict
    # where this block may run (it opens raw sockets to plant hardware).
    model_config = ConfigDict(
        json_schema_extra={
            "name": "PLC EthernetIP",
            "version": "v1",
            "short_description": "Generic PLC read/write block using pylogix over Ethernet/IP.",
            "long_description": LONG_DESCRIPTION,
            "license": "Roboflow Enterprise License",
            "block_type": "sinks",
            "ui_manifest": {
                "section": "industrial",
                "icon": "fal fa-microchip",
                "blockPriority": 13,
                "enterprise_only": True,
                "local_only": True,
            },
        }
    )

    # NOTE(review): "roboflow_core/sinks@v1" looks generic for a PLC-specific
    # block — confirm it matches the identifier under which this block is
    # registered with the execution engine.
    type: Literal["roboflow_core/sinks@v1"]

    # Target PLC address; may be given literally or via a workflow parameter.
    plc_ip: Union[str, WorkflowParameterSelector(kind=[STRING_KIND])] = Field(
        description="IP address of the target PLC.", examples=["192.168.1.10"]
    )

    # Selects which of tags_to_read / tags_to_write are consulted at run time.
    mode: Literal["read", "write", "read_and_write"] = Field(
        description="Mode of operation: 'read', 'write', or 'read_and_write'.",
        examples=["read", "write", "read_and_write"],
    )

    # Ignored unless mode is 'read' or 'read_and_write'.
    tags_to_read: Union[
        List[str],
        Selector(kind=[LIST_OF_VALUES_KIND]),
        WorkflowParameterSelector(kind=[LIST_OF_VALUES_KIND]),
    ] = Field(
        default=[],
        description="List of PLC tag names to read. Applicable if mode='read' or mode='read_and_write'.",
        examples=[["camera_msg", "sku_number"]],
    )

    # Ignored unless mode is 'write' or 'read_and_write'.
    tags_to_write: Union[
        Dict[str, Union[int, float, str]],
        Selector(kind=[DICTIONARY_KIND]),
        WorkflowParameterSelector(kind=[DICTIONARY_KIND]),
    ] = Field(
        default={},
        description="Dictionary of tags and the values to write. Applicable if mode='write' or mode='read_and_write'.",
        examples=[{"camera_fault": True, "defect_count": 5}],
    )

    # Forces execution ordering: this sink runs after the referenced step.
    depends_on: Selector() = Field(
        description="Reference to the step output this block depends on.",
        examples=["$steps.some_previous_step"],
    )

    @classmethod
    def describe_outputs(cls) -> List[OutputDefinition]:
        # Single output: a one-element list wrapping the read/write result dict.
        return [
            OutputDefinition(
                name="plc_results",
                kind=[LIST_OF_VALUES_KIND],
            ),
        ]

    @classmethod
    def get_execution_engine_compatibility(cls) -> Optional[str]:
        return ">=1.0.0,<2.0.0"

PLCBlockV1

Bases: WorkflowBlock

A PLC communication workflow block using Ethernet/IP and pylogix.

Depending on the selected mode: - 'read': Reads specified tags. - 'write': Writes provided values to specified tags. - 'read_and_write': Reads and writes in one go.

In case of failures, errors are logged and the corresponding tag entry in the output is set to "ReadFailure" or "WriteFailure".

Source code in inference/enterprise/workflows/enterprise_blocks/sinks/PLCethernetIP/v1.py
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
class PLCBlockV1(WorkflowBlock):
    """A PLC communication workflow block using Ethernet/IP and pylogix.

    Depending on the selected mode:
    - 'read': Reads specified tags.
    - 'write': Writes provided values to specified tags.
    - 'read_and_write': Reads and writes in one go.

    In case of failures, errors are logged and the corresponding tag entry in
    the output is set to "ReadFailure" or "WriteFailure".
    """

    @classmethod
    def get_manifest(cls) -> Type[WorkflowBlockManifest]:
        """Return the manifest class describing this block's interface."""
        return PLCBlockManifest

    def _read_single_tag(self, comm, tag):
        """Read one tag via `comm`; return its value, or "ReadFailure" on any error."""
        try:
            response = comm.Read(tag)
            if response.Status == "Success":
                return response.Value
            # %-style placeholders are resolved lazily by the logging framework;
            # the previous spurious f-string prefix has been removed.
            logger.error("Error reading tag '%s': %s", tag, response.Status)
            return "ReadFailure"
        except Exception as e:
            logger.error("Unhandled error reading tag '%s': %s", tag, e)
            return "ReadFailure"

    def _write_single_tag(self, comm, tag, value):
        """Write `value` to one tag via `comm`; return "WriteSuccess" or "WriteFailure"."""
        try:
            response = comm.Write(tag, value)
            if response.Status == "Success":
                return "WriteSuccess"
            logger.error(
                "Error writing tag '%s' with value '%s': %s",
                tag,
                value,
                response.Status,
            )
            return "WriteFailure"
        except Exception as e:
            logger.error("Unhandled error writing tag '%s': %s", tag, e)
            return "WriteFailure"

    def run(
        self,
        plc_ip: str,
        mode: str,
        tags_to_read: List[str],
        tags_to_write: Dict[str, Union[int, float, str]],
        depends_on: any,  # NOTE(review): should be typing.Any; left unchanged in case Any is not imported at module level
        image: Optional[WorkflowImageData] = None,
        metadata: Optional[VideoMetadata] = None,
    ) -> dict:
        """Run PLC read/write operations using pylogix over Ethernet/IP.

        Args:
            plc_ip (str): PLC IP address.
            mode (str): 'read', 'write', or 'read_and_write'.
            tags_to_read (List[str]): Tags to read if applicable.
            tags_to_write (Dict[str, Union[int, float, str]]): Tags to write if applicable.
            depends_on (any): The step output this block depends on.
            image (Optional[WorkflowImageData]): Not required for this block.
            metadata (Optional[VideoMetadata]): Not required for this block.

        Returns:
            dict: A dictionary with `plc_results` as a list containing one dictionary. That dictionary has 'read' and/or 'write' keys.
        """
        read_results = {}
        write_results = {}

        # The context manager guarantees the socket is closed even if a
        # per-tag helper raises unexpectedly.
        with pylogix.PLC() as comm:
            comm.IPAddress = plc_ip

            if mode in ["read", "read_and_write"]:
                read_results = {
                    tag: self._read_single_tag(comm, tag) for tag in tags_to_read
                }

            if mode in ["write", "read_and_write"]:
                write_results = {
                    tag: self._write_single_tag(comm, tag, value)
                    for tag, value in tags_to_write.items()
                }

        # Only include sections that actually produced results, so consumers
        # can distinguish "not requested" from "empty".
        plc_output = {}
        if read_results:
            plc_output["read"] = read_results
        if write_results:
            plc_output["write"] = write_results

        return {"plc_results": [plc_output]}
Functions
run
run(
    plc_ip,
    mode,
    tags_to_read,
    tags_to_write,
    depends_on,
    image=None,
    metadata=None,
)

Run PLC read/write operations using pylogix over Ethernet/IP.

Parameters:

Name Type Description Default
plc_ip str

PLC IP address.

required
mode str

'read', 'write', or 'read_and_write'.

required
tags_to_read List[str]

Tags to read if applicable.

required
tags_to_write Dict[str, Union[int, float, str]]

Tags to write if applicable.

required
depends_on any

The step output this block depends on.

required
image Optional[WorkflowImageData]

Not required for this block.

None
metadata Optional[VideoMetadata]

Not required for this block.

None

Returns:

Name Type Description
dict dict

A dictionary with plc_results as a list containing one dictionary. That dictionary has 'read' and/or 'write' keys.

Source code in inference/enterprise/workflows/enterprise_blocks/sinks/PLCethernetIP/v1.py
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
def run(
    self,
    plc_ip: str,
    mode: str,
    tags_to_read: List[str],
    tags_to_write: Dict[str, Union[int, float, str]],
    depends_on: any,
    image: Optional[WorkflowImageData] = None,
    metadata: Optional[VideoMetadata] = None,
) -> dict:
    """Run PLC read/write operations using pylogix over Ethernet/IP.

    Args:
        plc_ip (str): PLC IP address.
        mode (str): 'read', 'write', or 'read_and_write'.
        tags_to_read (List[str]): Tags to read if applicable.
        tags_to_write (Dict[str, Union[int, float, str]]): Tags to write if applicable.
        depends_on (any): The step output this block depends on.
        image (Optional[WorkflowImageData]): Not required for this block.
        metadata (Optional[VideoMetadata]): Not required for this block.

    Returns:
        dict: A dictionary with `plc_results` as a list containing one dictionary. That dictionary has 'read' and/or 'write' keys.
    """
    read_results = {}
    write_results = {}

    # Context manager closes the PLC socket even if a helper raises.
    with pylogix.PLC() as comm:
        comm.IPAddress = plc_ip

        # Per-tag failures are folded into the result dicts by the helpers
        # ("ReadFailure"/"WriteFailure"); they never abort the whole run.
        if mode in ["read", "read_and_write"]:
            read_results = {
                tag: self._read_single_tag(comm, tag) for tag in tags_to_read
            }

        if mode in ["write", "read_and_write"]:
            write_results = {
                tag: self._write_single_tag(comm, tag, value)
                for tag, value in tags_to_write.items()
            }

    # Only sections that produced results are emitted, so an absent key means
    # that operation was not requested (or yielded nothing).
    plc_output = {}
    if read_results:
        plc_output["read"] = read_results
    if write_results:
        plc_output["write"] = write_results

    return {"plc_results": [plc_output]}

enterprise/workflows/enterprise_blocks/sinks/microsoft_sql_server

inference.enterprise.workflows.enterprise_blocks.sinks.microsoft_sql_server.v1

Classes

SQLServerConnectionError

Bases: SQLServerError

Exception raised for connection-related errors

Source code in inference/enterprise/workflows/enterprise_blocks/sinks/microsoft_sql_server/v1.py
43
44
45
46
class SQLServerConnectionError(SQLServerError):
    """Raised when establishing or using the SQL Server connection fails."""

SQLServerError

Bases: Exception

Base exception for SQL Server related errors

Source code in inference/enterprise/workflows/enterprise_blocks/sinks/microsoft_sql_server/v1.py
37
38
39
40
class SQLServerError(Exception):
    """Root of the SQL Server sink exception hierarchy."""

SQLServerInsertError

Bases: SQLServerError

Exception raised for insert operation errors

Source code in inference/enterprise/workflows/enterprise_blocks/sinks/microsoft_sql_server/v1.py
49
50
51
52
class SQLServerInsertError(SQLServerError):
    """Raised when an insert operation against SQL Server fails."""

enterprise/workflows/enterprise_blocks/sinks/opc_writer

inference.enterprise.workflows.enterprise_blocks.sinks.opc_writer.v1

Classes

OPCUAConnectionManager

Thread-safe connection manager for OPC UA clients with connection pooling and circuit breaker pattern.

Maintains a pool of connections keyed by (url, user_name) to avoid creating new connections for every write operation. Uses circuit breaker to fail fast when servers are unreachable.

Source code in inference/enterprise/workflows/enterprise_blocks/sinks/opc_writer/v1.py
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
class OPCUAConnectionManager:
    """
    Thread-safe connection manager for OPC UA clients with connection pooling
    and circuit breaker pattern.

    Maintains a pool of connections keyed by (url, user_name) to avoid creating
    new connections for every write operation. Uses circuit breaker to fail fast
    when servers are unreachable.
    """

    _instance: Optional["OPCUAConnectionManager"] = None
    _lock = threading.Lock()

    # Circuit breaker: how long to wait before trying a failed server again
    CIRCUIT_BREAKER_TIMEOUT_SECONDS = 2.0

    def __new__(cls) -> "OPCUAConnectionManager":
        """Singleton pattern to ensure one connection manager across the application."""
        # Double-checked locking: cheap unlocked test first, re-test under the
        # lock so only one thread ever constructs the shared instance.
        if cls._instance is None:
            with cls._lock:
                if cls._instance is None:
                    cls._instance = super().__new__(cls)
                    cls._instance._initialized = False
        return cls._instance

    def __init__(self):
        # __init__ runs on every OPCUAConnectionManager() call even though the
        # instance is shared; skip re-initialization after the first call.
        if self._initialized:
            return
        self._connections: Dict[str, Client] = {}
        self._connection_locks: Dict[str, threading.Lock] = {}
        self._connection_metadata: Dict[str, dict] = {}
        # key -> timestamp of last failure (circuit breaker state)
        self._connection_failures: Dict[str, float] = {}
        self._global_lock = threading.Lock()
        self._tloop: Optional[ThreadLoop] = None
        self._initialized = True
        logger.debug("OPC UA Connection Manager initialized")

    def _get_tloop(self) -> ThreadLoop:
        """Get or create the shared ThreadLoop for all clients."""
        if self._tloop is None or not self._tloop.is_alive():
            logger.debug("OPC UA Connection Manager creating shared ThreadLoop")
            self._tloop = ThreadLoop(timeout=120)
            self._tloop.start()
        return self._tloop

    def _stop_tloop(self) -> None:
        """Stop the shared ThreadLoop if it exists."""
        if self._tloop is not None and self._tloop.is_alive():
            logger.debug("OPC UA Connection Manager stopping shared ThreadLoop")
            try:
                self._tloop.loop.call_soon_threadsafe(self._tloop.loop.stop)
                self._tloop.join(timeout=2.0)
            except Exception as exc:
                # Best-effort shutdown: a stuck loop must not crash close_all().
                logger.debug(f"OPC UA Connection Manager ThreadLoop stop error: {exc}")
            self._tloop = None

    def _get_connection_key(self, url: str, user_name: Optional[str]) -> str:
        """Generate a unique key for connection pooling."""
        return f"{url}|{user_name or ''}"

    def _get_connection_lock(self, key: str) -> threading.Lock:
        """Get or create a lock for a specific connection."""
        with self._global_lock:
            if key not in self._connection_locks:
                self._connection_locks[key] = threading.Lock()
            return self._connection_locks[key]

    def _create_client(
        self,
        url: str,
        user_name: Optional[str],
        password: Optional[str],
        timeout: int,
    ) -> Client:
        """Create and configure a new OPC UA client using the shared ThreadLoop."""
        logger.debug(f"OPC UA Connection Manager creating client for {url}")
        tloop = self._get_tloop()
        client = Client(url=url, tloop=tloop, sync_wrapper_timeout=timeout)
        # Credentials are only applied when both pieces are provided.
        if user_name and password:
            client.set_user(user_name)
            client.set_password(password)
        return client

    def _connect_with_retry(
        self,
        client: Client,
        url: str,
        max_retries: int = 3,
        base_backoff: float = 1.0,
    ) -> None:
        """
        Connect to OPC UA server with retry logic and exponential backoff.

        Args:
            client: The OPC UA client to connect
            url: Server URL (for logging)
            max_retries: Maximum number of connection attempts
            base_backoff: Base delay between retries (seconds), doubles each retry

        Raises:
            Exception: If all connection attempts fail
        """
        last_exception = None

        for attempt in range(max_retries):
            try:
                logger.debug(
                    f"OPC UA Connection Manager connecting to {url} "
                    f"(attempt {attempt + 1}/{max_retries})"
                )
                client.connect()
                logger.info(
                    f"OPC UA Connection Manager successfully connected to {url}"
                )
                return
            except BadUserAccessDenied as exc:
                # Auth errors should not be retried - they will keep failing
                logger.error(f"OPC UA Connection Manager authentication failed: {exc}")
                raise Exception(f"AUTH ERROR: {exc}") from exc
            except OSError as exc:
                last_exception = exc
                logger.warning(
                    f"OPC UA Connection Manager network error on attempt {attempt + 1}: {exc}"
                )
            except Exception as exc:
                last_exception = exc
                logger.warning(
                    f"OPC UA Connection Manager connection error on attempt {attempt + 1}: "
                    f"{type(exc).__name__}: {exc}"
                )

            # Don't sleep after the last attempt
            if attempt < max_retries - 1:
                backoff_time = base_backoff * (2**attempt)
                logger.debug(
                    f"OPC UA Connection Manager waiting {backoff_time}s before retry"
                )
                time.sleep(backoff_time)

        # All retries exhausted; chain the last failure for diagnostics.
        logger.error(
            f"OPC UA Connection Manager failed to connect to {url} "
            f"after {max_retries} attempts"
        )
        if isinstance(last_exception, OSError):
            raise Exception(
                f"NETWORK ERROR: Failed to connect after {max_retries} attempts. Last error: {last_exception}"
            ) from last_exception
        raise Exception(
            f"CONNECTION ERROR: Failed to connect after {max_retries} attempts. Last error: {last_exception}"
        ) from last_exception

    def _is_circuit_open(self, key: str) -> bool:
        """
        Check if circuit breaker is open (server recently failed).
        Returns True if we should NOT attempt connection (fail fast).
        """
        # Callers hold the per-connection lock, so mutating the failure map
        # here is safe for a given key.
        if key not in self._connection_failures:
            return False

        time_since_failure = time.time() - self._connection_failures[key]
        if time_since_failure < self.CIRCUIT_BREAKER_TIMEOUT_SECONDS:
            return True

        # Timeout expired, clear the failure record
        del self._connection_failures[key]
        return False

    def _record_failure(self, key: str) -> None:
        """Record a connection failure for circuit breaker."""
        self._connection_failures[key] = time.time()

    def _clear_failure(self, key: str) -> None:
        """Clear failure record after successful connection."""
        self._connection_failures.pop(key, None)

    def get_connection(
        self,
        url: str,
        user_name: Optional[str],
        password: Optional[str],
        timeout: int,
        max_retries: int = 1,
        base_backoff: float = 0.0,
    ) -> Client:
        """
        Get a connection from the pool or create a new one.

        This method is thread-safe and will reuse existing healthy connections.
        Uses circuit breaker pattern to fail fast for recently failed servers.

        Args:
            url: OPC UA server URL
            user_name: Optional username for authentication
            password: Optional password for authentication
            timeout: Connection timeout in seconds
            max_retries: Maximum number of connection attempts (default 1)
            base_backoff: Base delay between retries (default 0)

        Returns:
            A connected OPC UA client

        Raises:
            Exception: If connection fails or circuit breaker is open
        """
        key = self._get_connection_key(url, user_name)
        lock = self._get_connection_lock(key)

        with lock:
            # Circuit breaker: fail fast if server recently failed
            if self._is_circuit_open(key):
                logger.debug(
                    f"OPC UA Connection Manager circuit breaker open for {url}, "
                    f"failing fast (will retry in {self.CIRCUIT_BREAKER_TIMEOUT_SECONDS}s)"
                )
                raise Exception(
                    f"CIRCUIT OPEN: Server {url} recently failed, skipping connection attempt. "
                    f"Will retry after {self.CIRCUIT_BREAKER_TIMEOUT_SECONDS}s."
                )

            # Check if we have an existing connection
            if key in self._connections:
                logger.debug(f"OPC UA Connection Manager reusing connection for {url}")
                return self._connections[key]

            # Create new connection
            try:
                client = self._create_client(url, user_name, password, timeout)
                self._connect_with_retry(client, url, max_retries, base_backoff)

                # Success - clear any failure record and store in pool
                self._clear_failure(key)
                self._connections[key] = client
                self._connection_metadata[key] = {
                    "url": url,
                    "user_name": user_name,
                    "password": password,
                    "timeout": timeout,
                    "connected_at": datetime.now(),
                }

                return client
            except Exception:
                # Record failure for circuit breaker, then re-raise unchanged.
                self._record_failure(key)
                raise

    def _safe_disconnect(self, client: Client) -> None:
        """Safely disconnect a client, swallowing any errors."""
        try:
            logger.debug("OPC UA Connection Manager disconnecting client")
            client.disconnect()
        except Exception as exc:
            logger.debug(
                f"OPC UA Connection Manager disconnect error (non-fatal): {exc}"
            )

    def release_connection(
        self, url: str, user_name: Optional[str], force_close: bool = False
    ) -> None:
        """
        Release a connection back to the pool.

        By default, connections are kept alive for reuse. Set force_close=True
        to immediately close the connection.

        Args:
            url: OPC UA server URL
            user_name: Optional username used for the connection
            force_close: If True, close the connection instead of keeping it
        """
        if not force_close:
            # Connection stays in pool for reuse
            return

        key = self._get_connection_key(url, user_name)
        lock = self._get_connection_lock(key)

        with lock:
            if key in self._connections:
                self._safe_disconnect(self._connections[key])
                del self._connections[key]
                if key in self._connection_metadata:
                    del self._connection_metadata[key]
                logger.debug(f"OPC UA Connection Manager closed connection for {url}")

    def invalidate_connection(self, url: str, user_name: Optional[str]) -> None:
        """
        Invalidate a connection, forcing it to be recreated on next use.

        Call this when a connection error occurs during an operation to ensure
        the next operation gets a fresh connection.

        Args:
            url: OPC UA server URL
            user_name: Optional username used for the connection
        """
        key = self._get_connection_key(url, user_name)
        lock = self._get_connection_lock(key)

        with lock:
            if key in self._connections:
                self._safe_disconnect(self._connections[key])
                del self._connections[key]
                if key in self._connection_metadata:
                    del self._connection_metadata[key]
                logger.debug(
                    f"OPC UA Connection Manager invalidated connection for {url}"
                )

    def close_all(self) -> None:
        """Close all connections in the pool and stop the shared ThreadLoop."""
        with self._global_lock:
            # Disconnect does not mutate the pool, so iterating values directly
            # is safe; the pool is cleared afterwards in one step.
            for client in self._connections.values():
                self._safe_disconnect(client)
            self._connections.clear()
            self._connection_metadata.clear()
            self._stop_tloop()
            logger.info("OPC UA Connection Manager closed all connections")

    def get_pool_stats(self) -> dict:
        """Get statistics about the connection pool."""
        with self._global_lock:
            return {
                "total_connections": len(self._connections),
                "connections": [
                    {
                        "url": meta["url"],
                        "user_name": meta["user_name"],
                        "connected_at": meta["connected_at"].isoformat(),
                    }
                    for meta in self._connection_metadata.values()
                ],
            }
Functions
__new__
__new__()

Singleton pattern to ensure one connection manager across the application.

Source code in inference/enterprise/workflows/enterprise_blocks/sinks/opc_writer/v1.py
35
36
37
38
39
40
41
42
def __new__(cls) -> "OPCUAConnectionManager":
    """Singleton pattern to ensure one connection manager across the application."""
    # Double-checked locking: cheap unlocked test first, re-test under the
    # lock so only one thread ever constructs the shared instance.
    if cls._instance is None:
        with cls._lock:
            if cls._instance is None:
                cls._instance = super().__new__(cls)
                # Flag read by __init__ to make initialization run only once.
                cls._instance._initialized = False
    return cls._instance
close_all
close_all()

Close all connections in the pool and stop the shared ThreadLoop.

Source code in inference/enterprise/workflows/enterprise_blocks/sinks/opc_writer/v1.py
332
333
334
335
336
337
338
339
340
def close_all(self) -> None:
    """Close all connections in the pool and stop the shared ThreadLoop."""
    # Holds the global lock for the whole teardown so no connection can be
    # added or handed out while the pool is being drained.
    with self._global_lock:
        for key, client in list(self._connections.items()):
            self._safe_disconnect(client)
        self._connections.clear()
        self._connection_metadata.clear()
        self._stop_tloop()
        logger.info("OPC UA Connection Manager closed all connections")
get_connection
get_connection(
    url,
    user_name,
    password,
    timeout,
    max_retries=1,
    base_backoff=0.0,
)

Get a connection from the pool or create a new one.

This method is thread-safe and will reuse existing healthy connections. Uses circuit breaker pattern to fail fast for recently failed servers.

Parameters:

Name Type Description Default
url str

OPC UA server URL

required
user_name Optional[str]

Optional username for authentication

required
password Optional[str]

Optional password for authentication

required
timeout int

Connection timeout in seconds

required
max_retries int

Maximum number of connection attempts (default 1)

1
base_backoff float

Base delay between retries (default 0)

0.0

Returns:

Type Description
Client

A connected OPC UA client

Raises:

Type Description
Exception

If connection fails or circuit breaker is open

Source code in inference/enterprise/workflows/enterprise_blocks/sinks/opc_writer/v1.py
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
def get_connection(
    self,
    url: str,
    user_name: Optional[str],
    password: Optional[str],
    timeout: int,
    max_retries: int = 1,
    base_backoff: float = 0.0,
) -> Client:
    """
    Get a connection from the pool or create a new one.

    This method is thread-safe and will reuse existing healthy connections.
    Uses circuit breaker pattern to fail fast for recently failed servers.

    Args:
        url: OPC UA server URL
        user_name: Optional username for authentication
        password: Optional password for authentication
        timeout: Connection timeout in seconds
        max_retries: Maximum number of connection attempts (default 1)
        base_backoff: Base delay between retries (default 0)

    Returns:
        A connected OPC UA client

    Raises:
        Exception: If connection fails or circuit breaker is open
    """
    key = self._get_connection_key(url, user_name)
    # Per-connection lock: connections to different servers can be created
    # concurrently, but each (url, user) pair is serialized.
    lock = self._get_connection_lock(key)

    with lock:
        # Circuit breaker: fail fast if server recently failed
        if self._is_circuit_open(key):
            logger.debug(
                f"OPC UA Connection Manager circuit breaker open for {url}, "
                f"failing fast (will retry in {self.CIRCUIT_BREAKER_TIMEOUT_SECONDS}s)"
            )
            raise Exception(
                f"CIRCUIT OPEN: Server {url} recently failed, skipping connection attempt. "
                f"Will retry after {self.CIRCUIT_BREAKER_TIMEOUT_SECONDS}s."
            )

        # Check if we have an existing connection
        # NOTE(review): pooled clients are returned without a liveness check;
        # a stale client surfaces to the caller, who is expected to call
        # invalidate_connection() — confirm callers do so.
        if key in self._connections:
            logger.debug(f"OPC UA Connection Manager reusing connection for {url}")
            return self._connections[key]

        # Create new connection
        try:
            client = self._create_client(url, user_name, password, timeout)
            self._connect_with_retry(client, url, max_retries, base_backoff)

            # Success - clear any failure record and store in pool
            self._clear_failure(key)
            self._connections[key] = client
            self._connection_metadata[key] = {
                "url": url,
                "user_name": user_name,
                "password": password,
                "timeout": timeout,
                "connected_at": datetime.now(),
            }

            return client
        except Exception as exc:
            # Record failure for circuit breaker
            self._record_failure(key)
            raise
get_pool_stats
get_pool_stats()

Get statistics about the connection pool.

Source code in inference/enterprise/workflows/enterprise_blocks/sinks/opc_writer/v1.py
342
343
344
345
346
347
348
349
350
351
352
353
354
355
def get_pool_stats(self) -> dict:
    """Get statistics about the connection pool.

    Returns:
        dict with `total_connections` (int) and `connections`, a list of
        {url, user_name, connected_at} entries (passwords are not exposed).
    """
    with self._global_lock:
        return {
            "total_connections": len(self._connections),
            "connections": [
                {
                    "url": meta["url"],
                    "user_name": meta["user_name"],
                    "connected_at": meta["connected_at"].isoformat(),
                }
                for meta in self._connection_metadata.values()
            ],
        }
invalidate_connection
invalidate_connection(url, user_name)

Invalidate a connection, forcing it to be recreated on next use.

Call this when a connection error occurs during an operation to ensure the next operation gets a fresh connection.

Parameters:

Name Type Description Default
url str

OPC UA server URL

required
user_name Optional[str]

Optional username used for the connection

required
Source code in inference/enterprise/workflows/enterprise_blocks/sinks/opc_writer/v1.py
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
def invalidate_connection(self, url: str, user_name: Optional[str]) -> None:
    """
    Drop a pooled connection so the next use creates a fresh one.

    Intended to be called after an operation-level connection error, which
    usually means the pooled client is no longer usable.

    Args:
        url: OPC UA server URL
        user_name: Optional username used for the connection
    """
    key = self._get_connection_key(url, user_name)
    lock = self._get_connection_lock(key)

    with lock:
        # pop() handles the "already gone" case without a membership check.
        client = self._connections.pop(key, None)
        if client is None:
            return
        self._safe_disconnect(client)
        self._connection_metadata.pop(key, None)
        logger.debug(
            f"OPC UA Connection Manager invalidated connection for {url}"
        )
release_connection
release_connection(url, user_name, force_close=False)

Release a connection back to the pool.

By default, connections are kept alive for reuse. Set force_close=True to immediately close the connection.

Parameters:

Name Type Description Default
url str

OPC UA server URL

required
user_name Optional[str]

Optional username used for the connection

required
force_close bool

If True, close the connection instead of keeping it

False
Source code in inference/enterprise/workflows/enterprise_blocks/sinks/opc_writer/v1.py
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
def release_connection(
    self, url: str, user_name: Optional[str], force_close: bool = False
) -> None:
    """
    Return a connection to the pool, optionally closing it.

    Pooled connections are kept open for reuse unless ``force_close`` is
    True, in which case the connection is disconnected and removed.

    Args:
        url: OPC UA server URL
        user_name: Optional username used for the connection
        force_close: If True, close the connection instead of keeping it
    """
    if not force_close:
        # Nothing to do: the connection simply stays pooled for reuse.
        return

    key = self._get_connection_key(url, user_name)
    lock = self._get_connection_lock(key)

    with lock:
        client = self._connections.pop(key, None)
        if client is not None:
            self._safe_disconnect(client)
            self._connection_metadata.pop(key, None)
            logger.debug(f"OPC UA Connection Manager closed connection for {url}")

UnsupportedTypeError

Bases: Exception

Raised when an unsupported value type is specified

Source code in inference/enterprise/workflows/enterprise_blocks/sinks/opc_writer/v1.py
370
371
372
373
class UnsupportedTypeError(Exception):
    """Raised when an unsupported value type is specified."""

Functions

get_available_namespaces

get_available_namespaces(client)

Get list of available namespaces from OPC UA server. Returns a single-element placeholder list (["<unable to fetch namespaces>"]) if the namespaces cannot be fetched.

Source code in inference/enterprise/workflows/enterprise_blocks/sinks/opc_writer/v1.py
778
779
780
781
782
783
784
785
786
787
788
789
790
def get_available_namespaces(client: Client) -> List[str]:
    """
    Get the list of available namespaces from an OPC UA server.

    On failure this is non-fatal: a single-element placeholder list
    (``["<unable to fetch namespaces>"]``) is returned instead.
    """
    try:
        # Wrap the async namespace-array getter so it can be called synchronously.
        get_namespace_array = sync_async_client_method(AsyncClient.get_namespace_array)(
            client
        )
        return get_namespace_array()
    except Exception as exc:
        # Namespace listing is informational; log and fall back to a placeholder.
        logger.info(f"Failed to get namespace array (non-fatal): {exc}")
        return ["<unable to fetch namespaces>"]

get_connection_manager

get_connection_manager()

Get the global OPC UA connection manager instance.

Source code in inference/enterprise/workflows/enterprise_blocks/sinks/opc_writer/v1.py
362
363
364
365
366
367
def get_connection_manager() -> OPCUAConnectionManager:
    """Return the process-wide OPC UA connection manager, creating it lazily."""
    global _connection_manager
    manager = _connection_manager
    if manager is None:
        manager = OPCUAConnectionManager()
        _connection_manager = manager
    return manager

get_node_data_type

get_node_data_type(var)

Get the data type of an OPC UA node. Returns a string representation of the type, or "Unknown" if unable to read.

Source code in inference/enterprise/workflows/enterprise_blocks/sinks/opc_writer/v1.py
802
803
804
805
806
807
808
809
810
811
def get_node_data_type(var) -> str:
    """
    Read the variant data type of an OPC UA node as a string.

    Falls back to "Unknown" when the type cannot be read (non-fatal).
    """
    try:
        variant_type = var.read_data_type_as_variant_type()
        return str(variant_type)
    except Exception as exc:
        logger.info(f"Unable to read node data type: {exc}")
        return "Unknown"

opc_connect_and_write_value

opc_connect_and_write_value(
    url,
    namespace,
    user_name,
    password,
    object_name,
    variable_name,
    value,
    timeout,
    node_lookup_mode="hierarchical",
    value_type="String",
    max_retries=1,
    retry_backoff_seconds=0.0,
)

Connect to OPC UA server and write a value using connection pooling.

Uses the connection manager to reuse existing connections. If no connection exists, attempts to create one. Fails fast on connection errors to avoid blocking the pipeline.

Parameters:

Name Type Description Default
url str

OPC UA server URL

required
namespace str

Namespace URI or index

required
user_name Optional[str]

Optional username for authentication

required
password Optional[str]

Optional password for authentication

required
object_name str

Target object path

required
variable_name str

Variable to write

required
value Union[bool, float, int, str]

Value to write

required
timeout int

Connection timeout in seconds

required
node_lookup_mode Literal['hierarchical', 'direct']

Path lookup strategy ('hierarchical' or 'direct')

'hierarchical'
value_type str

OPC UA data type for the value

'String'
max_retries int

Maximum number of connection attempts (default 1 = no retries)

1
retry_backoff_seconds float

Base delay between retries (default 0 = no delay)

0.0

Returns:

Type Description
Tuple[bool, str]

Tuple of (error_status, message)

Source code in inference/enterprise/workflows/enterprise_blocks/sinks/opc_writer/v1.py
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
def opc_connect_and_write_value(
    url: str,
    namespace: str,
    user_name: Optional[str],
    password: Optional[str],
    object_name: str,
    variable_name: str,
    value: Union[bool, float, int, str],
    timeout: int,
    node_lookup_mode: Literal["hierarchical", "direct"] = "hierarchical",
    value_type: str = "String",
    max_retries: int = 1,
    retry_backoff_seconds: float = 0.0,
) -> Tuple[bool, str]:
    """
    Write a single value to an OPC UA server, reusing pooled connections.

    A client is fetched from the global connection manager (created on
    demand if absent). Errors are classified: user-configuration errors
    leave the pooled connection intact, while any other failure invalidates
    it so the next call starts with a fresh connection. Failures are
    reported via the return value rather than raised, to avoid blocking
    the pipeline.

    Args:
        url: OPC UA server URL
        namespace: Namespace URI or index
        user_name: Optional username for authentication
        password: Optional password for authentication
        object_name: Target object path
        variable_name: Variable to write
        value: Value to write
        timeout: Connection timeout in seconds
        node_lookup_mode: Path lookup strategy ('hierarchical' or 'direct')
        value_type: OPC UA data type for the value
        max_retries: Maximum number of connection attempts (default 1 = no retries)
        retry_backoff_seconds: Base delay between retries (default 0 = no delay)

    Returns:
        Tuple of (error_status, message)
    """
    logger.debug(
        f"OPC Writer attempting to write value={value} to {url}/{object_name}/{variable_name}"
    )

    manager = get_connection_manager()

    try:
        pooled_client = manager.get_connection(
            url=url,
            user_name=user_name,
            password=password,
            timeout=timeout,
            max_retries=max_retries,
            base_backoff=retry_backoff_seconds,
        )
        _opc_write_value(
            client=pooled_client,
            namespace=namespace,
            object_name=object_name,
            variable_name=variable_name,
            value=value,
            node_lookup_mode=node_lookup_mode,
            value_type=value_type,
        )
        logger.debug(
            f"OPC Writer successfully wrote value to {url}/{object_name}/{variable_name}"
        )
        return False, "Value set successfully"
    except Exception as exc:
        # A direct match, or a wrapped (__cause__) match, counts as a user
        # configuration problem rather than a connection problem.
        is_user_config_error = isinstance(exc, USER_CONFIG_ERROR_TYPES) or (
            exc.__cause__ is not None
            and isinstance(exc.__cause__, USER_CONFIG_ERROR_TYPES)
        )

        if is_user_config_error:
            # Connection is fine; the request itself was invalid.
            logger.error(f"OPC Writer configuration error: {type(exc).__name__}: {exc}")
        else:
            logger.warning(
                f"OPC Writer error (invalidating connection): {type(exc).__name__}: {exc}"
            )
            manager.invalidate_connection(url, user_name)

        return (
            True,
            f"Failed to write {value} to {object_name}:{variable_name} in {url}. Error: {exc}",
        )

safe_disconnect

safe_disconnect(client)

Safely disconnect from OPC UA server, swallowing any errors

Source code in inference/enterprise/workflows/enterprise_blocks/sinks/opc_writer/v1.py
793
794
795
796
797
798
799
def safe_disconnect(client: Client) -> None:
    """Best-effort disconnect from the OPC UA server; any exception is swallowed."""
    try:
        logger.debug("OPC Writer disconnecting from server")
        client.disconnect()
    except Exception as exc:
        # Teardown failures (e.g. already-closed sockets) are logged and ignored.
        logger.debug(f"OPC Writer disconnect error (non-fatal): {exc}")

models/clip

inference.models.clip.clip_inference_models

Classes

InferenceModelsClipAdapter

Bases: Model

Roboflow ONNX ClipModel model.

This class is responsible for handling the ONNX ClipModel model, including loading the model, preprocessing the input, and performing inference.

Attributes:

Name Type Description
visual_onnx_session InferenceSession

ONNX Runtime session for visual inference.

textual_onnx_session InferenceSession

ONNX Runtime session for textual inference.

resolution int

The resolution of the input image.

clip_preprocess function

Function to preprocess the image.

Source code in inference/models/clip/clip_inference_models.py
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
class InferenceModelsClipAdapter(Model):
    """Adapter exposing a CLIP implementation through the Roboflow ``Model`` interface.

    Loads the underlying CLIP model via ``AutoModel.from_pretrained`` and routes
    embedding / comparison requests to it, translating between the legacy
    request/response objects and the new model API.

    Attributes:
        metrics (dict): Simple counters for inference accounting.
        api_key (str): API key used when fetching model weights.
        task_type (str): Task identifier; always ``"embedding"``.
        _model (Union[ClipOnnx, ClipTorch]): Underlying CLIP implementation.
    """

    def __init__(
        self,
        model_id: str = CLIP_MODEL_ID,
        api_key: str = None,
        **kwargs,
    ):
        super().__init__()

        self.metrics = {"num_inferences": 0, "avg_inference_time": 0.0}

        self.api_key = api_key if api_key else API_KEY
        self.task_type = "embedding"
        weights_provider_extra_headers = get_extra_weights_provider_headers(
            countinference=kwargs.get("countinference"),
            service_secret=kwargs.get("service_secret"),
        )
        # Allow every valid backend that has not been explicitly disabled.
        backend = list(
            VALID_INFERENCE_MODELS_BACKENDS.difference(
                DISABLED_INFERENCE_MODELS_BACKENDS
            )
        )
        self._model: Union[ClipOnnx, ClipTorch] = AutoModel.from_pretrained(
            model_id_or_path=model_id,
            api_key=self.api_key,
            allow_untrusted_packages=ALLOW_INFERENCE_MODELS_UNTRUSTED_PACKAGES,
            allow_direct_local_storage_loading=ALLOW_INFERENCE_MODELS_DIRECTLY_ACCESS_LOCAL_PACKAGES,
            weights_provider_extra_headers=weights_provider_extra_headers,
            backend=backend,
            **kwargs,
        )

    def compare(
        self,
        subject: Any,
        prompt: Any,
        subject_type: str = "image",
        prompt_type: Union[str, List[str], Dict[str, Any]] = "text",
        **kwargs,
    ) -> Union[List[float], Dict[str, float]]:
        """
        Compares the subject with the prompt to calculate similarity scores.

        Args:
            subject (Any): The subject data to be compared. Can be either an image or text.
            prompt (Any): The prompt data to be compared against the subject. Can be a single value (image/text), list of values, or dictionary of values.
            subject_type (str, optional): Specifies the type of the subject data. Must be either "image" or "text". Defaults to "image".
            prompt_type (Union[str, List[str], Dict[str, Any]], optional): Specifies the type of the prompt data. Can be "image", "text", list of these types, or a dictionary containing these types. Defaults to "text".
            **kwargs: Additional keyword arguments.

        Returns:
            Union[List[float], Dict[str, float]]: A list or dictionary containing cosine similarity scores between the subject and prompt(s). If prompt is a dictionary, returns a dictionary with keys corresponding to the original prompt dictionary's keys.

        Raises:
            ValueError: If subject_type or prompt_type is neither "image" nor "text".
            ValueError: If the number of prompts exceeds the maximum batch size.
        """

        if subject_type == "image":
            subject_embeddings = self.embed_image(subject)
        elif subject_type == "text":
            subject_embeddings = self.embed_text(subject)
        else:
            # Bug fix: message was a plain string referencing a nonexistent
            # `request` variable; interpolate the actual argument instead.
            raise ValueError(
                f"subject_type must be either 'image' or 'text', but got {subject_type}"
            )

        # A dict that is not an image payload ({"type": ..., "value": ...}) is
        # treated as a mapping of named prompts; results keep the same keys.
        if isinstance(prompt, dict) and not ("type" in prompt and "value" in prompt):
            prompt_keys = prompt.keys()
            prompt = [prompt[k] for k in prompt_keys]
            prompt_obj = "dict"
        else:
            if not isinstance(prompt, list):
                prompt = [prompt]
            prompt_obj = "list"

        if len(prompt) > CLIP_MAX_BATCH_SIZE:
            raise ValueError(
                f"The maximum number of prompts that can be compared at once is {CLIP_MAX_BATCH_SIZE}"
            )

        if prompt_type == "image":
            prompt_embeddings = self.embed_image(prompt)
        elif prompt_type == "text":
            prompt_embeddings = self.embed_text(prompt)
        else:
            raise ValueError(
                f"prompt_type must be either 'image' or 'text', but got {prompt_type}"
            )

        similarities = [
            cosine_similarity(subject_embeddings, p) for p in prompt_embeddings
        ]

        if prompt_obj == "dict":
            similarities = dict(zip(prompt_keys, similarities))

        return similarities

    def make_compare_response(
        self, similarities: Union[List[float], Dict[str, float]]
    ) -> ClipCompareResponse:
        """
        Creates a ClipCompareResponse object from the provided similarity data.

        Args:
            similarities (Union[List[float], Dict[str, float]]): A list or dictionary containing similarity scores.

        Returns:
            ClipCompareResponse: An instance of the ClipCompareResponse with the given similarity scores.

        Example:
            Assuming `ClipCompareResponse` expects a dictionary of string-float pairs:

            >>> make_compare_response({"image1": 0.98, "image2": 0.76})
            ClipCompareResponse(similarity={"image1": 0.98, "image2": 0.76})
        """
        response = ClipCompareResponse(similarity=similarities)
        return response

    def embed_image(
        self,
        image: Any,
        **kwargs,
    ) -> np.ndarray:
        """
        Embeds an image or a list of images using the Clip model.

        Args:
            image (Any): The image or list of images to be embedded. Image can be in any format that is acceptable by the preproc_image method.
            **kwargs: Additional keyword arguments.

        Returns:
            np.ndarray: The embeddings of the image(s) as a numpy array.

        Raises:
            ValueError: If the number of images in the list exceeds the maximum batch size.

        Notes:
            Embeddings are produced by the underlying model adapter
            (``self._model.embed_images``).
        """
        if isinstance(image, list):
            if len(image) > CLIP_MAX_BATCH_SIZE:
                raise ValueError(
                    f"The maximum number of images that can be embedded at once is {CLIP_MAX_BATCH_SIZE}"
                )
            imgs = [self.preproc_image(i) for i in image]
            img_in = np.concatenate(imgs, axis=0)
        else:
            img_in = self.preproc_image(image)
        embeddings = self._model.embed_images(images=img_in)
        return embeddings.cpu().numpy()

    def predict(self, img_in: np.ndarray, **kwargs) -> Tuple[np.ndarray]:
        """Embed preprocessed image pixels; returns a one-element tuple of embeddings."""
        embeddings = self._model.embed_images(images=img_in)
        return (embeddings.cpu().numpy(),)

    def make_embed_image_response(
        self, embeddings: np.ndarray
    ) -> ClipEmbeddingResponse:
        """
        Converts the given embeddings into a ClipEmbeddingResponse object.

        Args:
            embeddings (np.ndarray): A numpy array containing the embeddings for an image or images.

        Returns:
            ClipEmbeddingResponse: An instance of the ClipEmbeddingResponse with the provided embeddings converted to a list.

        Example:
            >>> embeddings_array = np.array([[0.5, 0.3, 0.2], [0.1, 0.9, 0.0]])
            >>> make_embed_image_response(embeddings_array)
            ClipEmbeddingResponse(embeddings=[[0.5, 0.3, 0.2], [0.1, 0.9, 0.0]])
        """
        response = ClipEmbeddingResponse(embeddings=embeddings.tolist())

        return response

    def embed_text(
        self,
        text: Union[str, List[str]],
        **kwargs,
    ) -> np.ndarray:
        """
        Embeds a text or a list of texts using the Clip model.

        Args:
            text (Union[str, List[str]]): The text string or list of text strings to be embedded.
            **kwargs: Additional keyword arguments.

        Returns:
            np.ndarray: The embeddings of the text or texts as a numpy array.

        Notes:
            Embeddings are produced by the underlying model adapter
            (``self._model.embed_text``).
        """
        if isinstance(text, list):
            texts = text
        else:
            texts = [text]
        embeddings = self._model.embed_text(texts=texts)
        return embeddings.cpu().numpy()

    def make_embed_text_response(self, embeddings: np.ndarray) -> ClipEmbeddingResponse:
        """
        Converts the given text embeddings into a ClipEmbeddingResponse object.

        Args:
            embeddings (np.ndarray): A numpy array containing the embeddings for a text or texts.

        Returns:
            ClipEmbeddingResponse: An instance of the ClipEmbeddingResponse with the provided embeddings converted to a list.

        Example:
            >>> embeddings_array = np.array([[0.8, 0.1, 0.1], [0.4, 0.5, 0.1]])
            >>> make_embed_text_response(embeddings_array)
            ClipEmbeddingResponse(embeddings=[[0.8, 0.1, 0.1], [0.4, 0.5, 0.1]])
        """
        response = ClipEmbeddingResponse(embeddings=embeddings.tolist())
        return response

    def infer_from_request(
        self, request: ClipInferenceRequest
    ) -> ClipEmbeddingResponse:
        """Routes the request to the appropriate inference function.

        Args:
            request (ClipInferenceRequest): The request object containing the inference details.

        Returns:
            ClipEmbeddingResponse: The response object containing the embeddings,
            with ``time`` set to the wall-clock duration of the call.

        Raises:
            ValueError: If the request type is not a supported CLIP request.
        """
        t1 = perf_counter()
        if isinstance(request, ClipImageEmbeddingRequest):
            infer_func = self.embed_image
            make_response_func = self.make_embed_image_response
        elif isinstance(request, ClipTextEmbeddingRequest):
            infer_func = self.embed_text
            make_response_func = self.make_embed_text_response
        elif isinstance(request, ClipCompareRequest):
            infer_func = self.compare
            make_response_func = self.make_compare_response
        else:
            raise ValueError(
                f"Request type {type(request)} is not a valid ClipInferenceRequest"
            )
        data = infer_func(**request.dict())
        response = make_response_func(data)
        response.time = perf_counter() - t1
        return response

    def make_response(self, embeddings, *args, **kwargs) -> InferenceResponse:
        """Wrap embeddings in a single-element list of embedding responses."""
        return [self.make_embed_image_response(embeddings)]

    def postprocess(
        self,
        predictions: Tuple[np.ndarray],
        preprocess_return_metadata: PreprocessReturnMetadata,
        **kwargs,
    ) -> Any:
        """Convert raw prediction tuple into a list with one embedding response."""
        return [self.make_embed_image_response(predictions[0])]

    def infer(self, image: Any, **kwargs) -> Any:
        """Embeds an image
        - image:
            can be a BGR numpy array, filepath, InferenceRequestImage, PIL Image, byte-string, etc.
        """
        return super().infer(image, **kwargs)

    def preproc_image(self, image: InferenceRequestImage) -> np.ndarray:
        """Preprocesses an inference request image.

        Args:
            image (InferenceRequestImage): The object containing information necessary to load the image for inference.

        Returns:
            np.ndarray: A numpy array of the preprocessed image pixel data.
        """
        return load_image_bgr(image)

    def preprocess(
        self, image: Any, **kwargs
    ) -> Tuple[np.ndarray, PreprocessReturnMetadata]:
        """Load the image as BGR pixels and return it with empty metadata."""
        return self.preproc_image(image), PreprocessReturnMetadata({})
Functions
compare
compare(
    subject,
    prompt,
    subject_type="image",
    prompt_type="text",
    **kwargs
)

Compares the subject with the prompt to calculate similarity scores.

Parameters:

Name Type Description Default
subject Any

The subject data to be compared. Can be either an image or text.

required
prompt Any

The prompt data to be compared against the subject. Can be a single value (image/text), list of values, or dictionary of values.

required
subject_type str

Specifies the type of the subject data. Must be either "image" or "text". Defaults to "image".

'image'
prompt_type Union[str, List[str], Dict[str, Any]]

Specifies the type of the prompt data. Can be "image", "text", list of these types, or a dictionary containing these types. Defaults to "text".

'text'
**kwargs

Additional keyword arguments.

{}

Returns:

Type Description
Union[List[float], Dict[str, float]]

Union[List[float], Dict[str, float]]: A list or dictionary containing cosine similarity scores between the subject and prompt(s). If prompt is a dictionary, returns a dictionary with keys corresponding to the original prompt dictionary's keys.

Raises:

Type Description
ValueError

If subject_type or prompt_type is neither "image" nor "text".

ValueError

If the number of prompts exceeds the maximum batch size.

Source code in inference/models/clip/clip_inference_models.py
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
def compare(
    self,
    subject: Any,
    prompt: Any,
    subject_type: str = "image",
    prompt_type: Union[str, List[str], Dict[str, Any]] = "text",
    **kwargs,
) -> Union[List[float], Dict[str, float]]:
    """
    Compares the subject with the prompt to calculate similarity scores.

    Args:
        subject (Any): The subject data to be compared. Can be either an image or text.
        prompt (Any): The prompt data to be compared against the subject. Can be a single value (image/text), list of values, or dictionary of values.
        subject_type (str, optional): Specifies the type of the subject data. Must be either "image" or "text". Defaults to "image".
        prompt_type (Union[str, List[str], Dict[str, Any]], optional): Specifies the type of the prompt data. Can be "image", "text", list of these types, or a dictionary containing these types. Defaults to "text".
        **kwargs: Additional keyword arguments.

    Returns:
        Union[List[float], Dict[str, float]]: A list or dictionary containing cosine similarity scores between the subject and prompt(s). If prompt is a dictionary, returns a dictionary with keys corresponding to the original prompt dictionary's keys.

    Raises:
        ValueError: If subject_type or prompt_type is neither "image" nor "text".
        ValueError: If the number of prompts exceeds the maximum batch size.
    """

    if subject_type == "image":
        subject_embeddings = self.embed_image(subject)
    elif subject_type == "text":
        subject_embeddings = self.embed_text(subject)
    else:
        # Bug fix: message was a plain string referencing a nonexistent
        # `request` variable; interpolate the actual argument instead.
        raise ValueError(
            f"subject_type must be either 'image' or 'text', but got {subject_type}"
        )

    # A dict that is not an image payload ({"type": ..., "value": ...}) is
    # treated as a mapping of named prompts; results keep the same keys.
    if isinstance(prompt, dict) and not ("type" in prompt and "value" in prompt):
        prompt_keys = prompt.keys()
        prompt = [prompt[k] for k in prompt_keys]
        prompt_obj = "dict"
    else:
        if not isinstance(prompt, list):
            prompt = [prompt]
        prompt_obj = "list"

    if len(prompt) > CLIP_MAX_BATCH_SIZE:
        raise ValueError(
            f"The maximum number of prompts that can be compared at once is {CLIP_MAX_BATCH_SIZE}"
        )

    if prompt_type == "image":
        prompt_embeddings = self.embed_image(prompt)
    elif prompt_type == "text":
        prompt_embeddings = self.embed_text(prompt)
    else:
        raise ValueError(
            f"prompt_type must be either 'image' or 'text', but got {prompt_type}"
        )

    similarities = [
        cosine_similarity(subject_embeddings, p) for p in prompt_embeddings
    ]

    if prompt_obj == "dict":
        similarities = dict(zip(prompt_keys, similarities))

    return similarities
embed_image
embed_image(image, **kwargs)

Embeds an image or a list of images using the Clip model.

Parameters:

Name Type Description Default
image Any

The image or list of images to be embedded. Image can be in any format that is acceptable by the preproc_image method.

required
**kwargs

Additional keyword arguments.

{}

Returns:

Type Description
ndarray

np.ndarray: The embeddings of the image(s) as a numpy array.

Raises:

Type Description
ValueError

If the number of images in the list exceeds the maximum batch size.

Notes

The embeddings are produced by the underlying model adapter (self._model.embed_images).

Source code in inference/models/clip/clip_inference_models.py
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
def embed_image(
    self,
    image: Any,
    **kwargs,
) -> np.ndarray:
    """
    Embeds an image or a list of images using the Clip model.

    Args:
        image (Any): The image or list of images to be embedded. Image can be in any format that is acceptable by the preproc_image method.
        **kwargs: Additional keyword arguments.

    Returns:
        np.ndarray: The embeddings of the image(s) as a numpy array.

    Raises:
        ValueError: If the number of images in the list exceeds the maximum batch size.

    Notes:
        Embeddings are produced by the underlying model adapter
        (``self._model.embed_images``).
    """
    # Removed unused `t1 = perf_counter()` — the value was never read.
    if isinstance(image, list):
        if len(image) > CLIP_MAX_BATCH_SIZE:
            raise ValueError(
                f"The maximum number of images that can be embedded at once is {CLIP_MAX_BATCH_SIZE}"
            )
        imgs = [self.preproc_image(i) for i in image]
        img_in = np.concatenate(imgs, axis=0)
    else:
        img_in = self.preproc_image(image)
    embeddings = self._model.embed_images(images=img_in)
    return embeddings.cpu().numpy()
embed_text
embed_text(text, **kwargs)

Embeds a text or a list of texts using the Clip model.

Parameters:

Name Type Description Default
text Union[str, List[str]]

The text string or list of text strings to be embedded.

required
**kwargs

Additional keyword arguments.

{}

Returns:

Type Description
ndarray

np.ndarray: The embeddings of the text or texts as a numpy array.

Raises:

Type Description
ValueError

If the number of text strings in the list exceeds the maximum batch size.

Notes

The embeddings are produced by the underlying model adapter (self._model.embed_text).

Source code in inference/models/clip/clip_inference_models.py
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
def embed_text(
    self,
    text: Union[str, List[str]],
    **kwargs,
) -> np.ndarray:
    """
    Embed one text string or a batch of text strings with the CLIP model.

    Args:
        text (Union[str, List[str]]): Text or list of texts to embed.
        **kwargs: Additional keyword arguments (unused).

    Returns:
        np.ndarray: The embeddings of the text or texts as a numpy array.

    Notes:
        Embeddings are produced by the underlying model adapter
        (``self._model.embed_text``).
    """
    # Normalise the input to a list so the model always receives a batch.
    texts = text if isinstance(text, list) else [text]
    embeddings = self._model.embed_text(texts=texts)
    return embeddings.cpu().numpy()
infer
infer(image, **kwargs)

Embeds an image - image: can be a BGR numpy array, filepath, InferenceRequestImage, PIL Image, byte-string, etc.

Source code in inference/models/clip/clip_inference_models.py
318
319
320
321
322
323
def infer(self, image: Any, **kwargs) -> Any:
    """Embeds an image by delegating to the parent class inference pipeline.

    Args:
        image: The image to embed; can be a BGR numpy array, filepath,
            InferenceRequestImage, PIL Image, byte-string, etc.
        **kwargs: Forwarded unchanged to the parent implementation.

    Returns:
        Whatever the parent class ``infer`` pipeline produces for this model.
    """
    return super().infer(image, **kwargs)
infer_from_request
infer_from_request(request)

Routes the request to the appropriate inference function.

Parameters:

Name Type Description Default
request ClipInferenceRequest

The request object containing the inference details.

required

Returns:

Name Type Description
ClipEmbeddingResponse ClipEmbeddingResponse

The response object containing the embeddings.

Source code in inference/models/clip/clip_inference_models.py
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
def infer_from_request(
    self, request: ClipInferenceRequest
) -> ClipEmbeddingResponse:
    """Routes the request to the appropriate inference function.

    Args:
        request (ClipInferenceRequest): The request object containing the inference details.

    Returns:
        ClipEmbeddingResponse: The response object containing the embeddings,
        with ``time`` set to the wall-clock duration of the call.

    Raises:
        ValueError: If the request is not one of the supported request types.
    """
    started_at = perf_counter()
    # Dispatch table: (request type, inference function, response builder).
    # Checked in the same order as the original isinstance chain.
    dispatch = [
        (ClipImageEmbeddingRequest, self.embed_image, self.make_embed_image_response),
        (ClipTextEmbeddingRequest, self.embed_text, self.make_embed_text_response),
        (ClipCompareRequest, self.compare, self.make_compare_response),
    ]
    for request_type, run_inference, build_response in dispatch:
        if isinstance(request, request_type):
            break
    else:
        raise ValueError(
            f"Request type {type(request)} is not a valid ClipInferenceRequest"
        )
    result = run_inference(**request.dict())
    response = build_response(result)
    response.time = perf_counter() - started_at
    return response
make_compare_response
make_compare_response(similarities)

Creates a ClipCompareResponse object from the provided similarity data.

Parameters:

Name Type Description Default
similarities Union[List[float], Dict[str, float]]

A list or dictionary containing similarity scores.

required

Returns:

Name Type Description
ClipCompareResponse ClipCompareResponse

An instance of the ClipCompareResponse with the given similarity scores.

Example

Assuming ClipCompareResponse expects a dictionary of string-float pairs:

make_compare_response({"image1": 0.98, "image2": 0.76}) returns ClipCompareResponse(similarity={"image1": 0.98, "image2": 0.76})

Source code in inference/models/clip/clip_inference_models.py
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
def make_compare_response(
    self, similarities: Union[List[float], Dict[str, float]]
) -> ClipCompareResponse:
    """
    Wraps similarity scores in a ClipCompareResponse object.

    Args:
        similarities (Union[List[float], Dict[str, float]]): Similarity scores,
            either as a plain list or keyed by prompt name.

    Returns:
        ClipCompareResponse: A response object carrying the scores unchanged.

    Example:
        Assuming `ClipCompareResponse` expects a dictionary of string-float pairs:

        >>> make_compare_response({"image1": 0.98, "image2": 0.76})
        ClipCompareResponse(similarity={"image1": 0.98, "image2": 0.76})
    """
    return ClipCompareResponse(similarity=similarities)
make_embed_image_response
make_embed_image_response(embeddings)

Converts the given embeddings into a ClipEmbeddingResponse object.

Parameters:

Name Type Description Default
embeddings ndarray

A numpy array containing the embeddings for an image or images.

required

Returns:

Name Type Description
ClipEmbeddingResponse ClipEmbeddingResponse

An instance of the ClipEmbeddingResponse with the provided embeddings converted to a list.

Example

embeddings_array = np.array([[0.5, 0.3, 0.2], [0.1, 0.9, 0.0]]); make_embed_image_response(embeddings_array) returns ClipEmbeddingResponse(embeddings=[[0.5, 0.3, 0.2], [0.1, 0.9, 0.0]])

Source code in inference/models/clip/clip_inference_models.py
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
def make_embed_image_response(
    self, embeddings: np.ndarray
) -> ClipEmbeddingResponse:
    """
    Wraps image embeddings in a ClipEmbeddingResponse object.

    Args:
        embeddings (np.ndarray): A numpy array containing the embeddings for an image or images.

    Returns:
        ClipEmbeddingResponse: A response object with the embeddings converted
        to nested Python lists.

    Example:
        >>> embeddings_array = np.array([[0.5, 0.3, 0.2], [0.1, 0.9, 0.0]])
        >>> make_embed_image_response(embeddings_array)
        ClipEmbeddingResponse(embeddings=[[0.5, 0.3, 0.2], [0.1, 0.9, 0.0]])
    """
    embeddings_as_list = embeddings.tolist()
    return ClipEmbeddingResponse(embeddings=embeddings_as_list)
make_embed_text_response
make_embed_text_response(embeddings)

Converts the given text embeddings into a ClipEmbeddingResponse object.

Parameters:

Name Type Description Default
embeddings ndarray

A numpy array containing the embeddings for a text or texts.

required

Returns:

Name Type Description
ClipEmbeddingResponse ClipEmbeddingResponse

An instance of the ClipEmbeddingResponse with the provided embeddings converted to a list.

Example

embeddings_array = np.array([[0.8, 0.1, 0.1], [0.4, 0.5, 0.1]]); make_embed_text_response(embeddings_array) returns ClipEmbeddingResponse(embeddings=[[0.8, 0.1, 0.1], [0.4, 0.5, 0.1]])

Source code in inference/models/clip/clip_inference_models.py
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
def make_embed_text_response(self, embeddings: np.ndarray) -> ClipEmbeddingResponse:
    """
    Wraps text embeddings in a ClipEmbeddingResponse object.

    Args:
        embeddings (np.ndarray): A numpy array containing the embeddings for a text or texts.

    Returns:
        ClipEmbeddingResponse: A response object with the embeddings converted
        to nested Python lists.

    Example:
        >>> embeddings_array = np.array([[0.8, 0.1, 0.1], [0.4, 0.5, 0.1]])
        >>> make_embed_text_response(embeddings_array)
        ClipEmbeddingResponse(embeddings=[[0.8, 0.1, 0.1], [0.4, 0.5, 0.1]])
    """
    embeddings_as_list = embeddings.tolist()
    return ClipEmbeddingResponse(embeddings=embeddings_as_list)
preproc_image
preproc_image(image)

Preprocesses an inference request image.

Parameters:

Name Type Description Default
image InferenceRequestImage

The object containing information necessary to load the image for inference.

required

Returns:

Type Description
ndarray

np.ndarray: A numpy array of the preprocessed image pixel data.

Source code in inference/models/clip/clip_inference_models.py
325
326
327
328
329
330
331
332
333
334
def preproc_image(self, image: InferenceRequestImage) -> np.ndarray:
    """Preprocesses an inference request image.

    Args:
        image (InferenceRequestImage): The object containing information necessary to load the image for inference.

    Returns:
        np.ndarray: A numpy array of the loaded image pixel data in BGR channel order.

    Notes:
        Unlike the ONNX-based variant of this method, no resizing or
        normalization happens here — loading is delegated entirely to
        ``load_image_bgr``.
    """
    return load_image_bgr(image)

Functions

inference.models.clip.clip_model

Classes

Clip

Bases: OnnxRoboflowCoreModel

Roboflow ONNX ClipModel model.

This class is responsible for handling the ONNX ClipModel model, including loading the model, preprocessing the input, and performing inference.

Attributes:

Name Type Description
visual_onnx_session InferenceSession

ONNX Runtime session for visual inference.

textual_onnx_session InferenceSession

ONNX Runtime session for textual inference.

resolution int

The resolution of the input image.

clip_preprocess function

Function to preprocess the image.

Source code in inference/models/clip/clip_model.py
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
class Clip(OnnxRoboflowCoreModel):
    """Roboflow ONNX ClipModel model.

    This class is responsible for handling the ONNX ClipModel model, including
    loading the model, preprocessing the input, and performing inference.

    Attributes:
        visual_onnx_session (onnxruntime.InferenceSession): ONNX Runtime session for visual inference.
        textual_onnx_session (onnxruntime.InferenceSession): ONNX Runtime session for textual inference.
        resolution (int): The resolution of the input image.
        clip_preprocess (function): Function to preprocess the image.
    """

    def __init__(
        self,
        *args,
        model_id: str = CLIP_MODEL_ID,
        onnxruntime_execution_providers: List[
            str
        ] = get_onnxruntime_execution_providers(ONNXRUNTIME_EXECUTION_PROVIDERS),
        **kwargs,
    ):
        """Initializes the Clip with the given arguments and keyword arguments."""
        self.onnxruntime_execution_providers = onnxruntime_execution_providers
        t1 = perf_counter()
        super().__init__(*args, model_id=model_id, **kwargs)
        # Create an ONNX Runtime Session with a list of execution providers in
        # priority order. ORT attempts to load providers until one is successful.
        # This keeps the code across devices identical.
        self.log("Creating inference sessions")
        self.visual_onnx_session = onnxruntime.InferenceSession(
            self.cache_file("visual.onnx"),
            providers=self.onnxruntime_execution_providers,
        )
        # Locks serialize access to each ONNX session across threads.
        self._visual_session_lock = Lock()
        self.textual_onnx_session = onnxruntime.InferenceSession(
            self.cache_file("textual.onnx"),
            providers=self.onnxruntime_execution_providers,
        )
        self._textual_session_lock = Lock()
        if REQUIRED_ONNX_PROVIDERS:
            available_providers = onnxruntime.get_available_providers()
            for provider in REQUIRED_ONNX_PROVIDERS:
                if provider not in available_providers:
                    raise OnnxProviderNotAvailable(
                        # Fixed typo in original message: "availble".
                        f"Required ONNX Execution Provider {provider} is not available. Check that you are using the correct docker image on a supported device."
                    )

        self.resolution = self.visual_onnx_session.get_inputs()[0].shape[2]

        self.clip_preprocess = clip.clip._transform(self.resolution)
        self.log(f"CLIP model loaded in {perf_counter() - t1:.2f} seconds")
        self.task_type = "embedding"

    def compare(
        self,
        subject: Any,
        prompt: Any,
        subject_type: str = "image",
        prompt_type: Union[str, List[str], Dict[str, Any]] = "text",
        **kwargs,
    ) -> Union[List[float], Dict[str, float]]:
        """
        Compares the subject with the prompt to calculate similarity scores.

        Args:
            subject (Any): The subject data to be compared. Can be either an image or text.
            prompt (Any): The prompt data to be compared against the subject. Can be a single value (image/text), list of values, or dictionary of values.
            subject_type (str, optional): Specifies the type of the subject data. Must be either "image" or "text". Defaults to "image".
            prompt_type (Union[str, List[str], Dict[str, Any]], optional): Specifies the type of the prompt data. Can be "image", "text", list of these types, or a dictionary containing these types. Defaults to "text".
            **kwargs: Additional keyword arguments.

        Returns:
            Union[List[float], Dict[str, float]]: A list or dictionary containing cosine similarity scores between the subject and prompt(s). If prompt is a dictionary, returns a dictionary with keys corresponding to the original prompt dictionary's keys.

        Raises:
            ValueError: If subject_type or prompt_type is neither "image" nor "text".
            ValueError: If the number of prompts exceeds the maximum batch size.
        """

        if subject_type == "image":
            subject_embeddings = self.embed_image(subject)
        elif subject_type == "text":
            subject_embeddings = self.embed_text(subject)
        else:
            # Bug fix: the original message lacked the f-prefix and referenced
            # an undefined `request` variable inside the braces.
            raise ValueError(
                f"subject_type must be either 'image' or 'text', but got {subject_type}"
            )

        # A dict prompt maps names to values, unless it is itself a single
        # {"type": ..., "value": ...} image payload.
        if isinstance(prompt, dict) and not ("type" in prompt and "value" in prompt):
            prompt_keys = list(prompt.keys())
            prompt = [prompt[k] for k in prompt_keys]
            prompt_obj = "dict"
        else:
            if not isinstance(prompt, list):
                prompt = [prompt]
            prompt_obj = "list"

        if len(prompt) > CLIP_MAX_BATCH_SIZE:
            raise ValueError(
                f"The maximum number of prompts that can be compared at once is {CLIP_MAX_BATCH_SIZE}"
            )

        if prompt_type == "image":
            prompt_embeddings = self.embed_image(prompt)
        elif prompt_type == "text":
            prompt_embeddings = self.embed_text(prompt)
        else:
            # Bug fix: same missing f-prefix / undefined `request` as above.
            raise ValueError(
                f"prompt_type must be either 'image' or 'text', but got {prompt_type}"
            )

        similarities = [
            cosine_similarity(subject_embeddings, p) for p in prompt_embeddings
        ]

        if prompt_obj == "dict":
            similarities = dict(zip(prompt_keys, similarities))

        return similarities

    def make_compare_response(
        self, similarities: Union[List[float], Dict[str, float]]
    ) -> ClipCompareResponse:
        """
        Creates a ClipCompareResponse object from the provided similarity data.

        Args:
            similarities (Union[List[float], Dict[str, float]]): A list or dictionary containing similarity scores.

        Returns:
            ClipCompareResponse: An instance of the ClipCompareResponse with the given similarity scores.

        Example:
            Assuming `ClipCompareResponse` expects a dictionary of string-float pairs:

            >>> make_compare_response({"image1": 0.98, "image2": 0.76})
            ClipCompareResponse(similarity={"image1": 0.98, "image2": 0.76})
        """
        response = ClipCompareResponse(similarity=similarities)
        return response

    def embed_image(
        self,
        image: Any,
        **kwargs,
    ) -> np.ndarray:
        """
        Embeds an image or a list of images using the Clip model.

        Args:
            image (Any): The image or list of images to be embedded. Image can be in any format that is acceptable by the preproc_image method.
            **kwargs: Additional keyword arguments.

        Returns:
            np.ndarray: The embeddings of the image(s) as a numpy array.

        Raises:
            ValueError: If the number of images in the list exceeds the maximum batch size.
        """
        # Note: a dead `t1 = perf_counter()` (never read) was removed here.
        if isinstance(image, list):
            if len(image) > CLIP_MAX_BATCH_SIZE:
                raise ValueError(
                    f"The maximum number of images that can be embedded at once is {CLIP_MAX_BATCH_SIZE}"
                )
            imgs = [self.preproc_image(i) for i in image]
            img_in = np.concatenate(imgs, axis=0)
        else:
            img_in = self.preproc_image(image)

        onnx_input_image = {self.visual_onnx_session.get_inputs()[0].name: img_in}
        # The lock serializes access to the ONNX session across threads.
        with self._visual_session_lock:
            return self.visual_onnx_session.run(None, onnx_input_image)[0]

    def predict(self, img_in: np.ndarray, **kwargs) -> Tuple[np.ndarray]:
        """Runs the visual ONNX session on a preprocessed image batch.

        Args:
            img_in (np.ndarray): Preprocessed image tensor ready for the visual session.

        Returns:
            Tuple[np.ndarray]: One-element tuple containing the image embeddings.
        """
        onnx_input_image = {self.visual_onnx_session.get_inputs()[0].name: img_in}
        with self._visual_session_lock:
            embeddings = self.visual_onnx_session.run(None, onnx_input_image)[0]
        return (embeddings,)

    def make_embed_image_response(
        self, embeddings: np.ndarray
    ) -> ClipEmbeddingResponse:
        """
        Converts the given embeddings into a ClipEmbeddingResponse object.

        Args:
            embeddings (np.ndarray): A numpy array containing the embeddings for an image or images.

        Returns:
            ClipEmbeddingResponse: An instance of the ClipEmbeddingResponse with the provided embeddings converted to a list.

        Example:
            >>> embeddings_array = np.array([[0.5, 0.3, 0.2], [0.1, 0.9, 0.0]])
            >>> make_embed_image_response(embeddings_array)
            ClipEmbeddingResponse(embeddings=[[0.5, 0.3, 0.2], [0.1, 0.9, 0.0]])
        """
        response = ClipEmbeddingResponse(embeddings=embeddings.tolist())

        return response

    def embed_text(
        self,
        text: Union[str, List[str]],
        **kwargs,
    ) -> np.ndarray:
        """
        Embeds a text or a list of texts using the Clip model.

        Args:
            text (Union[str, List[str]]): The text string or list of text strings to be embedded.
            **kwargs: Additional keyword arguments.

        Returns:
            np.ndarray: The embeddings of the text or texts as a numpy array.

        Notes:
            Inputs are processed in batches of at most CLIP_MAX_BATCH_SIZE
            through the textual ONNX session; no ValueError is raised for
            long lists — they are simply split into multiple batches.
        """
        if isinstance(text, list):
            texts = text
        else:
            texts = [text]
        results = []
        for texts_batch in create_batches(
            sequence=texts, batch_size=CLIP_MAX_BATCH_SIZE
        ):
            tokenized_batch = clip.tokenize(texts_batch).numpy().astype(np.int32)
            onnx_input_text = {
                self.textual_onnx_session.get_inputs()[0].name: tokenized_batch
            }
            with self._textual_session_lock:
                embeddings = self.textual_onnx_session.run(None, onnx_input_text)[0]
            results.append(embeddings)
        return np.concatenate(results, axis=0)

    def make_embed_text_response(self, embeddings: np.ndarray) -> ClipEmbeddingResponse:
        """
        Converts the given text embeddings into a ClipEmbeddingResponse object.

        Args:
            embeddings (np.ndarray): A numpy array containing the embeddings for a text or texts.

        Returns:
            ClipEmbeddingResponse: An instance of the ClipEmbeddingResponse with the provided embeddings converted to a list.

        Example:
            >>> embeddings_array = np.array([[0.8, 0.1, 0.1], [0.4, 0.5, 0.1]])
            >>> make_embed_text_response(embeddings_array)
            ClipEmbeddingResponse(embeddings=[[0.8, 0.1, 0.1], [0.4, 0.5, 0.1]])
        """
        response = ClipEmbeddingResponse(embeddings=embeddings.tolist())
        return response

    def get_infer_bucket_file_list(self) -> List[str]:
        """Gets the list of files required for inference.

        Returns:
            List[str]: The list of file names.
        """
        return ["textual.onnx", "visual.onnx"]

    def infer_from_request(
        self, request: ClipInferenceRequest
    ) -> ClipEmbeddingResponse:
        """Routes the request to the appropriate inference function.

        Args:
            request (ClipInferenceRequest): The request object containing the inference details.

        Returns:
            ClipEmbeddingResponse: The response object containing the embeddings.

        Raises:
            ValueError: If the request is not one of the supported request types.
        """
        t1 = perf_counter()
        if isinstance(request, ClipImageEmbeddingRequest):
            infer_func = self.embed_image
            make_response_func = self.make_embed_image_response
        elif isinstance(request, ClipTextEmbeddingRequest):
            infer_func = self.embed_text
            make_response_func = self.make_embed_text_response
        elif isinstance(request, ClipCompareRequest):
            infer_func = self.compare
            make_response_func = self.make_compare_response
        else:
            raise ValueError(
                f"Request type {type(request)} is not a valid ClipInferenceRequest"
            )
        data = infer_func(**request.dict())
        response = make_response_func(data)
        response.time = perf_counter() - t1
        return response

    def make_response(self, embeddings, *args, **kwargs) -> InferenceResponse:
        """Wraps raw embeddings in a one-element list of image-embedding responses."""
        return [self.make_embed_image_response(embeddings)]

    def postprocess(
        self,
        predictions: Tuple[np.ndarray],
        preprocess_return_metadata: PreprocessReturnMetadata,
        **kwargs,
    ) -> Any:
        """Converts raw predict() output into a list with one embedding response."""
        return [self.make_embed_image_response(predictions[0])]

    def infer(self, image: Any, **kwargs) -> Any:
        """Embeds an image
        - image:
            can be a BGR numpy array, filepath, InferenceRequestImage, PIL Image, byte-string, etc.
        """
        return super().infer(image, **kwargs)

    def preproc_image(self, image: InferenceRequestImage) -> np.ndarray:
        """Preprocesses an inference request image.

        Args:
            image (InferenceRequestImage): The object containing information necessary to load the image for inference.

        Returns:
            np.ndarray: A numpy array of the preprocessed image pixel data.
        """
        pil_image = Image.fromarray(load_image_rgb(image))
        preprocessed_image = self.clip_preprocess(pil_image)

        # Add a leading batch dimension so a single image forms a batch of one.
        img_in = np.expand_dims(preprocessed_image, axis=0)

        return img_in.astype(np.float32)

    def preprocess(
        self, image: Any, **kwargs
    ) -> Tuple[np.ndarray, PreprocessReturnMetadata]:
        """Preprocesses an input image and returns it with empty metadata."""
        return self.preproc_image(image), PreprocessReturnMetadata({})
Functions
__init__
__init__(
    *args,
    model_id=CLIP_MODEL_ID,
    onnxruntime_execution_providers=get_onnxruntime_execution_providers(
        ONNXRUNTIME_EXECUTION_PROVIDERS
    ),
    **kwargs
)

Initializes the Clip with the given arguments and keyword arguments.

Source code in inference/models/clip/clip_model.py
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
def __init__(
    self,
    *args,
    model_id: str = CLIP_MODEL_ID,
    onnxruntime_execution_providers: List[
        str
    ] = get_onnxruntime_execution_providers(ONNXRUNTIME_EXECUTION_PROVIDERS),
    **kwargs,
):
    """Initializes the Clip with the given arguments and keyword arguments.

    Args:
        *args: Positional arguments forwarded to the parent constructor.
        model_id (str): Identifier of the CLIP model to load.
        onnxruntime_execution_providers (List[str]): ONNX Runtime execution
            providers, in priority order.
        **kwargs: Keyword arguments forwarded to the parent constructor.

    Raises:
        OnnxProviderNotAvailable: If a provider listed in
            REQUIRED_ONNX_PROVIDERS is not available in this runtime.
    """
    self.onnxruntime_execution_providers = onnxruntime_execution_providers
    t1 = perf_counter()
    super().__init__(*args, model_id=model_id, **kwargs)
    # Create an ONNX Runtime Session with a list of execution providers in
    # priority order. ORT attempts to load providers until one is successful.
    # This keeps the code across devices identical.
    self.log("Creating inference sessions")
    self.visual_onnx_session = onnxruntime.InferenceSession(
        self.cache_file("visual.onnx"),
        providers=self.onnxruntime_execution_providers,
    )
    # Locks serialize access to each ONNX session across threads.
    self._visual_session_lock = Lock()
    self.textual_onnx_session = onnxruntime.InferenceSession(
        self.cache_file("textual.onnx"),
        providers=self.onnxruntime_execution_providers,
    )
    self._textual_session_lock = Lock()
    if REQUIRED_ONNX_PROVIDERS:
        available_providers = onnxruntime.get_available_providers()
        for provider in REQUIRED_ONNX_PROVIDERS:
            if provider not in available_providers:
                raise OnnxProviderNotAvailable(
                    # Fixed typo in original message: "availble".
                    f"Required ONNX Execution Provider {provider} is not available. Check that you are using the correct docker image on a supported device."
                )

    self.resolution = self.visual_onnx_session.get_inputs()[0].shape[2]

    self.clip_preprocess = clip.clip._transform(self.resolution)
    self.log(f"CLIP model loaded in {perf_counter() - t1:.2f} seconds")
    self.task_type = "embedding"
compare
compare(
    subject,
    prompt,
    subject_type="image",
    prompt_type="text",
    **kwargs
)

Compares the subject with the prompt to calculate similarity scores.

Parameters:

Name Type Description Default
subject Any

The subject data to be compared. Can be either an image or text.

required
prompt Any

The prompt data to be compared against the subject. Can be a single value (image/text), list of values, or dictionary of values.

required
subject_type str

Specifies the type of the subject data. Must be either "image" or "text". Defaults to "image".

'image'
prompt_type Union[str, List[str], Dict[str, Any]]

Specifies the type of the prompt data. Can be "image", "text", list of these types, or a dictionary containing these types. Defaults to "text".

'text'
**kwargs

Additional keyword arguments.

{}

Returns:

Type Description
Union[List[float], Dict[str, float]]

Union[List[float], Dict[str, float]]: A list or dictionary containing cosine similarity scores between the subject and prompt(s). If prompt is a dictionary, returns a dictionary with keys corresponding to the original prompt dictionary's keys.

Raises:

Type Description
ValueError

If subject_type or prompt_type is neither "image" nor "text".

ValueError

If the number of prompts exceeds the maximum batch size.

Source code in inference/models/clip/clip_model.py
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
def compare(
    self,
    subject: Any,
    prompt: Any,
    subject_type: str = "image",
    prompt_type: Union[str, List[str], Dict[str, Any]] = "text",
    **kwargs,
) -> Union[List[float], Dict[str, float]]:
    """
    Compares the subject with the prompt to calculate similarity scores.

    Args:
        subject (Any): The subject data to be compared. Can be either an image or text.
        prompt (Any): The prompt data to be compared against the subject. Can be a single value (image/text), list of values, or dictionary of values.
        subject_type (str, optional): Specifies the type of the subject data. Must be either "image" or "text". Defaults to "image".
        prompt_type (Union[str, List[str], Dict[str, Any]], optional): Specifies the type of the prompt data. Can be "image", "text", list of these types, or a dictionary containing these types. Defaults to "text".
        **kwargs: Additional keyword arguments.

    Returns:
        Union[List[float], Dict[str, float]]: A list or dictionary containing cosine similarity scores between the subject and prompt(s). If prompt is a dictionary, returns a dictionary with keys corresponding to the original prompt dictionary's keys.

    Raises:
        ValueError: If subject_type or prompt_type is neither "image" nor "text".
        ValueError: If the number of prompts exceeds the maximum batch size.
    """

    if subject_type == "image":
        subject_embeddings = self.embed_image(subject)
    elif subject_type == "text":
        subject_embeddings = self.embed_text(subject)
    else:
        # Bug fix: the original message lacked the f-prefix and referenced
        # an undefined `request` variable inside the braces.
        raise ValueError(
            f"subject_type must be either 'image' or 'text', but got {subject_type}"
        )

    # A dict prompt maps names to values, unless it is itself a single
    # {"type": ..., "value": ...} image payload.
    if isinstance(prompt, dict) and not ("type" in prompt and "value" in prompt):
        prompt_keys = list(prompt.keys())
        prompt = [prompt[k] for k in prompt_keys]
        prompt_obj = "dict"
    else:
        if not isinstance(prompt, list):
            prompt = [prompt]
        prompt_obj = "list"

    if len(prompt) > CLIP_MAX_BATCH_SIZE:
        raise ValueError(
            f"The maximum number of prompts that can be compared at once is {CLIP_MAX_BATCH_SIZE}"
        )

    if prompt_type == "image":
        prompt_embeddings = self.embed_image(prompt)
    elif prompt_type == "text":
        prompt_embeddings = self.embed_text(prompt)
    else:
        # Bug fix: same missing f-prefix / undefined `request` as above.
        raise ValueError(
            f"prompt_type must be either 'image' or 'text', but got {prompt_type}"
        )

    similarities = [
        cosine_similarity(subject_embeddings, p) for p in prompt_embeddings
    ]

    if prompt_obj == "dict":
        similarities = dict(zip(prompt_keys, similarities))

    return similarities
embed_image
embed_image(image, **kwargs)

Embeds an image or a list of images using the Clip model.

Parameters:

Name Type Description Default
image Any

The image or list of images to be embedded. Image can be in any format that is acceptable by the preproc_image method.

required
**kwargs

Additional keyword arguments.

{}

Returns:

Type Description
ndarray

np.ndarray: The embeddings of the image(s) as a numpy array.

Raises:

Type Description
ValueError

If the number of images in the list exceeds the maximum batch size.

Notes

The function measures performance using perf_counter and also has support for ONNX session to get embeddings.

Source code in inference/models/clip/clip_model.py
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
def embed_image(
    self,
    image: Any,
    **kwargs,
) -> np.ndarray:
    """
    Embeds an image or a list of images using the Clip model.

    Args:
        image (Any): The image or list of images to be embedded. Image can be in any format that is acceptable by the preproc_image method.
        **kwargs: Additional keyword arguments.

    Returns:
        np.ndarray: The embeddings of the image(s) as a numpy array.

    Raises:
        ValueError: If the number of images in the list exceeds the maximum batch size.
    """
    # Note: a dead `t1 = perf_counter()` (never read) was removed here; the
    # stale docstring claim about perf_counter timing was dropped with it.
    if isinstance(image, list):
        if len(image) > CLIP_MAX_BATCH_SIZE:
            raise ValueError(
                f"The maximum number of images that can be embedded at once is {CLIP_MAX_BATCH_SIZE}"
            )
        imgs = [self.preproc_image(i) for i in image]
        img_in = np.concatenate(imgs, axis=0)
    else:
        img_in = self.preproc_image(image)

    onnx_input_image = {self.visual_onnx_session.get_inputs()[0].name: img_in}
    # The lock serializes access to the ONNX session across threads.
    with self._visual_session_lock:
        return self.visual_onnx_session.run(None, onnx_input_image)[0]
embed_text
embed_text(text, **kwargs)

Embeds a text or a list of texts using the Clip model.

Parameters:

Name Type Description Default
text Union[str, List[str]]

The text string or list of text strings to be embedded.

required
**kwargs

Additional keyword arguments.

{}

Returns:

Type Description
ndarray

np.ndarray: The embeddings of the text or texts as a numpy array.

Raises:

Type Description
ValueError

If the number of text strings in the list exceeds the maximum batch size.

Notes

The function utilizes an ONNX session to compute embeddings and measures the embedding time with perf_counter.

Source code in inference/models/clip/clip_model.py
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
def embed_text(
    self,
    text: Union[str, List[str]],
    **kwargs,
) -> np.ndarray:
    """
    Computes CLIP text embeddings for a single string or a list of strings.

    Args:
        text (Union[str, List[str]]): Text to embed; a bare string is
            treated as a one-element batch.
        **kwargs: Additional keyword arguments (accepted for interface
            compatibility; not used here).

    Returns:
        np.ndarray: Embedding vectors, one row per input text, stacked
        along the first axis.

    Notes:
        Inputs are tokenized with ``clip.tokenize`` and evaluated through
        the textual ONNX session in batches of at most CLIP_MAX_BATCH_SIZE.
    """
    # Normalize to a list so the batching logic has a single code path.
    texts = text if isinstance(text, list) else [text]
    batch_embeddings = []
    for batch in create_batches(sequence=texts, batch_size=CLIP_MAX_BATCH_SIZE):
        tokens = clip.tokenize(batch).numpy().astype(np.int32)
        feed = {self.textual_onnx_session.get_inputs()[0].name: tokens}
        # The ONNX session may not be safe for concurrent runs, hence the lock.
        with self._textual_session_lock:
            outputs = self.textual_onnx_session.run(None, feed)
        batch_embeddings.append(outputs[0])
    return np.concatenate(batch_embeddings, axis=0)
get_infer_bucket_file_list
get_infer_bucket_file_list()

Gets the list of files required for inference.

Returns:

Type Description
List[str]

List[str]: The list of file names.

Source code in inference/models/clip/clip_model.py
298
299
300
301
302
303
304
def get_infer_bucket_file_list(self) -> List[str]:
    """Return the model artifact file names required to run inference.

    Returns:
        List[str]: Names of the textual and visual ONNX model files.
    """
    # One ONNX file per CLIP tower (text encoder and image encoder).
    return [f"{modality}.onnx" for modality in ("textual", "visual")]
infer
infer(image, **kwargs)

Embeds an image - image: can be a BGR numpy array, filepath, InferenceRequestImage, PIL Image, byte-string, etc.

Source code in inference/models/clip/clip_model.py
347
348
349
350
351
352
def infer(self, image: Any, **kwargs) -> Any:
    """Embed an image by delegating to the parent class's inference pipeline.

    Args:
        image (Any): The image to embed; can be a BGR numpy array, filepath,
            InferenceRequestImage, PIL Image, byte-string, etc.
        **kwargs: Additional keyword arguments forwarded unchanged to the
            parent implementation.

    Returns:
        Any: The parent class's inference result (presumably the image
            embeddings — confirm against the base class).
    """
    return super().infer(image, **kwargs)
infer_from_request
infer_from_request(request)

Routes the request to the appropriate inference function.

Parameters:

Name Type Description Default
request ClipInferenceRequest

The request object containing the inference details.

required

Returns:

Name Type Description
ClipEmbeddingResponse ClipEmbeddingResponse

The response object containing the embeddings.

Source code in inference/models/clip/clip_model.py
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
def infer_from_request(
    self, request: ClipInferenceRequest
) -> ClipEmbeddingResponse:
    """Dispatch a CLIP request to the matching inference routine.

    Args:
        request (ClipInferenceRequest): The request object containing the
            inference details.

    Returns:
        ClipEmbeddingResponse: The response object containing the result,
            with its ``time`` field set to the elapsed seconds.

    Raises:
        ValueError: If the request is not a supported CLIP request type.
    """
    started_at = perf_counter()
    # Ordered dispatch: isinstance checks run top to bottom, preserving
    # the original resolution order for subclassed request types.
    dispatch_table = (
        (ClipImageEmbeddingRequest, self.embed_image, self.make_embed_image_response),
        (ClipTextEmbeddingRequest, self.embed_text, self.make_embed_text_response),
        (ClipCompareRequest, self.compare, self.make_compare_response),
    )
    for request_type, run_inference, build_response in dispatch_table:
        if isinstance(request, request_type):
            break
    else:
        raise ValueError(
            f"Request type {type(request)} is not a valid ClipInferenceRequest"
        )
    response = build_response(run_inference(**request.dict()))
    response.time = perf_counter() - started_at
    return response
make_compare_response
make_compare_response(similarities)

Creates a ClipCompareResponse object from the provided similarity data.

Parameters:

Name Type Description Default
similarities Union[List[float], Dict[str, float]]

A list or dictionary containing similarity scores.

required

Returns:

Name Type Description
ClipCompareResponse ClipCompareResponse

An instance of the ClipCompareResponse with the given similarity scores.

Example

Assuming ClipCompareResponse expects a dictionary of string-float pairs:

>>> make_compare_response({"image1": 0.98, "image2": 0.76})
ClipCompareResponse(similarity={"image1": 0.98, "image2": 0.76})

Source code in inference/models/clip/clip_model.py
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
def make_compare_response(
    self, similarities: Union[List[float], Dict[str, float]]
) -> ClipCompareResponse:
    """
    Wrap raw similarity scores in a ClipCompareResponse.

    Args:
        similarities (Union[List[float], Dict[str, float]]): Similarity
            scores, either positional (list) or keyed by name (dict).

    Returns:
        ClipCompareResponse: Response object carrying the scores unchanged.

    Example:
        Assuming `ClipCompareResponse` expects a dictionary of string-float pairs:

        >>> make_compare_response({"image1": 0.98, "image2": 0.76})
        ClipCompareResponse(similarity={"image1": 0.98, "image2": 0.76})
    """
    return ClipCompareResponse(similarity=similarities)
make_embed_image_response
make_embed_image_response(embeddings)

Converts the given embeddings into a ClipEmbeddingResponse object.

Parameters:

Name Type Description Default
embeddings ndarray

A numpy array containing the embeddings for an image or images.

required

Returns:

Name Type Description
ClipEmbeddingResponse ClipEmbeddingResponse

An instance of the ClipEmbeddingResponse with the provided embeddings converted to a list.

Example

>>> embeddings_array = np.array([[0.5, 0.3, 0.2], [0.1, 0.9, 0.0]])
>>> make_embed_image_response(embeddings_array)
ClipEmbeddingResponse(embeddings=[[0.5, 0.3, 0.2], [0.1, 0.9, 0.0]])

Source code in inference/models/clip/clip_model.py
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
def make_embed_image_response(
    self, embeddings: np.ndarray
) -> ClipEmbeddingResponse:
    """
    Build a ClipEmbeddingResponse from raw image embeddings.

    Args:
        embeddings (np.ndarray): Embedding matrix, one row per image.

    Returns:
        ClipEmbeddingResponse: Response holding the embeddings as nested lists.

    Example:
        >>> embeddings_array = np.array([[0.5, 0.3, 0.2], [0.1, 0.9, 0.0]])
        >>> make_embed_image_response(embeddings_array)
        ClipEmbeddingResponse(embeddings=[[0.5, 0.3, 0.2], [0.1, 0.9, 0.0]])
    """
    # The response model expects plain Python lists, not numpy arrays.
    embeddings_as_lists = embeddings.tolist()
    return ClipEmbeddingResponse(embeddings=embeddings_as_lists)
make_embed_text_response
make_embed_text_response(embeddings)

Converts the given text embeddings into a ClipEmbeddingResponse object.

Parameters:

Name Type Description Default
embeddings ndarray

A numpy array containing the embeddings for a text or texts.

required

Returns:

Name Type Description
ClipEmbeddingResponse ClipEmbeddingResponse

An instance of the ClipEmbeddingResponse with the provided embeddings converted to a list.

Example

>>> embeddings_array = np.array([[0.8, 0.1, 0.1], [0.4, 0.5, 0.1]])
>>> make_embed_text_response(embeddings_array)
ClipEmbeddingResponse(embeddings=[[0.8, 0.1, 0.1], [0.4, 0.5, 0.1]])

Source code in inference/models/clip/clip_model.py
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
def make_embed_text_response(self, embeddings: np.ndarray) -> ClipEmbeddingResponse:
    """
    Build a ClipEmbeddingResponse from raw text embeddings.

    Args:
        embeddings (np.ndarray): Embedding matrix, one row per input text.

    Returns:
        ClipEmbeddingResponse: Response holding the embeddings as nested lists.

    Example:
        >>> embeddings_array = np.array([[0.8, 0.1, 0.1], [0.4, 0.5, 0.1]])
        >>> make_embed_text_response(embeddings_array)
        ClipEmbeddingResponse(embeddings=[[0.8, 0.1, 0.1], [0.4, 0.5, 0.1]])
    """
    # The response model expects plain Python lists, not numpy arrays.
    return ClipEmbeddingResponse(embeddings=embeddings.tolist())
preproc_image
preproc_image(image)

Preprocesses an inference request image.

Parameters:

Name Type Description Default
image InferenceRequestImage

The object containing information necessary to load the image for inference.

required

Returns:

Type Description
ndarray

np.ndarray: A numpy array of the preprocessed image pixel data.

Source code in inference/models/clip/clip_model.py
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
def preproc_image(self, image: InferenceRequestImage) -> np.ndarray:
    """Load and preprocess an inference request image for the CLIP encoder.

    Args:
        image (InferenceRequestImage): The object containing information
            necessary to load the image for inference.

    Returns:
        np.ndarray: Float32 array of the preprocessed pixels with a leading
            batch axis of size 1.
    """
    rgb_array = load_image_rgb(image)
    # clip_preprocess expects a PIL image, so convert before preprocessing.
    preprocessed = self.clip_preprocess(Image.fromarray(rgb_array))
    # Add the batch dimension and cast for the ONNX runtime.
    batched = np.expand_dims(preprocessed, axis=0)
    return batched.astype(np.float32)

Functions

models/deep_lab_v3_plus

inference.models.deep_lab_v3_plus.deep_lab_v3_plus_segmentation

Classes

DeepLabV3PlusSemanticSegmentation

Bases: SemanticSegmentationBaseOnnxRoboflowInferenceModel

DeepLabV3Plus Semantic Segmentation ONNX Inference Model.

This class is responsible for performing semantic segmentation using the DeepLabV3Plus model with ONNX runtime.

Attributes:

Name Type Description
weights_file str

Path to the ONNX weights file.

Methods:

Name Description
predict

Performs inference on the given image using the ONNX session.

Source code in inference/models/deep_lab_v3_plus/deep_lab_v3_plus_segmentation.py
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
class DeepLabV3PlusSemanticSegmentation(
    SemanticSegmentationBaseOnnxRoboflowInferenceModel
):
    """ONNX inference model for DeepLabV3Plus semantic segmentation.

    Runs semantic segmentation with the DeepLabV3Plus architecture through
    the ONNX-runtime machinery provided by the base class.

    Attributes:
        weights_file (str): Path to the ONNX weights file.
        preprocess_means (list): Per-channel means matching training params.
        preprocess_stds (list): Per-channel stds matching training params.

    Methods:
        predict: Performs inference on the given image using the ONNX session.
    """

    # Normalization constants — must match the training configuration.
    preprocess_means = [0.485, 0.456, 0.406]
    preprocess_stds = [0.229, 0.224, 0.225]

    @property
    def weights_file(self) -> str:
        """str: File name of the ONNX weights ("weights.onnx")."""
        return "weights.onnx"
Attributes
weights_file property
weights_file

Gets the weights file for the DeepLabV3Plus model.

Returns:

Name Type Description
str str

Path to the ONNX weights file.

models/depth_anything_v3/architecture

inference.models.depth_anything_v3.architecture.da3

Classes

DepthAnything3Net

Bases: Module

Depth Anything 3 network for depth estimation. Simplified for single-view depth-only inference.

This network consists of: - Backbone: DinoV2 feature extractor - Head: DualDPT for depth prediction

Returns:

Type Description

Dictionary containing:

  • depth: Predicted depth map (B, H, W)
  • depth_conf: Depth confidence map (B, H, W)
Source code in inference/models/depth_anything_v3/architecture/da3.py
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
class DepthAnything3Net(nn.Module):
    """
    Depth Anything 3 network for depth estimation.
    Simplified for single-view depth-only inference.

    The model is composed of:
    - Backbone: DinoV2 feature extractor
    - Head: DualDPT for depth prediction

    Returns:
        Dictionary containing:
        - depth: Predicted depth map (B, H, W)
        - depth_conf: Depth confidence map (B, H, W)
    """

    PATCH_SIZE = 14

    def __init__(
        self,
        backbone_name: str,
        out_layers: list,
        alt_start: int,
        qknorm_start: int,
        rope_start: int,
        cat_token: bool,
        head_dim_in: int,
        head_output_dim: int,
        head_features: int,
        head_out_channels: list,
    ):
        """
        Initialize DepthAnything3Net.

        Args:
            backbone_name: DinoV2 backbone variant ("vits" or "vitb")
            out_layers: Layer indices to extract features from
            alt_start: Layer index to start alternating attention
            qknorm_start: Layer index to start QK normalization
            rope_start: Layer index to start RoPE
            cat_token: Whether to concatenate local and global tokens
            head_dim_in: Input dimension for the head
            head_output_dim: Output dimension for the head
            head_features: Feature dimension in the head
            head_out_channels: Output channel dimensions per stage
        """
        super().__init__()
        # Feature extractor over input patches.
        self.backbone = DinoV2(
            name=backbone_name,
            out_layers=out_layers,
            alt_start=alt_start,
            qknorm_start=qknorm_start,
            rope_start=rope_start,
            cat_token=cat_token,
        )
        # Dense prediction head producing depth and confidence.
        self.head = DualDPT(
            dim_in=head_dim_in,
            output_dim=head_output_dim,
            features=head_features,
            out_channels=head_out_channels,
        )
        # Cached preferred device; prefers GPU when one is available.
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    def forward(
        self,
        x: torch.Tensor,
    ) -> Dict[str, torch.Tensor]:
        """
        Forward pass through the network.

        Args:
            x: Input images (B, N, 3, H, W) where N=1 for single-view

        Returns:
            Dictionary containing depth predictions
        """
        # Backbone returns (features, aux); only the features are needed here.
        features, _ = self.backbone(x)
        height, width = x.shape[-2:]

        # The head runs in full precision even under an autocast context.
        with torch.autocast(device_type=x.device.type, enabled=False):
            return self._process_depth_head(features, height, width)

    def _process_depth_head(
        self, feats: list[torch.Tensor], H: int, W: int
    ) -> Dict[str, torch.Tensor]:
        """Process features through the depth prediction head."""
        return self.head(feats, H, W, patch_start_idx=0)
Functions
__init__
__init__(
    backbone_name,
    out_layers,
    alt_start,
    qknorm_start,
    rope_start,
    cat_token,
    head_dim_in,
    head_output_dim,
    head_features,
    head_out_channels,
)

Initialize DepthAnything3Net.

Parameters:

Name Type Description Default
backbone_name str

DinoV2 backbone variant ("vits" or "vitb")

required
out_layers list

Layer indices to extract features from

required
alt_start int

Layer index to start alternating attention

required
qknorm_start int

Layer index to start QK normalization

required
rope_start int

Layer index to start RoPE

required
cat_token bool

Whether to concatenate local and global tokens

required
head_dim_in int

Input dimension for the head

required
head_output_dim int

Output dimension for the head

required
head_features int

Feature dimension in the head

required
head_out_channels list

Output channel dimensions per stage

required
Source code in inference/models/depth_anything_v3/architecture/da3.py
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
def __init__(
    self,
    backbone_name: str,
    out_layers: list,
    alt_start: int,
    qknorm_start: int,
    rope_start: int,
    cat_token: bool,
    head_dim_in: int,
    head_output_dim: int,
    head_features: int,
    head_out_channels: list,
):
    """
    Initialize DepthAnything3Net.

    Args:
        backbone_name: DinoV2 backbone variant ("vits" or "vitb")
        out_layers: Layer indices to extract features from
        alt_start: Layer index to start alternating attention
        qknorm_start: Layer index to start QK normalization
        rope_start: Layer index to start RoPE
        cat_token: Whether to concatenate local and global tokens
        head_dim_in: Input dimension for the head
        head_output_dim: Output dimension for the head
        head_features: Feature dimension in the head
        head_out_channels: Output channel dimensions per stage
    """
    super().__init__()
    # Feature-extraction backbone; configuration forwarded verbatim.
    self.backbone = DinoV2(
        name=backbone_name,
        out_layers=out_layers,
        alt_start=alt_start,
        qknorm_start=qknorm_start,
        rope_start=rope_start,
        cat_token=cat_token,
    )
    # Dense prediction head consuming the backbone features.
    self.head = DualDPT(
        dim_in=head_dim_in,
        output_dim=head_output_dim,
        features=head_features,
        out_channels=head_out_channels,
    )
    # Preferred compute device; prefers CUDA when available.
    self.device = (
        torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
    )
forward
forward(x)

Forward pass through the network.

Parameters:

Name Type Description Default
x Tensor

Input images (B, N, 3, H, W) where N=1 for single-view

required

Returns:

Type Description
Dict[str, Tensor]

Dictionary containing depth predictions

Source code in inference/models/depth_anything_v3/architecture/da3.py
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
def forward(
    self,
    x: torch.Tensor,
) -> Dict[str, torch.Tensor]:
    """
    Forward pass through the network.

    Args:
        x: Input images (B, N, 3, H, W) where N=1 for single-view

    Returns:
        Dictionary containing depth predictions (per the class contract,
        keys "depth" and "depth_conf")
    """
    # Extract features using backbone; the second return value is unused here.
    feats, _ = self.backbone(x)
    H, W = x.shape[-2], x.shape[-1]

    # Process features through depth head with autocast disabled so the
    # head always runs in full precision.
    with torch.autocast(device_type=x.device.type, enabled=False):
        output = self._process_depth_head(feats, H, W)

    return output

inference.models.depth_anything_v3.architecture.dpt

Classes

FeatureFusionBlock

Bases: Module

Top-down fusion block

Source code in inference/models/depth_anything_v3/architecture/dpt.py
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
class FeatureFusionBlock(nn.Module):
    """Top-down fusion block"""

    def __init__(
        self,
        features: int,
        activation: nn.Module,
        deconv: bool = False,
        bn: bool = False,
        expand: bool = False,
        align_corners: bool = True,
        size: Tuple[int, int] = None,
        has_residual: bool = True,
        groups: int = 1,
    ) -> None:
        super().__init__()
        self.align_corners = align_corners
        self.size = size
        self.has_residual = has_residual

        # Residual unit for the lateral (skip) branch; only built when the
        # block actually fuses a second input.
        if has_residual:
            self.resConfUnit1 = ResidualConvUnit(features, activation, bn, groups=groups)
        else:
            self.resConfUnit1 = None
        self.resConfUnit2 = ResidualConvUnit(features, activation, bn, groups=groups)

        out_features = features // 2 if expand else features
        self.out_conv = nn.Conv2d(
            features, out_features, 1, 1, 0, bias=True, groups=groups
        )
        # Quantization-friendly elementwise add.
        self.skip_add = nn.quantized.FloatFunctional()

    def forward(self, *xs: torch.Tensor, size: Tuple[int, int] = None) -> torch.Tensor:
        fused = xs[0]
        if self.has_residual and len(xs) > 1 and self.resConfUnit1 is not None:
            fused = self.skip_add.add(fused, self.resConfUnit1(xs[1]))

        fused = self.resConfUnit2(fused)

        # Resolve the upsampling target: a call-time size wins, then the
        # size fixed at construction, otherwise a plain 2x upscale.
        if size is not None:
            up_kwargs = {"size": size}
        elif self.size is not None:
            up_kwargs = {"size": self.size}
        else:
            up_kwargs = {"scale_factor": 2}

        fused = custom_interpolate(
            fused, **up_kwargs, mode="bilinear", align_corners=self.align_corners
        )
        return self.out_conv(fused)

ResidualConvUnit

Bases: Module

Lightweight residual convolution block for fusion

Source code in inference/models/depth_anything_v3/architecture/dpt.py
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
class ResidualConvUnit(nn.Module):
    """Lightweight residual convolution block for fusion"""

    def __init__(
        self, features: int, activation: nn.Module, bn: bool, groups: int = 1
    ) -> None:
        super().__init__()
        self.bn = bn
        self.groups = groups
        # Two 3x3 convolutions preserving channel count and spatial size.
        self.conv1 = nn.Conv2d(features, features, 3, 1, 1, bias=True, groups=groups)
        self.conv2 = nn.Conv2d(features, features, 3, 1, 1, bias=True, groups=groups)
        # Normalization slots are placeholders (always None here); kept so
        # the forward pass can skip them uniformly.
        self.norm1 = None
        self.norm2 = None
        self.activation = activation
        # Quantization-friendly elementwise add for the residual connection.
        self.skip_add = nn.quantized.FloatFunctional()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Pre-activation ordering: the activation precedes each convolution.
        out = self.conv1(self.activation(x))
        if self.norm1 is not None:
            out = self.norm1(out)
        out = self.conv2(self.activation(out))
        if self.norm2 is not None:
            out = self.norm2(out)
        return self.skip_add.add(out, x)

Functions

inference.models.depth_anything_v3.architecture.dualdpt

Classes

DualDPT

Bases: Module

Dual-head DPT for dense prediction with an auxiliary head. Simplified for single-view depth estimation - only depth output is used.

Source code in inference/models/depth_anything_v3/architecture/dualdpt.py
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
class DualDPT(nn.Module):
    """
    Dual-head DPT for dense prediction with an auxiliary head.
    Simplified for single-view depth estimation - only depth output is used.
    """

    def __init__(
        self,
        dim_in: int,
        *,
        patch_size: int = 14,
        output_dim: int = 2,
        activation: str = "exp",
        conf_activation: str = "expp1",
        features: int = 256,
        out_channels: Sequence[int] = (256, 512, 1024, 1024),
        pos_embed: bool = True,
        down_ratio: int = 1,
        aux_pyramid_levels: int = 4,
        aux_out1_conv_num: int = 5,
        head_names: Tuple[str, str] = ("depth", "ray"),
    ) -> None:
        """Build the dual-head DPT decoder.

        Args:
            dim_in: Channel dimension of the incoming backbone tokens.
            patch_size: ViT patch size used to map tokens back to a 2D grid.
            output_dim: Channels of the main head output; the last channel
                is treated as confidence by the forward pass.
            activation: Activation name for the main prediction channels.
            conf_activation: Activation name for the confidence channel.
            features: Working channel width of the fusion trunk.
            out_channels: Per-stage projection widths for the four levels.
            pos_embed: Whether to add UV positional embeddings to feature maps.
            down_ratio: Output downsampling ratio relative to the input size.
            aux_pyramid_levels: Number of auxiliary pyramid outputs to build.
            aux_out1_conv_num: Conv count in each aux out1 block (5, 3 or 1).
            head_names: Names for the (main, auxiliary) heads.
        """
        super().__init__()

        self.patch_size = patch_size
        self.activation = activation
        self.conf_activation = conf_activation
        self.pos_embed = pos_embed
        self.down_ratio = down_ratio

        self.aux_levels = aux_pyramid_levels
        self.aux_out1_conv_num = aux_out1_conv_num

        self.head_main, self.head_aux = head_names

        # Which of the provided feature lists feed decoder stages 1-4.
        self.intermediate_layer_idx: Tuple[int, int, int, int] = (0, 1, 2, 3)

        self.norm = nn.LayerNorm(dim_in)
        # 1x1 projections from the token dimension to each stage's width.
        self.projects = nn.ModuleList(
            [
                nn.Conv2d(dim_in, oc, kernel_size=1, stride=1, padding=0)
                for oc in out_channels
            ]
        )

        # Rescale each stage to its pyramid resolution: 4x up, 2x up,
        # identity, and 2x down respectively.
        self.resize_layers = nn.ModuleList(
            [
                nn.ConvTranspose2d(
                    out_channels[0], out_channels[0], kernel_size=4, stride=4, padding=0
                ),
                nn.ConvTranspose2d(
                    out_channels[1], out_channels[1], kernel_size=2, stride=2, padding=0
                ),
                nn.Identity(),
                nn.Conv2d(
                    out_channels[3], out_channels[3], kernel_size=3, stride=2, padding=1
                ),
            ]
        )

        self.scratch = _make_scratch(list(out_channels), features, expand=False)

        # Main fusion chain
        self.scratch.refinenet1 = _make_fusion_block(features)
        self.scratch.refinenet2 = _make_fusion_block(features)
        self.scratch.refinenet3 = _make_fusion_block(features)
        self.scratch.refinenet4 = _make_fusion_block(features, has_residual=False)

        head_features_1 = features
        head_features_2 = 32
        self.scratch.output_conv1 = nn.Conv2d(
            head_features_1, head_features_1 // 2, kernel_size=3, stride=1, padding=1
        )
        self.scratch.output_conv2 = nn.Sequential(
            nn.Conv2d(
                head_features_1 // 2,
                head_features_2,
                kernel_size=3,
                stride=1,
                padding=1,
            ),
            nn.ReLU(inplace=True),
            nn.Conv2d(head_features_2, output_dim, kernel_size=1, stride=1, padding=0),
        )

        # Auxiliary fusion chain (for ray head - not used for inference but needed for weight loading)
        self.scratch.refinenet1_aux = _make_fusion_block(features)
        self.scratch.refinenet2_aux = _make_fusion_block(features)
        self.scratch.refinenet3_aux = _make_fusion_block(features)
        self.scratch.refinenet4_aux = _make_fusion_block(features, has_residual=False)

        self.scratch.output_conv1_aux = nn.ModuleList(
            [self._make_aux_out1_block(head_features_1) for _ in range(self.aux_levels)]
        )

        use_ln = True
        # Optional channel-wise LayerNorm applied in NHWC via Permute wrappers.
        ln_seq = (
            [
                Permute((0, 2, 3, 1)),
                nn.LayerNorm(head_features_2),
                Permute((0, 3, 1, 2)),
            ]
            if use_ln
            else []
        )
        self.scratch.output_conv2_aux = nn.ModuleList(
            [
                nn.Sequential(
                    nn.Conv2d(
                        head_features_1 // 2,
                        head_features_2,
                        kernel_size=3,
                        stride=1,
                        padding=1,
                    ),
                    *ln_seq,
                    nn.ReLU(inplace=True),
                    nn.Conv2d(head_features_2, 7, kernel_size=1, stride=1, padding=0),
                )
                for _ in range(self.aux_levels)
            ]
        )

    def forward(
        self,
        feats: List[torch.Tensor],
        H: int,
        W: int,
        patch_start_idx: int,
        chunk_size: int = 8,
    ) -> Dict[str, torch.Tensor]:
        """Decode backbone features into main-head predictions.

        Args:
            feats: Per-stage features; each entry's first element has shape
                (B, S, N, C), where S is the number of views.
            H: Input image height in pixels.
            W: Input image width in pixels.
            patch_start_idx: Index of the first patch token (skips any
                special tokens preceding the patches).
            chunk_size: Max fused batch entries decoded per pass, to bound
                peak memory; None disables chunking.

        Returns:
            Mapping from head name (and "<name>_conf") to tensors whose
            leading dims are (B, S).
        """
        B, S, N, C = feats[0][0].shape
        # Fold batch and view dims together so the decoder sees flat batches.
        feats = [feat[0].reshape(B * S, N, C) for feat in feats]
        if chunk_size is None or chunk_size >= S:
            out_dict = self._forward_impl(feats, H, W, patch_start_idx)
            out_dict = {k: v.reshape(B, S, *v.shape[1:]) for k, v in out_dict.items()}
            return out_dict
        # Decode in chunks along the fused (B*S) axis, then stitch results.
        out_dicts = []
        for s0 in range(0, B * S, chunk_size):
            s1 = min(s0 + chunk_size, B * S)
            out_dict = self._forward_impl(
                [feat[s0:s1] for feat in feats],
                H,
                W,
                patch_start_idx,
            )
            out_dicts.append(out_dict)
        out_dict = {
            k: torch.cat([out_dict[k] for out_dict in out_dicts], dim=0)
            for k in out_dicts[0].keys()
        }
        out_dict = {k: v.view(B, S, *v.shape[1:]) for k, v in out_dict.items()}
        return out_dict

    def _forward_impl(
        self,
        feats: List[torch.Tensor],
        H: int,
        W: int,
        patch_start_idx: int,
    ) -> Dict[str, torch.Tensor]:
        """Single-pass decode: tokens -> pyramid -> fused map -> predictions."""
        B, _, C = feats[0].shape
        # Patch-grid resolution implied by the ViT patch size.
        ph, pw = H // self.patch_size, W // self.patch_size
        resized_feats = []
        for stage_idx, take_idx in enumerate(self.intermediate_layer_idx):
            # Drop special tokens, normalize, and fold tokens into a 2D map.
            x = feats[take_idx][:, patch_start_idx:]
            x = self.norm(x)
            x = x.permute(0, 2, 1).reshape(B, C, ph, pw)

            x = self.projects[stage_idx](x)
            if self.pos_embed:
                x = self._add_pos_embed(x, W, H)
            x = self.resize_layers[stage_idx](x)
            resized_feats.append(x)

        # Only compute main fusion for depth (skip aux for inference)
        fused_main, _ = self._fuse(resized_feats)

        h_out = int(ph * self.patch_size / self.down_ratio)
        w_out = int(pw * self.patch_size / self.down_ratio)

        fused_main = custom_interpolate(
            fused_main, (h_out, w_out), mode="bilinear", align_corners=True
        )
        if self.pos_embed:
            fused_main = self._add_pos_embed(fused_main, W, H)

        main_logits = self.scratch.output_conv2(fused_main)
        # Channel split: all but the last channel are the prediction, the
        # last channel is the confidence.
        fmap = main_logits.permute(0, 2, 3, 1)
        main_pred = self._apply_activation_single(fmap[..., :-1], self.activation)
        main_conf = self._apply_activation_single(fmap[..., -1], self.conf_activation)

        return {
            self.head_main: main_pred.squeeze(-1),
            f"{self.head_main}_conf": main_conf,
        }

    def _fuse(
        self, feats: List[torch.Tensor]
    ) -> Tuple[torch.Tensor, List[torch.Tensor]]:
        """Top-down RefineNet fusion for both the main and auxiliary chains."""
        l1, l2, l3, l4 = feats

        l1_rn = self.scratch.layer1_rn(l1)
        l2_rn = self.scratch.layer2_rn(l2)
        l3_rn = self.scratch.layer3_rn(l3)
        l4_rn = self.scratch.layer4_rn(l4)

        # Coarsest level first; each step upsamples to the next level's size.
        out = self.scratch.refinenet4(l4_rn, size=l3_rn.shape[2:])
        aux_out = self.scratch.refinenet4_aux(l4_rn, size=l3_rn.shape[2:])
        aux_list: List[torch.Tensor] = []
        if self.aux_levels >= 4:
            aux_list.append(aux_out)

        out = self.scratch.refinenet3(out, l3_rn, size=l2_rn.shape[2:])
        aux_out = self.scratch.refinenet3_aux(aux_out, l3_rn, size=l2_rn.shape[2:])
        if self.aux_levels >= 3:
            aux_list.append(aux_out)

        out = self.scratch.refinenet2(out, l2_rn, size=l1_rn.shape[2:])
        aux_out = self.scratch.refinenet2_aux(aux_out, l2_rn, size=l1_rn.shape[2:])
        if self.aux_levels >= 2:
            aux_list.append(aux_out)

        out = self.scratch.refinenet1(out, l1_rn)
        aux_out = self.scratch.refinenet1_aux(aux_out, l1_rn)
        aux_list.append(aux_out)

        out = self.scratch.output_conv1(out)
        aux_list = [
            self.scratch.output_conv1_aux[i](aux) for i, aux in enumerate(aux_list)
        ]

        return out, aux_list

    def _add_pos_embed(
        self, x: torch.Tensor, W: int, H: int, ratio: float = 0.1
    ) -> torch.Tensor:
        """Add a scaled UV positional embedding to a (B, C, h, w) feature map."""
        pw, ph = x.shape[-1], x.shape[-2]
        pe = create_uv_grid(pw, ph, aspect_ratio=W / H, dtype=x.dtype, device=x.device)
        pe = position_grid_to_embed(pe, x.shape[1]) * ratio
        pe = pe.permute(2, 0, 1)[None].expand(x.shape[0], -1, -1, -1)
        return x + pe.to(x.dtype)

    def _make_aux_out1_block(self, in_ch: int) -> nn.Sequential:
        """Build the auxiliary out1 conv stack; depth set by aux_out1_conv_num."""
        if self.aux_out1_conv_num == 5:
            return nn.Sequential(
                nn.Conv2d(in_ch, in_ch // 2, 3, 1, 1),
                nn.Conv2d(in_ch // 2, in_ch, 3, 1, 1),
                nn.Conv2d(in_ch, in_ch // 2, 3, 1, 1),
                nn.Conv2d(in_ch // 2, in_ch, 3, 1, 1),
                nn.Conv2d(in_ch, in_ch // 2, 3, 1, 1),
            )
        if self.aux_out1_conv_num == 3:
            return nn.Sequential(
                nn.Conv2d(in_ch, in_ch // 2, 3, 1, 1),
                nn.Conv2d(in_ch // 2, in_ch, 3, 1, 1),
                nn.Conv2d(in_ch, in_ch // 2, 3, 1, 1),
            )
        if self.aux_out1_conv_num == 1:
            return nn.Sequential(nn.Conv2d(in_ch, in_ch // 2, 3, 1, 1))
        raise ValueError(f"aux_out1_conv_num {self.aux_out1_conv_num} not supported")

    def _apply_activation_single(
        self, x: torch.Tensor, activation: str = "linear"
    ) -> torch.Tensor:
        """Apply a named activation; unrecognized names fall through to identity."""
        act = activation.lower() if isinstance(activation, str) else activation
        if act == "exp":
            return torch.exp(x)
        if act == "expm1":
            return torch.expm1(x)
        if act == "expp1":
            return torch.exp(x) + 1
        if act == "relu":
            return torch.relu(x)
        if act == "sigmoid":
            return torch.sigmoid(x)
        if act == "softplus":
            return torch.nn.functional.softplus(x)
        if act == "tanh":
            return torch.tanh(x)
        return x

Functions

inference.models.depth_anything_v3.architecture.head_utils

Classes

Permute

Bases: Module

nn.Module wrapper around Tensor.permute for cleaner nn.Sequential usage.

Source code in inference/models/depth_anything_v3/architecture/head_utils.py
22
23
24
25
26
27
28
29
30
31
32
class Permute(nn.Module):
    """Thin nn.Module adapter around Tensor.permute.

    Captures a fixed axis permutation at construction time so that axis
    reordering can be placed inside an nn.Sequential pipeline.
    """

    dims: Tuple[int, ...]

    def __init__(self, dims: Tuple[int, ...]) -> None:
        super().__init__()
        self.dims = dims

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Functional form of x.permute(*self.dims).
        return torch.permute(x, self.dims)

Functions

create_uv_grid

create_uv_grid(
    width,
    height,
    aspect_ratio=None,
    dtype=None,
    device=None,
)

Create a normalized UV grid of shape (height, width, 2).

Source code in inference/models/depth_anything_v3/architecture/head_utils.py
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
def create_uv_grid(
    width: int,
    height: int,
    aspect_ratio: float = None,
    dtype: torch.dtype = None,
    device: torch.device = None,
) -> torch.Tensor:
    """Create a normalized UV grid of shape (height, width, 2).

    Coordinates are scaled so a grid corner is at unit distance from the
    center: x spans +/- aspect_ratio / sqrt(aspect_ratio**2 + 1) and
    y spans +/- 1 / sqrt(aspect_ratio**2 + 1).

    Args:
        width: Number of horizontal samples (size of dim 1 of the result).
        height: Number of vertical samples (size of dim 0 of the result).
        aspect_ratio: Width/height ratio; defaults to width / height.
        dtype: Optional dtype for the coordinate tensors.
        device: Optional device for the resulting grid.

    Returns:
        Tensor of shape (height, width, 2): [..., 0] is x (u), [..., 1] is
        y (v). Note the row-major layout — torch.meshgrid with
        indexing="xy" puts the y axis first, so the result is
        (height, width, 2), not (width, height, 2).
    """
    if aspect_ratio is None:
        aspect_ratio = float(width) / float(height)

    # Normalize spans so the center-to-corner distance equals 1.
    diag_factor = (aspect_ratio**2 + 1.0) ** 0.5
    span_x = aspect_ratio / diag_factor
    span_y = 1.0 / diag_factor

    # Scale the endpoints by (n - 1) / n, insetting the outermost samples
    # by half a cell at each edge.
    left_x = -span_x * (width - 1) / width
    right_x = span_x * (width - 1) / width
    top_y = -span_y * (height - 1) / height
    bottom_y = span_y * (height - 1) / height

    x_coords = torch.linspace(left_x, right_x, steps=width, dtype=dtype, device=device)
    y_coords = torch.linspace(top_y, bottom_y, steps=height, dtype=dtype, device=device)

    # indexing="xy": the row index varies over y, the column index over x.
    uu, vv = torch.meshgrid(x_coords, y_coords, indexing="xy")
    uv_grid = torch.stack((uu, vv), dim=-1)

    return uv_grid

custom_interpolate

custom_interpolate(
    x,
    size=None,
    scale_factor=None,
    mode="bilinear",
    align_corners=True,
)

Safe interpolation implementation to avoid INT_MAX overflow.

Source code in inference/models/depth_anything_v3/architecture/head_utils.py
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
def custom_interpolate(
    x: torch.Tensor,
    size: Union[Tuple[int, int], None] = None,
    scale_factor: Union[float, None] = None,
    mode: str = "bilinear",
    align_corners: bool = True,
) -> torch.Tensor:
    """Safe interpolation implementation to avoid INT_MAX overflow.

    Falls back to plain F.interpolate for ordinary workloads; when the
    output element count would exceed the internal INT_MAX budget, the
    batch is split along dim 0 and interpolated piecewise.
    """
    if size is None:
        assert scale_factor is not None, "Either size or scale_factor must be provided."
        size = (int(x.shape[-2] * scale_factor), int(x.shape[-1] * scale_factor))

    INT_MAX = 1610612736
    total = size[0] * size[1] * x.shape[0] * x.shape[1]

    # Common case: small enough to do in one shot.
    if total <= INT_MAX:
        return F.interpolate(x, size=size, mode=mode, align_corners=align_corners)

    n_chunks = (total // INT_MAX) + 1
    pieces = [
        F.interpolate(piece, size=size, mode=mode, align_corners=align_corners)
        for piece in torch.chunk(x, chunks=n_chunks, dim=0)
    ]
    return torch.cat(pieces, dim=0).contiguous()

make_sincos_pos_embed

make_sincos_pos_embed(embed_dim, pos, omega_0=100)

Generate 1D positional embedding from a given grid using sine and cosine functions.

Source code in inference/models/depth_anything_v3/architecture/head_utils.py
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
def make_sincos_pos_embed(
    embed_dim: int, pos: torch.Tensor, omega_0: float = 100
) -> torch.Tensor:
    """Generate 1D positional embedding from a given grid using sine and cosine functions."""
    assert embed_dim % 2 == 0
    half_dim = embed_dim // 2
    # Geometric frequency ladder: 1 / omega_0 ** (k / half_dim), k = 0..half_dim-1.
    exponents = torch.arange(half_dim, dtype=torch.float32, device=pos.device) / half_dim
    inv_freq = omega_0**-exponents
    # Outer product of flattened positions and frequencies -> (M, half_dim).
    angles = torch.einsum("m,d->md", pos.reshape(-1), inv_freq)
    # First half sine, second half cosine.
    return torch.cat([angles.sin(), angles.cos()], dim=1).float()

position_grid_to_embed

position_grid_to_embed(pos_grid, embed_dim, omega_0=100)

Convert 2D position grid (HxWx2) to sinusoidal embeddings (HxWxC)

Source code in inference/models/depth_anything_v3/architecture/head_utils.py
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
def position_grid_to_embed(
    pos_grid: torch.Tensor, embed_dim: int, omega_0: float = 100
) -> torch.Tensor:
    """Expand an (H, W, 2) coordinate grid into (H, W, embed_dim) sin/cos embeddings."""
    H, W, grid_dim = pos_grid.shape
    assert grid_dim == 2
    flat = pos_grid.reshape(-1, grid_dim)

    # Half the channels encode the x coordinate, the other half encode y.
    per_axis = embed_dim // 2
    emb_x = make_sincos_pos_embed(per_axis, flat[:, 0], omega_0=omega_0)
    emb_y = make_sincos_pos_embed(per_axis, flat[:, 1], omega_0=omega_0)

    return torch.cat([emb_x, emb_y], dim=-1).view(H, W, embed_dim)

models/depth_anything_v3/architecture/layers

inference.models.depth_anything_v3.architecture.layers.drop_path

Classes

DropPath

Bases: Module

Drop paths (Stochastic Depth) per sample.

Source code in inference/models/depth_anything_v3/architecture/layers/drop_path.py
27
28
29
30
31
32
33
34
35
class DropPath(nn.Module):
    """Drop paths (Stochastic Depth) per sample."""

    def __init__(self, drop_prob=None):
        super().__init__()
        # Probability of dropping a residual path; presumably None is
        # treated as "disabled" by the functional helper — TODO confirm.
        self.drop_prob = drop_prob

    def forward(self, x):
        # Delegates to the module-level drop_path(); self.training gates
        # whether dropping is active.
        return drop_path(x, self.drop_prob, self.training)

inference.models.depth_anything_v3.architecture.layers.patch_embed

Classes

PatchEmbed

Bases: Module

2D image to patch embedding: (B,C,H,W) -> (B,N,D)

Parameters:

Name Type Description Default
img_size Union[int, Tuple[int, int]]

Image size.

224
patch_size Union[int, Tuple[int, int]]

Patch token size.

16
in_chans int

Number of input image channels.

3
embed_dim int

Number of linear projection output channels.

768
norm_layer Optional[Callable]

Normalization layer.

None
Source code in inference/models/depth_anything_v3/architecture/layers/patch_embed.py
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
class PatchEmbed(nn.Module):
    """
    2D image to patch embedding: (B,C,H,W) -> (B,N,D)

    Args:
        img_size: Image size.
        patch_size: Patch token size.
        in_chans: Number of input image channels.
        embed_dim: Number of linear projection output channels.
        norm_layer: Normalization layer.
    """

    def __init__(
        self,
        img_size: Union[int, Tuple[int, int]] = 224,
        patch_size: Union[int, Tuple[int, int]] = 16,
        in_chans: int = 3,
        embed_dim: int = 768,
        norm_layer: Optional[Callable] = None,
        flatten_embedding: bool = True,
    ) -> None:
        super().__init__()

        image_hw = make_2tuple(img_size)
        patch_hw = make_2tuple(patch_size)
        grid_h = image_hw[0] // patch_hw[0]
        grid_w = image_hw[1] // patch_hw[1]

        self.img_size = image_hw
        self.patch_size = patch_hw
        self.patches_resolution = (grid_h, grid_w)
        self.num_patches = grid_h * grid_w
        self.in_chans = in_chans
        self.embed_dim = embed_dim
        self.flatten_embedding = flatten_embedding

        # A stride == kernel conv is exactly a linear projection per patch.
        self.proj = nn.Conv2d(
            in_chans, embed_dim, kernel_size=patch_hw, stride=patch_hw
        )
        self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity()

    def forward(self, x: Tensor) -> Tensor:
        _, _, H, W = x.shape
        patch_H, patch_W = self.patch_size

        assert (
            H % patch_H == 0
        ), f"Input image height {H} is not a multiple of patch height {patch_H}"
        assert (
            W % patch_W == 0
        ), f"Input image width {W} is not a multiple of patch width: {patch_W}"

        x = self.proj(x)  # B C H W
        grid_h, grid_w = x.size(2), x.size(3)
        x = x.flatten(2).transpose(1, 2)  # B HW C
        x = self.norm(x)
        if not self.flatten_embedding:
            x = x.reshape(-1, grid_h, grid_w, self.embed_dim)  # B H W C
        return x

inference.models.depth_anything_v3.architecture.layers.rope

Classes

PositionGetter

Generates and caches 2D spatial positions for patches in a grid.

Source code in inference/models/depth_anything_v3/architecture/layers/rope.py
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
class PositionGetter:
    """Generates and caches 2D spatial positions for patches in a grid."""

    def __init__(self):
        # Flattened (row, col) grids keyed by (height, width, device).
        self.position_cache: Dict[Tuple, torch.Tensor] = {}

    def __call__(
        self, batch_size: int, height: int, width: int, device: torch.device
    ) -> torch.Tensor:
        """Return per-sample positions of shape (batch_size, height * width, 2).

        Args:
            batch_size: Number of samples to broadcast positions across.
            height: Grid height in patches.
            width: Grid width in patches.
            device: Device the positions should live on.
        """
        # Key includes the device: keying on (height, width) alone would
        # return a tensor cached on a different device for later calls.
        cache_key = (height, width, device)
        if cache_key not in self.position_cache:
            y_coords = torch.arange(height, device=device)
            x_coords = torch.arange(width, device=device)
            # Row-major (y, x) pairs for every cell of the grid.
            self.position_cache[cache_key] = torch.cartesian_prod(y_coords, x_coords)

        cached_positions = self.position_cache[cache_key]
        # Broadcast across the batch; clone so callers can mutate safely.
        return (
            cached_positions.view(1, height * width, 2)
            .expand(batch_size, -1, -1)
            .clone()
        )

RotaryPositionEmbedding2D

Bases: Module

2D Rotary Position Embedding implementation.

Source code in inference/models/depth_anything_v3/architecture/layers/rope.py
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
class RotaryPositionEmbedding2D(nn.Module):
    """2D Rotary Position Embedding implementation.

    The feature dimension is split in half: the first half is rotated by
    each token's vertical (row) position, the second half by its
    horizontal (column) position.
    """

    def __init__(self, frequency: float = 100.0, scaling_factor: float = 1.0):
        super().__init__()
        # Base of the geometric inverse-frequency spectrum.
        self.base_frequency = frequency
        # NOTE(review): scaling_factor is stored but never read in this
        # class — confirm whether anything external consumes it.
        self.scaling_factor = scaling_factor
        # Memoized (cos, sin) tables keyed by (dim, seq_len, device, dtype).
        # Grows unbounded with distinct key combinations.
        self.frequency_cache: Dict[Tuple, Tuple[torch.Tensor, torch.Tensor]] = {}

    def _compute_frequency_components(
        self, dim: int, seq_len: int, device: torch.device, dtype: torch.dtype
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Return cached (cos, sin) tables, each of shape (seq_len, dim)."""
        cache_key = (dim, seq_len, device, dtype)
        if cache_key not in self.frequency_cache:
            # inv_freq[i] = 1 / base ** (2i / dim): geometric frequency decay.
            exponents = torch.arange(0, dim, 2, device=device).float() / dim
            inv_freq = 1.0 / (self.base_frequency**exponents)

            # Outer product position x frequency -> (seq_len, dim // 2) angles.
            positions = torch.arange(seq_len, device=device, dtype=inv_freq.dtype)
            angles = torch.einsum("i,j->ij", positions, inv_freq)

            # Duplicate the angle table so it covers the full feature dim.
            angles = angles.to(dtype)
            angles = torch.cat((angles, angles), dim=-1)
            cos_components = angles.cos().to(dtype)
            sin_components = angles.sin().to(dtype)
            self.frequency_cache[cache_key] = (cos_components, sin_components)

        return self.frequency_cache[cache_key]

    @staticmethod
    def _rotate_features(x: torch.Tensor) -> torch.Tensor:
        """Pairwise 90-degree rotation: (x1, x2) -> (-x2, x1) on feature halves."""
        feature_dim = x.shape[-1]
        x1, x2 = x[..., : feature_dim // 2], x[..., feature_dim // 2 :]
        return torch.cat((-x2, x1), dim=-1)

    def _apply_1d_rope(
        self,
        tokens: torch.Tensor,
        positions: torch.Tensor,
        cos_comp: torch.Tensor,
        sin_comp: torch.Tensor,
    ) -> torch.Tensor:
        """Apply 1D RoPE to tokens given per-token integer positions."""
        # F.embedding gathers the per-position rows; [:, None] inserts an
        # axis so cos/sin broadcast over tokens' second dimension.
        cos = F.embedding(positions, cos_comp)[:, None, :, :]
        sin = F.embedding(positions, sin_comp)[:, None, :, :]
        return (tokens * cos) + (self._rotate_features(tokens) * sin)

    def forward(self, tokens: torch.Tensor, positions: torch.Tensor) -> torch.Tensor:
        """Rotate tokens by their 2D grid positions.

        Args:
            tokens: Feature tensor; assumes shape
                (batch, heads, n_tokens, feature_dim) — TODO confirm.
            positions: Integer (row, col) pairs, shape (batch, n_tokens, 2).
        """
        assert tokens.size(-1) % 2 == 0, "Feature dimension must be even"
        assert (
            positions.ndim == 3 and positions.shape[-1] == 2
        ), "Positions must have shape (batch_size, n_tokens, 2)"

        # Each spatial axis rotates half of the feature dimension.
        feature_dim = tokens.size(-1) // 2

        max_position = int(positions.max()) + 1
        cos_comp, sin_comp = self._compute_frequency_components(
            feature_dim, max_position, tokens.device, tokens.dtype
        )

        vertical_features, horizontal_features = tokens.chunk(2, dim=-1)

        vertical_features = self._apply_1d_rope(
            vertical_features, positions[..., 0], cos_comp, sin_comp
        )
        horizontal_features = self._apply_1d_rope(
            horizontal_features, positions[..., 1], cos_comp, sin_comp
        )

        return torch.cat((vertical_features, horizontal_features), dim=-1)

models/depth_anything_v3

inference.models.depth_anything_v3.depth_anything_v3

Classes

DepthAnythingV3

Bases: DepthAnythingV2

Depth Anything V3 model for monocular depth estimation.

This model uses the Depth Anything V3 architecture with DinoV2 backbone and DualDPT head for dense depth prediction.

Note: Unlike V2, V3 is not HuggingFace Transformers compatible, so the architecture is vendored in and model loading is custom. However, the external interface (inputs/outputs) matches V2.

Source code in inference/models/depth_anything_v3/depth_anything_v3.py
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
class DepthAnythingV3(DepthAnythingV2):
    """
    Depth Anything V3 model for monocular depth estimation.

    This model uses the Depth Anything V3 architecture with DinoV2 backbone
    and DualDPT head for dense depth prediction.

    Note: Unlike V2, V3 is not HuggingFace Transformers compatible, so the
    architecture is vendored in and model loading is custom. However, the
    external interface (inputs/outputs) matches V2.
    """

    # Default Roboflow endpoint identifier for this model.
    endpoint = "depth-anything-v3/small"

    def __init__(self, *args, **kwargs):
        """Delegate setup to the V2 base class, then adjust dtype per device.

        Re-raises any construction error after printing a short diagnostic.
        """
        try:
            super().__init__(*args, **kwargs)
        except Exception as e:
            print(f"Error initializing depth estimation model: {str(e)}")
            raise

        # Set appropriate dtype based on device
        if self.device.type == "mps":
            self.model = self.model.to(torch.float32)  # MPS prefers float32
        elif self.device.type == "cpu":
            warnings.warn(
                "Running DepthAnythingV3 on CPU. This may be very slow. Consider using GPU or MPS if available."
            )

    def initialize_model(self, **kwargs):
        """Initialize the model with vendored architecture instead of HF Transformers.

        Device preference is CUDA > MPS > CPU; dtype is bf16/fp16 on CUDA
        and fp32 elsewhere. The vendored DepthAnything3Net is built from
        config.json, weights are loaded from the model cache, and the HF
        image processor is loaded from the same cache dir.
        """
        # Determine device
        if torch.cuda.is_available():
            self.device = torch.device("cuda")
        elif torch.backends.mps.is_available():
            self.device = torch.device("mps")
        else:
            self.device = torch.device("cpu")
            warnings.warn(
                "Running DepthAnythingV3 on CPU. This may be slow. "
                "Consider using GPU or MPS if available."
            )

        # Determine dtype
        if self.device.type == "cuda":
            self.dtype = (
                torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float16
            )
        elif self.device.type == "mps":
            self.dtype = torch.float32  # MPS works better with float32
        else:
            self.dtype = torch.float32

        # Load configuration from config.json
        config_path = self._get_config_path()
        self.config = parse_config(config_path)

        # Build model with vendored architecture
        self.model = DepthAnything3Net(**self.config)

        # Load weights
        self._load_weights()

        # Move model to device and set eval mode
        self.model = self.model.to(self.device, dtype=self.dtype)
        self.model.eval()

        # Load processor from cache dir (uses preprocessor_config.json)
        self.processor = AutoImageProcessor.from_pretrained(self.cache_dir)

    def _load_weights(self):
        """Load pretrained weights from the model cache.

        Missing/unexpected keys belonging to sub-modules that are not used
        for depth-only inference are filtered out; anything else triggers a
        warning rather than a hard failure.
        """
        weights_path = self._get_model_weights_path()

        if weights_path.endswith(".safetensors"):
            state_dict = load_safetensors(weights_path)
        else:
            state_dict = torch.load(weights_path, map_location="cpu")

        # Convert state dict format
        state_dict = convert_state_dict(state_dict)

        # Load weights (strict=False to handle missing aux weights)
        missing, unexpected = self.model.load_state_dict(state_dict, strict=False)

        # Filter out expected missing keys:
        # - cam_enc, cam_dec: Camera encoder/decoder (not used for depth-only)
        # - gs_head, gs_adapter: Gaussian splatting head (not used)
        # - output_conv2_aux: Auxiliary ray prediction heads (not used for depth-only)
        expected_missing = [
            "cam_enc",
            "cam_dec",
            "gs_head",
            "gs_adapter",
            "output_conv2_aux",
        ]
        unexpected_filtered = [
            k for k in unexpected if not any(skip in k for skip in expected_missing)
        ]
        missing_filtered = [
            k for k in missing if not any(skip in k for skip in expected_missing)
        ]

        if missing_filtered:
            warnings.warn(f"Missing keys when loading weights: {missing_filtered}")
        if unexpected_filtered:
            warnings.warn(
                f"Unexpected keys when loading weights: {unexpected_filtered}"
            )

    def _get_config_path(self) -> str:
        """Get path to model config file.

        Returns:
            str: Path to config.json inside the model cache dir.

        Raises:
            FileNotFoundError: If config.json is absent from the cache dir.
        """
        cache_dir = Path(self.cache_dir)
        config_file = cache_dir / "config.json"
        if config_file.exists():
            return str(config_file)
        raise FileNotFoundError(
            f"Could not find config.json in {cache_dir}. "
            f"Expected config.json to be downloaded alongside model weights."
        )

    def _get_model_weights_path(self) -> str:
        """Get path to model weights file.

        Raises:
            FileNotFoundError: If model.safetensors is absent from the cache dir.
        """
        cache_dir = Path(self.cache_dir)

        # Expect model.safetensors (standard HF single-file convention)
        weights_file = cache_dir / "model.safetensors"
        if weights_file.exists():
            return str(weights_file)
        else:
            raise FileNotFoundError(f"Could not find {weights_file} in {cache_dir}")

    def predict(self, image_in: Image.Image, prompt="", history=None, **kwargs):
        """
        Run depth prediction on an input image.

        Unlike V2, the vendored DepthAnything3Net expects a tensor directly
        with shape (B, N, 3, H, W) where N=1 for single-view inference.

        Args:
            image_in: Input PIL image.
            prompt: Unused here; kept for signature compatibility.
            history: Unused here; kept for signature compatibility.

        Returns:
            A 1-tuple containing a dict with keys "image" (colored depth
            visualization as WorkflowImageData) and "normalized_depth"
            (float array in [0, 1], inverted relative to the raw output).

        Raises:
            ValueError: If the predicted depth map is constant.
        """
        from inference.core.workflows.execution_engine.entities.base import (
            ImageParentMetadata,
            WorkflowImageData,
        )

        # Process input image using the HF processor
        inputs = self.processor(images=image_in, return_tensors="pt")

        # Extract pixel_values and add the N dimension
        # Processor outputs: (B, C, H, W) -> Model expects: (B, N, C, H, W)
        pixel_values = inputs["pixel_values"]
        pixel_values = pixel_values.unsqueeze(1)  # Add N=1 dimension

        # Move to device and dtype
        pixel_values = pixel_values.to(self.device, dtype=self.dtype)

        # Run inference
        with torch.inference_mode():
            outputs = self.model(pixel_values)

            # Extract depth from model output
            # Model returns dict with 'depth' key containing (B, S, H, W) tensor
            # where S=1 for single-view, so we squeeze it to (B, H, W)
            depth_map = outputs["depth"].squeeze(1)

            # Resize back to original image size
            depth_map = torch.nn.functional.interpolate(
                depth_map.unsqueeze(1),
                size=(image_in.height, image_in.width),
                mode="bilinear",
                align_corners=False,
            ).squeeze()

            depth_map = depth_map.to(torch.float32).cpu().numpy()

            # Normalize depth values to [0, 1], then invert
            depth_min = depth_map.min()
            depth_max = depth_map.max()
            if depth_max == depth_min:
                raise ValueError("Depth map has no variation (min equals max)")
            normalized_depth = (depth_map - depth_min) / (depth_max - depth_min)
            normalized_depth = 1 - normalized_depth

            # Create visualization
            depth_for_viz = (normalized_depth * 255.0).astype(np.uint8)
            cmap = plt.get_cmap("viridis")
            colored_depth = (cmap(depth_for_viz)[:, :, :3] * 255).astype(np.uint8)

            # Convert numpy array to WorkflowImageData
            parent_metadata = ImageParentMetadata(parent_id=f"{uuid4()}")
            colored_depth_image = WorkflowImageData(
                numpy_image=colored_depth, parent_metadata=parent_metadata
            )

            result = {
                "image": colored_depth_image,
                "normalized_depth": normalized_depth,
            }

            return (result,)
Functions
initialize_model
initialize_model(**kwargs)

Initialize the model with vendored architecture instead of HF Transformers.

Source code in inference/models/depth_anything_v3/depth_anything_v3.py
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
def initialize_model(self, **kwargs):
    """Initialize the model with vendored architecture instead of HF Transformers."""
    # Determine device
    if torch.cuda.is_available():
        self.device = torch.device("cuda")
    elif torch.backends.mps.is_available():
        self.device = torch.device("mps")
    else:
        self.device = torch.device("cpu")
        warnings.warn(
            "Running DepthAnythingV3 on CPU. This may be slow. "
            "Consider using GPU or MPS if available."
        )

    # Determine dtype
    if self.device.type == "cuda":
        self.dtype = (
            torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float16
        )
    elif self.device.type == "mps":
        self.dtype = torch.float32  # MPS works better with float32
    else:
        self.dtype = torch.float32

    # Load configuration from config.json
    config_path = self._get_config_path()
    self.config = parse_config(config_path)

    # Build model with vendored architecture
    self.model = DepthAnything3Net(**self.config)

    # Load weights
    self._load_weights()

    # Move model to device and set eval mode
    self.model = self.model.to(self.device, dtype=self.dtype)
    self.model.eval()

    # Load processor from cache dir (uses preprocessor_config.json)
    self.processor = AutoImageProcessor.from_pretrained(self.cache_dir)
predict
predict(image_in, prompt='', history=None, **kwargs)

Run depth prediction on an input image.

Unlike V2, the vendored DepthAnything3Net expects a tensor directly with shape (B, N, 3, H, W) where N=1 for single-view inference.

Source code in inference/models/depth_anything_v3/depth_anything_v3.py
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
def predict(self, image_in: Image.Image, prompt="", history=None, **kwargs):
    """
    Run depth prediction on an input image.

    Unlike V2, the vendored DepthAnything3Net expects a tensor directly
    with shape (B, N, 3, H, W) where N=1 for single-view inference.

    Args:
        image_in: Input PIL image.
        prompt: Unused here; kept for signature compatibility.
        history: Unused here; kept for signature compatibility.

    Returns:
        A 1-tuple containing a dict with keys "image" (colored depth
        visualization) and "normalized_depth" (float array in [0, 1],
        inverted relative to the raw output).

    Raises:
        ValueError: If the predicted depth map is constant.
    """
    from inference.core.workflows.execution_engine.entities.base import (
        ImageParentMetadata,
        WorkflowImageData,
    )

    # Process input image using the HF processor
    inputs = self.processor(images=image_in, return_tensors="pt")

    # Extract pixel_values and add the N dimension
    # Processor outputs: (B, C, H, W) -> Model expects: (B, N, C, H, W)
    pixel_values = inputs["pixel_values"]
    pixel_values = pixel_values.unsqueeze(1)  # Add N=1 dimension

    # Move to device and dtype
    pixel_values = pixel_values.to(self.device, dtype=self.dtype)

    # Run inference
    with torch.inference_mode():
        outputs = self.model(pixel_values)

        # Extract depth from model output
        # Model returns dict with 'depth' key containing (B, S, H, W) tensor
        # where S=1 for single-view, so we squeeze it to (B, H, W)
        depth_map = outputs["depth"].squeeze(1)

        # Resize back to original image size
        depth_map = torch.nn.functional.interpolate(
            depth_map.unsqueeze(1),
            size=(image_in.height, image_in.width),
            mode="bilinear",
            align_corners=False,
        ).squeeze()

        depth_map = depth_map.to(torch.float32).cpu().numpy()

        # Normalize depth values to [0, 1], then invert
        depth_min = depth_map.min()
        depth_max = depth_map.max()
        if depth_max == depth_min:
            raise ValueError("Depth map has no variation (min equals max)")
        normalized_depth = (depth_map - depth_min) / (depth_max - depth_min)
        normalized_depth = 1 - normalized_depth

        # Create visualization
        depth_for_viz = (normalized_depth * 255.0).astype(np.uint8)
        cmap = plt.get_cmap("viridis")
        colored_depth = (cmap(depth_for_viz)[:, :, :3] * 255).astype(np.uint8)

        # Convert numpy array to WorkflowImageData
        parent_metadata = ImageParentMetadata(parent_id=f"{uuid4()}")
        colored_depth_image = WorkflowImageData(
            numpy_image=colored_depth, parent_metadata=parent_metadata
        )

        result = {
            "image": colored_depth_image,
            "normalized_depth": normalized_depth,
        }

        return (result,)

Functions

convert_state_dict

convert_state_dict(state_dict)

Convert state dict from official DA3 format to our simplified format.

Source code in inference/models/depth_anything_v3/depth_anything_v3.py
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
def convert_state_dict(state_dict: dict) -> dict:
    """
    Convert state dict from official DA3 format to our simplified format.

    Strips the leading "model." prefix, maps "net." segments to
    "backbone.", and drops weights for sub-modules that are never run
    (camera encoder/decoder and gaussian-splatting head).
    """
    skip_markers = ("cam_enc", "cam_dec", "gs_head", "gs_adapter")

    converted = {}
    for key, value in state_dict.items():
        # Remove 'model.' prefix if present
        new_key = key[6:] if key.startswith("model.") else key

        # Map backbone paths
        new_key = new_key.replace("net.", "backbone.")

        # Drop weights for modules not used in depth-only inference.
        if any(marker in new_key for marker in skip_markers):
            continue

        converted[new_key] = value

    return converted

parse_config

parse_config(config_path)

Parse the config.json file from HuggingFace/official DA3 format.

Parameters:

Name Type Description Default
config_path str

Path to the config.json file

required

Returns:

Type Description
dict

Dictionary with model configuration parameters

Source code in inference/models/depth_anything_v3/depth_anything_v3.py
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
def parse_config(config_path: str) -> dict:
    """
    Parse the config.json file from HuggingFace/official DA3 format.

    Args:
        config_path: Path to the config.json file

    Returns:
        Dictionary with model configuration parameters
    """
    with open(config_path, "r") as f:
        raw_config = json.load(f)

    # The official format nests everything under a "config" key; fall back
    # to the top level when that wrapper is absent.
    config = raw_config.get("config", raw_config)
    net_cfg = config.get("net", {})
    head_cfg = config.get("head", {})

    # Flatten backbone ("net") and head sections into the keyword arguments
    # expected by DepthAnything3Net, applying defaults for absent fields.
    return {
        "backbone_name": net_cfg.get("name", "vitb"),
        "out_layers": net_cfg.get("out_layers", [5, 7, 9, 11]),
        "alt_start": net_cfg.get("alt_start", 4),
        "qknorm_start": net_cfg.get("qknorm_start", 4),
        "rope_start": net_cfg.get("rope_start", 4),
        "cat_token": net_cfg.get("cat_token", True),
        "head_dim_in": head_cfg.get("dim_in", 1536),
        "head_output_dim": head_cfg.get("output_dim", 2),
        "head_features": head_cfg.get("features", 128),
        "head_out_channels": head_cfg.get("out_channels", [96, 192, 384, 768]),
    }

models/dinov3

inference.models.dinov3.dinov3_classification

Classes

DinoV3Classification

Bases: ClassificationBaseOnnxRoboflowInferenceModel

DinoV3Classification handles classification inference for Dinov3 linear probe models using ONNX.

Inherits

Attributes:

Name Type Description
multiclass bool

A flag that specifies if the model should handle multiclass classification.

Source code in inference/models/dinov3/dinov3_classification.py
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
class DinoV3Classification(ClassificationBaseOnnxRoboflowInferenceModel):
    """DinoV3Classification handles classification inference
    for Dinov3 linear probe models using ONNX.

    Inherits:
        ClassificationBaseOnnxRoboflowInferenceModel: Base class for ONNX Roboflow Inference.
        ClassificationMixin: Mixin class providing classification-specific methods.

    Attributes:
        multiclass (bool): A flag that specifies if the model should handle multiclass classification.
    """

    # Standard ImageNet channel normalization constants.
    preprocess_means = [0.485, 0.456, 0.406]
    preprocess_stds = [0.229, 0.224, 0.225]

    def __init__(self, *args, **kwargs):
        """Initializes the DinoV3Classification instance.

        Args:
            *args: Variable length argument list.
            **kwargs: Arbitrary keyword arguments.
        """
        super().__init__(*args, **kwargs)
        # NOTE(review): assumes the base class exposes a dict-like
        # ``self.environment`` holding deployment flags — confirm upstream.
        self.multiclass = self.environment.get("MULTICLASS", False)

    @property
    def weights_file(self) -> str:
        """Determines the weights file to be used based on the availability of AWS keys.

        If AWS_ACCESS_KEY_ID, AWS_SECRET_ACCESS_KEY and LAMBDA are all set,
        it returns the path to 'weights.onnx'. Otherwise, it returns the
        path to 'best.onnx'.

        Returns:
            str: Path to the weights file.
        """
        if AWS_ACCESS_KEY_ID and AWS_SECRET_ACCESS_KEY and LAMBDA:
            return "weights.onnx"
        else:
            return "best.onnx"
Attributes
weights_file property
weights_file

Determines the weights file to be used based on the availability of AWS keys.

If AWS_ACCESS_KEY_ID, AWS_SECRET_ACCESS_KEY and LAMBDA are all set, it returns the path to 'weights.onnx'. Otherwise, it returns the path to 'best.onnx'.

Returns:

Name Type Description
str str

Path to the weights file.

Functions
__init__
__init__(*args, **kwargs)

Initializes the DinoV3Classification instance.

Parameters:

Name Type Description Default
*args

Variable length argument list.

()
**kwargs

Arbitrary keyword arguments.

{}
Source code in inference/models/dinov3/dinov3_classification.py
22
23
24
25
26
27
28
29
30
def __init__(self, *args, **kwargs):
    """Initializes the DinoV3Classification instance.

    Args:
        *args: Variable length argument list.
        **kwargs: Arbitrary keyword arguments.
    """
    super().__init__(*args, **kwargs)
    self.multiclass = self.environment.get("MULTICLASS", False)

models/doctr

inference.models.doctr.doctr_model

Classes

DocTR

Bases: RoboflowCoreModel

Source code in inference/models/doctr/doctr_model.py
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
class DocTR(RoboflowCoreModel):
    """Composite OCR model pairing a DocTR text detector with a recognizer.

    Both sub-models download their weights through the Roboflow cache; this
    class copies those checkpoints into the directory layout DocTR expects and
    wires them into a single ``ocr_predictor``.
    """

    def __init__(self, *args, model_id: str = "doctr_rec/crnn_vgg16_bn", **kwargs):
        """Initializes the DocTR model.

        Args:
            *args: Variable length argument list.
            model_id: Identifier recorded as ``self.endpoint``.
            **kwargs: Arbitrary keyword arguments; ``api_key`` is read from
                here and forwarded to both sub-models.
        """
        self.api_key = kwargs.get("api_key")
        self.dataset_id = "doctr"
        self.version_id = "default"
        self.endpoint = model_id
        # NOTE(review): the lowered value is never used afterwards — looks
        # like dead code; confirm before removing.
        model_id = model_id.lower()

        # Sub-models fetch their own weights into the model cache.
        self.det_model = DocTRDet(api_key=kwargs.get("api_key"))
        self.rec_model = DocTRRec(api_key=kwargs.get("api_key"))

        os.makedirs(f"{MODEL_CACHE_DIR}/doctr/models/", exist_ok=True)

        # DocTR resolves weights by file name, so mirror each sub-model's
        # cached checkpoint into doctr/models/ named after its version id.
        detector_weights_path = (
            f"{MODEL_CACHE_DIR}/doctr/models/{self.det_model.version_id}.pt"
        )
        shutil.copyfile(
            f"{MODEL_CACHE_DIR}/doctr_det/{self.det_model.version_id}/model.pt",
            detector_weights_path,
        )
        recognizer_weights_path = (
            f"{MODEL_CACHE_DIR}/doctr/models/{self.rec_model.version_id}.pt"
        )
        shutil.copyfile(
            f"{MODEL_CACHE_DIR}/doctr_rec/{self.rec_model.version_id}/model.pt",
            recognizer_weights_path,
        )

        # Instantiate untrained architectures and load the downloaded state
        # dicts; pretrained downloads are deliberately disabled.
        det_model = db_resnet50(pretrained=False, pretrained_backbone=False)
        det_model.load_state_dict(
            torch.load(detector_weights_path, map_location=DEVICE, weights_only=True)
        )

        reco_model = crnn_vgg16_bn(pretrained=False, pretrained_backbone=False)
        reco_model.load_state_dict(
            torch.load(recognizer_weights_path, map_location=DEVICE, weights_only=True)
        )

        self.model = ocr_predictor(
            det_arch=det_model,
            reco_arch=reco_model,
            pretrained=False,
        )
        self.task_type = "ocr"

    def clear_cache(self, delete_from_disk: bool = True) -> None:
        """Clear cached weights for both the detector and recognizer."""
        self.det_model.clear_cache(delete_from_disk=delete_from_disk)
        self.rec_model.clear_cache(delete_from_disk=delete_from_disk)

    def preprocess_image(self, image: Image.Image) -> Image.Image:
        """
        DocTR pre-processes images as part of its inference pipeline.

        Thus, no preprocessing is required here.
        """
        # NOTE(review): returns None despite the Image.Image annotation;
        # callers appear not to use the return value — confirm.
        pass

    def infer_from_request(
        self, request: DoctrOCRInferenceRequest
    ) -> Union[OCRInferenceResponse, List]:
        """Run OCR for a request, fanning out when ``image`` is a list.

        Args:
            request: OCR request whose ``image`` may be one image or a list.

        Returns:
            One OCRInferenceResponse, or a list of them for batch input.
        """
        if type(request.image) is list:
            response = []
            request_copy = copy.copy(request)
            for image in request.image:
                request_copy.image = image
                response.append(self.single_request(request=request_copy))
            return response
        return self.single_request(request)

    def single_request(self, request: DoctrOCRInferenceRequest) -> OCRInferenceResponse:
        """Run OCR on one image and wrap the result in a timed response."""
        t1 = perf_counter()
        result = self.infer(**request.dict())
        if not isinstance(result, tuple):
            # infer() returns a bare string when bounding boxes were not
            # requested; normalize to a (text, image, predictions) triple.
            result = (result, None, None)
        # maintaining backwards compatibility with previous implementation
        if request.generate_bounding_boxes:
            return OCRInferenceResponse(
                result=result[0],
                image=result[1],
                predictions=result[2],
                time=perf_counter() - t1,
            )
        else:
            return OCRInferenceResponse(
                result=result[0],
                time=perf_counter() - t1,
            )

    def infer(
        self, image: Any, **kwargs
    ) -> Union[
        str, Tuple[str, InferenceResponseImage, List[ObjectDetectionPrediction]]
    ]:
        """
        Run inference on a provided image.

        Args:
            image: A BGR numpy array, filepath, InferenceRequestImage,
                PIL Image, byte-string, etc.
            **kwargs: ``generate_bounding_boxes`` selects the richer output.

        Returns:
            The recognized text, or — when ``generate_bounding_boxes`` is
            truthy — a tuple of (text, image metadata, per-word detections).
        """

        img = load_image(image)

        with tempfile.NamedTemporaryFile(suffix=".jpg") as f:
            # DocumentFile reads from disk, so round-trip the array through a
            # temporary JPEG.
            image = Image.fromarray(img[0])

            image.save(f.name)

            doc = DocumentFile.from_images([f.name])

            result = self.model(doc).export()

            blocks = result["pages"][0]["blocks"]
            page_dimensions = result["pages"][0]["dimensions"]

            # Flatten the block -> line -> word nesting into one word list.
            words = [
                word
                for block in blocks
                for line in block["lines"]
                for word in line["words"]
            ]

            result = " ".join([word["value"] for word in words])
            # maintaining backwards compatibility with previous implementation
            if not kwargs.get("generate_bounding_boxes", False):
                return result

            bounding_boxes = [
                _geometry_to_bbox(page_dimensions, word["geometry"]) for word in words
            ]
            objects = [
                ObjectDetectionPrediction(
                    # Passed as a dict because 'class' is a reserved word.
                    **{
                        "x": bbox[0] + (bbox[2] - bbox[0]) // 2,
                        "y": bbox[1] + (bbox[3] - bbox[1]) // 2,
                        "width": bbox[2] - bbox[0],
                        "height": bbox[3] - bbox[1],
                        "confidence": float(word["objectness_score"]),
                        "class": word["value"],
                        "class_id": 0,
                        "detection_id": str(uuid.uuid4()),
                    }
                )
                for word, bbox in zip(words, bounding_boxes)
            ]
            image_height, image_width = img[0].shape[:2]
            return (
                result,
                InferenceResponseImage(width=image_width, height=image_height),
                objects,
            )

    def get_infer_bucket_file_list(self) -> list:
        """Get the list of required files for inference.

        Returns:
            list: A list of required files for inference, e.g., ["model.pt"].
        """
        return ["model.pt"]
Functions
__init__
__init__(
    *args, model_id="doctr_rec/crnn_vgg16_bn", **kwargs
)

Initializes the DocTR model.

Parameters:

Name Type Description Default
*args

Variable length argument list.

()
**kwargs

Arbitrary keyword arguments.

{}
Source code in inference/models/doctr/doctr_model.py
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
def __init__(self, *args, model_id: str = "doctr_rec/crnn_vgg16_bn", **kwargs):
    """Initializes the DocTR model.

    Args:
        *args: Variable length argument list.
        **kwargs: Arbitrary keyword arguments.
    """
    self.api_key = kwargs.get("api_key")
    self.dataset_id = "doctr"
    self.version_id = "default"
    self.endpoint = model_id
    model_id = model_id.lower()

    self.det_model = DocTRDet(api_key=kwargs.get("api_key"))
    self.rec_model = DocTRRec(api_key=kwargs.get("api_key"))

    os.makedirs(f"{MODEL_CACHE_DIR}/doctr/models/", exist_ok=True)

    detector_weights_path = (
        f"{MODEL_CACHE_DIR}/doctr/models/{self.det_model.version_id}.pt"
    )
    shutil.copyfile(
        f"{MODEL_CACHE_DIR}/doctr_det/{self.det_model.version_id}/model.pt",
        detector_weights_path,
    )
    recognizer_weights_path = (
        f"{MODEL_CACHE_DIR}/doctr/models/{self.rec_model.version_id}.pt"
    )
    shutil.copyfile(
        f"{MODEL_CACHE_DIR}/doctr_rec/{self.rec_model.version_id}/model.pt",
        recognizer_weights_path,
    )

    det_model = db_resnet50(pretrained=False, pretrained_backbone=False)
    det_model.load_state_dict(
        torch.load(detector_weights_path, map_location=DEVICE, weights_only=True)
    )

    reco_model = crnn_vgg16_bn(pretrained=False, pretrained_backbone=False)
    reco_model.load_state_dict(
        torch.load(recognizer_weights_path, map_location=DEVICE, weights_only=True)
    )

    self.model = ocr_predictor(
        det_arch=det_model,
        reco_arch=reco_model,
        pretrained=False,
    )
    self.task_type = "ocr"
get_infer_bucket_file_list
get_infer_bucket_file_list()

Get the list of required files for inference.

Returns:

Name Type Description
list list

A list of required files for inference, e.g., ["model.pt"].

Source code in inference/models/doctr/doctr_model.py
210
211
212
213
214
215
216
def get_infer_bucket_file_list(self) -> list:
    """Get the list of required files for inference.

    Returns:
        list: A list of required files for inference, e.g., ["model.pt"].
    """
    return ["model.pt"]
infer
infer(image, **kwargs)

Run inference on a provided image. - image: can be a BGR numpy array, filepath, InferenceRequestImage, PIL Image, byte-string, etc.

Parameters:

Name Type Description Default
request DoctrOCRInferenceRequest

The inference request.

required

Returns:

Name Type Description
OCRInferenceResponse Union[str, Tuple[str, InferenceResponseImage, List[ObjectDetectionPrediction]]]

The inference response.

Source code in inference/models/doctr/doctr_model.py
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
def infer(
    self, image: Any, **kwargs
) -> Union[
    str, Tuple[str, InferenceResponseImage, List[ObjectDetectionPrediction]]
]:
    """
    Run inference on a provided image.
        - image: can be a BGR numpy array, filepath, InferenceRequestImage, PIL Image, byte-string, etc.

    Args:
        request (DoctrOCRInferenceRequest): The inference request.

    Returns:
        OCRInferenceResponse: The inference response.
    """

    img = load_image(image)

    with tempfile.NamedTemporaryFile(suffix=".jpg") as f:
        image = Image.fromarray(img[0])

        image.save(f.name)

        doc = DocumentFile.from_images([f.name])

        result = self.model(doc).export()

        blocks = result["pages"][0]["blocks"]
        page_dimensions = result["pages"][0]["dimensions"]

        words = [
            word
            for block in blocks
            for line in block["lines"]
            for word in line["words"]
        ]

        result = " ".join([word["value"] for word in words])
        # maintaining backwards compatibility with previous implementation
        if not kwargs.get("generate_bounding_boxes", False):
            return result

        bounding_boxes = [
            _geometry_to_bbox(page_dimensions, word["geometry"]) for word in words
        ]
        objects = [
            ObjectDetectionPrediction(
                **{
                    "x": bbox[0] + (bbox[2] - bbox[0]) // 2,
                    "y": bbox[1] + (bbox[3] - bbox[1]) // 2,
                    "width": bbox[2] - bbox[0],
                    "height": bbox[3] - bbox[1],
                    "confidence": float(word["objectness_score"]),
                    "class": word["value"],
                    "class_id": 0,
                    "detection_id": str(uuid.uuid4()),
                }
            )
            for word, bbox in zip(words, bounding_boxes)
        ]
        image_height, image_width = img[0].shape[:2]
        return (
            result,
            InferenceResponseImage(width=image_width, height=image_height),
            objects,
        )
preprocess_image
preprocess_image(image)

DocTR pre-processes images as part of its inference pipeline.

Thus, no preprocessing is required here.

Source code in inference/models/doctr/doctr_model.py
104
105
106
107
108
109
110
def preprocess_image(self, image: Image.Image) -> Image.Image:
    """
    DocTR pre-processes images as part of its inference pipeline.

    Thus, no preprocessing is required here.
    """
    pass

DocTRDet

Bases: RoboflowCoreModel

DocTR class for document Optical Character Recognition (OCR).

Attributes:

Name Type Description
doctr

The DocTR model.

ort_session

ONNX runtime inference session.

Source code in inference/models/doctr/doctr_model.py
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
class DocTRDet(RoboflowCoreModel):
    """DocTR class for document Optical Character Recognition (OCR).

    Attributes:
        doctr: The DocTR model.
        ort_session: ONNX runtime inference session.
    """

    def __init__(self, *args, model_id: str = "doctr_det/db_resnet50_v2", **kwargs):
        """Initializes the DocTR detection model.

        Args:
            *args: Variable length argument list.
            model_id: Detector model id forwarded to the base class.
            **kwargs: Arbitrary keyword arguments.
        """

        # NOTE(review): the return value is discarded; this looks like a
        # no-op unless the base class depends on a side effect — confirm.
        self.get_infer_bucket_file_list()

        super().__init__(*args, model_id=model_id, **kwargs)

    def clear_cache(self, delete_from_disk: bool = True) -> None:
        """Delegate cache clearing to the base implementation."""
        super().clear_cache(delete_from_disk=delete_from_disk)

    def get_infer_bucket_file_list(self) -> list:
        """Get the list of required files for inference.

        Returns:
            list: A list of required files for inference, e.g., ["model.pt"].
        """
        return ["model.pt"]
Functions
__init__
__init__(
    *args, model_id="doctr_det/db_resnet50_v2", **kwargs
)

Initializes the DocTR model.

Parameters:

Name Type Description Default
*args

Variable length argument list.

()
**kwargs

Arbitrary keyword arguments.

{}
Source code in inference/models/doctr/doctr_model.py
251
252
253
254
255
256
257
258
259
260
261
def __init__(self, *args, model_id: str = "doctr_det/db_resnet50_v2", **kwargs):
    """Initializes the DocTR model.

    Args:
        *args: Variable length argument list.
        **kwargs: Arbitrary keyword arguments.
    """

    self.get_infer_bucket_file_list()

    super().__init__(*args, model_id=model_id, **kwargs)
get_infer_bucket_file_list
get_infer_bucket_file_list()

Get the list of required files for inference.

Returns:

Name Type Description
list list

A list of required files for inference, e.g., ["model.pt"].

Source code in inference/models/doctr/doctr_model.py
266
267
268
269
270
271
272
def get_infer_bucket_file_list(self) -> list:
    """Get the list of required files for inference.

    Returns:
        list: A list of required files for inference, e.g., ["model.pt"].
    """
    return ["model.pt"]

DocTRRec

Bases: RoboflowCoreModel

Source code in inference/models/doctr/doctr_model.py
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
class DocTRRec(RoboflowCoreModel):
    """Wrapper that downloads the DocTR text-recognition checkpoint."""

    def __init__(self, *args, model_id: str = "doctr_rec/crnn_vgg16_bn_v2", **kwargs):
        """Initializes the DocTR recognition model.

        Args:
            *args: Variable length argument list.
            model_id: Recognizer model id forwarded to the base class.
            **kwargs: Arbitrary keyword arguments.
        """
        # NOTE(review): the return value is discarded; this looks like a
        # no-op unless the base class depends on a side effect — confirm.
        self.get_infer_bucket_file_list()

        super().__init__(*args, model_id=model_id, **kwargs)

    def clear_cache(self, delete_from_disk: bool = True) -> None:
        """Delegate cache clearing to the base implementation."""
        super().clear_cache(delete_from_disk=delete_from_disk)

    def get_infer_bucket_file_list(self) -> list:
        """Get the list of required files for inference.

        Returns:
            list: A list of required files for inference, e.g., ["model.pt"].
        """
        return ["model.pt"]
Functions
__init__
__init__(
    *args, model_id="doctr_rec/crnn_vgg16_bn_v2", **kwargs
)

Initializes the DocTR model.

Parameters:

Name Type Description Default
*args

Variable length argument list.

()
**kwargs

Arbitrary keyword arguments.

{}
Source code in inference/models/doctr/doctr_model.py
220
221
222
223
224
225
226
227
228
229
def __init__(self, *args, model_id: str = "doctr_rec/crnn_vgg16_bn_v2", **kwargs):
    """Initializes the DocTR model.

    Args:
        *args: Variable length argument list.
        **kwargs: Arbitrary keyword arguments.
    """
    self.get_infer_bucket_file_list()

    super().__init__(*args, model_id=model_id, **kwargs)
get_infer_bucket_file_list
get_infer_bucket_file_list()

Get the list of required files for inference.

Returns:

Name Type Description
list list

A list of required files for inference, e.g., ["model.pt"].

Source code in inference/models/doctr/doctr_model.py
234
235
236
237
238
239
240
def get_infer_bucket_file_list(self) -> list:
    """Get the list of required files for inference.

    Returns:
        list: A list of required files for inference, e.g., ["model.pt"].
    """
    return ["model.pt"]

Functions

models/easy_ocr

inference.models.easy_ocr.easy_ocr

Classes

EasyOCR

Bases: RoboflowCoreModel

Roboflow EasyOCR model implementation.

This class is responsible for handling the EasyOCR model, including loading the model, preprocessing the input, and performing inference.

Source code in inference/models/easy_ocr/easy_ocr.py
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
class EasyOCR(RoboflowCoreModel):
    """Roboflow EasyOCR model implementation.

    This class is responsible for handling the EasyOCR model, including
    loading the model, preprocessing the input, and performing inference.
    """

    def __init__(
        self,
        model_id: str = "easy_ocr/english_g2",
        device: str = DEVICE,
        *args,
        **kwargs,
    ):
        """Initializes EasyOCR with the given arguments and keyword arguments.

        Args:
            model_id: Identifier of the form "easy_ocr/<recognizer>".
            device: Device label stored on the instance.
                NOTE(review): predict() passes gpu=True unconditionally and
                never reads self.device — confirm intended.
        """

        super().__init__(model_id=model_id.lower(), *args, **kwargs)
        self.device = device
        self.task_type = "ocr"
        # Recognizer network name is the segment after the slash.
        self.recognizer = model_id.split("/")[1]

        # EasyOCR looks recognizer weights up as <name>.pth, so mirror the
        # downloaded weights.pt under that name in the same cache directory.
        shutil.copyfile(
            f"{MODEL_CACHE_DIR}/{model_id}/weights.pt",
            f"{MODEL_CACHE_DIR}/{model_id}/{self.recognizer}.pth",
        )

    def predict(self, image_in: np.ndarray, prompt="", history=None, **kwargs):
        """Run EasyOCR text detection + recognition over one image.

        ``prompt`` and ``history`` are accepted for interface compatibility
        and unused. ``language_codes`` (default ["en"]) and ``quantize``
        (default False) are read from kwargs.

        Returns:
            List of (quad_box, text, confidence) triples in plain Python types.
        """
        language_codes = kwargs.get("language_codes", ["en"])
        quantize = kwargs.get("quantize", False)
        # A fresh Reader per call; weights were staged in __init__, so
        # downloads stay disabled.
        reader = easyocr.Reader(
            language_codes,
            download_enabled=False,
            user_network_directory=f"{MODEL_CACHE_DIR}/easy_ocr/{self.recognizer}/",
            model_storage_directory=f"{MODEL_CACHE_DIR}/easy_ocr/{self.recognizer}/",
            detect_network="craft",
            recog_network=self.recognizer,
            detector=True,
            recognizer=True,
            gpu=True,
            quantize=quantize,
        )

        results = reader.readtext(image_in)
        # convert native EasyOCR results from numpy to standard python types
        results = [
            (
                [
                    [x.item() if not isinstance(x, (int, float)) else x for x in c]
                    for c in res[0]
                ],
                res[1],
                res[2].item() if not isinstance(res[2], (int, float)) else res[2],
            )
            for res in results
        ]

        return results

    def postprocess(
        self,
        predictions: Tuple[np.ndarray, ...],
        preprocess_return_metadata: PreprocessReturnMetadata,
        **kwargs,
    ) -> Any:
        """Pass predictions and metadata through unchanged."""
        return predictions, preprocess_return_metadata

    def preprocess(
        self, image: Any, **kwargs
    ) -> Tuple[np.ndarray, InferenceResponseImage]:
        """Load the image and capture its dimensions as response metadata.

        Note: the second element is an InferenceResponseImage (the original
        annotation said PreprocessReturnMetadata, which the code contradicts).
        """
        image = load_image(image)[0]
        return image, InferenceResponseImage(
            width=image.shape[1], height=image.shape[0]
        )

    def infer_from_request(
        self, request: EasyOCRInferenceRequest
    ) -> Union[OCRInferenceResponse, List]:
        """Run OCR for a request, fanning out when ``image`` is a list.

        Returns:
            One OCRInferenceResponse, or a list of them for batch input.
        """
        if type(request.image) is list:
            response = []
            request_copy = copy.copy(request)
            for image in request.image:
                request_copy.image = image
                response.append(self.single_request(request=request_copy))
            return response
        return self.single_request(request)

    def single_request(self, request: EasyOCRInferenceRequest) -> OCRInferenceResponse:
        """Run OCR on one image and build a timed detection response."""
        t1 = perf_counter()
        prediction_result, image_metadata = self.infer(**request.dict())
        strings = [res[1] for res in prediction_result]
        return OCRInferenceResponse(
            result=" ".join(strings),
            image=image_metadata,
            predictions=[
                ObjectDetectionPrediction(
                    # box is a 4-corner quad; corners 0 and 2 are opposite,
                    # giving the axis-aligned extent used below.
                    **{
                        "x": box[0][0] + (box[2][0] - box[0][0]) // 2,
                        "y": box[0][1] + (box[2][1] - box[0][1]) // 2,
                        "width": box[2][0] - box[0][0],
                        "height": box[2][1] - box[0][1],
                        "confidence": float(confidence),
                        "class": string,
                        "class_id": 0,
                        "detection_id": str(uuid.uuid4()),
                    }
                )
                for box, string, confidence in prediction_result
            ],
            time=perf_counter() - t1,
        )

    def get_infer_bucket_file_list(self) -> List[str]:
        """Return the weight files required for inference."""
        return ["weights.pt", "craft_mlt_25k.pth"]
Functions
__init__
__init__(
    model_id="easy_ocr/english_g2",
    device=DEVICE,
    *args,
    **kwargs
)

Initializes EasyOCR with the given arguments and keyword arguments.

Source code in inference/models/easy_ocr/easy_ocr.py
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
def __init__(
    self,
    model_id: str = "easy_ocr/english_g2",
    device: str = DEVICE,
    *args,
    **kwargs,
):
    """Initializes EasyOCR with the given arguments and keyword arguments."""

    super().__init__(model_id=model_id.lower(), *args, **kwargs)
    self.device = device
    self.task_type = "ocr"
    self.recognizer = model_id.split("/")[1]

    shutil.copyfile(
        f"{MODEL_CACHE_DIR}/{model_id}/weights.pt",
        f"{MODEL_CACHE_DIR}/{model_id}/{self.recognizer}.pth",
    )

Functions

inference.models.easy_ocr.easy_ocr_inference_models

Classes

InferenceModelsEasyOCRAdapter

Bases: Model

Roboflow EasyOCR model implementation.

This class is responsible for handling the EasyOCR model, including loading the model, preprocessing the input, and performing inference.

Source code in inference/models/easy_ocr/easy_ocr_inference_models.py
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
class InferenceModelsEasyOCRAdapter(Model):
    """Adapter exposing the inference-models EasyOCR backend through the
    legacy ``Model`` interface.

    Loads the underlying model via ``AutoModel.from_pretrained`` and
    translates its detections into OCRInferenceResponse objects.
    """

    def __init__(
        self, model_id: str = "easy_ocr/english_g2", api_key: str = None, **kwargs
    ):
        """Initialize the adapter and load the wrapped EasyOCR model.

        Args:
            model_id: Model id or local path passed to AutoModel.
            api_key: Roboflow API key; falls back to the API_KEY global.
            **kwargs: Extra options; ``countinference`` and ``service_secret``
                feed the weights-provider headers, the rest is forwarded to
                AutoModel.from_pretrained.
        """
        super().__init__()

        self.metrics = {"num_inferences": 0, "avg_inference_time": 0.0}

        self.api_key = api_key if api_key else API_KEY

        self.task_type = "ocr"

        extra_weights_provider_headers = get_extra_weights_provider_headers(
            countinference=kwargs.get("countinference"),
            service_secret=kwargs.get("service_secret"),
        )
        # Any valid backend that has not been explicitly disabled.
        backend = list(
            VALID_INFERENCE_MODELS_BACKENDS.difference(
                DISABLED_INFERENCE_MODELS_BACKENDS
            )
        )
        self._model: EasyOCRTorch = AutoModel.from_pretrained(
            model_id_or_path=model_id,
            api_key=self.api_key,
            allow_untrusted_packages=ALLOW_INFERENCE_MODELS_UNTRUSTED_PACKAGES,
            allow_direct_local_storage_loading=ALLOW_INFERENCE_MODELS_DIRECTLY_ACCESS_LOCAL_PACKAGES,
            weights_provider_extra_headers=extra_weights_provider_headers,
            backend=backend,
            **kwargs,
        )

    def predict(self, image_in: np.ndarray, **kwargs) -> Tuple[str, Detections]:
        """Run OCR on one image; returns (parsed text, detections)."""
        parsed_texts, parsed_structures = self._model.infer(images=image_in, **kwargs)
        # The wrapped model batches; unwrap the single-image result.
        parsed_text = parsed_texts[0]
        parsed_structure = parsed_structures[0]
        return parsed_text, parsed_structure

    def postprocess(
        self,
        predictions: Tuple[np.ndarray, ...],
        preprocess_return_metadata: PreprocessReturnMetadata,
        **kwargs,
    ) -> Any:
        """Pass predictions and metadata through unchanged."""
        return predictions, preprocess_return_metadata

    def preprocess(
        self, image: Any, **kwargs
    ) -> Tuple[np.ndarray, InferenceResponseImage]:
        """Load the image as BGR and capture its dimensions as metadata."""
        image = load_image_bgr(image)
        return image, InferenceResponseImage(
            width=image.shape[1], height=image.shape[0]
        )

    def infer_from_request(
        self, request: EasyOCRInferenceRequest
    ) -> Union[OCRInferenceResponse, List]:
        """Run OCR for a request, fanning out when ``image`` is a list."""
        if type(request.image) is list:
            response = []
            request_copy = copy.copy(request)
            for image in request.image:
                request_copy.image = image
                response.append(self.single_request(request=request_copy))
            return response
        return self.single_request(request)

    def single_request(self, request: EasyOCRInferenceRequest) -> OCRInferenceResponse:
        """Run OCR on one image and build a timed detection response."""
        t1 = perf_counter()
        kwargs = request.dict()
        # Zero confidence threshold so every detection is returned — the
        # wrapped model does not expose per-word confidences (see below).
        kwargs["confidence"] = 0.0
        prediction_result, image_metadata = self.infer(**kwargs)
        predictions_for_image = []
        for instance_id in range(prediction_result[1].xyxy.shape[0]):
            x_min, y_min, x_max, y_max = prediction_result[1].xyxy[instance_id].tolist()
            width = x_max - x_min
            height = y_max - y_min
            center_x = (x_min + x_max) / 2
            center_y = (y_min + y_max) / 2
            predictions_for_image.append(
                ObjectDetectionPrediction(
                    # Passing args as a dictionary here since one of the args is 'class' (a protected term in Python)
                    **{
                        "x": center_x,
                        "y": center_y,
                        "width": width,
                        "height": height,
                        "confidence": 1.0,  # confidence is not returned by the model
                        "class": prediction_result[1].bboxes_metadata[instance_id][
                            "text"
                        ],
                        "class_id": 0,  # you can only prompt for one object at once
                        "detection_id": str(uuid.uuid4()),
                    }
                )
            )
        return OCRInferenceResponse(
            result=prediction_result[0],
            image=image_metadata,
            predictions=predictions_for_image,
            time=perf_counter() - t1,
        )

models/florence2

inference.models.florence2.utils

Functions

import_class_from_file

import_class_from_file(
    file_path, class_name, alias_name=None
)

Emulates what huggingface transformers does to load remote code with trust_remote_code=True, but allows us to use the class directly so that we don't have to load untrusted code.

Source code in inference/models/florence2/utils.py
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
def import_class_from_file(file_path, class_name, alias_name=None):
    """
    Load ``class_name`` from a Python source file.

    Emulates what huggingface transformers does to load remote code with
    trust_remote_code=True, but allows us to use the class directly so that
    we don't have to load untrusted code.

    Args:
        file_path: Path to the ``.py`` file containing the class.
        class_name: Name of the class to extract from the loaded module.
        alias_name: Optional name under which the class is also published in
            this module's globals.

    Returns:
        The class object.

    Raises:
        AttributeError: If ``class_name`` is not defined in the module.
        Exception: Re-raises anything raised while executing the module; the
            partially-registered entry is rolled back from ``sys.modules``.
    """
    file_path = os.path.abspath(file_path)
    module_name = os.path.splitext(os.path.basename(file_path))[0]
    module_dir = os.path.dirname(file_path)
    parent_dir = os.path.dirname(module_dir)

    sys.path.insert(0, parent_dir)

    previous_module = sys.modules.get(module_name)
    injected = False
    try:
        spec = importlib.util.spec_from_file_location(module_name, file_path)
        module = importlib.util.module_from_spec(spec)

        sys.modules[module_name] = module
        injected = True

        # Manually set the __package__ attribute to the parent package so
        # relative imports inside the loaded file can resolve.
        module.__package__ = os.path.basename(module_dir)

        spec.loader.exec_module(module)
        cls = getattr(module, class_name)
        if alias_name:
            globals()[alias_name] = cls
        return cls
    except Exception:
        if injected:
            if previous_module is not None:
                sys.modules[module_name] = previous_module
            else:
                sys.modules.pop(module_name, None)
        raise
    finally:
        # Remove the exact entry we inserted. The previous sys.path.pop(0)
        # could drop the wrong entry if the executed module mutated sys.path.
        try:
            sys.path.remove(parent_dir)
        except ValueError:
            pass

models/gaze

inference.models.gaze.gaze

Classes

Gaze

Bases: OnnxRoboflowCoreModel

Roboflow ONNX Gaze model.

This class is responsible for handling the ONNX Gaze model, including loading the model, preprocessing the input, and performing inference.

Attributes:

Name Type Description
gaze_onnx_session InferenceSession

ONNX Runtime session for gaze detection inference.

Source code in inference/models/gaze/gaze.py
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
class Gaze(OnnxRoboflowCoreModel):
    """Roboflow ONNX Gaze model.

    This class is responsible for handling the ONNX Gaze model, including
    loading the model, preprocessing the input, and performing inference.

    Attributes:
        gaze_onnx_session (onnxruntime.InferenceSession): ONNX Runtime session for gaze detection inference.
    """

    def __init__(self, *args, **kwargs):
        """Initializes the Gaze with the given arguments and keyword arguments."""

        t1 = perf_counter()
        super().__init__(*args, **kwargs)
        # Create an ONNX Runtime Session with a list of execution providers in priority order. ORT attempts to load providers until one is successful. This keeps the code across devices identical.
        self.log("Creating inference sessions")

        # TODO: convert face detector (TensorflowLite) to ONNX model

        self.gaze_onnx_session = onnxruntime.InferenceSession(
            self.cache_file("L2CSNet_gaze360_resnet50_90bins.onnx"),
            providers=[
                (
                    "TensorrtExecutionProvider",
                    {
                        # Cache compiled TensorRT engines so restarts skip the engine build.
                        "trt_engine_cache_enable": True,
                        "trt_engine_cache_path": TENSORRT_CACHE_PATH,
                    },
                ),
                "CUDAExecutionProvider",
                "OpenVINOExecutionProvider",
                "CPUExecutionProvider",
            ],
        )
        # Serializes calls to gaze_onnx_session.run() from concurrent request handlers.
        self._gaze_session_lock = Lock()

        if REQUIRED_ONNX_PROVIDERS:
            available_providers = onnxruntime.get_available_providers()
            for provider in REQUIRED_ONNX_PROVIDERS:
                if provider not in available_providers:
                    raise OnnxProviderNotAvailable(
                        f"Required ONNX Execution Provider {provider} is not available. Check that you are using the correct docker image on a supported device."
                    )

        # init face detector
        self.face_detector = mp.tasks.vision.FaceDetector.create_from_options(
            mp.tasks.vision.FaceDetectorOptions(
                base_options=mp.tasks.BaseOptions(
                    model_asset_path=self.cache_file("mediapipe_face_detector.tflite")
                ),
                running_mode=mp.tasks.vision.RunningMode.IMAGE,
            )
        )

        # additional settings for gaze detection
        self._gaze_transformations = transforms.Compose(
            [
                transforms.ToTensor(),
                transforms.Resize(448),
                # Standard ImageNet normalization constants.
                transforms.Normalize(
                    mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]
                ),
            ]
        )
        self.task_type = "gaze-detection"
        self.log(f"GAZE model loaded in {perf_counter() - t1:.2f} seconds")

    def _crop_face_img(self, np_img: np.ndarray, face: Detection) -> np.ndarray:
        """Extract facial area in an image.

        Args:
            np_img (np.ndarray): The numpy image.
            face (mediapipe.tasks.python.components.containers.detections.Detection): The detected face.

        Returns:
            np.ndarray: Cropped face image, resized to 224x224.
        """
        # extract face area; numpy slicing silently clamps out-of-bounds indices
        bbox = face.bounding_box
        x_min = bbox.origin_x
        y_min = bbox.origin_y
        x_max = bbox.origin_x + bbox.width
        y_max = bbox.origin_y + bbox.height
        face_img = np_img[y_min:y_max, x_min:x_max, :]
        face_img = cv2.resize(face_img, (224, 224))
        return face_img

    def _detect_gaze(self, np_imgs: List[np.ndarray]) -> List[Tuple[float, float]]:
        """Detect gazes for a list of cropped face images.

        Args:
            np_imgs (List[np.ndarray]): The numpy image list, each image is a cropped facial image.

        Returns:
            List[Tuple[float, float]]: Yaw (radian) and Pitch (radian), one pair per input image.
        """
        ret = []
        # Process in chunks so a single ONNX run never exceeds the batch limit.
        for i in range(0, len(np_imgs), GAZE_MAX_BATCH_SIZE):
            img_batch = []
            for j in range(i, min(len(np_imgs), i + GAZE_MAX_BATCH_SIZE)):
                img = self._gaze_transformations(np_imgs[j])
                img = np.expand_dims(img, axis=0).astype(np.float32)
                img_batch.append(img)

            img_batch = np.concatenate(img_batch, axis=0)
            onnx_input_image = {self.gaze_onnx_session.get_inputs()[0].name: img_batch}
            with self._gaze_session_lock:
                yaw, pitch = self.gaze_onnx_session.run(None, onnx_input_image)
            for j in range(len(img_batch)):
                ret.append((yaw[j], pitch[j]))

        return ret

    def _make_response(
        self,
        faces: List[Detection],
        gazes: List[Tuple[float, float]],
        imgW: int,
        imgH: int,
        time_total: float,
        time_face_det: float = None,
        time_gaze_det: float = None,
    ) -> GazeDetectionInferenceResponse:
        """Prepare response object from detected faces and corresponding gazes.

        Args:
            faces (List[Detection]): The detected faces.
            gazes (List[tuple(float, float)]): The detected gazes (yaw, pitch).
            imgW (int): The width (px) of original image.
            imgH (int): The height (px) of original image.
            time_total (float): Total processing time per image (seconds).
            time_face_det (float): Face-detection time per image (seconds), if measured.
            time_gaze_det (float): Gaze-detection time per image (seconds), if measured.

        Returns:
            GazeDetectionInferenceResponse: The response object including the detected faces and gazes info.
        """
        predictions = []
        for face, gaze in zip(faces, gazes):
            landmarks = []
            for keypoint in face.keypoints:
                # Keypoints are normalized; convert to pixel coords clamped inside the image.
                x = min(max(int(keypoint.x * imgW), 0), imgW - 1)
                y = min(max(int(keypoint.y * imgH), 0), imgH - 1)
                landmarks.append(Point(x=x, y=y))

            bbox = face.bounding_box
            x_center = bbox.origin_x + bbox.width / 2
            y_center = bbox.origin_y + bbox.height / 2
            score = face.categories[0].score

            prediction = GazeDetectionPrediction(
                face=FaceDetectionPrediction(
                    x=x_center,
                    y=y_center,
                    width=bbox.width,
                    height=bbox.height,
                    confidence=score,
                    class_name="face",
                    landmarks=landmarks,
                ),
                yaw=gaze[0],
                pitch=gaze[1],
            )
            predictions.append(prediction)

        response = GazeDetectionInferenceResponse(
            predictions=predictions,
            time=time_total,
            time_face_det=time_face_det,
            time_gaze_det=time_gaze_det,
        )
        return response

    def get_infer_bucket_file_list(self) -> List[str]:
        """Gets the list of files required for inference.

        Returns:
            List[str]: The list of file names.
        """
        return [
            "mediapipe_face_detector.tflite",
            "L2CSNet_gaze360_resnet50_90bins.onnx",
        ]

    def infer_from_request(
        self, request: GazeDetectionInferenceRequest
    ) -> List[GazeDetectionInferenceResponse]:
        """Detect faces and gazes in image(s).

        Args:
            request (GazeDetectionInferenceRequest): The request object containing the image.

        Returns:
            List[GazeDetectionInferenceResponse]: The list of response objects containing the faces and corresponding gazes.

        Raises:
            ValueError: If no image is provided, or the batch exceeds GAZE_MAX_BATCH_SIZE.
        """
        if isinstance(request.image, list):
            if len(request.image) > GAZE_MAX_BATCH_SIZE:
                raise ValueError(
                    f"The maximum number of images that can be inferred with gaze detection at one time is {GAZE_MAX_BATCH_SIZE}"
                )
            imgs = request.image
        else:
            imgs = [request.image]
        if not imgs:
            # Fail fast: an empty batch would otherwise raise ZeroDivisionError
            # when averaging the timing metrics below.
            raise ValueError("At least one image must be provided for gaze detection")

        time_total = perf_counter()

        # load images as numpy RGB arrays
        num_img = len(imgs)
        np_imgs = [load_image_rgb(img) for img in imgs]

        # face detection
        # TODO: face detection for batch
        time_face_det = perf_counter()
        faces = []
        for np_img in np_imgs:
            if request.do_run_face_detection:
                mp_img = mp.Image(
                    image_format=mp.ImageFormat.SRGB, data=np_img.astype(np.uint8)
                )
                faces_per_img = self.face_detector.detect(mp_img).detections
            else:
                # Treat the whole image as one face covering the full frame.
                faces_per_img = [
                    Detection(
                        bounding_box=BoundingBox(
                            origin_x=0,
                            origin_y=0,
                            width=np_img.shape[1],
                            height=np_img.shape[0],
                        ),
                        categories=[Category(score=1.0, category_name="face")],
                        keypoints=[],
                    )
                ]
            faces.append(faces_per_img)
        time_face_det = (perf_counter() - time_face_det) / num_img

        # gaze detection
        time_gaze_det = perf_counter()
        face_imgs = []
        for i, np_img in enumerate(np_imgs):
            if request.do_run_face_detection:
                face_imgs.extend(
                    [self._crop_face_img(np_img, face) for face in faces[i]]
                )
            else:
                face_imgs.append(cv2.resize(np_img, (224, 224)))
        gazes = self._detect_gaze(face_imgs)
        time_gaze_det = (perf_counter() - time_gaze_det) / num_img

        time_total = (perf_counter() - time_total) / num_img

        # prepare response; gazes is flat, so slice out each image's share
        response = []
        idx_gaze = 0
        for i in range(len(np_imgs)):
            imgH, imgW, _ = np_imgs[i].shape
            faces_per_img = faces[i]
            gazes_per_img = gazes[idx_gaze : idx_gaze + len(faces_per_img)]
            # Forward the measured timings so they appear in the response
            # (previously computed but silently discarded).
            response.append(
                self._make_response(
                    faces_per_img,
                    gazes_per_img,
                    imgW,
                    imgH,
                    time_total,
                    time_face_det=time_face_det,
                    time_gaze_det=time_gaze_det,
                )
            )
            idx_gaze += len(faces_per_img)

        return response
Functions
__init__
__init__(*args, **kwargs)

Initializes the Gaze with the given arguments and keyword arguments.

Source code in inference/models/gaze/gaze.py
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
def __init__(self, *args, **kwargs):
    """Initializes the Gaze with the given arguments and keyword arguments."""

    t1 = perf_counter()
    super().__init__(*args, **kwargs)
    # Create an ONNX Runtime Session with a list of execution providers in priority order. ORT attempts to load providers until one is successful. This keeps the code across devices identical.
    self.log("Creating inference sessions")

    # TODO: convert face detector (TensorflowLite) to ONNX model

    self.gaze_onnx_session = onnxruntime.InferenceSession(
        self.cache_file("L2CSNet_gaze360_resnet50_90bins.onnx"),
        providers=[
            (
                "TensorrtExecutionProvider",
                {
                    # Cache compiled TensorRT engines so restarts skip the engine build.
                    "trt_engine_cache_enable": True,
                    "trt_engine_cache_path": TENSORRT_CACHE_PATH,
                },
            ),
            "CUDAExecutionProvider",
            "OpenVINOExecutionProvider",
            "CPUExecutionProvider",
        ],
    )
    # Serializes calls to gaze_onnx_session.run() from concurrent request handlers.
    self._gaze_session_lock = Lock()

    if REQUIRED_ONNX_PROVIDERS:
        available_providers = onnxruntime.get_available_providers()
        for provider in REQUIRED_ONNX_PROVIDERS:
            if provider not in available_providers:
                raise OnnxProviderNotAvailable(
                    f"Required ONNX Execution Provider {provider} is not available. Check that you are using the correct docker image on a supported device."
                )

    # init face detector
    self.face_detector = mp.tasks.vision.FaceDetector.create_from_options(
        mp.tasks.vision.FaceDetectorOptions(
            base_options=mp.tasks.BaseOptions(
                model_asset_path=self.cache_file("mediapipe_face_detector.tflite")
            ),
            running_mode=mp.tasks.vision.RunningMode.IMAGE,
        )
    )

    # additional settings for gaze detection
    self._gaze_transformations = transforms.Compose(
        [
            transforms.ToTensor(),
            transforms.Resize(448),
            # Standard ImageNet normalization constants.
            transforms.Normalize(
                mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]
            ),
        ]
    )
    self.task_type = "gaze-detection"
    self.log(f"GAZE model loaded in {perf_counter() - t1:.2f} seconds")
get_infer_bucket_file_list
get_infer_bucket_file_list()

Gets the list of files required for inference.

Returns:

Type Description
List[str]

List[str]: The list of file names.

Source code in inference/models/gaze/gaze.py
209
210
211
212
213
214
215
216
217
218
def get_infer_bucket_file_list(self) -> List[str]:
    """Return the file names this model must download before inference.

    Returns:
        List[str]: Names of the required weight files.
    """
    # MediaPipe face-detector weights plus the L2CS gaze ONNX model.
    required_files = [
        "mediapipe_face_detector.tflite",
        "L2CSNet_gaze360_resnet50_90bins.onnx",
    ]
    return required_files
infer_from_request
infer_from_request(request)

Detect faces and gazes in image(s).

Parameters:

Name Type Description Default
request GazeDetectionInferenceRequest

The request object containing the image.

required

Returns:

Type Description
List[GazeDetectionInferenceResponse]

List[GazeDetectionInferenceResponse]: The list of response objects containing the faces and corresponding gazes.

Source code in inference/models/gaze/gaze.py
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
def infer_from_request(
    self, request: GazeDetectionInferenceRequest
) -> List[GazeDetectionInferenceResponse]:
    """Detect faces and gazes in image(s).

    Args:
        request (GazeDetectionInferenceRequest): The request object containing the image.

    Returns:
        List[GazeDetectionInferenceResponse]: The list of response objects containing the faces and corresponding gazes.

    Raises:
        ValueError: If no image is provided, or the batch exceeds GAZE_MAX_BATCH_SIZE.
    """
    if isinstance(request.image, list):
        if len(request.image) > GAZE_MAX_BATCH_SIZE:
            raise ValueError(
                f"The maximum number of images that can be inferred with gaze detection at one time is {GAZE_MAX_BATCH_SIZE}"
            )
        imgs = request.image
    else:
        imgs = [request.image]
    if not imgs:
        # Fail fast: an empty batch would otherwise raise ZeroDivisionError
        # when averaging the timing metrics below.
        raise ValueError("At least one image must be provided for gaze detection")

    time_total = perf_counter()

    # load images as numpy RGB arrays
    num_img = len(imgs)
    np_imgs = [load_image_rgb(img) for img in imgs]

    # face detection
    # TODO: face detection for batch
    time_face_det = perf_counter()
    faces = []
    for np_img in np_imgs:
        if request.do_run_face_detection:
            mp_img = mp.Image(
                image_format=mp.ImageFormat.SRGB, data=np_img.astype(np.uint8)
            )
            faces_per_img = self.face_detector.detect(mp_img).detections
        else:
            # Treat the whole image as one face covering the full frame.
            faces_per_img = [
                Detection(
                    bounding_box=BoundingBox(
                        origin_x=0,
                        origin_y=0,
                        width=np_img.shape[1],
                        height=np_img.shape[0],
                    ),
                    categories=[Category(score=1.0, category_name="face")],
                    keypoints=[],
                )
            ]
        faces.append(faces_per_img)
    time_face_det = (perf_counter() - time_face_det) / num_img

    # gaze detection
    time_gaze_det = perf_counter()
    face_imgs = []
    for i, np_img in enumerate(np_imgs):
        if request.do_run_face_detection:
            face_imgs.extend(
                [self._crop_face_img(np_img, face) for face in faces[i]]
            )
        else:
            face_imgs.append(cv2.resize(np_img, (224, 224)))
    gazes = self._detect_gaze(face_imgs)
    time_gaze_det = (perf_counter() - time_gaze_det) / num_img

    time_total = (perf_counter() - time_total) / num_img

    # prepare response; gazes is flat, so slice out each image's share
    response = []
    idx_gaze = 0
    for i in range(len(np_imgs)):
        imgH, imgW, _ = np_imgs[i].shape
        faces_per_img = faces[i]
        gazes_per_img = gazes[idx_gaze : idx_gaze + len(faces_per_img)]
        # Forward the measured timings so they appear in the response
        # (previously computed but silently discarded).
        response.append(
            self._make_response(
                faces_per_img,
                gazes_per_img,
                imgW,
                imgH,
                time_total,
                time_face_det=time_face_det,
                time_gaze_det=time_gaze_det,
            )
        )
        idx_gaze += len(faces_per_img)

    return response

L2C2Wrapper

Bases: L2CS

Roboflow L2CS Gaze detection model.

This class is responsible for converting the L2CS model to an ONNX model. It is ONLY intended for internal usage.

Workflow

After training an L2CS model, create an instance of this wrapper class, load the trained weights file, and save it as an ONNX model.

Source code in inference/models/gaze/gaze.py
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
class L2C2Wrapper(L2CS):
    """Internal-only wrapper for exporting a trained L2CS gaze model to ONNX.

    Workflow:
        After training an L2CS model, instantiate this wrapper, load the
        trained weights file, then save the network as an ONNX model.
    """

    def __init__(self):
        # Export runs on CPU; the L2CS head predicts 90 bins per angle.
        self.device = torch.device("cpu")
        self.num_bins = 90
        super().__init__(
            torchvision.models.resnet.Bottleneck, [3, 4, 6, 3], self.num_bins
        )
        self._gaze_softmax = nn.Softmax(dim=1)
        # Bin indices 0..89, used to compute the expected bin of the distribution.
        self._gaze_idx_tensor = torch.arange(90, dtype=torch.float32).to(self.device)

    def forward(self, x):
        # One row of bin indices per batch element.
        bin_indices = torch.stack([self._gaze_idx_tensor] * x.shape[0], dim=0)
        raw_yaw, raw_pitch = super().forward(x)

        def _to_radians(bin_logits):
            # Expected bin -> degrees (bins are 4 degrees wide, offset -180) -> radians.
            probs = self._gaze_softmax(bin_logits)
            degrees = torch.sum(probs * bin_indices, dim=1) * 4 - 180
            return degrees * np.pi / 180

        return _to_radians(raw_yaw), _to_radians(raw_pitch)

    def load_L2CS_model(
        self,
        file_path=f"{MODEL_CACHE_DIR}/gaze/L2CS/L2CSNet_gaze360_resnet50_90bins.pkl",
    ):
        # Load trained weights onto CPU and move every module there.
        state_dict = torch.load(file_path, map_location=self.device)
        super().load_state_dict(state_dict)
        super().to(self.device)

    def saveas_ONNX_model(
        self,
        file_path=f"{MODEL_CACHE_DIR}/gaze/L2CS/L2CSNet_gaze360_resnet50_90bins.onnx",
    ):
        # The batch axis is dynamic so the exported model accepts any batch size.
        example_input = torch.randn(1, 3, 448, 448)
        torch.onnx.export(
            self,
            example_input,
            file_path,
            input_names=["input"],
            output_names=["output_yaw", "output_pitch"],
            dynamic_axes={
                "input": {0: "batch_size"},
                "output_yaw": {0: "batch_size"},
                "output_pitch": {0: "batch_size"},
            },
            verbose=False,
        )

inference.models.gaze.gaze_inference_models

Classes

InferenceModelsGazeAdapter

Bases: Model

Roboflow ONNX Gaze model.

This class is responsible for handling the ONNX Gaze model, including loading the model, preprocessing the input, and performing inference.

Attributes:

Name Type Description
_pipeline FaceAndGazeDetectionMPAndL2CS

Face-and-gaze detection pipeline used for inference.

Source code in inference/models/gaze/gaze_inference_models.py
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
class InferenceModelsGazeAdapter(Model):
    """Adapter exposing the inference-models face-and-gaze pipeline through the
    legacy Gaze model interface.

    This class loads a combined face-detection + gaze-estimation pipeline and
    translates its outputs into the classic gaze-detection response objects.

    Attributes:
        _pipeline (FaceAndGazeDetectionMPAndL2CS): Pipeline combining a face
            detector with an L2CS gaze detector.
    """

    def __init__(self, *args, api_key: str = None, **kwargs):
        """Initializes the adapter and loads the face-and-gaze pipeline.

        Args:
            api_key: Roboflow API key; falls back to the globally configured
                API_KEY when not provided.
        """
        super().__init__()
        self.task_type = "gaze-detection"
        self.api_key = api_key if api_key else API_KEY

        extra_weights_provider_headers = get_extra_weights_provider_headers(
            countinference=kwargs.get("countinference"),
            service_secret=kwargs.get("service_secret"),
        )
        # Only backends that are valid and not explicitly disabled are eligible.
        backend = list(
            VALID_INFERENCE_MODELS_BACKENDS.difference(
                DISABLED_INFERENCE_MODELS_BACKENDS
            )
        )
        self._pipeline: FaceAndGazeDetectionMPAndL2CS = (
            AutoModelPipeline.from_pretrained(
                "face-and-gaze-detection",
                api_key=self.api_key,
                extra_weights_provider_headers=extra_weights_provider_headers,
                allow_untrusted_packages=ALLOW_INFERENCE_MODELS_UNTRUSTED_PACKAGES,
                allow_direct_local_storage_loading=ALLOW_INFERENCE_MODELS_DIRECTLY_ACCESS_LOCAL_PACKAGES,
                backend=backend,
            )
        )

    def infer_from_request(
        self, request: GazeDetectionInferenceRequest
    ) -> List[GazeDetectionInferenceResponse]:
        """Detect faces and gazes in image(s).

        Args:
            request (GazeDetectionInferenceRequest): The request object containing the image.

        Returns:
            List[GazeDetectionInferenceResponse]: The list of response objects containing the faces and corresponding gazes.

        Raises:
            ValueError: If no image is provided, or the batch exceeds GAZE_MAX_BATCH_SIZE.
        """
        timer_start = perf_counter()
        if isinstance(request.image, list):
            if len(request.image) > GAZE_MAX_BATCH_SIZE:
                raise ValueError(
                    f"The maximum number of images that can be inferred with gaze detection at one time is {GAZE_MAX_BATCH_SIZE}"
                )
            imgs = request.image
        else:
            imgs = [request.image]
        if not imgs:
            # Fail fast: an empty batch would otherwise raise ZeroDivisionError
            # when averaging the timing metrics below.
            raise ValueError("At least one image must be provided for gaze detection")
        np_imgs = [load_image_bgr(img) for img in imgs]
        avg_image_loading_time = (perf_counter() - timer_start) / len(np_imgs)
        if not request.do_run_face_detection:
            # Whole images are treated as single faces; run only the gaze detector.
            predictions_time_start = perf_counter()
            gaze_detections = self._pipeline._gaze_detector.infer(images=np_imgs)
            avg_image_prediction_time = (perf_counter() - predictions_time_start) / len(
                np_imgs
            )
            predictions = []
            for i, image in enumerate(np_imgs):
                image_yaw = gaze_detections.yaw[i].item()
                image_pitch = gaze_detections.pitch[i].item()
                # Synthesize a full-frame "face" so the response shape matches
                # the face-detection code path.
                faces = [
                    Detection(
                        bounding_box=BoundingBox(
                            origin_x=0,
                            origin_y=0,
                            width=image.shape[1],
                            height=image.shape[0],
                        ),
                        categories=[Category(score=1.0, category_name="face")],
                        keypoints=[],
                    )
                ]
                gazes = [(image_yaw, image_pitch)]
                image_predictions = self._make_response(
                    faces=faces,
                    gazes=gazes,
                    imgW=image.shape[1],
                    imgH=image.shape[0],
                    time_total=avg_image_prediction_time + avg_image_loading_time,
                    time_face_det=0,
                    time_gaze_det=avg_image_prediction_time,
                )
                predictions.append(image_predictions)
            return predictions

        predictions_time_start = perf_counter()
        landmarks, faces, gazes = self._pipeline(images=np_imgs)
        # prepare response
        avg_image_prediction_time = (perf_counter() - predictions_time_start) / len(
            np_imgs
        )
        response = []
        for i in range(len(np_imgs)):
            imgH, imgW, _ = np_imgs[i].shape
            faces_per_img = faces[i]
            landmarks_per_img = landmarks[i]
            gazes_per_img = gazes[i]
            processed_faces_for_image = []
            processed_gazes_for_image = []
            for detection_id in range(faces_per_img.xyxy.shape[0]):
                min_x, min_y, max_x, max_y = faces_per_img.xyxy[detection_id].tolist()
                width = max_x - min_x
                height = max_y - min_y
                score = faces_per_img.confidence[detection_id].item()
                detection_keypoints = landmarks_per_img.xy[detection_id].tolist()
                processed_keypoints = []
                for x, y in detection_keypoints:
                    # _make_response expects normalized keypoints, so divide by image size.
                    processed_keypoints.append(
                        NormalizedKeypoint(x=x / imgW, y=y / imgH)
                    )
                face_detection_mp = Detection(
                    bounding_box=BoundingBox(
                        origin_x=min_x,
                        origin_y=min_y,
                        width=width,
                        height=height,
                    ),
                    categories=[Category(score=score, category_name="face")],
                    keypoints=processed_keypoints,
                )
                processed_faces_for_image.append(face_detection_mp)
                if gazes_per_img is None:
                    # Keep the gaze list aligned with the face list even when
                    # the pipeline produced no gaze estimates for this image.
                    processed_gazes_for_image.append(None)
                else:
                    processed_gazes_for_image.append(
                        (
                            gazes_per_img.yaw[detection_id].item(),
                            gazes_per_img.pitch[detection_id].item(),
                        )
                    )
            response.append(
                self._make_response(
                    processed_faces_for_image,
                    processed_gazes_for_image,
                    imgW,
                    imgH,
                    avg_image_prediction_time + avg_image_loading_time,
                )
            )
        return response

    def _make_response(
        self,
        faces: List[Detection],
        gazes: List[Optional[Tuple[float, float]]],
        imgW: int,
        imgH: int,
        time_total: float,
        time_face_det: Optional[float] = None,
        time_gaze_det: Optional[float] = None,
    ) -> GazeDetectionInferenceResponse:
        """Prepare response object from detected faces and corresponding gazes.

        Args:
            faces (List[Detection]): The detected faces.
            gazes (List[Optional[Tuple[float, float]]]): The detected gazes
                (yaw, pitch), or None for faces without a gaze estimate.
            imgW (int): The width (px) of original image.
            imgH (int): The height (px) of original image.
            time_total (float): Total processing time per image (seconds).
            time_face_det (Optional[float]): Face-detection time per image (seconds).
            time_gaze_det (Optional[float]): Gaze-detection time per image (seconds).

        Returns:
            GazeDetectionInferenceResponse: The response object including the detected faces and gazes info.
        """
        predictions = []
        for face, gaze in zip(faces, gazes):
            landmarks = []
            for keypoint in face.keypoints:
                # Keypoints are normalized; convert to pixel coords clamped inside the image.
                x = min(max(int(keypoint.x * imgW), 0), imgW - 1)
                y = min(max(int(keypoint.y * imgH), 0), imgH - 1)
                landmarks.append(Point(x=x, y=y))

            bbox = face.bounding_box
            x_center = bbox.origin_x + bbox.width / 2
            y_center = bbox.origin_y + bbox.height / 2
            score = face.categories[0].score

            # Pass keyword arguments directly instead of through a redundant
            # intermediate **dict(...) wrapper.
            prediction = GazeDetectionPrediction(
                face=FaceDetectionPrediction(
                    x=x_center,
                    y=y_center,
                    width=bbox.width,
                    height=bbox.height,
                    confidence=score,
                    class_name="face",
                    landmarks=landmarks,
                ),
                yaw=gaze[0] if gaze is not None else None,
                pitch=gaze[1] if gaze is not None else None,
            )
            predictions.append(prediction)

        return GazeDetectionInferenceResponse(
            predictions=predictions,
            time=time_total,
            time_face_det=time_face_det,
            time_gaze_det=time_gaze_det,
        )
Functions
__init__
__init__(*args, api_key=None, **kwargs)

Initializes the Gaze with the given arguments and keyword arguments.

Source code in inference/models/gaze/gaze_inference_models.py
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
def __init__(self, *args, api_key: str = None, **kwargs):
    """Initialize the adapter and load the face-and-gaze pipeline."""
    super().__init__()
    self.task_type = "gaze-detection"
    # Fall back to the globally configured API key when none is supplied.
    self.api_key = api_key if api_key else API_KEY

    weights_headers = get_extra_weights_provider_headers(
        countinference=kwargs.get("countinference"),
        service_secret=kwargs.get("service_secret"),
    )
    # Eligible backends: every valid backend that is not explicitly disabled.
    enabled_backends = list(
        VALID_INFERENCE_MODELS_BACKENDS.difference(DISABLED_INFERENCE_MODELS_BACKENDS)
    )
    self._pipeline: FaceAndGazeDetectionMPAndL2CS = AutoModelPipeline.from_pretrained(
        "face-and-gaze-detection",
        api_key=self.api_key,
        extra_weights_provider_headers=weights_headers,
        allow_untrusted_packages=ALLOW_INFERENCE_MODELS_UNTRUSTED_PACKAGES,
        allow_direct_local_storage_loading=ALLOW_INFERENCE_MODELS_DIRECTLY_ACCESS_LOCAL_PACKAGES,
        backend=enabled_backends,
    )
infer_from_request
infer_from_request(request)

Detect faces and gazes in image(s).

Parameters:

Name Type Description Default
request GazeDetectionInferenceRequest

The request object containing the image.

required

Returns:

Type Description
List[GazeDetectionInferenceResponse]

List[GazeDetectionInferenceResponse]: The list of response objects containing the faces and corresponding gazes.

Source code in inference/models/gaze/gaze_inference_models.py
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
def infer_from_request(
    self, request: GazeDetectionInferenceRequest
) -> List[GazeDetectionInferenceResponse]:
    """Detect faces and gazes in image(s).

    Args:
        request (GazeDetectionInferenceRequest): The request object containing the image.

    Returns:
        List[GazeDetectionInferenceResponse]: The list of response objects containing the faces and corresponding gazes.

    Raises:
        ValueError: If more than GAZE_MAX_BATCH_SIZE images are submitted in one request.
    """
    timer_start = perf_counter()
    # Normalise the payload to a list of images, enforcing the batch cap.
    if isinstance(request.image, list):
        if len(request.image) > GAZE_MAX_BATCH_SIZE:
            raise ValueError(
                f"The maximum number of images that can be inferred with gaze detection at one time is {GAZE_MAX_BATCH_SIZE}"
            )
        imgs = request.image
    else:
        imgs = [request.image]
    np_imgs = [load_image_bgr(img) for img in imgs]
    # Image-loading cost is amortised evenly across the batch for reporting.
    avg_image_loading_time = (perf_counter() - timer_start) / len(np_imgs)
    if not request.do_run_face_detection:
        # Face detection skipped: run the gaze detector on each whole frame
        # and report one synthetic full-frame "face" per image.
        predictions_time_start = perf_counter()
        gaze_detections = self._pipeline._gaze_detector.infer(images=np_imgs)
        avg_image_prediction_time = (perf_counter() - predictions_time_start) / len(
            np_imgs
        )
        predictions = []
        for i, image in enumerate(np_imgs):
            image_yaw = gaze_detections.yaw[i].item()
            image_pitch = gaze_detections.pitch[i].item()
            # Synthetic detection spanning the full image, confidence 1.0,
            # no facial keypoints available on this code path.
            faces = [
                Detection(
                    bounding_box=BoundingBox(
                        origin_x=0,
                        origin_y=0,
                        width=image.shape[1],
                        height=image.shape[0],
                    ),
                    categories=[Category(score=1.0, category_name="face")],
                    keypoints=[],
                )
            ]
            gazes = [(image_yaw, image_pitch)]
            image_predictions = self._make_response(
                faces=faces,
                gazes=gazes,
                imgW=image.shape[1],
                imgH=image.shape[0],
                time_total=avg_image_prediction_time + avg_image_loading_time,
                time_face_det=0,  # face detector was not run on this path
                time_gaze_det=avg_image_prediction_time,
            )
            predictions.append(image_predictions)
        return predictions

    # Full pipeline: face detection + landmarks + gaze per detected face.
    predictions_time_start = perf_counter()
    landmarks, faces, gazes = self._pipeline(images=np_imgs)
    # prepare response
    avg_image_prediction_time = (perf_counter() - predictions_time_start) / len(
        np_imgs
    )
    response = []
    for i in range(len(np_imgs)):
        imgH, imgW, _ = np_imgs[i].shape
        faces_per_img = faces[i]
        landmarks_per_img = landmarks[i]
        gazes_per_img = gazes[i]
        processed_faces_for_image = []
        processed_gazes_for_image = []
        # One entry per detected face bounding box in this image.
        for detection_id in range(faces_per_img.xyxy.shape[0]):
            min_x, min_y, max_x, max_y = faces_per_img.xyxy[detection_id].tolist()
            width = max_x - min_x
            height = max_y - min_y
            score = faces_per_img.confidence[detection_id].item()
            detection_keypoints = landmarks_per_img.xy[detection_id].tolist()
            # Keypoints are normalised to [0, 1] relative to image dimensions.
            processed_keypoints = []
            for x, y in detection_keypoints:
                processed_keypoints.append(
                    NormalizedKeypoint(x=x / imgW, y=y / imgH)
                )
            face_detection_mp = Detection(
                bounding_box=BoundingBox(
                    origin_x=min_x,
                    origin_y=min_y,
                    width=width,
                    height=height,
                ),
                categories=[Category(score=score, category_name="face")],
                keypoints=processed_keypoints,
            )
            processed_faces_for_image.append(face_detection_mp)
            # Gaze may be absent for an image; keep lists index-aligned with faces.
            if gazes_per_img is None:
                processed_gazes_for_image.append(None)
            else:
                processed_gazes_for_image.append(
                    (
                        gazes_per_img.yaw[detection_id].item(),
                        gazes_per_img.pitch[detection_id].item(),
                    )
                )
        response.append(
            self._make_response(
                processed_faces_for_image,
                processed_gazes_for_image,
                imgW,
                imgH,
                avg_image_prediction_time + avg_image_loading_time,
            )
        )
    return response

inference.models.gaze.l2cs

Classes

L2CS

Bases: Module

L2CS Gaze Detection Model.

This class is responsible for performing gaze detection using the L2CS-Net model. Ref: https://github.com/Ahmednull/L2CS-Net

Methods:

Name Description
forward

Performs inference on the given image.

Source code in inference/models/gaze/l2cs.py
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
class L2CS(nn.Module):
    """L2CS Gaze Detection Model.

    This class is responsible for performing gaze detection using the L2CS-Net model.
    Ref: https://github.com/Ahmednull/L2CS-Net

    A ResNet-style backbone with two parallel classification heads that
    predict binned yaw and pitch angles.

    Methods:
        forward: Performs inference on the given image.
    """

    def __init__(self, block, layers, num_bins):
        # Running input-channel count for _make_layer; mutated as stages are built.
        self.inplanes = 64
        super(L2CS, self).__init__()
        # Standard ResNet stem: 7x7 stride-2 conv, BN, ReLU, 3x3 stride-2 max-pool.
        self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, bias=False)
        self.bn1 = nn.BatchNorm2d(64)
        self.relu = nn.ReLU(inplace=True)
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
        # Four residual stages; `layers` gives the block count per stage.
        self.layer1 = self._make_layer(block, 64, layers[0])
        self.layer2 = self._make_layer(block, 128, layers[1], stride=2)
        self.layer3 = self._make_layer(block, 256, layers[2], stride=2)
        self.layer4 = self._make_layer(block, 512, layers[3], stride=2)
        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))

        # Separate heads producing `num_bins` logits each for yaw and pitch.
        self.fc_yaw_gaze = nn.Linear(512 * block.expansion, num_bins)
        self.fc_pitch_gaze = nn.Linear(512 * block.expansion, num_bins)

        # Vestigial layer from previous experiments
        self.fc_finetune = nn.Linear(512 * block.expansion + 3, 3)

        # He-style init for convs; BN initialised to identity transform.
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
                m.weight.data.normal_(0, math.sqrt(2.0 / n))
            elif isinstance(m, nn.BatchNorm2d):
                m.weight.data.fill_(1)
                m.bias.data.zero_()

    def _make_layer(self, block, planes, blocks, stride=1):
        """Build one residual stage of `blocks` blocks, downsampling if needed."""
        downsample = None
        # A 1x1 projection is required when spatial or channel dims change.
        if stride != 1 or self.inplanes != planes * block.expansion:
            downsample = nn.Sequential(
                nn.Conv2d(
                    self.inplanes,
                    planes * block.expansion,
                    kernel_size=1,
                    stride=stride,
                    bias=False,
                ),
                nn.BatchNorm2d(planes * block.expansion),
            )

        layers = []
        layers.append(block(self.inplanes, planes, stride, downsample))
        # Subsequent blocks in the stage keep the expanded channel count.
        self.inplanes = planes * block.expansion
        for i in range(1, blocks):
            layers.append(block(self.inplanes, planes))

        return nn.Sequential(*layers)

    def forward(self, x):
        """Return (yaw_logits, pitch_logits) for the input image batch."""
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)
        x = self.maxpool(x)

        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)
        x = self.avgpool(x)
        # Flatten pooled features to (batch, channels) for the linear heads.
        x = x.view(x.size(0), -1)

        # gaze
        pre_yaw_gaze = self.fc_yaw_gaze(x)
        pre_pitch_gaze = self.fc_pitch_gaze(x)
        return pre_yaw_gaze, pre_pitch_gaze

models/grounding_dino

inference.models.grounding_dino.grounding_dino

Classes

GroundingDINO

Bases: RoboflowCoreModel

GroundingDINO class for zero-shot object detection.

Attributes:

Name Type Description
model

The GroundingDINO model.

Source code in inference/models/grounding_dino/grounding_dino.py
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
class GroundingDINO(RoboflowCoreModel):
    """GroundingDINO class for zero-shot object detection.

    Attributes:
        model: The GroundingDINO model.
    """

    def __init__(
        self, *args, model_id="grounding_dino/groundingdino_swint_ogc", **kwargs
    ):
        """Initializes the GroundingDINO model.

        Downloads the model config on first use and loads the checkpoint onto
        GPU when available, CPU otherwise.

        Args:
            *args: Variable length argument list.
            **kwargs: Arbitrary keyword arguments.
        """
        super().__init__(*args, model_id=model_id, **kwargs)

        GROUNDING_DINO_CACHE_DIR = os.path.join(MODEL_CACHE_DIR, model_id)

        GROUNDING_DINO_CONFIG_PATH = os.path.join(
            GROUNDING_DINO_CACHE_DIR, "GroundingDINO_SwinT_OGC.py"
        )

        # exist_ok avoids a race when several workers initialise concurrently.
        os.makedirs(GROUNDING_DINO_CACHE_DIR, exist_ok=True)

        if not os.path.exists(GROUNDING_DINO_CONFIG_PATH):
            url = "https://raw.githubusercontent.com/roboflow/GroundingDINO/main/groundingdino/config/GroundingDINO_SwinT_OGC.py"
            urllib.request.urlretrieve(url, GROUNDING_DINO_CONFIG_PATH)

        self.model = Model(
            model_config_path=GROUNDING_DINO_CONFIG_PATH,
            model_checkpoint_path=os.path.join(
                GROUNDING_DINO_CACHE_DIR, "groundingdino_swint_ogc.pth"
            ),
            device="cuda" if torch.cuda.is_available() else "cpu",
        )
        self.task_type = "object-detection"

    def preproc_image(self, image: Any):
        """Preprocesses an image.

        Args:
            image (InferenceRequestImage): The image to preprocess.

        Returns:
            np.array: The preprocessed image (BGR numpy array).
        """
        np_image = load_image_bgr(image)
        return np_image

    def infer_from_request(
        self,
        request: GroundingDINOInferenceRequest,
    ) -> ObjectDetectionInferenceResponse:
        """
        Perform inference based on the details provided in the request, and return the associated responses.
        """
        result = self.infer(**request.dict())
        return result

    def infer(
        self,
        image: InferenceRequestImage,
        text: List[str] = None,
        class_filter: list = None,
        box_threshold=0.5,
        text_threshold=0.5,
        class_agnostic_nms=CLASS_AGNOSTIC_NMS,
        **kwargs
    ):
        """
        Run inference on a provided image.
            - image: can be a BGR numpy array, filepath, InferenceRequestImage, PIL Image, byte-string, etc.

        Args:
            image: The image to run detection on.
            text (List[str]): Class prompts to detect.
            class_filter (Optional[List[str]]): A list of class names to keep, if provided.
            box_threshold: Minimum box confidence.
            text_threshold: Minimum text-match confidence.
            class_agnostic_nms: Whether NMS ignores class labels.

        Returns:
            ObjectDetectionInferenceResponse: The inference response.
        """
        t1 = perf_counter()
        image = self.preproc_image(image)
        img_dims = image.shape

        detections = self.model.predict_with_classes(
            image=image,
            classes=text,
            box_threshold=box_threshold,
            text_threshold=text_threshold,
        )

        self.class_names = text

        if class_agnostic_nms:
            detections = detections.with_nms(class_agnostic=True)
        else:
            detections = detections.with_nms()

        xywh_bboxes = [xyxy_to_xywh(detection) for detection in detections.xyxy]

        t2 = perf_counter() - t1

        # Build predictions, skipping detections with no resolved class id and,
        # when a filter is supplied, detections whose class name is not wanted.
        # NOTE(fix): the previous comprehension indexed `pred[6]` on a
        # 4-element xyxy row (IndexError whenever class_filter was set) and
        # its un-parenthesised `or`/`and` chain only applied the None check on
        # the filtered branch.
        predictions = []
        for i, bbox in enumerate(xywh_bboxes):
            class_id = detections.class_id[i]
            if class_id is None:
                continue
            class_name = self.class_names[int(class_id)]
            if class_filter and class_name not in class_filter:
                continue
            predictions.append(
                ObjectDetectionPrediction(
                    # 'class' is a reserved word, hence the dict expansion.
                    **{
                        "x": bbox[0],
                        "y": bbox[1],
                        "width": bbox[2],
                        "height": bbox[3],
                        "confidence": detections.confidence[i],
                        "class": class_name,
                        "class_id": int(class_id),
                    }
                )
            )

        responses = ObjectDetectionInferenceResponse(
            predictions=predictions,
            image=InferenceResponseImage(width=img_dims[1], height=img_dims[0]),
            time=t2,
        )
        return responses

    def get_infer_bucket_file_list(self) -> list:
        """Get the list of required files for inference.

        Returns:
            list: A list of required files for inference, e.g., ["model.pt"].
        """
        return ["groundingdino_swint_ogc.pth"]
Functions
__init__
__init__(
    *args,
    model_id="grounding_dino/groundingdino_swint_ogc",
    **kwargs
)

Initializes the GroundingDINO model.

Parameters:

Name Type Description Default
*args

Variable length argument list.

()
**kwargs

Arbitrary keyword arguments.

{}
Source code in inference/models/grounding_dino/grounding_dino.py
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
def __init__(
    self, *args, model_id="grounding_dino/groundingdino_swint_ogc", **kwargs
):
    """Initializes the GroundingDINO model.

    Downloads the model config file on first use and loads the checkpoint
    onto GPU when available, CPU otherwise.

    Args:
        *args: Variable length argument list.
        **kwargs: Arbitrary keyword arguments.
    """

    super().__init__(*args, model_id=model_id, **kwargs)

    # Per-model cache directory under the shared model cache root.
    GROUNDING_DINO_CACHE_DIR = os.path.join(MODEL_CACHE_DIR, model_id)

    GROUNDING_DINO_CONFIG_PATH = os.path.join(
        GROUNDING_DINO_CACHE_DIR, "GroundingDINO_SwinT_OGC.py"
    )

    if not os.path.exists(GROUNDING_DINO_CACHE_DIR):
        os.makedirs(GROUNDING_DINO_CACHE_DIR)

    # Fetch the upstream model config once and cache it locally.
    if not os.path.exists(GROUNDING_DINO_CONFIG_PATH):
        url = "https://raw.githubusercontent.com/roboflow/GroundingDINO/main/groundingdino/config/GroundingDINO_SwinT_OGC.py"
        urllib.request.urlretrieve(url, GROUNDING_DINO_CONFIG_PATH)

    self.model = Model(
        model_config_path=GROUNDING_DINO_CONFIG_PATH,
        model_checkpoint_path=os.path.join(
            GROUNDING_DINO_CACHE_DIR, "groundingdino_swint_ogc.pth"
        ),
        device="cuda" if torch.cuda.is_available() else "cpu",
    )
    self.task_type = "object-detection"
get_infer_bucket_file_list
get_infer_bucket_file_list()

Get the list of required files for inference.

Returns:

Name Type Description
list list

A list of required files for inference, e.g., ["model.pt"].

Source code in inference/models/grounding_dino/grounding_dino.py
192
193
194
195
196
197
198
def get_infer_bucket_file_list(self) -> list:
    """Return the names of the files this model must download for inference.

    Returns:
        list: Required file names; GroundingDINO only needs its checkpoint.
    """
    return ["groundingdino_swint_ogc.pth"]
infer
infer(
    image,
    text=None,
    class_filter=None,
    box_threshold=0.5,
    text_threshold=0.5,
    class_agnostic_nms=CLASS_AGNOSTIC_NMS,
    **kwargs
)

Run inference on a provided image. - image: can be a BGR numpy array, filepath, InferenceRequestImage, PIL Image, byte-string, etc.

Parameters:

Name Type Description Default
request CVInferenceRequest

The inference request.

required
class_filter Optional[List[str]]

A list of class names to filter, if provided.

None

Returns:

Name Type Description
ObjectDetectionInferenceResponse

The inference response.

Source code in inference/models/grounding_dino/grounding_dino.py
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
def infer(
    self,
    image: InferenceRequestImage,
    text: List[str] = None,
    class_filter: list = None,
    box_threshold=0.5,
    text_threshold=0.5,
    class_agnostic_nms=CLASS_AGNOSTIC_NMS,
    **kwargs
):
    """
    Run inference on a provided image.
        - image: can be a BGR numpy array, filepath, InferenceRequestImage, PIL Image, byte-string, etc.

    Args:
        image: The image to run detection on.
        text (List[str]): Class prompts to detect.
        class_filter (Optional[List[str]]): A list of class names to keep, if provided.
        box_threshold: Minimum box confidence.
        text_threshold: Minimum text-match confidence.
        class_agnostic_nms: Whether NMS ignores class labels.

    Returns:
        ObjectDetectionInferenceResponse: The inference response.
    """
    t1 = perf_counter()
    image = self.preproc_image(image)
    img_dims = image.shape

    detections = self.model.predict_with_classes(
        image=image,
        classes=text,
        box_threshold=box_threshold,
        text_threshold=text_threshold,
    )

    self.class_names = text

    if class_agnostic_nms:
        detections = detections.with_nms(class_agnostic=True)
    else:
        detections = detections.with_nms()

    xywh_bboxes = [xyxy_to_xywh(detection) for detection in detections.xyxy]

    t2 = perf_counter() - t1

    # Build predictions, skipping detections with no resolved class id and,
    # when a filter is supplied, detections whose class name is not wanted.
    # NOTE(fix): the previous comprehension indexed `pred[6]` on a 4-element
    # xyxy row (IndexError whenever class_filter was set) and its
    # un-parenthesised `or`/`and` chain only applied the None check on the
    # filtered branch.
    predictions = []
    for i, bbox in enumerate(xywh_bboxes):
        class_id = detections.class_id[i]
        if class_id is None:
            continue
        class_name = self.class_names[int(class_id)]
        if class_filter and class_name not in class_filter:
            continue
        predictions.append(
            ObjectDetectionPrediction(
                # 'class' is a reserved word, hence the dict expansion.
                **{
                    "x": bbox[0],
                    "y": bbox[1],
                    "width": bbox[2],
                    "height": bbox[3],
                    "confidence": detections.confidence[i],
                    "class": class_name,
                    "class_id": int(class_id),
                }
            )
        )

    responses = ObjectDetectionInferenceResponse(
        predictions=predictions,
        image=InferenceResponseImage(width=img_dims[1], height=img_dims[0]),
        time=t2,
    )
    return responses
infer_from_request
infer_from_request(request)

Perform inference based on the details provided in the request, and return the associated responses.

Source code in inference/models/grounding_dino/grounding_dino.py
116
117
118
119
120
121
122
123
124
def infer_from_request(
    self,
    request: GroundingDINOInferenceRequest,
) -> ObjectDetectionInferenceResponse:
    """Run inference for the given request and return the detection response."""
    # Unpack the request fields into keyword arguments for `infer`.
    return self.infer(**request.dict())
preproc_image
preproc_image(image)

Preprocesses an image.

Parameters:

Name Type Description Default
image InferenceRequestImage

The image to preprocess.

required

Returns:

Type Description

np.array: The preprocessed image.

Source code in inference/models/grounding_dino/grounding_dino.py
104
105
106
107
108
109
110
111
112
113
114
def preproc_image(self, image: Any):
    """Decode the incoming image into a BGR numpy array.

    Args:
        image (InferenceRequestImage): The image to preprocess.

    Returns:
        np.array: The decoded BGR image.
    """
    return load_image_bgr(image)

Functions

inference.models.grounding_dino.grounding_dino_inference_models

Classes

InferenceModelsGroundingDINOAdapter

Bases: Model

GroundingDINO class for zero-shot object detection.

Attributes:

Name Type Description
model

The GroundingDINO model.

Source code in inference/models/grounding_dino/grounding_dino_inference_models.py
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
class InferenceModelsGroundingDINOAdapter(Model):
    """GroundingDINO class for zero-shot object detection.

    Adapts the inference-models GroundingDINO backend to the legacy
    `Model` interface.

    Attributes:
        model: The GroundingDINO model.
    """

    def __init__(
        self,
        model_id: str = "grounding_dino/groundingdino_swint_ogc",
        api_key: str = None,
        **kwargs
    ):
        """Load the GroundingDINO backend via AutoModel."""
        super().__init__()

        self.metrics = {"num_inferences": 0, "avg_inference_time": 0.0}

        # Explicit key wins; otherwise fall back to the globally configured one.
        self.api_key = api_key or API_KEY

        self.task_type = "object-detection"

        provider_headers = get_extra_weights_provider_headers(
            countinference=kwargs.get("countinference"),
            service_secret=kwargs.get("service_secret"),
        )
        # Enabled backends = all valid backends minus the disabled ones.
        enabled_backends = list(
            VALID_INFERENCE_MODELS_BACKENDS.difference(
                DISABLED_INFERENCE_MODELS_BACKENDS
            )
        )
        self._model: GroundingDinoForObjectDetectionTorch = AutoModel.from_pretrained(
            model_id_or_path=model_id,
            api_key=self.api_key,
            allow_untrusted_packages=ALLOW_INFERENCE_MODELS_UNTRUSTED_PACKAGES,
            allow_direct_local_storage_loading=ALLOW_INFERENCE_MODELS_DIRECTLY_ACCESS_LOCAL_PACKAGES,
            weights_provider_extra_headers=provider_headers,
            backend=enabled_backends,
            **kwargs,
        )

    def preproc_image(self, image: Any):
        """Decode the incoming image into a BGR numpy array.

        Args:
            image (InferenceRequestImage): The image to preprocess.

        Returns:
            np.array: The decoded BGR image.
        """
        return load_image_bgr(image)

    def infer_from_request(
        self,
        request: GroundingDINOInferenceRequest,
    ) -> ObjectDetectionInferenceResponse:
        """Run inference for the given request and return the detection response."""
        # Unpack the request fields into keyword arguments for `infer`.
        return self.infer(**request.dict())

    def infer(
        self,
        image: InferenceRequestImage,
        text: List[str] = None,
        class_filter: list = None,
        box_threshold=0.5,
        text_threshold=0.5,
        class_agnostic_nms=CLASS_AGNOSTIC_NMS,
        **kwargs
    ):
        """
        Run inference on a provided image.
            - image: can be a BGR numpy array, filepath, InferenceRequestImage, PIL Image, byte-string, etc.

        Args:
            image: The image to run detection on.
            text (List[str]): Class prompts to detect; required.
            class_filter (Optional[List[str]]): A list of class names to keep, if provided.

        Returns:
            ObjectDetectionInferenceResponse: The inference response.

        Raises:
            ValueError: If `text` is not supplied.
        """
        if text is None:
            raise ValueError(
                "`text` parameter is required for GroundingDINO inference."
            )
        started_at = perf_counter()
        np_image = self.preproc_image(image)
        image_height, image_width = np_image.shape[0], np_image.shape[1]

        # Backend returns a batch; a single image yields a single entry.
        detections = self._model.infer(
            images=np_image,
            classes=text,
            box_threshold=box_threshold,
            text_threshold=text_threshold,
            class_agnostic_nms=class_agnostic_nms,
        )[0]
        elapsed = perf_counter() - started_at

        predictions = []
        for idx in range(detections.xyxy.shape[0]):
            x_min, y_min, x_max, y_max = detections.xyxy[idx].tolist()
            label_id = detections.class_id[idx].item()
            label = text[label_id]
            if class_filter and label not in class_filter:
                continue
            predictions.append(
                ObjectDetectionPrediction(
                    # 'class' is a reserved word in Python, hence the dict expansion.
                    **{
                        "x": (x_min + x_max) / 2,
                        "y": (y_min + y_max) / 2,
                        "width": x_max - x_min,
                        "height": y_max - y_min,
                        "confidence": detections.confidence[idx].item(),
                        "class": label,
                        "class_id": label_id,
                    }
                )
            )
        return ObjectDetectionInferenceResponse(
            predictions=predictions,
            image=InferenceResponseImage(width=image_width, height=image_height),
            time=elapsed,
        )
Functions
infer
infer(
    image,
    text=None,
    class_filter=None,
    box_threshold=0.5,
    text_threshold=0.5,
    class_agnostic_nms=CLASS_AGNOSTIC_NMS,
    **kwargs
)

Run inference on a provided image. - image: can be a BGR numpy array, filepath, InferenceRequestImage, PIL Image, byte-string, etc.

Parameters:

Name Type Description Default
request CVInferenceRequest

The inference request.

required
class_filter Optional[List[str]]

A list of class names to filter, if provided.

None

Returns:

Name Type Description
ObjectDetectionInferenceResponse

The inference response.

Source code in inference/models/grounding_dino/grounding_dino_inference_models.py
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
def infer(
    self,
    image: InferenceRequestImage,
    text: List[str] = None,
    class_filter: list = None,
    box_threshold=0.5,
    text_threshold=0.5,
    class_agnostic_nms=CLASS_AGNOSTIC_NMS,
    **kwargs
):
    """
    Run inference on a provided image.
        - image: can be a BGR numpy array, filepath, InferenceRequestImage, PIL Image, byte-string, etc.

    Args:
        image: The image to run detection on.
        text (List[str]): Class prompts to detect; required.
        class_filter (Optional[List[str]]): A list of class names to keep, if provided.

    Returns:
        ObjectDetectionInferenceResponse: The inference response.

    Raises:
        ValueError: If `text` is not supplied.
    """
    if text is None:
        raise ValueError(
            "`text` parameter is required for GroundingDINO inference."
        )
    started_at = perf_counter()
    np_image = self.preproc_image(image)
    image_height, image_width = np_image.shape[0], np_image.shape[1]

    # Backend returns a batch; a single image yields a single entry.
    detections = self._model.infer(
        images=np_image,
        classes=text,
        box_threshold=box_threshold,
        text_threshold=text_threshold,
        class_agnostic_nms=class_agnostic_nms,
    )[0]
    elapsed = perf_counter() - started_at

    predictions = []
    for idx in range(detections.xyxy.shape[0]):
        x_min, y_min, x_max, y_max = detections.xyxy[idx].tolist()
        label_id = detections.class_id[idx].item()
        label = text[label_id]
        if class_filter and label not in class_filter:
            continue
        predictions.append(
            ObjectDetectionPrediction(
                # 'class' is a reserved word in Python, hence the dict expansion.
                **{
                    "x": (x_min + x_max) / 2,
                    "y": (y_min + y_max) / 2,
                    "width": x_max - x_min,
                    "height": y_max - y_min,
                    "confidence": detections.confidence[idx].item(),
                    "class": label,
                    "class_id": label_id,
                }
            )
        )
    return ObjectDetectionInferenceResponse(
        predictions=predictions,
        image=InferenceResponseImage(width=image_width, height=image_height),
        time=elapsed,
    )
infer_from_request
infer_from_request(request)

Perform inference based on the details provided in the request, and return the associated responses.

Source code in inference/models/grounding_dino/grounding_dino_inference_models.py
79
80
81
82
83
84
85
86
87
def infer_from_request(
    self,
    request: GroundingDINOInferenceRequest,
) -> ObjectDetectionInferenceResponse:
    """Run inference for the given request and return the detection response."""
    # Unpack the request fields into keyword arguments for `infer`.
    return self.infer(**request.dict())
preproc_image
preproc_image(image)

Preprocesses an image.

Parameters:

Name Type Description Default
image InferenceRequestImage

The image to preprocess.

required

Returns:

Type Description

np.array: The preprocessed image.

Source code in inference/models/grounding_dino/grounding_dino_inference_models.py
68
69
70
71
72
73
74
75
76
77
def preproc_image(self, image: Any):
    """Decode the incoming image into a BGR numpy array.

    Args:
        image (InferenceRequestImage): The image to preprocess.

    Returns:
        np.array: The decoded BGR image.
    """
    return load_image_bgr(image)

Functions

models/owlv2

inference.models.owlv2.owlv2

Classes

OwlV2

Bases: RoboflowInferenceModel

Source code in inference/models/owlv2/owlv2.py
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
class OwlV2(RoboflowInferenceModel):
    """Few-shot object detector built on the OWLv2 model.

    Classes are defined by example boxes on training images rather than by a
    fixed label set: ``make_class_embeddings_dict`` turns (image, boxes) pairs
    into per-class positive/negative embeddings, and ``infer`` scores new
    images against those embeddings.

    Three LRU-style caches (guarded by ``owlv2_lock``) avoid recomputing
    work: image embeddings on DEVICE, image embeddings on CPU, and image
    sizes; a fourth caches the class-embedding dicts keyed by a hash of the
    whole training set.
    """

    task_type = "object-detection"
    box_format = "xywh"

    def __init__(self, model_id=f"owlv2/{OWLV2_VERSION_ID}", *args, **kwargs):
        """Load the OWLv2 processor/model for ``model_id`` and set up caches.

        Args:
            model_id: "<dataset>/<version>" style id; the version part is
                mapped to a HuggingFace model under the "google" org.
        """
        super().__init__(model_id, *args, **kwargs)
        # TODO: owlv2 makes use of version_id - version_id is being dropped so this class needs to be refactored

        self.owlv2_lock = RLock()

        if self.version_id is None:
            owlv2_model_id_chunks = model_id.split("/")
            if len(owlv2_model_id_chunks) != 2:
                # NOTE(review): logging-style args don't interpolate in an
                # exception message; str(exc) will show the tuple, not the id.
                raise InvalidModelIDError("Model ID: `%s` is invalid.", model_id)
            self.dataset_id = owlv2_model_id_chunks[0]
            self.version_id = owlv2_model_id_chunks[1]
        # NOTE(review): os.path.join is used to build a HuggingFace hub id
        # ("google/<version>") — on Windows this would produce a backslash;
        # presumably only run on POSIX. Verify if Windows support matters.
        hf_id = os.path.join("google", self.version_id)
        processor = Owlv2Processor.from_pretrained(hf_id)
        self.image_size = tuple(processor.image_processor.size.values())
        # Mean/std are pre-shaped (1, 3, 1, 1) so they broadcast over NCHW
        # batches during preprocessing without per-call reshaping.
        self.image_mean = torch.tensor(
            processor.image_processor.image_mean, device=DEVICE
        ).view(1, 3, 1, 1)
        self.image_std = torch.tensor(
            processor.image_processor.image_std, device=DEVICE
        ).view(1, 3, 1, 1)
        # Singleton shares one set of model weights across OwlV2 instances.
        self.model = Owlv2Singleton(hf_id).model
        self.reset_cache()

    def reset_cache(self):
        """(Re)create all caches empty; discards any previously cached work."""
        # each entry should be on the order of 300*4KB, so 1000 is 400MB of CUDA memory
        self.image_embed_cache = LimitedSizeDict(size_limit=OWLV2_IMAGE_CACHE_SIZE)
        # no need for limit here, as we're only storing on CPU
        self.cpu_image_embed_cache = LimitedSizeDict(
            size_limit=CPU_IMAGE_EMBED_CACHE_SIZE
        )
        # each entry should be on the order of 10 bytes, so 1000 is 10KB
        self.image_size_cache = LimitedSizeDict(size_limit=OWLV2_IMAGE_CACHE_SIZE)
        # entry size will vary depending on the number of samples, but 10 should be safe
        self.class_embeddings_cache = LimitedSizeDict(size_limit=OWLV2_MODEL_CACHE_SIZE)

    def draw_predictions(
        self,
        inference_request,
        inference_response,
    ) -> bytes:
        """Draw predictions from an inference response onto the original image provided by an inference request

        Args:
            inference_request (ObjectDetectionInferenceRequest): The inference request containing the image on which to draw predictions
            inference_response (ObjectDetectionInferenceResponse): The inference response containing predictions to be drawn

        Returns:
            bytes: A base64 encoded image
        """
        # Sort class names so color assignment is deterministic per request.
        all_class_names = [x.class_name for x in inference_response.predictions]
        all_class_names = sorted(list(set(all_class_names)))

        return draw_detection_predictions(
            inference_request=inference_request,
            inference_response=inference_response,
            colors={
                class_name: DEFAULT_COLOR_PALETTE[i % len(DEFAULT_COLOR_PALETTE)]
                for (i, class_name) in enumerate(all_class_names)
            },
        )

    def download_weights(self) -> None:
        # Download from huggingface — handled by Owlv2Processor/Owlv2Singleton
        # in __init__, so nothing to do here.
        pass

    def get_image_embeds(self, image_hash: Hash) -> Optional[tuple]:
        """Return cached embeddings for ``image_hash``, or None on miss.

        Checks the DEVICE cache first; on a CPU-cache hit the tensors are
        moved back to DEVICE before being returned.
        """
        image_embed_cache_hit = self.image_embed_cache.get(image_hash)
        if image_embed_cache_hit is not None:
            return image_embed_cache_hit
        cpu_image_embed_cache_hit = self.cpu_image_embed_cache.get(image_hash)
        if cpu_image_embed_cache_hit is not None:
            tensors = tuple(t.to(DEVICE) for t in cpu_image_embed_cache_hit)
            return tensors
        return None

    def compute_image_size(
        self, image: Union[np.ndarray, LazyImageRetrievalWrapper]
    ) -> Tuple[int, int]:
        """Return the image size as (width, height).

        For lazy wrappers the size is cached by image hash so that a cache
        hit avoids loading the image bytes at all.
        """
        if isinstance(image, LazyImageRetrievalWrapper):
            image_size = self.image_size_cache.get(image.image_hash)
            if image_size is None:
                np_img = image.image_as_numpy
                # shape is (H, W, C); [:2][::-1] yields (W, H)
                image_size = np_img.shape[:2][::-1]
                with self.owlv2_lock:
                    self.image_size_cache[image.image_hash] = image_size
            return image_size
        else:
            return image.shape[:2][::-1]

    @torch.no_grad()
    def embed_image(
        self, image: Union[np.ndarray, LazyImageRetrievalWrapper]
    ) -> Tuple[Hash, tuple]:
        """Embed an image with the OWLv2 vision tower, with caching.

        Returns:
            (image_hash, image_embeds) where image_embeds is the tuple
            (objectness, boxes, image_class_embeds, logit_shift, logit_scale),
            already filtered by objectness.
        """
        if isinstance(image, LazyImageRetrievalWrapper):
            image_hash = image.image_hash
        else:
            image_hash = hash_function(image.tobytes())

        # Fast path: embeddings already cached (possibly on CPU).
        image_embeds = self.get_image_embeds(image_hash)
        if image_embeds is not None:
            return image_hash, image_embeds

        np_image = (
            image.image_as_numpy
            if isinstance(image, LazyImageRetrievalWrapper)
            else image
        )
        pixel_values = preprocess_image(
            np_image, self.image_size, self.image_mean, self.image_std
        )

        # torch 2.4 lets you use "cuda:0" as device_type
        # but this crashes in 2.3
        # so we parse DEVICE as a string to make it work in both 2.3 and 2.4
        # as we don't know a priori our torch version
        device_str = "cuda" if str(DEVICE).startswith("cuda") else "cpu"
        # we disable autocast on CPU for stability, although it's possible using bfloat16 would work
        with torch.autocast(
            device_type=device_str, dtype=torch.float16, enabled=device_str == "cuda"
        ):
            image_embeds, _ = self.model.image_embedder(pixel_values=pixel_values)
            batch_size, h, w, dim = image_embeds.shape
            image_features = image_embeds.reshape(batch_size, h * w, dim)
            objectness = self.model.objectness_predictor(image_features)
            boxes = self.model.box_predictor(image_features, feature_map=image_embeds)
        # Class head is run outside autocast; embeddings are L2-normalized
        # (epsilon guards against division by zero).
        image_class_embeds = self.model.class_head.dense0(image_features)
        image_class_embeds /= (
            torch.linalg.norm(image_class_embeds, ord=2, dim=-1, keepdim=True) + 1e-6
        )
        logit_shift = self.model.class_head.logit_shift(image_features)
        logit_scale = (
            self.model.class_head.elu(self.model.class_head.logit_scale(image_features))
            + 1
        )
        objectness = objectness.sigmoid()

        # Drop low-objectness proposals before caching to bound memory use.
        objectness, boxes, image_class_embeds, logit_shift, logit_scale = (
            filter_tensors_by_objectness(
                objectness, boxes, image_class_embeds, logit_shift, logit_scale
            )
        )
        image_embeds = (
            objectness,
            boxes,
            image_class_embeds,
            logit_shift,
            logit_scale,
        )
        with self.owlv2_lock:
            self.image_embed_cache[image_hash] = image_embeds

        if torch.cuda.is_available():
            torch.cuda.empty_cache()

        if isinstance(image, LazyImageRetrievalWrapper):
            image.unload_numpy_image()  # Clears both _image_as_numpy and image if needed.

        return image_hash, image_embeds

    def get_query_embedding(
        self,
        query_spec: QuerySpecType,
        iou_threshold: float,
        precomputed_embeddings: Optional[Dict[Hash, Tuple[torch.Tensor]]] = None,
    ) -> Optional[torch.Tensor]:
        """Collect class embeddings for user-provided query boxes.

        For each (image_hash -> normalized boxes) pair, the query box is
        matched to the predicted box with the highest IoU; matches below
        ``iou_threshold`` are discarded.

        Args:
            query_spec: Mapping of image hash to a list of query boxes.
            iou_threshold: Minimum IoU for a query box to count as matched.
            precomputed_embeddings: Optional per-hash embeddings to use
                instead of the cache.

        Returns:
            Concatenated class embeddings for all matched boxes, or None if
            nothing matched.

        Raises:
            KeyError: If an image was neither precomputed nor embedded.
        """
        # NOTE: for now we're handling each image separately
        query_embeds = []
        if precomputed_embeddings is None:
            precomputed_embeddings = {}
        for image_hash, query_boxes in query_spec.items():
            if image_hash in precomputed_embeddings:
                image_embeds = precomputed_embeddings[image_hash]
            else:
                image_embeds = self.get_image_embeds(image_hash)
            if image_embeds is None:
                raise KeyError("We didn't embed the image first!")
            _objectness, image_boxes, image_class_embeds, _, _ = image_embeds

            query_boxes_tensor = torch.tensor(
                query_boxes, dtype=image_boxes.dtype, device=image_boxes.device
            )
            # Nothing to match against (or no queries) — skip this image.
            if image_boxes.numel() == 0 or query_boxes_tensor.numel() == 0:
                continue
            iou, _ = box_iou(
                to_corners(image_boxes), to_corners(query_boxes_tensor)
            )  # 3000, k
            # Best predicted box for each query box.
            ious, indices = torch.max(iou, dim=0)
            # filter for only iou > iou_threshold
            iou_mask = ious > iou_threshold
            indices = indices[iou_mask]
            if not indices.numel() > 0:
                continue

            embeds = image_class_embeds[indices]
            query_embeds.append(embeds)
        if not query_embeds:
            return None
        query = torch.cat(query_embeds, dim=0)
        return query

    def infer_from_embed(
        self,
        image_hash: Hash,
        query_embeddings: Dict[str, PosNegDictType],
        confidence: float,
        iou_threshold: float,
        max_detections: int = MAX_DETECTIONS,
        image_embeds: Optional[tuple] = None,
    ) -> List[Dict]:
        """Score a pre-embedded image against class query embeddings.

        Args:
            image_hash: Hash identifying a previously embedded image.
            query_embeddings: Per-class positive/negative embeddings.
            confidence: Minimum score for a detection to survive.
            iou_threshold: IoU used for the final cross-class NMS.
            max_detections: Hard cap on detections returned.
            image_embeds: Optional embeddings to use instead of the cache.

        Returns:
            List of dicts with normalized (0..1) "x", "y", "w", "h" plus
            "class_name" and "confidence".

        Raises:
            KeyError: If the image was never embedded.
        """
        if image_embeds is None:
            image_embeds = self.get_image_embeds(image_hash)
        if image_embeds is None:
            raise KeyError("We didn't embed the image first!")
        _, image_boxes, image_class_embeds, _, _ = image_embeds
        class_map, class_names = make_class_map(query_embeddings)
        all_predicted_boxes, all_predicted_classes, all_predicted_scores = [], [], []
        for class_name, pos_neg_embedding_dict in query_embeddings.items():
            boxes, classes, scores = get_class_preds_from_embeds(
                pos_neg_embedding_dict,
                image_class_embeds,
                confidence,
                image_boxes,
                class_map,
                class_name,
                iou_threshold,
            )

            all_predicted_boxes.append(boxes)
            all_predicted_classes.append(classes)
            all_predicted_scores.append(scores)

        if not all_predicted_boxes:
            return []

        all_predicted_boxes = torch.cat(all_predicted_boxes, dim=0)
        all_predicted_classes = torch.cat(all_predicted_classes, dim=0)
        all_predicted_scores = torch.cat(all_predicted_scores, dim=0)

        # run nms on all predictions
        survival_indices = torchvision.ops.nms(
            to_corners(all_predicted_boxes), all_predicted_scores, iou_threshold
        )
        all_predicted_boxes = all_predicted_boxes[survival_indices]
        all_predicted_classes = all_predicted_classes[survival_indices]
        all_predicted_scores = all_predicted_scores[survival_indices]

        # NMS output is sorted by score, so truncation keeps the best ones.
        if len(all_predicted_boxes) > max_detections:
            all_predicted_boxes = all_predicted_boxes[:max_detections]
            all_predicted_classes = all_predicted_classes[:max_detections]
            all_predicted_scores = all_predicted_scores[:max_detections]

        # move tensors to numpy before returning
        all_predicted_boxes = all_predicted_boxes.cpu().numpy()
        all_predicted_classes = all_predicted_classes.cpu().numpy()
        all_predicted_scores = all_predicted_scores.cpu().numpy()

        return [
            {
                "class_name": class_names[int(c)],
                "x": float(x),
                "y": float(y),
                "w": float(w),
                "h": float(h),
                "confidence": float(score),
            }
            for c, (x, y, w, h), score in zip(
                all_predicted_classes, all_predicted_boxes, all_predicted_scores
            )
        ]

    def infer(
        self,
        image: Any,
        training_data: Dict,
        confidence: float = 0.99,
        iou_threshold: float = 0.3,
        max_detections: int = MAX_DETECTIONS,
        **kwargs,
    ):
        """Run few-shot detection: build class embeddings from
        ``training_data`` (cached across calls), then score ``image``.
        """
        class_embeddings_dict = self.make_class_embeddings_dict(
            training_data, iou_threshold
        )
        return self.infer_from_embedding_dict(
            image,
            class_embeddings_dict,
            confidence,
            iou_threshold,
            max_detections=max_detections,
        )

    def infer_from_embedding_dict(
        self,
        image: Any,
        class_embeddings_dict: Dict[str, PosNegDictType],
        confidence: float,
        iou_threshold: float,
        max_detections: int = MAX_DETECTIONS,
        **kwargs,
    ):
        """Detect objects in one image (or a list of images) using
        precomputed class embeddings.

        Returns:
            The list of ObjectDetectionInferenceResponse built by
            ``make_response`` (one per input image).
        """
        if not isinstance(image, list):
            images = [image]
        else:
            images = image

        images = [LazyImageRetrievalWrapper(image) for image in images]

        results = []
        image_sizes = []
        for image_wrapper in images:
            # happy path here is that both image size and image embeddings are cached
            # in which case we avoid loading the image at all
            image_size = self.compute_image_size(image_wrapper)
            image_sizes.append(image_size)
            image_hash, image_embeds = self.embed_image(image_wrapper)
            # Free the decoded pixels as soon as the embeddings exist.
            image_wrapper.unload_numpy_image()
            result = self.infer_from_embed(
                image_hash,
                class_embeddings_dict,
                confidence,
                iou_threshold,
                max_detections=max_detections,
                image_embeds=image_embeds,
            )
            results.append(result)
        return self.make_response(
            results, image_sizes, sorted(list(class_embeddings_dict.keys()))
        )

    def make_class_embeddings_dict(
        self,
        training_data: List[Any],
        iou_threshold: float,
        return_image_embeds: bool = False,
    ) -> Dict[str, PosNegDictType]:
        """Build per-class positive/negative embedding tensors from
        (image, boxes) training examples; results are cached by a hash of
        the whole training set.

        Args:
            training_data: List of {"image": ..., "boxes": [...]} dicts; each
                box carries "x", "y", "w", "h", "cls" and a "negative" flag.
            iou_threshold: Passed through to ``get_query_embedding``.
            return_image_embeds: When True, also return a dict of per-image
                CPU embeddings (then the return is a 2-tuple).

        Returns:
            {class_name: {"positive": Tensor|None, "negative": Tensor|None}},
            optionally paired with the image-embeddings dict.
        """

        wrapped_training_data = [
            {
                "image": LazyImageRetrievalWrapper(train_image["image"]),
                "boxes": train_image["boxes"],
            }
            for train_image in training_data
        ]

        wrapped_training_data_hash = hash_wrapped_training_data(wrapped_training_data)

        if (
            class_embeddings_dict := self.class_embeddings_cache.get(
                wrapped_training_data_hash
            )
        ) is not None:
            if return_image_embeds:
                # Return a dummy empty dict as the second value
                # or extract it from CPU cache if available
                # NOTE(review): this returns the WHOLE CPU cache, which may
                # include embeddings for unrelated images — verify callers
                # tolerate extra entries.
                return_image_embeds_dict = {}
                with self.owlv2_lock:
                    for image_hash, value in self.cpu_image_embed_cache.items():
                        return_image_embeds_dict[image_hash] = value
                return class_embeddings_dict, return_image_embeds_dict
            else:
                return class_embeddings_dict

        class_embeddings_dict = defaultdict(lambda: {"positive": [], "negative": []})

        bool_to_literal = {True: "positive", False: "negative"}
        return_image_embeds_dict = dict()

        for train_image in wrapped_training_data:
            image_size = self.compute_image_size(train_image["image"])
            image_hash, image_embeds = self.embed_image(train_image["image"])
            if return_image_embeds:
                return_image_embeds_dict[image_hash] = tuple(
                    t.to("cpu") for t in image_embeds
                )
            # grab and normalize box prompts for this image
            boxes = train_image["boxes"]
            coords = [[box["x"], box["y"], box["w"], box["h"]] for box in boxes]
            # Normalize by the longer image side — matches the un-scaling
            # done in make_response.
            coords = [tuple([c / max(image_size) for c in coord]) for coord in coords]
            classes = [box["cls"] for box in boxes]
            is_positive = [not box["negative"] for box in boxes]
            query_spec = {image_hash: coords}
            precomputed_embeddings = {image_hash: image_embeds}
            # compute the embeddings for the box prompts
            embeddings = self.get_query_embedding(
                query_spec,
                iou_threshold,
                precomputed_embeddings=precomputed_embeddings,
            )

            del train_image

            if embeddings is None:
                continue

            # NOTE(review): get_query_embedding may return fewer embeddings
            # than boxes (low-IoU boxes are dropped); zip silently truncates,
            # so class/positivity alignment assumes the order is preserved —
            # confirm against get_query_embedding's filtering.
            for embedding, class_name, is_pos in zip(embeddings, classes, is_positive):
                class_embeddings_dict[class_name][bool_to_literal[is_pos]].append(
                    embedding
                )
        # Convert lists of embeddings to tensors.
        class_embeddings_dict = {
            k: {
                "positive": torch.stack(v["positive"]) if v["positive"] else None,
                "negative": torch.stack(v["negative"]) if v["negative"] else None,
            }
            for k, v in class_embeddings_dict.items()
        }

        with self.owlv2_lock:
            self.class_embeddings_cache[wrapped_training_data_hash] = (
                class_embeddings_dict
            )
        if return_image_embeds:
            return class_embeddings_dict, return_image_embeds_dict

        return class_embeddings_dict

    def make_response(self, predictions, image_sizes, class_names):
        """Convert normalized prediction dicts into
        ObjectDetectionInferenceResponse objects, scaling coordinates back
        to pixels by the longer image side (the same factor used to
        normalize them).
        """
        responses = [
            ObjectDetectionInferenceResponse(
                predictions=[
                    ObjectDetectionPrediction(
                        # Passing args as a dictionary here since one of the args is 'class' (a protected term in Python)
                        **{
                            "x": pred["x"] * max(image_sizes[ind]),
                            "y": pred["y"] * max(image_sizes[ind]),
                            "width": pred["w"] * max(image_sizes[ind]),
                            "height": pred["h"] * max(image_sizes[ind]),
                            "confidence": pred["confidence"],
                            "class": pred["class_name"],
                            "class_id": class_names.index(pred["class_name"]),
                        }
                    )
                    for pred in batch_predictions
                ],
                image=InferenceResponseImage(
                    width=image_sizes[ind][0], height=image_sizes[ind][1]
                ),
            )
            for ind, batch_predictions in enumerate(predictions)
        ]
        return responses
Functions
draw_predictions
draw_predictions(inference_request, inference_response)

Draw predictions from an inference response onto the original image provided by an inference request

Parameters:

Name Type Description Default
inference_request ObjectDetectionInferenceRequest

The inference request containing the image on which to draw predictions

required
inference_response ObjectDetectionInferenceResponse

The inference response containing predictions to be drawn

required

Returns:

Name Type Description
str bytes

A base64 encoded image, returned as bytes (the function's annotated return type is bytes, not str)

Source code in inference/models/owlv2/owlv2.py
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
def draw_predictions(
    self,
    inference_request,
    inference_response,
) -> bytes:
    """Render the response's predictions onto the request's original image.

    Args:
        inference_request (ObjectDetectionInferenceRequest): Request holding
            the image to annotate.
        inference_response (ObjectDetectionInferenceResponse): Response whose
            predictions are drawn.

    Returns:
        bytes: A base64 encoded image.
    """
    # Deduplicate and sort class names so each class gets a stable color.
    unique_names = sorted({pred.class_name for pred in inference_response.predictions})
    palette_size = len(DEFAULT_COLOR_PALETTE)
    color_assignment = {
        name: DEFAULT_COLOR_PALETTE[position % palette_size]
        for position, name in enumerate(unique_names)
    }
    return draw_detection_predictions(
        inference_request=inference_request,
        inference_response=inference_response,
        colors=color_assignment,
    )

SerializedOwlV2

Bases: RoboflowInferenceModel

Source code in inference/models/owlv2/owlv2.py
 824
 825
 826
 827
 828
 829
 830
 831
 832
 833
 834
 835
 836
 837
 838
 839
 840
 841
 842
 843
 844
 845
 846
 847
 848
 849
 850
 851
 852
 853
 854
 855
 856
 857
 858
 859
 860
 861
 862
 863
 864
 865
 866
 867
 868
 869
 870
 871
 872
 873
 874
 875
 876
 877
 878
 879
 880
 881
 882
 883
 884
 885
 886
 887
 888
 889
 890
 891
 892
 893
 894
 895
 896
 897
 898
 899
 900
 901
 902
 903
 904
 905
 906
 907
 908
 909
 910
 911
 912
 913
 914
 915
 916
 917
 918
 919
 920
 921
 922
 923
 924
 925
 926
 927
 928
 929
 930
 931
 932
 933
 934
 935
 936
 937
 938
 939
 940
 941
 942
 943
 944
 945
 946
 947
 948
 949
 950
 951
 952
 953
 954
 955
 956
 957
 958
 959
 960
 961
 962
 963
 964
 965
 966
 967
 968
 969
 970
 971
 972
 973
 974
 975
 976
 977
 978
 979
 980
 981
 982
 983
 984
 985
 986
 987
 988
 989
 990
 991
 992
 993
 994
 995
 996
 997
 998
 999
1000
1001
1002
1003
1004
1005
1006
1007
1008
1009
1010
1011
1012
1013
1014
1015
1016
1017
1018
1019
1020
1021
1022
1023
1024
1025
1026
1027
1028
1029
1030
1031
1032
1033
1034
1035
1036
1037
1038
1039
1040
1041
1042
1043
1044
1045
1046
1047
1048
1049
1050
1051
1052
1053
1054
1055
1056
1057
1058
1059
1060
1061
1062
1063
1064
1065
1066
1067
class SerializedOwlV2(RoboflowInferenceModel):
    """Wrapper around OwlV2 that loads pre-serialized class embeddings.

    Instead of computing class embeddings from training data at inference
    time, this model loads a ``weights.pt`` artifact (produced by
    ``serialize_training_data``/``save_model``) containing the
    train-data dict, class names, and optionally cached image embeddings,
    then delegates inference to a shared OwlV2 instance.
    """

    task_type = "object-detection"
    box_format = "xywh"

    # Cache of OwlV2 instances to avoid creating new ones for each serialize_training_data call
    # This improves performance by reusing model instances across serialization operations
    _base_owlv2_instances = {}

    @classmethod
    def get_or_create_owlv2_instance(cls, roboflow_id: str) -> OwlV2:
        """Get an existing OwlV2 instance from cache or create a new one if it doesn't exist.

        Args:
            roboflow_id: The model ID for the OwlV2 model

        Returns:
            An OwlV2 instance
        """
        if roboflow_id in cls._base_owlv2_instances:
            return cls._base_owlv2_instances[roboflow_id]
        else:
            owlv2 = OwlV2(model_id=roboflow_id)
            cls._base_owlv2_instances[roboflow_id] = owlv2
            return owlv2

    @classmethod
    def serialize_training_data(
        cls,
        training_data: List[Any],
        hf_id: str = f"google/{OWLV2_VERSION_ID}",
        iou_threshold: float = 0.3,
        save_dir: str = os.path.join(MODEL_CACHE_DIR, "owl-v2-serialized-data"),
        previous_embeddings_file: str = None,
    ):
        """Compute class embeddings for ``training_data`` and save them to disk.

        Args:
            training_data: (image, boxes) examples for OwlV2.
            hf_id: HuggingFace model id; converted to a roboflow-style id.
            iou_threshold: Box-matching threshold for embedding extraction.
            save_dir: Directory to write the serialized artifact to.
            previous_embeddings_file: Optional earlier artifact whose cached
                image embeddings are pre-loaded to skip re-embedding.

        Returns:
            Path to the saved weights file.
        """
        roboflow_id = hf_id.replace("google/", "owlv2/")

        owlv2 = cls.get_or_create_owlv2_instance(roboflow_id)

        if previous_embeddings_file is not None:
            # weights_only=False: artifact contains arbitrary pickled objects
            # (LimitedSizeDict etc.) — only load trusted files.
            if DEVICE == "cpu":
                model_data = torch.load(
                    previous_embeddings_file, map_location="cpu", weights_only=False
                )
            else:
                model_data = torch.load(previous_embeddings_file, weights_only=False)

            train_data_dict = model_data["train_data_dict"]
            # Older artifacts may store a plain dict; migrate into a
            # LimitedSizeDict so the cache size stays bounded.
            if isinstance(model_data["image_embeds"], LimitedSizeDict):
                owlv2.cpu_image_embed_cache = model_data["image_embeds"]
            else:
                cache = LimitedSizeDict(size_limit=CPU_IMAGE_EMBED_CACHE_SIZE)
                for key, value in model_data["image_embeds"].items():
                    cache[key] = value
                owlv2.cpu_image_embed_cache = cache

        train_data_dict, image_embeds = owlv2.make_class_embeddings_dict(
            training_data, iou_threshold, return_image_embeds=True
        )
        return cls.save_model(
            hf_id, roboflow_id, train_data_dict, image_embeds, save_dir
        )

    @classmethod
    def save_model(
        cls,
        hf_id: str,
        roboflow_id: str,
        train_data_dict: Dict,
        image_embeds: Dict,
        save_dir: str,
    ):
        """Bundle embeddings and metadata into one dict and torch.save it.

        Returns:
            Path to the written weights file.
        """
        # NOTE: rebinding the parameter name — the inner ``train_data_dict``
        # references on the right-hand side are the original argument.
        train_data_dict = {
            "huggingface_id": hf_id,
            "train_data_dict": train_data_dict,
            "class_names": list(train_data_dict.keys()),
            "roboflow_id": roboflow_id,
            "image_embeds": image_embeds,
        }
        train_data_path = os.path.join(save_dir, cls.weights_file_path)
        os.makedirs(save_dir, exist_ok=True)
        torch.save(train_data_dict, train_data_path)
        return train_data_path

    def infer_from_request(
        self,
        request: ObjectDetectionInferenceRequest,
    ) -> Union[
        List[ObjectDetectionInferenceResponse], ObjectDetectionInferenceResponse
    ]:
        # Thin pass-through; kept only to narrow the type signature.
        return super().infer_from_request(request)

    def __init__(self, model_id, *args, **kwargs):
        """Initialize and immediately fetch/load the serialized artifacts."""
        super().__init__(model_id, *args, **kwargs)
        self.get_model_artifacts(**kwargs)

    def get_infer_bucket_file_list(self):
        # No additional bucket files beyond the weights artifact.
        return []

    def download_model_artefacts_from_s3(self):
        raise NotImplementedError("Owlv2 not currently supported on hosted inference")

    def download_model_artifacts_from_roboflow_api(
        self,
        countinference: Optional[bool] = None,
        service_secret: Optional[str] = None,
        **kwargs,
    ):
        """Download the serialized weights from the Roboflow API into cache.

        Uses a file lock so concurrent processes don't download the same
        artifact simultaneously; versioned and "instant" models hit
        different API endpoints.

        Raises:
            ModelArtefactError: If the API response lacks the model URL.
        """
        logger.info("Downloading OWLv2 model artifacts")

        # Use the same lock file pattern as in clear_cache
        lock_dir = MODEL_CACHE_DIR + "/_file_locks"  # Dedicated lock directory
        os.makedirs(lock_dir, exist_ok=True)  # Ensure lock directory exists.
        lock_file = os.path.join(lock_dir, f"{os.path.basename(self.cache_dir)}.lock")
        try:
            lock = FileLock(lock_file, timeout=120)  # 120 second timeout for downloads
            with lock:
                if self.version_id is not None:
                    api_data = get_roboflow_model_data(
                        api_key=self.api_key,
                        model_id=self.endpoint,
                        endpoint_type=ModelEndpointType.OWLV2,
                        device_id=self.device_id,
                        countinference=countinference,
                        service_secret=service_secret,
                    )
                    api_data = api_data["owlv2"]
                    if "model" not in api_data:
                        raise ModelArtefactError(
                            "Could not find `model` key in roboflow API model description response."
                        )
                    logger.info("Downloading OWLv2 model weights for %s", self.endpoint)
                    model_weights_response = get_from_url(
                        api_data["model"], json_response=False
                    )
                else:
                    logger.info("Getting OWLv2 model data for %s", self.endpoint)
                    api_data = get_roboflow_instant_model_data(
                        api_key=self.api_key,
                        model_id=self.endpoint,
                        countinference=countinference,
                        service_secret=service_secret,
                    )
                    if (
                        "modelFiles" not in api_data
                        or "owlv2" not in api_data["modelFiles"]
                        or "model" not in api_data["modelFiles"]["owlv2"]
                    ):
                        raise ModelArtefactError(
                            "Could not find `modelFiles` key or `modelFiles`.`owlv2` or `modelFiles`.`owlv2`.`model` key in roboflow API model description response."
                        )
                    logger.info("Downloading OWLv2 model weights for %s", self.endpoint)
                    model_weights_response = get_from_url(
                        api_data["modelFiles"]["owlv2"]["model"], json_response=False
                    )
                save_bytes_in_cache(
                    content=model_weights_response.content,
                    file=self.weights_file,
                    model_id=self.endpoint,
                )
                logger.info("OWLv2 model weights saved to cache")
        except Exception as e:
            logger.error("Error downloading OWLv2 model artifacts: %s", e)
            raise
        finally:
            # NOTE(review): the lock file is unlinked unconditionally, even
            # while another process may still be waiting on it — presumably
            # acceptable for this use; confirm against filelock semantics.
            try:
                if os.path.exists(lock_file):
                    os.unlink(lock_file)  # Clean up lock file
            except OSError:
                pass  # Best effort cleanup

    def load_model_artifacts_from_cache(self):
        """Load the cached weights file and wire its contents into a shared
        OwlV2 instance (class names, train-data dict, CPU embed cache)."""
        # weights_only=False: artifact contains pickled Python objects —
        # only load artifacts from trusted sources.
        if DEVICE == "cpu":
            self.model_data = torch.load(
                self.cache_file(self.weights_file),
                map_location="cpu",
                weights_only=False,
            )
        else:
            self.model_data = torch.load(
                self.cache_file(self.weights_file), weights_only=False
            )
        self.class_names = self.model_data["class_names"]
        self.train_data_dict = self.model_data["train_data_dict"]
        self.huggingface_id = self.model_data["huggingface_id"]
        self.roboflow_id = self.model_data["roboflow_id"]
        # Use the same cached OwlV2 instance mechanism to avoid creating duplicates
        self.owlv2 = self.__class__.get_or_create_owlv2_instance(self.roboflow_id)
        if isinstance(self.model_data["image_embeds"], LimitedSizeDict):
            self.owlv2.cpu_image_embed_cache = self.model_data["image_embeds"]
        else:
            # Migrate plain-dict artifacts into a size-bounded cache.
            cache = LimitedSizeDict(size_limit=CPU_IMAGE_EMBED_CACHE_SIZE)
            for key, value in self.model_data["image_embeds"].items():
                cache[key] = value
            self.owlv2.cpu_image_embed_cache = cache

    # Relative filename of the serialized artifact within the cache dir.
    weights_file_path = "weights.pt"

    @property
    def weights_file(self):
        """Filename of the weights artifact (relative to the cache dir)."""
        return self.weights_file_path

    def infer(
        self,
        image,
        confidence: float = 0.99,
        iou_threshold: float = 0.3,
        max_detections: int = MAX_DETECTIONS,
        **kwargs,
    ):
        """Run detection on ``image`` using the deserialized class
        embeddings, delegating to the shared OwlV2 instance."""
        logger.debug("Inferring OWLv2 model")
        result = self.owlv2.infer_from_embedding_dict(
            image,
            self.train_data_dict,
            confidence=confidence,
            iou_threshold=iou_threshold,
            max_detections=max_detections,
            **kwargs,
        )
        logger.debug("OWLv2 model inference complete")
        return result

    def draw_predictions(
        self,
        inference_request: ObjectDetectionInferenceRequest,
        inference_response: ObjectDetectionInferenceResponse,
    ):
        """Delegate prediction drawing to the underlying OwlV2 model."""
        return self.owlv2.draw_predictions(
            inference_request,
            inference_response,
        )

    def save_small_model_without_image_embeds(
        self, save_dir: str = os.path.join(MODEL_CACHE_DIR, "owl-v2-serialized-data")
    ):
        """Re-save the artifact without cached image embeddings to shrink it.

        Side effect: clears the shared OwlV2 instance's CPU embed cache.
        """
        self.owlv2.cpu_image_embed_cache = LimitedSizeDict(
            size_limit=CPU_IMAGE_EMBED_CACHE_SIZE
        )
        return self.save_model(
            self.huggingface_id,
            self.roboflow_id,
            self.train_data_dict,
            {},
            save_dir,
        )
Functions
get_or_create_owlv2_instance classmethod
get_or_create_owlv2_instance(roboflow_id)

Get an existing OwlV2 instance from cache or create a new one if it doesn't exist.

Parameters:

Name Type Description Default
roboflow_id str

The model ID for the OwlV2 model

required

Returns:

Type Description
OwlV2

An OwlV2 instance

Source code in inference/models/owlv2/owlv2.py
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
@classmethod
def get_or_create_owlv2_instance(cls, roboflow_id: str) -> OwlV2:
    """Return the cached OwlV2 instance for ``roboflow_id``, creating and
    caching a new one on first use.

    Args:
        roboflow_id: The model ID for the OwlV2 model

    Returns:
        An OwlV2 instance
    """
    registry = cls._base_owlv2_instances
    if roboflow_id not in registry:
        registry[roboflow_id] = OwlV2(model_id=roboflow_id)
    return registry[roboflow_id]

Functions

preprocess_image

preprocess_image(
    np_image, image_size, image_mean, image_std
)

Preprocess an image for OWLv2 by resizing, normalizing, and padding it. This is much faster than using the Owlv2Processor directly, as we ensure we use GPU if available.

Parameters:

Name Type Description Default
np_image ndarray

The image to preprocess, with shape (H, W, 3)

required
image_size tuple[int, int]

The target size of the image

required
image_mean Tensor

The mean of the image, on DEVICE, with shape (1, 3, 1, 1)

required
image_std Tensor

The standard deviation of the image, on DEVICE, with shape (1, 3, 1, 1)

required

Returns:

Type Description
Tensor

torch.Tensor: The preprocessed image, on DEVICE, with shape (1, 3, H, W)

Source code in inference/models/owlv2/owlv2.py
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
def preprocess_image(
    np_image: np.ndarray,
    image_size: Tuple[int, int],
    image_mean: torch.Tensor,
    image_std: torch.Tensor,
) -> torch.Tensor:
    """Preprocess an image for OWLv2 by resizing, normalizing, and padding it.
    This is much faster than using the Owlv2Processor directly, as we ensure we use GPU if available.

    Args:
        np_image (np.ndarray): The image to preprocess, with shape (H, W, 3)
        image_size (tuple[int, int]): The target size of the image
        image_mean (torch.Tensor): The mean of the image, on DEVICE, with shape (1, 3, 1, 1)
        image_std (torch.Tensor): The standard deviation of the image, on DEVICE, with shape (1, 3, 1, 1)

    Returns:
        torch.Tensor: The preprocessed image, on DEVICE, with shape (1, 3, H, W)
    """
    source_h, source_w = np_image.shape[:2]

    # Aspect-preserving scale so the image fits inside the target canvas.
    scale = min(image_size[0] / source_h, image_size[1] / source_w)
    resized_h, resized_w = int(scale * source_h), int(scale * source_w)

    # HWC uint8 -> NCHW float in [0, 1] on DEVICE.
    image_tensor = torch.tensor(np_image).permute(2, 0, 1).unsqueeze(0)
    image_tensor = image_tensor.to(DEVICE).to(dtype=torch.float32) / 255.0
    image_tensor = F.interpolate(
        image_tensor, size=(resized_h, resized_w), mode="bilinear", align_corners=False
    )

    # Pad with mid-gray (0.5) out to the exact target size, image in the top-left.
    canvas = torch.ones((1, 3, *image_size), device=DEVICE) * 0.5
    canvas[:, :, :resized_h, :resized_w] = image_tensor

    # Channel-wise normalization.
    return (canvas - image_mean) / image_std

models/paligemma

inference.models.paligemma.paligemma

Classes

LoRAPaliGemma

Bases: LoRATransformerModel

By using you agree to the terms listed at https://ai.google.dev/gemma/terms

Source code in inference/models/paligemma/paligemma.py
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
class LoRAPaliGemma(LoRATransformerModel):
    """By using you agree to the terms listed at https://ai.google.dev/gemma/terms"""

    generation_includes_input = True
    transformers_class = PaliGemmaForConditionalGeneration
    load_base_from_roboflow = True

    def initialize_model(self, **kwargs):
        """Load the base PaliGemma model, attach the LoRA adapter stored in
        ``self.cache_dir`` and merge the adapter weights into the base model.

        The adapter config's ``revision`` field may name a torch dtype
        (e.g. "bfloat16"); when it does, that dtype is used for the model.
        """
        import torch

        lora_config = LoraConfig.from_pretrained(self.cache_dir, device_map=DEVICE)
        model_id = lora_config.base_model_name_or_path
        revision = lora_config.revision
        if revision is not None:
            try:
                # Treat the revision string as a dtype name when it is one.
                self.dtype = getattr(torch, revision)
            except AttributeError:
                # Not a dtype name; keep the current dtype.
                pass
        if not self.load_base_from_roboflow:
            # Pull base weights directly from HuggingFace with the original
            # model id, revision and auth token.
            model_load_id = model_id
            cache_dir = os.path.join(MODEL_CACHE_DIR, "huggingface")
            token = self.huggingface_token
        else:
            # Base weights are mirrored on Roboflow; revision/token are not
            # needed once the local copy is resolved.
            model_load_id = self.get_lora_base_from_roboflow(model_id, revision)
            cache_dir = model_load_id
            revision = None
            token = None
        self.base_model = self.transformers_class.from_pretrained(
            model_load_id,
            revision=revision,
            device_map=DEVICE,
            cache_dir=cache_dir,
            token=token,
            attn_implementation=_get_paligemma_attn_implementation(),
        ).to(self.dtype)
        self.model = (
            PeftModel.from_pretrained(self.base_model, self.cache_dir)
            .eval()
            .to(self.dtype)
        )

        # Fold the LoRA weights into the base model for faster inference.
        self.model.merge_and_unload()

        self.processor = self.processor_class.from_pretrained(
            model_load_id, revision=revision, cache_dir=cache_dir, token=token
        )

PaliGemma

Bases: TransformerModel

By using you agree to the terms listed at https://ai.google.dev/gemma/terms

Source code in inference/models/paligemma/paligemma.py
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
class PaliGemma(TransformerModel):
    """By using you agree to the terms listed at https://ai.google.dev/gemma/terms"""

    generation_includes_input = True
    transformers_class = PaliGemmaForConditionalGeneration

    def initialize_model(self, **kwargs):
        """Load the PaliGemma model and its processor.

        The weights are read either from HuggingFace (via ``self.dataset_id``)
        or from the local Roboflow cache directory, depending on
        ``self.load_base_from_roboflow``.
        """
        model_id = (
            self.cache_dir if self.load_base_from_roboflow else self.dataset_id
        )

        loaded = self.transformers_class.from_pretrained(
            model_id,
            cache_dir=self.cache_dir,
            device_map=DEVICE,
            token=self.huggingface_token,
            torch_dtype=self.default_dtype,
            attn_implementation=_get_paligemma_attn_implementation(),
        )
        self.model = loaded.eval().to(self.dtype)

        self.processor = self.processor_class.from_pretrained(
            model_id, cache_dir=self.cache_dir, token=self.huggingface_token
        )

models/perception_encoder

inference.models.perception_encoder.perception_encoder

Classes

PerceptionEncoder

Bases: RoboflowCoreModel

Roboflow Perception Encoder model implementation.

This class is responsible for handling the Perception Encoder model, including loading the model, preprocessing the input, and performing inference.

Attributes:

Name Type Description
model CLIP

The PE-CLIP model instance.

preprocess function

Function to preprocess the image.

tokenizer function

Function to tokenize text.

device str

The device to run inference on (cuda/cpu).

Source code in inference/models/perception_encoder/perception_encoder.py
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
class PerceptionEncoder(RoboflowCoreModel):
    """Roboflow Perception Encoder model implementation.

    This class is responsible for handling the Perception Encoder model,
    including loading the model, preprocessing the input, and performing
    inference.

    Attributes:
        model (pe.CLIP): The PE-CLIP model instance.
        preprocessor (function): Function to preprocess the image.
        tokenizer (function): Function to tokenize text.
        device (str): The device to run inference on (cuda/cpu/mps).
    """

    def __init__(
        self,
        model_id: str = PERCEPTION_ENCODER_MODEL_ID,
        device: str = DEVICE,
        *args,
        **kwargs,
    ):
        """Initializes the PerceptionEncoder with the given arguments and keyword arguments.

        Args:
            model_id: Roboflow model id in the form
                "perception-encoder/<config-name>",
                e.g. "perception-encoder/PE-Core-L14-336".
            device: Device to run inference on ("cuda", "cpu" or "mps").
        """
        super().__init__(model_id=model_id.lower(), *args, **kwargs)
        self.device = device
        self.log("Creating PE-CLIP model")
        # Parse model config from model_id (format: perception-encoder/PE-Core-L14-336)
        model_config = model_id.split("/")[-1]
        checkpoint_path = os.path.join(self.cache_dir, "model.pt")
        self.model = pe.CLIP.from_config(
            model_config, pretrained=True, checkpoint_path=checkpoint_path
        )
        self.model = self.model.to(device)
        self.model.eval()

        self.preprocessor = transforms.get_image_transform(self.model.image_size)
        self.tokenizer = transforms.get_text_tokenizer(self.model.context_length)

        self.task_type = "embedding"

    def get_infer_bucket_file_list(self) -> List[str]:
        """Gets the list of files required for inference."""
        return ["model.pt"]  # serialized PE-CLIP checkpoint loaded in __init__

    def initialize_model(self, **kwargs) -> None:
        """Initialize the model. Not needed for PE-CLIP as it's loaded in __init__."""
        pass

    def preproc_image(self, image: InferenceRequestImage) -> torch.Tensor:
        """Preprocesses an inference request image.

        Returns:
            torch.Tensor: the preprocessed image with a leading batch
            dimension added by ``unsqueeze(0)``, i.e. shape (1, C, H, W).
        """
        pil_image = Image.fromarray(load_image_rgb(image))
        preprocessed_image = self.preprocessor(pil_image)
        return preprocessed_image.unsqueeze(0)

    def preprocess(
        self, image: Any, **kwargs
    ) -> Tuple[torch.Tensor, PreprocessReturnMetadata]:
        """Preprocess an image for the standard inference pipeline."""
        return self.preproc_image(image), PreprocessReturnMetadata({})

    def _forward(self, image_input, text_input):
        """Run a model forward pass under the appropriate inference context.

        On "cpu"/"mps" the pass runs in full precision under
        ``torch.inference_mode``; on any other device (CUDA) it additionally
        runs under ``torch.autocast`` for mixed precision.

        Returns:
            The 3-tuple produced by the PE-CLIP model (image features, text
            features, and a third value unused by the callers here).
        """
        if self.device in ("cpu", "mps"):
            with torch.inference_mode():
                return self.model(image_input, text_input)
        with torch.inference_mode(), torch.autocast(self.device):
            return self.model(image_input, text_input)

    def compare(
        self,
        subject: Any,
        prompt: Any,
        subject_type: str = "image",
        prompt_type: Union[str, List[str], Dict[str, Any]] = "text",
        **kwargs,
    ) -> Union[List[float], Dict[str, float]]:
        """
        Compares the subject with the prompt to calculate similarity scores.

        Args:
            subject (Any): The subject data to be compared. Can be either an image or text.
            prompt (Any): The prompt data to be compared against the subject. Can be a single value (image/text), list of values, or dictionary of values.
            subject_type (str, optional): Specifies the type of the subject data. Must be either "image" or "text". Defaults to "image".
            prompt_type (Union[str, List[str], Dict[str, Any]], optional): Specifies the type of the prompt data. Can be "image", "text", list of these types, or a dictionary containing these types. Defaults to "text".
            **kwargs: Additional keyword arguments.

        Returns:
            Union[List[float], Dict[str, float]]: A list or dictionary containing cosine similarity scores between the subject and prompt(s).
        """
        if subject_type == "image":
            subject_embeddings = self.embed_image(subject)
        elif subject_type == "text":
            subject_embeddings = self.embed_text(subject)
        else:
            raise ValueError(
                f"subject_type must be either 'image' or 'text', but got {subject_type}"
            )

        # A dict with both "type" and "value" keys is treated as a single
        # prompt payload; any other dict is a mapping of labels to prompts,
        # whose keys are preserved so scores can be returned per label.
        if isinstance(prompt, dict) and not ("type" in prompt and "value" in prompt):
            prompt_keys = prompt.keys()
            prompt = [prompt[k] for k in prompt_keys]
            prompt_obj = "dict"
        else:
            if not isinstance(prompt, list):
                prompt = [prompt]
            prompt_obj = "list"

        if len(prompt) > CLIP_MAX_BATCH_SIZE:
            raise ValueError(
                f"The maximum number of prompts that can be compared at once is {CLIP_MAX_BATCH_SIZE}"
            )

        if prompt_type == "image":
            prompt_embeddings = self.embed_image(prompt)
        elif prompt_type == "text":
            prompt_embeddings = self.embed_text(prompt)
        else:
            raise ValueError(
                f"prompt_type must be either 'image' or 'text', but got {prompt_type}"
            )

        similarities = [
            cosine_similarity(subject_embeddings, p) for p in prompt_embeddings
        ]

        # Re-attach the original dict keys when the prompts came in as a mapping.
        if prompt_obj == "dict":
            similarities = dict(zip(prompt_keys, similarities))

        return similarities

    def make_compare_response(
        self, similarities: Union[List[float], Dict[str, float]]
    ) -> PerceptionEncoderCompareResponse:
        """Creates a PerceptionEncoderCompareResponse object from the provided similarity data."""
        response = PerceptionEncoderCompareResponse(similarity=similarities)
        return response

    def embed_image(
        self,
        image: Any,
        **kwargs,
    ) -> np.ndarray:
        """
        Embeds an image or a list of images using the PE-CLIP model.

        Args:
            image (Any): The image or list of images to be embedded.
            **kwargs: Additional keyword arguments.

        Returns:
            np.ndarray: The embeddings of the image(s) as a numpy array.
        """
        if isinstance(image, list):
            # Batched path: preprocess each image and stack along the batch dim.
            if len(image) > CLIP_MAX_BATCH_SIZE:
                raise ValueError(
                    f"The maximum number of images that can be embedded at once is {CLIP_MAX_BATCH_SIZE}"
                )
            imgs = [self.preproc_image(i) for i in image]
            img_in = torch.cat(imgs, dim=0).to(self.device)
        else:
            img_in = self.preproc_image(image).to(self.device)

        image_features, _, _ = self._forward(img_in, None)
        # Convert to float32 before converting to numpy
        embeddings = image_features.float().cpu().numpy()
        return embeddings

    def embed_text(
        self,
        text: Union[str, List[str]],
        **kwargs,
    ) -> np.ndarray:
        """
        Embeds a text or a list of texts using the PE-CLIP model.

        Args:
            text (Union[str, List[str]]): The text string or list of text strings to be embedded.
            **kwargs: Additional keyword arguments.

        Returns:
            np.ndarray: The embeddings of the text or texts as a numpy array.
        """
        texts = text if isinstance(text, list) else [text]

        results = []
        # Tokenize and embed in batches to respect the batch-size cap.
        for texts_batch in create_batches(
            sequence=texts, batch_size=CLIP_MAX_BATCH_SIZE
        ):
            tokenized = self.tokenizer(texts_batch).to(self.device)
            _, text_features, _ = self._forward(None, tokenized)

            # Convert to float32 before converting to numpy
            embeddings = text_features.float().cpu().numpy()
            results.append(embeddings)

        return np.concatenate(results, axis=0)

    def predict(self, img_in: torch.Tensor, **kwargs) -> Tuple[np.ndarray]:
        """Predict embeddings for an input tensor.

        Args:
            img_in (torch.Tensor): The input tensor to get embeddings for.
            **kwargs: Additional keyword arguments.

        Returns:
            Tuple[np.ndarray]: A tuple containing the embeddings as a numpy array.
        """
        img_in = img_in.to(self.device)
        image_features, _, _ = self._forward(img_in, None)

        embeddings = image_features.float().cpu().numpy()
        return (embeddings,)

    def make_embed_image_response(
        self, embeddings: np.ndarray
    ) -> PerceptionEncoderEmbeddingResponse:
        """Converts the given embeddings into a PerceptionEncoderEmbeddingResponse object."""
        # tolist() makes the numpy array JSON-serializable.
        response = PerceptionEncoderEmbeddingResponse(embeddings=embeddings.tolist())
        return response

    def make_embed_text_response(
        self, embeddings: np.ndarray
    ) -> PerceptionEncoderEmbeddingResponse:
        """Converts the given text embeddings into a PerceptionEncoderEmbeddingResponse object."""
        response = PerceptionEncoderEmbeddingResponse(embeddings=embeddings.tolist())
        return response

    def infer_from_request(
        self, request: PerceptionEncoderInferenceRequest
    ) -> PerceptionEncoderEmbeddingResponse:
        """Routes the request to the appropriate inference function."""
        t1 = perf_counter()
        # Dispatch on the concrete request type; each branch pairs an
        # inference function with the response factory wrapping its result.
        if isinstance(request, PerceptionEncoderImageEmbeddingRequest):
            infer_func = self.embed_image
            make_response_func = self.make_embed_image_response
        elif isinstance(request, PerceptionEncoderTextEmbeddingRequest):
            infer_func = self.embed_text
            make_response_func = self.make_embed_text_response
        elif isinstance(request, PerceptionEncoderCompareRequest):
            infer_func = self.compare
            make_response_func = self.make_compare_response
        else:
            raise ValueError(
                f"Request type {type(request)} is not a valid PerceptionEncoderInferenceRequest"
            )
        data = infer_func(**request.dict())
        response = make_response_func(data)
        response.time = perf_counter() - t1  # total wall-clock time for the request
        return response

    def make_response(self, embeddings, *args, **kwargs) -> InferenceResponse:
        """Wrap raw embeddings in a single-element response list."""
        return [self.make_embed_image_response(embeddings)]

    def postprocess(
        self,
        predictions: Tuple[np.ndarray],
        preprocess_return_metadata: PreprocessReturnMetadata,
        **kwargs,
    ) -> Any:
        """Convert the tuple returned by ``predict`` into a response list."""
        return [self.make_embed_image_response(predictions[0])]

    def infer(self, image: Any, **kwargs) -> Any:
        """Embeds an image"""
        return super().infer(image, **kwargs)
Functions
__init__
__init__(
    model_id=PERCEPTION_ENCODER_MODEL_ID,
    device=DEVICE,
    *args,
    **kwargs
)

Initializes the PerceptionEncoder with the given arguments and keyword arguments.

Source code in inference/models/perception_encoder/perception_encoder.py
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
def __init__(
    self,
    model_id: str = PERCEPTION_ENCODER_MODEL_ID,
    device: str = DEVICE,
    *args,
    **kwargs,
):
    """Initializes the PerceptionEncoder with the given arguments and keyword arguments.

    Args:
        model_id: Roboflow model id in the form
            "perception-encoder/<config-name>".
        device: Device to run inference on ("cuda", "cpu" or "mps").
    """
    t1 = perf_counter()  # NOTE(review): assigned but never used in this method
    super().__init__(model_id=model_id.lower(), *args, **kwargs)
    self.device = device
    self.log("Creating PE-CLIP model")
    # Parse model config from model_id (format: perception-encoder/PE-Core-L14-336)
    model_config = model_id.split("/")[-1]
    checkpoint_path = os.path.join(self.cache_dir, "model.pt")
    self.model = pe.CLIP.from_config(
        model_config, pretrained=True, checkpoint_path=checkpoint_path
    )
    self.model = self.model.to(device)
    self.model.eval()

    # Image transform and text tokenizer matched to the model's configuration.
    self.preprocessor = transforms.get_image_transform(self.model.image_size)
    self.tokenizer = transforms.get_text_tokenizer(self.model.context_length)

    self.task_type = "embedding"
compare
compare(
    subject,
    prompt,
    subject_type="image",
    prompt_type="text",
    **kwargs
)

Compares the subject with the prompt to calculate similarity scores.

Parameters:

Name Type Description Default
subject Any

The subject data to be compared. Can be either an image or text.

required
prompt Any

The prompt data to be compared against the subject. Can be a single value (image/text), list of values, or dictionary of values.

required
subject_type str

Specifies the type of the subject data. Must be either "image" or "text". Defaults to "image".

'image'
prompt_type Union[str, List[str], Dict[str, Any]]

Specifies the type of the prompt data. Can be "image", "text", list of these types, or a dictionary containing these types. Defaults to "text".

'text'
**kwargs

Additional keyword arguments.

{}

Returns:

Type Description
Union[List[float], Dict[str, float]]

Union[List[float], Dict[str, float]]: A list or dictionary containing cosine similarity scores between the subject and prompt(s).

Source code in inference/models/perception_encoder/perception_encoder.py
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
def compare(
    self,
    subject: Any,
    prompt: Any,
    subject_type: str = "image",
    prompt_type: Union[str, List[str], Dict[str, Any]] = "text",
    **kwargs,
) -> Union[List[float], Dict[str, float]]:
    """
    Compares the subject with the prompt to calculate similarity scores.

    Args:
        subject (Any): The subject data to be compared. Can be either an image or text.
        prompt (Any): The prompt data to be compared against the subject. Can be a single value (image/text), list of values, or dictionary of values.
        subject_type (str, optional): Specifies the type of the subject data. Must be either "image" or "text". Defaults to "image".
        prompt_type (Union[str, List[str], Dict[str, Any]], optional): Specifies the type of the prompt data. Can be "image", "text", list of these types, or a dictionary containing these types. Defaults to "text".
        **kwargs: Additional keyword arguments.

    Returns:
        Union[List[float], Dict[str, float]]: A list or dictionary containing cosine similarity scores between the subject and prompt(s).
    """
    if subject_type == "image":
        subject_embeddings = self.embed_image(subject)
    elif subject_type == "text":
        subject_embeddings = self.embed_text(subject)
    else:
        raise ValueError(
            f"subject_type must be either 'image' or 'text', but got {subject_type}"
        )

    # A dict with both "type" and "value" keys is treated as a single prompt
    # payload; any other dict is a mapping of labels to prompts, whose keys
    # are preserved so scores can be returned per label.
    if isinstance(prompt, dict) and not ("type" in prompt and "value" in prompt):
        prompt_keys = prompt.keys()
        prompt = [prompt[k] for k in prompt_keys]
        prompt_obj = "dict"
    else:
        if not isinstance(prompt, list):
            prompt = [prompt]
        prompt_obj = "list"

    if len(prompt) > CLIP_MAX_BATCH_SIZE:
        raise ValueError(
            f"The maximum number of prompts that can be compared at once is {CLIP_MAX_BATCH_SIZE}"
        )

    if prompt_type == "image":
        prompt_embeddings = self.embed_image(prompt)
    elif prompt_type == "text":
        prompt_embeddings = self.embed_text(prompt)
    else:
        raise ValueError(
            f"prompt_type must be either 'image' or 'text', but got {prompt_type}"
        )

    similarities = [
        cosine_similarity(subject_embeddings, p) for p in prompt_embeddings
    ]

    # Re-attach the original dict keys when the prompts came in as a mapping.
    if prompt_obj == "dict":
        similarities = dict(zip(prompt_keys, similarities))

    return similarities
embed_image
embed_image(image, **kwargs)

Embeds an image or a list of images using the PE-CLIP model.

Parameters:

Name Type Description Default
image Any

The image or list of images to be embedded.

required
**kwargs

Additional keyword arguments.

{}

Returns:

Type Description
ndarray

np.ndarray: The embeddings of the image(s) as a numpy array.

Source code in inference/models/perception_encoder/perception_encoder.py
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
def embed_image(
    self,
    image: Any,
    **kwargs,
) -> np.ndarray:
    """
    Embeds an image or a list of images using the PE-CLIP model.

    Args:
        image (Any): The image or list of images to be embedded.
        **kwargs: Additional keyword arguments.

    Returns:
        np.ndarray: The embeddings of the image(s) as a numpy array.
    """
    t1 = perf_counter()  # NOTE(review): assigned but never used in this method

    if isinstance(image, list):
        # Batched path: preprocess each image and stack along the batch dim.
        if len(image) > CLIP_MAX_BATCH_SIZE:
            raise ValueError(
                f"The maximum number of images that can be embedded at once is {CLIP_MAX_BATCH_SIZE}"
            )
        imgs = [self.preproc_image(i) for i in image]
        img_in = torch.cat(imgs, dim=0).to(self.device)
    else:
        img_in = self.preproc_image(image).to(self.device)

    # CPU/MPS: full-precision pass; otherwise (CUDA) run under autocast.
    if self.device == "cpu" or self.device == "mps":
        with torch.inference_mode():
            image_features, _, _ = self.model(img_in, None)
            # Convert to float32 before converting to numpy
            embeddings = image_features.float().cpu().numpy()
    else:
        with torch.inference_mode(), torch.autocast(self.device):
            image_features, _, _ = self.model(img_in, None)
            # Convert to float32 before converting to numpy
            embeddings = image_features.float().cpu().numpy()

    return embeddings
embed_text
embed_text(text, **kwargs)

Embeds a text or a list of texts using the PE-CLIP model.

Parameters:

Name Type Description Default
text Union[str, List[str]]

The text string or list of text strings to be embedded.

required
**kwargs

Additional keyword arguments.

{}

Returns:

Type Description
ndarray

np.ndarray: The embeddings of the text or texts as a numpy array.

Source code in inference/models/perception_encoder/perception_encoder.py
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
def embed_text(
    self,
    text: Union[str, List[str]],
    **kwargs,
) -> np.ndarray:
    """
    Embeds a text or a list of texts using the PE-CLIP model.

    Args:
        text (Union[str, List[str]]): The text string or list of text strings to be embedded.
        **kwargs: Additional keyword arguments.

    Returns:
        np.ndarray: The embeddings of the text or texts as a numpy array.
    """
    if isinstance(text, list):
        texts = text
    else:
        texts = [text]

    results = []
    # Tokenize and embed in batches to respect the batch-size cap.
    for texts_batch in create_batches(
        sequence=texts, batch_size=CLIP_MAX_BATCH_SIZE
    ):
        tokenized = self.tokenizer(texts_batch).to(self.device)
        # CPU/MPS: full-precision pass; otherwise (CUDA) run under autocast.
        # NOTE(review): this branch uses torch.no_grad while the image path
        # uses torch.inference_mode — confirm the difference is intentional.
        if self.device == "cpu" or self.device == "mps":
            with torch.no_grad():
                _, text_features, _ = self.model(None, tokenized)
        else:
            with torch.inference_mode(), torch.autocast(self.device):
                _, text_features, _ = self.model(None, tokenized)

        # Convert to float32 before converting to numpy
        embeddings = text_features.float().cpu().numpy()
        results.append(embeddings)

    return np.concatenate(results, axis=0)
get_infer_bucket_file_list
get_infer_bucket_file_list()

Gets the list of files required for inference.

Source code in inference/models/perception_encoder/perception_encoder.py
78
79
80
def get_infer_bucket_file_list(self) -> List[str]:
    """Gets the list of files required for inference."""
    return ["model.pt"]  # serialized PE-CLIP checkpoint loaded in __init__
infer
infer(image, **kwargs)

Embeds an image

Source code in inference/models/perception_encoder/perception_encoder.py
314
315
316
def infer(self, image: Any, **kwargs) -> Any:
    """Embed an image by delegating to the base class inference pipeline."""
    return super().infer(image, **kwargs)
infer_from_request
infer_from_request(request)

Routes the request to the appropriate inference function.

Source code in inference/models/perception_encoder/perception_encoder.py
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
def infer_from_request(
    self, request: PerceptionEncoderInferenceRequest
) -> PerceptionEncoderEmbeddingResponse:
    """Routes the request to the appropriate inference function."""
    t1 = perf_counter()
    # Dispatch on the concrete request type; each branch pairs an inference
    # function with the response factory that wraps its result.
    if isinstance(request, PerceptionEncoderImageEmbeddingRequest):
        infer_func = self.embed_image
        make_response_func = self.make_embed_image_response
    elif isinstance(request, PerceptionEncoderTextEmbeddingRequest):
        infer_func = self.embed_text
        make_response_func = self.make_embed_text_response
    elif isinstance(request, PerceptionEncoderCompareRequest):
        infer_func = self.compare
        make_response_func = self.make_compare_response
    else:
        raise ValueError(
            f"Request type {type(request)} is not a valid PerceptionEncoderInferenceRequest"
        )
    data = infer_func(**request.dict())
    response = make_response_func(data)
    response.time = perf_counter() - t1  # total wall-clock time for the request
    return response
initialize_model
initialize_model(**kwargs)

Initialize the model. Not needed for PE-CLIP as it's loaded in init.

Source code in inference/models/perception_encoder/perception_encoder.py
82
83
84
def initialize_model(self, **kwargs) -> None:
    """Initialize the model. Not needed for PE-CLIP as it's loaded in __init__."""
    pass  # model construction happens entirely in __init__
make_compare_response
make_compare_response(similarities)

Creates a PerceptionEncoderCompareResponse object from the provided similarity data.

Source code in inference/models/perception_encoder/perception_encoder.py
159
160
161
162
163
164
def make_compare_response(
    self, similarities: Union[List[float], Dict[str, float]]
) -> PerceptionEncoderCompareResponse:
    """Creates a PerceptionEncoderCompareResponse object from the provided similarity data."""
    # similarities may be a list (ordered prompts) or a dict (keyed prompts).
    response = PerceptionEncoderCompareResponse(similarity=similarities)
    return response
make_embed_image_response
make_embed_image_response(embeddings)

Converts the given embeddings into a PerceptionEncoderEmbeddingResponse object.

Source code in inference/models/perception_encoder/perception_encoder.py
266
267
268
269
270
271
def make_embed_image_response(
    self, embeddings: np.ndarray
) -> PerceptionEncoderEmbeddingResponse:
    """Converts the given embeddings into a PerceptionEncoderEmbeddingResponse object."""
    # tolist() makes the numpy array JSON-serializable.
    response = PerceptionEncoderEmbeddingResponse(embeddings=embeddings.tolist())
    return response
make_embed_text_response
make_embed_text_response(embeddings)

Converts the given text embeddings into a PerceptionEncoderEmbeddingResponse object.

Source code in inference/models/perception_encoder/perception_encoder.py
273
274
275
276
277
278
def make_embed_text_response(
    self, embeddings: np.ndarray
) -> PerceptionEncoderEmbeddingResponse:
    """Wrap text embeddings in a PerceptionEncoderEmbeddingResponse (as nested lists)."""
    return PerceptionEncoderEmbeddingResponse(embeddings=embeddings.tolist())
predict
predict(img_in, **kwargs)

Predict embeddings for an input tensor.

Parameters:

Name Type Description Default
img_in Tensor

The input tensor to get embeddings for.

required
**kwargs

Additional keyword arguments.

{}

Returns:

Type Description
Tuple[ndarray]

Tuple[np.ndarray]: A tuple containing the embeddings as a numpy array.

Source code in inference/models/perception_encoder/perception_encoder.py
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
def predict(self, img_in: torch.Tensor, **kwargs) -> Tuple[np.ndarray]:
    """Compute image embeddings for a preprocessed batch tensor.

    Args:
        img_in (torch.Tensor): The input tensor to get embeddings for.
        **kwargs: Additional keyword arguments (unused).

    Returns:
        Tuple[np.ndarray]: A single-element tuple with the embeddings array.
    """
    tensor = img_in.to(self.device)
    # autocast is only used on accelerated devices; cpu/mps run in full precision
    on_cpu_like = self.device in ("cpu", "mps")
    with torch.inference_mode():
        if on_cpu_like:
            image_features, _, _ = self.model(tensor, None)
        else:
            with torch.autocast(self.device):
                image_features, _, _ = self.model(tensor, None)
    return (image_features.float().cpu().numpy(),)
preproc_image
preproc_image(image)

Preprocesses an inference request image.

Source code in inference/models/perception_encoder/perception_encoder.py
86
87
88
89
90
def preproc_image(self, image: InferenceRequestImage) -> torch.Tensor:
    """Decode a request image, run the CLIP preprocessor, and add a batch dimension."""
    rgb_array = load_image_rgb(image)
    as_pil = Image.fromarray(rgb_array)
    tensor = self.preprocessor(as_pil)
    return tensor.unsqueeze(0)

Functions

inference.models.perception_encoder.perception_encoder_inference_models

Classes

InferenceModelsPerceptionEncoderAdapter

Bases: Model

Roboflow Perception Encoder model implementation.

This class is responsible for handling the Perception Encoder model, including loading the model, preprocessing the input, and performing inference.

Source code in inference/models/perception_encoder/perception_encoder_inference_models.py
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
class InferenceModelsPerceptionEncoderAdapter(Model):
    """Roboflow Perception Encoder model implementation.

    This class is responsible for handling the Perception Encoder model, including
    loading the model, preprocessing the input, and performing inference.
    """

    def __init__(
        self, model_id: str = PERCEPTION_ENCODER_MODEL_ID, api_key: str = None, **kwargs
    ):
        """Load Perception Encoder weights via AutoModel.

        Args:
            model_id: Model identifier; a legacy "perception_encoder/" prefix is
                rewritten to the hyphenated "perception-encoder/" form.
            api_key: Roboflow API key; falls back to the module-level API_KEY.
            **kwargs: Forwarded to AutoModel.from_pretrained; also consulted for
                "countinference" and "service_secret" header values.
        """
        super().__init__()
        # Normalize legacy underscore-style model ids to the hyphenated form.
        if model_id.startswith("perception_encoder/"):
            model_id = model_id.replace("perception_encoder/", "perception-encoder/")

        self.metrics = {"num_inferences": 0, "avg_inference_time": 0.0}

        self.api_key = api_key if api_key else API_KEY

        self.task_type = "embedding"

        extra_weights_provider_headers = get_extra_weights_provider_headers(
            countinference=kwargs.get("countinference"),
            service_secret=kwargs.get("service_secret"),
        )
        # Every valid backend that has not been explicitly disabled.
        backend = list(
            VALID_INFERENCE_MODELS_BACKENDS.difference(
                DISABLED_INFERENCE_MODELS_BACKENDS
            )
        )
        self._model: PerceptionEncoderTorch = AutoModel.from_pretrained(
            model_id_or_path=model_id,
            api_key=self.api_key,
            allow_untrusted_packages=ALLOW_INFERENCE_MODELS_UNTRUSTED_PACKAGES,
            allow_direct_local_storage_loading=ALLOW_INFERENCE_MODELS_DIRECTLY_ACCESS_LOCAL_PACKAGES,
            weights_provider_extra_headers=extra_weights_provider_headers,
            backend=backend,
            **kwargs,
        )

    def preproc_image(self, image: InferenceRequestImage) -> np.ndarray:
        """Preprocesses an inference request image (decodes it to a BGR array)."""
        return load_image_bgr(image)

    def preprocess(
        self, image: Any, **kwargs
    ) -> Tuple[torch.Tensor, PreprocessReturnMetadata]:
        # No metadata is produced; downstream postprocess ignores it.
        return self.preproc_image(image), PreprocessReturnMetadata({})

    def compare(
        self,
        subject: Any,
        prompt: Any,
        subject_type: str = "image",
        prompt_type: Union[str, List[str], Dict[str, Any]] = "text",
        **kwargs,
    ) -> Union[List[float], Dict[str, float]]:
        """
        Compares the subject with the prompt to calculate similarity scores.

        Args:
            subject (Any): The subject data to be compared. Can be either an image or text.
            prompt (Any): The prompt data to be compared against the subject. Can be a single value (image/text), list of values, or dictionary of values.
            subject_type (str, optional): Specifies the type of the subject data. Must be either "image" or "text". Defaults to "image".
            prompt_type (Union[str, List[str], Dict[str, Any]], optional): Specifies the type of the prompt data. Can be "image", "text", list of these types, or a dictionary containing these types. Defaults to "text".
            **kwargs: Additional keyword arguments.

        Returns:
            Union[List[float], Dict[str, float]]: A list or dictionary containing cosine similarity scores between the subject and prompt(s).

        Raises:
            ValueError: If subject_type/prompt_type is invalid or the prompt batch
                exceeds CLIP_MAX_BATCH_SIZE.
        """
        if subject_type == "image":
            subject_embeddings = self.embed_image(subject)
        elif subject_type == "text":
            subject_embeddings = self.embed_text(subject)
        else:
            raise ValueError(
                f"subject_type must be either 'image' or 'text', but got {subject_type}"
            )

        # A dict that is itself a {"type": ..., "value": ...} descriptor is treated
        # as a single prompt; any other dict maps keys to individual prompts.
        if isinstance(prompt, dict) and not ("type" in prompt and "value" in prompt):
            prompt_keys = prompt.keys()
            prompt = [prompt[k] for k in prompt_keys]
            prompt_obj = "dict"
        else:
            if not isinstance(prompt, list):
                prompt = [prompt]
            prompt_obj = "list"

        if len(prompt) > CLIP_MAX_BATCH_SIZE:
            raise ValueError(
                f"The maximum number of prompts that can be compared at once is {CLIP_MAX_BATCH_SIZE}"
            )

        if prompt_type == "image":
            prompt_embeddings = self.embed_image(prompt)
        elif prompt_type == "text":
            prompt_embeddings = self.embed_text(prompt)
        else:
            raise ValueError(
                f"prompt_type must be either 'image' or 'text', but got {prompt_type}"
            )

        similarities = [
            cosine_similarity(subject_embeddings, p) for p in prompt_embeddings
        ]

        # Re-attach the original dict keys so callers get {key: score}.
        if prompt_obj == "dict":
            similarities = dict(zip(prompt_keys, similarities))

        return similarities

    def make_compare_response(
        self, similarities: Union[List[float], Dict[str, float]]
    ) -> PerceptionEncoderCompareResponse:
        """Creates a PerceptionEncoderCompareResponse object from the provided similarity data."""
        response = PerceptionEncoderCompareResponse(similarity=similarities)
        return response

    def embed_image(
        self,
        image: Any,
        **kwargs,
    ) -> np.ndarray:
        """
        Embeds an image or a list of images using the PE-CLIP model.

        Args:
            image (Any): The image or list of images to be embedded.
            **kwargs: Additional keyword arguments.

        Returns:
            np.ndarray: The embeddings of the image(s) as a numpy array.

        Raises:
            ValueError: If more than CLIP_MAX_BATCH_SIZE images are supplied.
        """
        if isinstance(image, list):
            if len(image) > CLIP_MAX_BATCH_SIZE:
                raise ValueError(
                    f"The maximum number of images that can be embedded at once is {CLIP_MAX_BATCH_SIZE}"
                )
            img_in = [self.preproc_image(i) for i in image]
        else:
            img_in = [self.preproc_image(image)]

        return self._model.embed_images(img_in).cpu().numpy()

    def embed_text(
        self,
        text: Union[str, List[str]],
        **kwargs,
    ) -> np.ndarray:
        """
        Embeds a text or a list of texts using the PE-CLIP model.

        Args:
            text (Union[str, List[str]]): The text string or list of text strings to be embedded.
            **kwargs: Additional keyword arguments.

        Returns:
            np.ndarray: The embeddings of the text or texts as a numpy array.

        Raises:
            ValueError: If more than CLIP_MAX_BATCH_SIZE texts are supplied.
        """
        if isinstance(text, list):
            texts = text
        else:
            texts = [text]
        if len(texts) > CLIP_MAX_BATCH_SIZE:
            raise ValueError(
                f"The maximum number of texts that can be embedded at once is {CLIP_MAX_BATCH_SIZE}"
            )
        return self._model.embed_text(texts).cpu().numpy()

    def predict(self, img_in: np.ndarray, **kwargs) -> Tuple[np.ndarray]:
        """Predict embeddings for preprocessed image input.

        Args:
            img_in (np.ndarray): The decoded image(s) to get embeddings for.
            **kwargs: Additional keyword arguments.

        Returns:
            Tuple[np.ndarray]: A tuple containing the embeddings as a numpy array.
        """
        embeddings = self._model.embed_images(img_in).cpu().numpy()
        return (embeddings,)

    def make_embed_image_response(
        self, embeddings: np.ndarray
    ) -> PerceptionEncoderEmbeddingResponse:
        """Converts the given embeddings into a PerceptionEncoderEmbeddingResponse object."""
        response = PerceptionEncoderEmbeddingResponse(embeddings=embeddings.tolist())
        return response

    def make_embed_text_response(
        self, embeddings: np.ndarray
    ) -> PerceptionEncoderEmbeddingResponse:
        """Converts the given text embeddings into a PerceptionEncoderEmbeddingResponse object."""
        response = PerceptionEncoderEmbeddingResponse(embeddings=embeddings.tolist())
        return response

    def infer_from_request(
        self, request: PerceptionEncoderInferenceRequest
    ) -> PerceptionEncoderEmbeddingResponse:
        """Routes the request to the appropriate inference function.

        Raises:
            ValueError: If the request is not one of the supported request types.
        """
        t1 = perf_counter()
        if isinstance(request, PerceptionEncoderImageEmbeddingRequest):
            infer_func = self.embed_image
            make_response_func = self.make_embed_image_response
        elif isinstance(request, PerceptionEncoderTextEmbeddingRequest):
            infer_func = self.embed_text
            make_response_func = self.make_embed_text_response
        elif isinstance(request, PerceptionEncoderCompareRequest):
            infer_func = self.compare
            make_response_func = self.make_compare_response
        else:
            raise ValueError(
                f"Request type {type(request)} is not a valid PerceptionEncoderInferenceRequest"
            )
        data = infer_func(**request.dict())
        response = make_response_func(data)
        response.time = perf_counter() - t1
        return response

    # NOTE(review): despite the InferenceResponse annotation, this returns a
    # single-element list (matching postprocess below) — confirm callers expect a list.
    def make_response(self, embeddings, *args, **kwargs) -> InferenceResponse:
        return [self.make_embed_image_response(embeddings)]

    def postprocess(
        self,
        predictions: Tuple[np.ndarray],
        preprocess_return_metadata: PreprocessReturnMetadata,
        **kwargs,
    ) -> Any:
        # predictions is the tuple produced by predict(); wrap its first element.
        return [self.make_embed_image_response(predictions[0])]

    def infer(self, image: Any, **kwargs) -> Any:
        """Embeds an image by delegating to the parent class's infer pipeline."""
        return super().infer(image, **kwargs)
Functions
compare
compare(
    subject,
    prompt,
    subject_type="image",
    prompt_type="text",
    **kwargs
)

Compares the subject with the prompt to calculate similarity scores.

Parameters:

Name Type Description Default
subject Any

The subject data to be compared. Can be either an image or text.

required
prompt Any

The prompt data to be compared against the subject. Can be a single value (image/text), list of values, or dictionary of values.

required
subject_type str

Specifies the type of the subject data. Must be either "image" or "text". Defaults to "image".

'image'
prompt_type Union[str, List[str], Dict[str, Any]]

Specifies the type of the prompt data. Can be "image", "text", list of these types, or a dictionary containing these types. Defaults to "text".

'text'
**kwargs

Additional keyword arguments.

{}

Returns:

Type Description
Union[List[float], Dict[str, float]]

Union[List[float], Dict[str, float]]: A list or dictionary containing cosine similarity scores between the subject and prompt(s).

Source code in inference/models/perception_encoder/perception_encoder_inference_models.py
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
def compare(
    self,
    subject: Any,
    prompt: Any,
    subject_type: str = "image",
    prompt_type: Union[str, List[str], Dict[str, Any]] = "text",
    **kwargs,
) -> Union[List[float], Dict[str, float]]:
    """
    Compare a subject against one or more prompts via cosine similarity.

    Args:
        subject (Any): The subject data (an image or a text string).
        prompt (Any): A single prompt, a list of prompts, or a dict mapping
            names to prompts.
        subject_type (str, optional): "image" or "text". Defaults to "image".
        prompt_type (Union[str, List[str], Dict[str, Any]], optional): "image"
            or "text". Defaults to "text".
        **kwargs: Additional keyword arguments.

    Returns:
        Union[List[float], Dict[str, float]]: Similarity scores, keyed by the
        prompt dict's keys when a plain dict of prompts was supplied.

    Raises:
        ValueError: On an invalid subject_type/prompt_type, or when the prompt
            batch exceeds CLIP_MAX_BATCH_SIZE.
    """
    embedders = {"image": self.embed_image, "text": self.embed_text}

    if subject_type not in embedders:
        raise ValueError(
            f"subject_type must be either 'image' or 'text', but got {subject_type}"
        )
    subject_embeddings = embedders[subject_type](subject)

    # A dict shaped like {"type": ..., "value": ...} is a single prompt
    # descriptor; any other dict maps names to individual prompts.
    keys = None
    if isinstance(prompt, dict) and not ("type" in prompt and "value" in prompt):
        keys = prompt.keys()
        prompts = [prompt[k] for k in keys]
    elif isinstance(prompt, list):
        prompts = prompt
    else:
        prompts = [prompt]

    if len(prompts) > CLIP_MAX_BATCH_SIZE:
        raise ValueError(
            f"The maximum number of prompts that can be compared at once is {CLIP_MAX_BATCH_SIZE}"
        )

    if prompt_type not in embedders:
        raise ValueError(
            f"prompt_type must be either 'image' or 'text', but got {prompt_type}"
        )
    prompt_embeddings = embedders[prompt_type](prompts)

    scores = [cosine_similarity(subject_embeddings, e) for e in prompt_embeddings]
    return dict(zip(keys, scores)) if keys is not None else scores
embed_image
embed_image(image, **kwargs)

Embeds an image or a list of images using the PE-CLIP model.

Parameters:

Name Type Description Default
image Any

The image or list of images to be embedded.

required
**kwargs

Additional keyword arguments.

{}

Returns:

Type Description
ndarray

np.ndarray: The embeddings of the image(s) as a numpy array.

Source code in inference/models/perception_encoder/perception_encoder_inference_models.py
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
def embed_image(
    self,
    image: Any,
    **kwargs,
) -> np.ndarray:
    """
    Embed one image or a batch of images with the PE-CLIP model.

    Args:
        image (Any): A single image or a list of images to embed.
        **kwargs: Additional keyword arguments.

    Returns:
        np.ndarray: The image embeddings as a numpy array.

    Raises:
        ValueError: If more than CLIP_MAX_BATCH_SIZE images are supplied.
    """
    batch = image if isinstance(image, list) else [image]
    if len(batch) > CLIP_MAX_BATCH_SIZE:
        raise ValueError(
            f"The maximum number of images that can be embedded at once is {CLIP_MAX_BATCH_SIZE}"
        )
    preprocessed = [self.preproc_image(single) for single in batch]
    return self._model.embed_images(preprocessed).cpu().numpy()
embed_text
embed_text(text, **kwargs)

Embeds a text or a list of texts using the PE-CLIP model.

Parameters:

Name Type Description Default
text Union[str, List[str]]

The text string or list of text strings to be embedded.

required
**kwargs

Additional keyword arguments.

{}

Returns:

Type Description
ndarray

np.ndarray: The embeddings of the text or texts as a numpy array.

Source code in inference/models/perception_encoder/perception_encoder_inference_models.py
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
def embed_text(
    self,
    text: Union[str, List[str]],
    **kwargs,
) -> np.ndarray:
    """
    Embed one text string or a batch of strings with the PE-CLIP model.

    Args:
        text (Union[str, List[str]]): A string or list of strings to embed.
        **kwargs: Additional keyword arguments.

    Returns:
        np.ndarray: The text embeddings as a numpy array.

    Raises:
        ValueError: If more than CLIP_MAX_BATCH_SIZE texts are supplied.
    """
    batch = text if isinstance(text, list) else [text]
    if len(batch) > CLIP_MAX_BATCH_SIZE:
        raise ValueError(
            f"The maximum number of texts that can be embedded at once is {CLIP_MAX_BATCH_SIZE}"
        )
    return self._model.embed_text(batch).cpu().numpy()
infer
infer(image, **kwargs)

Embeds an image

Source code in inference/models/perception_encoder/perception_encoder_inference_models.py
269
270
271
def infer(self, image: Any, **kwargs) -> Any:
    """Embeds an image by delegating to the parent class's infer pipeline."""
    return super().infer(image, **kwargs)
infer_from_request
infer_from_request(request)

Routes the request to the appropriate inference function.

Source code in inference/models/perception_encoder/perception_encoder_inference_models.py
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
def infer_from_request(
    self, request: PerceptionEncoderInferenceRequest
) -> PerceptionEncoderEmbeddingResponse:
    """Dispatch the request to the matching inference + response-builder pair.

    Raises:
        ValueError: If the request is not a supported request type.
    """
    started_at = perf_counter()
    # (request class, inference callable, response builder) dispatch table.
    handlers = [
        (PerceptionEncoderImageEmbeddingRequest, self.embed_image, self.make_embed_image_response),
        (PerceptionEncoderTextEmbeddingRequest, self.embed_text, self.make_embed_text_response),
        (PerceptionEncoderCompareRequest, self.compare, self.make_compare_response),
    ]
    for request_cls, infer_func, build_response in handlers:
        if isinstance(request, request_cls):
            break
    else:
        raise ValueError(
            f"Request type {type(request)} is not a valid PerceptionEncoderInferenceRequest"
        )
    response = build_response(infer_func(**request.dict()))
    response.time = perf_counter() - started_at
    return response
make_compare_response
make_compare_response(similarities)

Creates a PerceptionEncoderCompareResponse object from the provided similarity data.

Source code in inference/models/perception_encoder/perception_encoder_inference_models.py
150
151
152
153
154
155
def make_compare_response(
    self, similarities: Union[List[float], Dict[str, float]]
) -> PerceptionEncoderCompareResponse:
    """Wrap the similarity scores in a PerceptionEncoderCompareResponse."""
    return PerceptionEncoderCompareResponse(similarity=similarities)
make_embed_image_response
make_embed_image_response(embeddings)

Converts the given embeddings into a PerceptionEncoderEmbeddingResponse object.

Source code in inference/models/perception_encoder/perception_encoder_inference_models.py
221
222
223
224
225
226
def make_embed_image_response(
    self, embeddings: np.ndarray
) -> PerceptionEncoderEmbeddingResponse:
    """Wrap image embeddings in a PerceptionEncoderEmbeddingResponse (as nested lists)."""
    return PerceptionEncoderEmbeddingResponse(embeddings=embeddings.tolist())
make_embed_text_response
make_embed_text_response(embeddings)

Converts the given text embeddings into a PerceptionEncoderEmbeddingResponse object.

Source code in inference/models/perception_encoder/perception_encoder_inference_models.py
228
229
230
231
232
233
def make_embed_text_response(
    self, embeddings: np.ndarray
) -> PerceptionEncoderEmbeddingResponse:
    """Wrap text embeddings in a PerceptionEncoderEmbeddingResponse (as nested lists)."""
    return PerceptionEncoderEmbeddingResponse(embeddings=embeddings.tolist())
predict
predict(img_in, **kwargs)

Predict embeddings for an input tensor.

Parameters:

Name Type Description Default
img_in Tensor

The input tensor to get embeddings for.

required
**kwargs

Additional keyword arguments.

{}

Returns:

Type Description
Tuple[ndarray]

Tuple[np.ndarray]: A tuple containing the embeddings as a numpy array.

Source code in inference/models/perception_encoder/perception_encoder_inference_models.py
208
209
210
211
212
213
214
215
216
217
218
219
def predict(self, img_in: np.ndarray, **kwargs) -> Tuple[np.ndarray]:
    """Compute embeddings for already-decoded image input.

    Args:
        img_in (np.ndarray): The decoded image(s) to get embeddings for.
        **kwargs: Additional keyword arguments (unused).

    Returns:
        Tuple[np.ndarray]: A single-element tuple with the embeddings array.
    """
    return (self._model.embed_images(img_in).cpu().numpy(),)
preproc_image
preproc_image(image)

Preprocesses an inference request image.

Source code in inference/models/perception_encoder/perception_encoder_inference_models.py
79
80
81
def preproc_image(self, image: InferenceRequestImage) -> np.ndarray:
    """Decode an inference request image to a BGR numpy array."""
    decoded = load_image_bgr(image)
    return decoded

Functions

models/perception_encoder/vision_encoder

inference.models.perception_encoder.vision_encoder.config

Include all available vision encoder configurations.

Classes

PEConfig dataclass

Vision Tower Config.

Source code in inference/models/perception_encoder/vision_encoder/config.py
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
@dataclass
class PEConfig:
    """Vision Tower Config.

    Hyperparameters describing the Perception Encoder vision transformer tower.
    """

    patch_size: int  # side length of the square image patches
    width: int  # transformer hidden dimension
    layers: int  # number of transformer blocks
    heads: int  # attention heads per block
    mlp_ratio: float  # MLP hidden size as a multiple of width
    output_dim: Optional[int]  # output projection dim; None presumably keeps width — confirm

    ls_init_value: Optional[float] = None  # annotation fixed: None is the documented default
    drop_path: float = 0.0

    # Bug fix: the default was previously ``(224,)`` — a stray trailing comma made
    # it a one-element tuple, contradicting the ``int`` annotation.
    image_size: int = 224
    use_abs_posemb: bool = True
    use_cls_token: bool = False
    use_rope2d: bool = True

    pool_type: str = "attn"
    attn_pooler_heads: int = 8

    use_ln_pre: bool = True
    use_ln_post: bool = True

PETextConfig dataclass

Text Tower Config.

Source code in inference/models/perception_encoder/vision_encoder/config.py
55
56
57
58
59
60
61
62
63
64
65
66
67
@dataclass
class PETextConfig:
    """Text Tower Config."""

    context_length: int  # maximum number of tokens per input sequence
    width: int  # transformer hidden dimension
    heads: int  # attention heads per block
    layers: int  # number of transformer blocks

    output_dim: int  # final text embedding dimension

    mlp_ratio: float = 4.0  # MLP hidden size as a multiple of width
    vocab_size: int = 49408  # presumably the CLIP BPE tokenizer vocab size — confirm

inference.models.perception_encoder.vision_encoder.pe

Classes

SelfAttention

Bases: Module

Implements sequence packed attention and RoPe

Source code in inference/models/perception_encoder/vision_encoder/pe.py
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
class SelfAttention(nn.Module):
    r"""
    Implements sequence packed attention and RoPe
    """

    def __init__(
        self,
        embed_dim: int,
        num_heads: int,
        rope: Optional[nn.Module] = None,  # optional rotary-embedding module applied to q/k
    ):
        super(SelfAttention, self).__init__()
        self.embed_dim = embed_dim

        self.num_heads = num_heads
        self.head_dim = embed_dim // num_heads
        assert (
            self.head_dim * num_heads == self.embed_dim
        ), "embed_dim must be divisible by num_heads"

        # Fused qkv projection parameters, named to stay state-dict compatible
        # with nn.MultiheadAttention.
        self.in_proj_weight = Parameter(torch.empty(3 * embed_dim, embed_dim))
        self.in_proj_bias = Parameter(torch.empty(3 * embed_dim))
        self.out_proj = nn.Linear(embed_dim, embed_dim, bias=True)

        self.rope = rope
        self.scale = self.head_dim ** (-0.5)

    def init_tensors(self):
        # Xavier init for the fused qkv weight; zero biases. out_proj.weight is
        # left to nn.Linear's default initialization.
        xavier_uniform_(self.in_proj_weight)
        constant_(self.in_proj_bias, 0.0)
        constant_(self.out_proj.bias, 0.0)

    def forward(self, x, attn_mask=None):
        # x: (batch, seq, embed_dim); single fused projection produces q, k, v.
        batch, seq, embed_dim = x.shape
        proj = F.linear(x, self.in_proj_weight, self.in_proj_bias)

        # reshape to 3, E and not E, 3 is deliberate for better memory coalescing and keeping same order as chunk()
        proj = (
            proj.unflatten(-1, (3, embed_dim))
            .unsqueeze(0)
            .transpose(0, -2)
            .squeeze(-2)
            .contiguous()
        )
        q, k, v = proj[0], proj[1], proj[2]

        # Split heads: (batch, seq, heads*dim) -> (batch, heads, seq, dim).
        q = rearrange(q, "b s (h d) -> b h s d", h=self.num_heads)
        k = rearrange(k, "b s (h d) -> b h s d", h=self.num_heads)
        v = rearrange(v, "b s (h d) -> b h s d", h=self.num_heads)

        if self.rope:
            q, k = self.rope(q, k)

        # NOTE(review): the attn_mask argument is accepted but NOT forwarded here
        # (attn_mask=None is hard-coded) — confirm this is intentional.
        attn = F.scaled_dot_product_attention(
            q, k, v, attn_mask=None, dropout_p=0.0, is_causal=False, scale=self.scale
        )
        attn = rearrange(attn, "b h s d -> b s (h d)")

        return F.linear(attn, self.out_proj.weight, self.out_proj.bias)

Transformer

Bases: Module

Source code in inference/models/perception_encoder/vision_encoder/pe.py
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
class Transformer(nn.Module):
    """Stack of residual attention blocks with optional gradient checkpointing
    and early-exit at a chosen layer index."""

    def __init__(
        self,
        width: int,
        layers: int,
        heads: int,
        mlp_ratio: float = 4.0,
        ls_init_value: float = None,
        act_layer: Callable = nn.GELU,
        norm_layer: Callable = nn.LayerNorm,
        drop_path: float = 0.0,
        rope: Optional[nn.Module] = None,
    ):
        super().__init__()
        self.width = width
        self.layers = layers
        self.grad_checkpointing = False

        self.resblocks = nn.ModuleList(
            [
                ResidualAttentionBlock(
                    width,
                    heads,
                    mlp_ratio,
                    ls_init_value=ls_init_value,
                    act_layer=act_layer,
                    norm_layer=norm_layer,
                    drop_path=drop_path,
                    rope=rope,
                )
                for _ in range(layers)
            ]
        )

    @torch.jit.ignore
    def set_grad_checkpointing(self, enable=True):
        self.grad_checkpointing = enable

    @torch.jit.ignore
    def truncate(self, layer_idx: int):
        """Delete layers so the last layer is the given layer index."""
        # Maps a (possibly negative, Python-style) index to a layer count:
        # layer_idx=-1 keeps all layers, layer_idx=0 keeps only the first.
        self.layers = ((self.layers + layer_idx) % self.layers) + 1
        self.resblocks = nn.ModuleList(self.resblocks[: self.layers])

    def forward(
        self,
        x: torch.Tensor,
        attn_mask: Optional[torch.Tensor] = None,
        layer_idx: int = -1,  # negative indices select layers from the end
    ):
        # Index of the last block to run (same negative-index normalization as truncate).
        stop_idx = (self.layers + layer_idx) % self.layers

        for i, r in enumerate(self.resblocks):
            if self.grad_checkpointing and not torch.jit.is_scripting():
                # TODO: handle kwargs https://github.com/pytorch/pytorch/issues/79887#issuecomment-1161758372
                # NOTE(review): this invokes r(x, None, None, attn_mask) — four
                # positional args; confirm ResidualAttentionBlock's forward accepts them.
                x = checkpoint(r, x, None, None, attn_mask)
            else:
                x = r(x, attn_mask=attn_mask)

            if i == stop_idx:
                break

        return x
Functions
truncate
truncate(layer_idx)

Delete layers so the last layer is the given layer index.

Source code in inference/models/perception_encoder/vision_encoder/pe.py
269
270
271
272
273
@torch.jit.ignore
def truncate(self, layer_idx: int):
    """Delete layers so the last layer is the given layer index."""
    # Maps a (possibly negative, Python-style) index to a layer count:
    # layer_idx=-1 keeps all layers, layer_idx=0 keeps only the first.
    self.layers = ((self.layers + layer_idx) % self.layers) + 1
    self.resblocks = nn.ModuleList(self.resblocks[: self.layers])

VisionTransformer

Bases: Module

Source code in inference/models/perception_encoder/vision_encoder/pe.py
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
class VisionTransformer(nn.Module):
    """Vision Transformer backbone for the Perception Encoder.

    Patchifies an image with a strided convolution, optionally adds a class
    token and absolute positional embeddings (interpolated on the fly for
    resolutions other than the pretrain grid), runs a stack of transformer
    blocks (optionally with 2D RoPE), then pools according to ``pool_type``
    and optionally projects to ``output_dim``.
    """

    def __init__(
        self,
        patch_size: int,
        width: int,
        layers: int,
        heads: int,
        mlp_ratio: float,
        act_layer: Callable = nn.GELU,
        norm_layer: Callable = partial(nn.LayerNorm, eps=1e-5),
        use_ln_pre: bool = True,
        use_ln_post: bool = True,
        ls_init_value: Optional[float] = None,
        drop_path: float = 0.0,
        image_size: int = 448,  # Pretrain image size only; you can pass in any image size
        use_abs_posemb: bool = True,
        use_rope2d: bool = True,
        use_cls_token: bool = False,
        output_dim: Optional[int] = 1280,
        attn_pooler_heads: int = 8,
        pool_type: Literal["attn", "tok", "avg", "none"] = "attn",
    ):
        super().__init__()
        assert pool_type in ("attn", "tok", "avg", "none")
        self.pool_type = pool_type
        self.patch_size = patch_size

        # When no projection is requested, features keep the transformer width.
        self.output_dim = output_dim or width
        self.proj_dim = output_dim
        self.heads = heads
        self.width = width
        self.layers = layers

        self.use_abs_posemb = use_abs_posemb
        self.use_cls_token = use_cls_token
        self.use_rope2d = use_rope2d
        self.image_size = image_size

        # Patch embedding: non-overlapping patch_size x patch_size convolution.
        self.conv1 = nn.Conv2d(
            in_channels=3,
            out_channels=width,
            kernel_size=patch_size,
            stride=patch_size,
            bias=False,
        )
        # self.rope is None when RoPE is disabled; every later use must
        # respect that.
        self.rope = (
            Rope2D(
                dim=width // heads,
                use_cls_token=self.use_cls_token,
            )
            if self.use_rope2d
            else None
        )

        self.ln_pre = norm_layer(width) if use_ln_pre else nn.Identity()
        self.ln_post = norm_layer(self.width) if use_ln_post else nn.Identity()

        self.transformer = Transformer(
            width,
            layers,
            heads,
            mlp_ratio,
            ls_init_value=ls_init_value,
            act_layer=act_layer,
            norm_layer=norm_layer,
            drop_path=drop_path,
            rope=self.rope,
        )

        if pool_type == "attn":
            self.attn_pool = AttentionPooling(
                embed_dim=width,
                num_heads=attn_pooler_heads,
                act_layer=act_layer,
                norm_layer=norm_layer,
            )
        else:
            self.attn_pool = None

        self.init_tensors()

    def init_tensors(self):
        """Initialize (or re-initialize) all learnable tensors of the model."""

        def init_submodule_tensors(module):
            # Depth-first: let every child that defines init_tensors
            # (re)build its own parameters first.
            for name, child in module.named_children():
                if hasattr(child, "init_tensors"):
                    logger.debug(f"Initializing tensors for submodule: {name}")
                    child.init_tensors()
                init_submodule_tensors(child)

        init_submodule_tensors(self)
        # Rope2D is a plain object (not an nn.Module), so named_children()
        # never visits it; initialize it explicitly. Guard for
        # use_rope2d=False, where self.rope is None — previously this line
        # raised AttributeError in that configuration.
        if self.rope is not None:
            self.rope.init_tensors()

        # class embeddings and positional embeddings
        init_scale = self.width**-0.5

        if self.use_cls_token:
            self.class_embedding = nn.Parameter(init_scale * torch.randn(self.width))

        if self.use_abs_posemb:
            # Square grid at pretrain resolution; other input sizes are
            # handled by interpolation in _sample_abs_posemb.
            self.posemb_grid_size = self.image_size // self.patch_size
            self.positional_embedding = nn.Parameter(
                init_scale
                * torch.randn(
                    int(self.use_cls_token) + self.posemb_grid_size**2, self.width
                )
            )

        if self.proj_dim is not None:
            self.proj = nn.Parameter(
                init_scale * torch.randn(self.width, self.proj_dim)
            )

    def load_ckpt(self, ckpt_path: str):
        """Load weights from a checkpoint, tolerating common wrappers.

        Accepts raw state dicts as well as checkpoints nested under
        "state_dict" or "weights", and strips "module."/"visual." prefixes
        for backwards compatibility. Missing/unexpected keys are logged
        (the duplicate print() debug statements were removed).
        """
        _sd = torch.load(ckpt_path, weights_only=True)
        if "state_dict" in _sd:
            _sd = _sd["state_dict"]
        elif "weights" in _sd:
            _sd = _sd["weights"]

        # for backwards compatibility
        _sd = {k.replace("module.", ""): v for k, v in _sd.items()}
        if any(k.startswith("visual.") for k in _sd):
            # NOTE(review): replace() strips "visual." anywhere in the key,
            # not just as a prefix — presumably fine for these checkpoints.
            _sd = {k.replace("visual.", ""): v for k, v in _sd.items() if "visual" in k}

        m, u = self.load_state_dict(_sd, strict=False)
        logger.info(f"Missing keys for loading vision encoder: {m}")
        logger.info(f"Unexpected keys for loading vision encoder: {u}")

    def truncate(self, layer_idx: int):
        """Delete layers so the last layer is the given layer index."""
        self.transformer.truncate(layer_idx)
        self.layers = self.transformer.layers

    @classmethod
    def from_config(
        cls,
        name: str,
        pretrained: bool = False,
        checkpoint_path: Optional[str] = None,
        **kwdargs,
    ):
        """Build a model from a named entry in PE_VISION_CONFIG.

        Args:
            name: Config key; see available_configs().
            pretrained: If True, also fetch and load pretrained weights.
            checkpoint_path: Optional explicit checkpoint location.
            **kwdargs: Overrides applied on top of the named config.

        Raises:
            RuntimeError: If ``name`` is not a known config.
        """
        if name not in PE_VISION_CONFIG:
            raise RuntimeError(f"{name} not found in configs.")

        args = asdict(PE_VISION_CONFIG[name])
        args.update(kwdargs)

        model = cls(**args)
        if pretrained:
            model.load_ckpt(fetch_pe_checkpoint(name, checkpoint_path))

        return model

    @classmethod
    def available_configs(cls):
        """Return the list of config names usable with from_config()."""
        return list(PE_VISION_CONFIG.keys())

    @torch.jit.ignore
    def set_grad_checkpointing(self, enable=True):
        """Toggle gradient checkpointing on the transformer stack."""
        self.transformer.set_grad_checkpointing(enable=enable)

    def _sample_abs_posemb(self, grid_h: int, grid_w: int):
        """Interpolates the absolute position embedding if necessary."""
        if self.posemb_grid_size == grid_h and self.posemb_grid_size == grid_w:
            return self.positional_embedding[None, ...]

        pos_embed = self.positional_embedding
        if self.use_cls_token:
            # The cls-token embedding is not spatial; keep it out of the
            # bilinear resize and re-attach it afterwards.
            cls_token_embed, pos_embed = pos_embed[:1], pos_embed[1:]

        pos_embed = (
            pos_embed.reshape(1, self.posemb_grid_size, self.posemb_grid_size, -1)
            .permute(0, 3, 1, 2)
            .contiguous()
        )
        pos_embed = F.interpolate(
            pos_embed, size=(grid_h, grid_w), mode="bilinear", align_corners=False
        )
        pos_embed = pos_embed.permute(0, 2, 3, 1).reshape(-1, self.width).contiguous()

        if self.use_cls_token:
            pos_embed = torch.cat([cls_token_embed, pos_embed], dim=0)

        return pos_embed[None, ...]

    def _pool(self, x: torch.Tensor):
        """Pool token features according to self.pool_type."""
        if self.pool_type == "tok":
            return x[:, 0]
        elif self.pool_type == "avg":
            return x.mean(dim=1)
        elif self.pool_type == "attn":
            return self.attn_pool(x).squeeze(1)
        elif self.pool_type == "none":
            return x
        else:
            raise NotImplementedError

    def forward_features(
        self,
        x: torch.Tensor,
        norm: bool = False,
        layer_idx: int = -1,
        strip_cls_token: bool = False,
    ):
        """Return the (unpooled) token features for image batch ``x``.

        Args:
            x: Input images, shape (batch, 3, H, W); H and W should be
                divisible by patch_size.
            norm: Apply the final layer norm (ln_post) to the output.
            layer_idx: Stop after this transformer layer (negative indexes
                from the end); forwarded to the transformer.
            strip_cls_token: Drop the leading cls token from the output.
        """
        batch, _, h, w = x.shape
        grid_h, grid_w = h // self.patch_size, w // self.patch_size

        x = self.conv1(x)
        x = x.permute(0, 2, 3, 1).reshape(batch, -1, self.width)

        if self.use_cls_token:
            x = torch.cat(
                [self.class_embedding.view(1, 1, -1).expand(batch, -1, -1), x],
                dim=1,
            )

        if self.use_abs_posemb:
            x = x + self._sample_abs_posemb(grid_h, grid_w)

        if self.use_rope2d:
            self.rope.update_grid(x.device, grid_h, grid_w)

        x = self.ln_pre(x)
        x = self.transformer(x, layer_idx=layer_idx)

        if norm:
            x = self.ln_post(x)

        if strip_cls_token and self.use_cls_token:
            x = x[:, 1:, :]

        return x

    def forward(self, x: torch.Tensor, **kwargs):
        """Full forward pass: features -> pooling -> optional projection."""
        x = self.forward_features(x, norm=True, **kwargs)
        x = self._pool(x)

        if self.proj_dim is not None:
            x = x @ self.proj

        return x
Functions
truncate
truncate(layer_idx)

Delete layers so the last layer is the given layer index.

Source code in inference/models/perception_encoder/vision_encoder/pe.py
426
427
428
429
def truncate(self, layer_idx: int):
    """Delete layers so the last layer is the given layer index."""
    # Delegate the actual pruning to the inner transformer stack, then
    # mirror its updated depth on this module.
    inner = self.transformer
    inner.truncate(layer_idx)
    self.layers = inner.layers

inference.models.perception_encoder.vision_encoder.rope

Classes

Rope2D

Helper class to apply RoPE2D as well as interpolate on the fly.

Source code in inference/models/perception_encoder/vision_encoder/rope.py
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
class Rope2D:
    """Helper class to apply RoPE2D as well as interpolate on the fly."""

    def __init__(self, dim, use_cls_token=False):
        self.dim = dim
        self.use_cls_token = use_cls_token
        # Lazily-built cache of rotary frequencies for the current grid.
        self.grid_size = None
        self.freq = None

    def init_tensors(self):
        # Half of the head dim goes to each spatial axis (x and y).
        self.rope = RotaryEmbedding(self.dim // 2)

    def update_grid(self, device, grid_h, grid_w):
        # Rebuild the cached frequency table only when the token grid changes.
        if (grid_h, grid_w) != self.grid_size:
            self.grid_size = (grid_h, grid_w)

            self.rope = self.rope.to(device)

            # Reserve position (0, 0) for the cls token when present by
            # shifting the spatial coordinates up by one.
            start = 1 if self.use_cls_token else 0
            grid_y_range = torch.arange(grid_h, device=device) + start
            grid_x_range = torch.arange(grid_w, device=device) + start

            freqs_y = self.rope(grid_y_range)[:, None].expand(grid_h, grid_w, -1)
            freqs_x = self.rope(grid_x_range)[None, :].expand(grid_h, grid_w, -1)
            freq = torch.cat([freqs_x, freqs_y], dim=-1).reshape(grid_h * grid_w, -1)

            if self.use_cls_token:
                zero_row = torch.zeros(1, freq.shape[-1], device=device)
                freq = torch.cat([zero_row, freq], dim=0)

            self.freq = freq[None, ...]

        self.freq = self.freq.to(device)

    def __call__(self, q, k):
        # batch, heads, seq, dim = q.shape; broadcast freq across heads.
        rotated_q = apply_rotary_emb(self.freq[:, None, :, :], q)
        rotated_k = apply_rotary_emb(self.freq[:, None, :, :], k)

        return rotated_q, rotated_k

inference.models.perception_encoder.vision_encoder.tokenizer

CLIP tokenizer

Copied from https://github.com/openai/CLIP. Originally MIT License, Copyright (c) 2021 OpenAI.

Classes

SimpleTokenizer

Bases: object

Source code in inference/models/perception_encoder/vision_encoder/tokenizer.py
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
class SimpleTokenizer(object):
    """CLIP byte-pair-encoding (BPE) tokenizer.

    Text is cleaned, split by a regex into words/specials, mapped to a
    byte-level alphabet, merged via the BPE ranks loaded from ``bpe_path``,
    and converted to integer ids framed by <start_of_text>/<end_of_text>.
    """

    def __init__(
        self,
        bpe_path: str = default_bpe(),
        additional_special_tokens: Optional[List[str]] = None,
        context_length: Optional[int] = DEFAULT_CONTEXT_LENGTH,
        clean: str = "lower",
        reduction_mask: str = "",
    ):
        """Build encoder/decoder tables from the gzip'd BPE merges file.

        Args:
            bpe_path: Path to the gzip'd merges file.
            additional_special_tokens: Extra special tokens appended to the vocab.
            context_length: Default tokenized sequence length for __call__.
            clean: Name of the text-cleaning strategy (see get_clean_fn).
            reduction_mask: Optional token-reduction strategy name
                (see get_reduction_mask_fn); empty string disables it.
        """
        self.byte_encoder = bytes_to_unicode()
        self.byte_decoder = {v: k for k, v in self.byte_encoder.items()}
        merges = gzip.open(bpe_path).read().decode("utf-8").split("\n")
        # Drop the version header line and keep exactly the merges that fit
        # CLIP's 49152-entry vocab (minus 512 byte tokens and 2 specials).
        merges = merges[1 : 49152 - 256 - 2 + 1]
        merges = [tuple(merge.split()) for merge in merges]
        vocab = list(bytes_to_unicode().values())
        vocab = vocab + [v + "</w>" for v in vocab]  # word-final variants
        for merge in merges:
            vocab.append("".join(merge))
        special_tokens = ["<start_of_text>", "<end_of_text>"]
        if additional_special_tokens:
            special_tokens += additional_special_tokens
        vocab.extend(special_tokens)
        self.encoder = dict(zip(vocab, range(len(vocab))))
        self.decoder = {v: k for k, v in self.encoder.items()}
        self.bpe_ranks = dict(zip(merges, range(len(merges))))
        self.cache = {t: t for t in special_tokens}  # specials never split
        special = "|".join(special_tokens)
        self.pat = re.compile(
            special + r"""|'s|'t|'re|'ve|'m|'ll|'d|[\p{L}]+|[\p{N}]|[^\s\p{L}\p{N}]+""",
            re.IGNORECASE,
        )
        self.vocab_size = len(self.encoder)
        self.all_special_ids = [self.encoder[t] for t in special_tokens]
        self.sot_token_id = self.all_special_ids[0]
        self.eot_token_id = self.all_special_ids[1]
        self.context_length = context_length
        self.clean_fn = get_clean_fn(clean)
        self.reduction_fn = (
            get_reduction_mask_fn(reduction_mask) if reduction_mask else None
        )

    def bpe(self, token):
        """Apply BPE merges to a single byte-encoded token string.

        Returns the merged symbols joined by spaces, with "</w>" marking the
        word boundary. Results are memoized in self.cache.
        """
        if token in self.cache:
            return self.cache[token]
        word = tuple(token[:-1]) + (token[-1] + "</w>",)
        pairs = get_pairs(word)

        if not pairs:
            return token + "</w>"

        while True:
            # Merge the lowest-ranked (most frequent) adjacent pair first.
            bigram = min(pairs, key=lambda pair: self.bpe_ranks.get(pair, float("inf")))
            if bigram not in self.bpe_ranks:
                break
            first, second = bigram
            new_word = []
            i = 0
            while i < len(word):
                try:
                    j = word.index(first, i)
                    new_word.extend(word[i:j])
                    i = j
                except ValueError:
                    # `first` no longer occurs past i; keep the remainder.
                    # (Was a bare `except:`, which also swallowed
                    # KeyboardInterrupt/SystemExit and could mask real bugs.)
                    new_word.extend(word[i:])
                    break

                if word[i] == first and i < len(word) - 1 and word[i + 1] == second:
                    new_word.append(first + second)
                    i += 2
                else:
                    new_word.append(word[i])
                    i += 1
            new_word = tuple(new_word)
            word = new_word
            if len(word) == 1:
                break
            else:
                pairs = get_pairs(word)
        word = " ".join(word)
        self.cache[token] = word
        return word

    def encode(self, text):
        """Encode ``text`` into a list of BPE token ids."""
        bpe_tokens = []
        text = self.clean_fn(text)
        for token in re.findall(self.pat, text):
            token = "".join(self.byte_encoder[b] for b in token.encode("utf-8"))
            bpe_tokens.extend(
                self.encoder[bpe_token] for bpe_token in self.bpe(token).split(" ")
            )
        return bpe_tokens

    def decode(self, tokens):
        """Decode a list of token ids back into a text string."""
        text = "".join([self.decoder[token] for token in tokens])
        text = (
            bytearray([self.byte_decoder[c] for c in text])
            .decode("utf-8", errors="replace")
            .replace("</w>", " ")
        )
        return text

    def __call__(
        self, texts: Union[str, List[str]], context_length: Optional[int] = None
    ) -> torch.LongTensor:
        """Returns the tokenized representation of given input string(s)

        Parameters
        ----------
        texts : Union[str, List[str]]
            An input string or a list of input strings to tokenize
        context_length : int
            The context length to use; all CLIP models use 77 as the context length

        Returns
        -------
        A two-dimensional tensor containing the resulting tokens, shape = [number of input strings, context_length]
        """
        if isinstance(texts, str):
            texts = [texts]

        context_length = context_length or self.context_length
        assert context_length, "Please set a valid context length"

        if self.reduction_fn is not None:
            # use reduction strategy for tokenize if set, otherwise default to truncation below
            return self.reduction_fn(
                texts,
                context_length=context_length,
                sot_token_id=self.sot_token_id,
                eot_token_id=self.eot_token_id,
                encode_fn=self.encode,
            )

        all_tokens = [
            [self.sot_token_id] + self.encode(text) + [self.eot_token_id]
            for text in texts
        ]
        result = torch.zeros(len(all_tokens), context_length, dtype=torch.long)

        for i, tokens in enumerate(all_tokens):
            if len(tokens) > context_length:
                tokens = tokens[:context_length]  # Truncate
                tokens[-1] = self.eot_token_id  # always keep the end marker
            result[i, : len(tokens)] = torch.tensor(tokens)

        return result
Functions
__call__
__call__(texts, context_length=None)

Returns the tokenized representation of given input string(s)

Parameters

texts : Union[str, List[str]] An input string or a list of input strings to tokenize context_length : int The context length to use; all CLIP models use 77 as the context length

Returns

A two-dimensional tensor containing the resulting tokens, shape = [number of input strings, context_length]

Source code in inference/models/perception_encoder/vision_encoder/tokenizer.py
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
def __call__(
    self, texts: Union[str, List[str]], context_length: Optional[int] = None
) -> torch.LongTensor:
    """Returns the tokenized representation of given input string(s)

    Parameters
    ----------
    texts : Union[str, List[str]]
        An input string or a list of input strings to tokenize
    context_length : int
        The context length to use; all CLIP models use 77 as the context length

    Returns
    -------
    A two-dimensional tensor containing the resulting tokens, shape = [number of input strings, context_length]
    """
    # Normalize to a list so single strings and batches share one code path.
    if isinstance(texts, str):
        texts = [texts]

    context_length = context_length or self.context_length
    assert context_length, "Please set a valid context length"

    if self.reduction_fn is not None:
        # use reduction strategy for tokenize if set, otherwise default to truncation below
        return self.reduction_fn(
            texts,
            context_length=context_length,
            sot_token_id=self.sot_token_id,
            eot_token_id=self.eot_token_id,
            encode_fn=self.encode,
        )

    # Frame each encoded text with start/end-of-text marker ids.
    all_tokens = [
        [self.sot_token_id] + self.encode(text) + [self.eot_token_id]
        for text in texts
    ]
    # Zero-padded (batch, context_length) matrix of token ids.
    result = torch.zeros(len(all_tokens), context_length, dtype=torch.long)

    for i, tokens in enumerate(all_tokens):
        if len(tokens) > context_length:
            tokens = tokens[:context_length]  # Truncate
            tokens[-1] = self.eot_token_id  # keep end marker after truncation
        result[i, : len(tokens)] = torch.tensor(tokens)

    return result

Functions

bytes_to_unicode cached

bytes_to_unicode()

Returns list of utf-8 byte and a corresponding list of unicode strings. The reversible bpe codes work on unicode strings. This means you need a large # of unicode characters in your vocab if you want to avoid UNKs. When you're at something like a 10B token dataset you end up needing around 5K for decent coverage. This is a significant percentage of your normal, say, 32K bpe vocab. To avoid that, we want lookup tables between utf-8 bytes and unicode strings. And avoids mapping to whitespace/control characters the bpe code barfs on.

Source code in inference/models/perception_encoder/vision_encoder/tokenizer.py
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
@lru_cache()
def bytes_to_unicode():
    """
    Returns list of utf-8 byte and a corresponding list of unicode strings.
    The reversible bpe codes work on unicode strings.
    This means you need a large # of unicode characters in your vocab if you want to avoid UNKs.
    When you're at something like a 10B token dataset you end up needing around 5K for decent coverage.
    This is a significant percentage of your normal, say, 32K bpe vocab.
    To avoid that, we want lookup tables between utf-8 bytes and unicode strings.
    And avoids mapping to whitespace/control characters the bpe code barfs on.
    """
    # Bytes whose own character is printable map to themselves.
    printable = (
        list(range(ord("!"), ord("~") + 1))
        + list(range(ord("¡"), ord("¬") + 1))
        + list(range(ord("®"), ord("ÿ") + 1))
    )
    mapping = dict(zip(printable, map(chr, printable)))
    # Remaining bytes (whitespace/control chars) are pushed above U+00FF so
    # every byte gets a distinct, printable stand-in.
    shift = 0
    for byte in range(256):
        if byte not in mapping:
            mapping[byte] = chr(256 + shift)
            shift += 1
    return mapping

canonicalize_text

canonicalize_text(
    text, *, keep_punctuation_exact_string=None
)

Returns canonicalized text (lowercase and punctuation removed).

From: https://github.com/google-research/big_vision/blob/53f18caf27a9419231bbf08d3388b07671616d3d/big_vision/evaluators/proj/image_text/prompt_engineering.py#L94

Parameters:

Name Type Description Default
text

string to be canonicalized.

required
keep_punctuation_exact_string

If provided, then this exact string kept. For example providing '{}' will keep any occurrences of '{}' (but will still remove '{' and '}' that appear separately).

None
Source code in inference/models/perception_encoder/vision_encoder/tokenizer.py
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
def canonicalize_text(text, *, keep_punctuation_exact_string=None):
    """Returns canonicalized `text` (lowercase and punctuation removed).

    From: https://github.com/google-research/big_vision/blob/53f18caf27a9419231bbf08d3388b07671616d3d/big_vision/evaluators/proj/image_text/prompt_engineering.py#L94

    Args:
      text: string to be canonicalized.
      keep_punctuation_exact_string: If provided, then this exact string kept.
        For example providing '{}' will keep any occurrences of '{}' (but will
        still remove '{' and '}' that appear separately).
    """
    text = text.replace("_", " ")
    strip_punct = str.maketrans("", "", string.punctuation)
    if keep_punctuation_exact_string:
        # Split on the protected string so it survives punctuation removal.
        pieces = text.split(keep_punctuation_exact_string)
        text = keep_punctuation_exact_string.join(
            piece.translate(strip_punct) for piece in pieces
        )
    else:
        text = text.translate(strip_punct)
    # Lowercase, collapse runs of whitespace, and trim the ends.
    return re.sub(r"\s+", " ", text.lower()).strip()

get_pairs

get_pairs(word)

Return set of symbol pairs in a word. Word is represented as tuple of symbols (symbols being variable-length strings).

Source code in inference/models/perception_encoder/vision_encoder/tokenizer.py
62
63
64
65
66
67
68
69
70
71
def get_pairs(word):
    """Return set of symbol pairs in a word.

    Word is represented as tuple of symbols (symbols being variable-length
    strings). Returns an empty set for words with fewer than two symbols
    (the previous implementation raised IndexError on an empty word).
    """
    # zip pairs each symbol with its successor; set() dedupes repeats.
    return set(zip(word, word[1:]))

get_reduction_mask_fn

get_reduction_mask_fn(type)

Choose strategy for dropping (masking) tokens to achieve target context length

Source code in inference/models/perception_encoder/vision_encoder/tokenizer.py
335
336
337
338
339
340
341
342
343
344
345
def get_reduction_mask_fn(type: str):
    """Choose strategy for dropping (masking) tokens to achieve target context length.

    Args:
        type: One of "simple", "random", or "shuffle".

    Returns:
        The tokenize function implementing the chosen strategy.

    Raises:
        ValueError: If ``type`` is not a known strategy. (Previously an
            ``assert``, which is silently stripped under ``python -O``.)
    """
    if type == "simple":
        return simple_mask_tokenize  # randomly select block [start:end]
    elif type == "random":
        return random_mask_tokenize  # randomly drop tokens (keep order)
    elif type == "shuffle":
        return partial(
            random_mask_tokenize, shuffle=True
        )  # randomly drop tokens (shuffle order)
    raise ValueError(f"Unknown reduction mask type: {type!r}")

models/qwen25vl

inference.models.qwen25vl.qwen25vl

Classes

models/resnet

inference.models.resnet.resnet_classification

Classes

ResNetClassification

Bases: ClassificationBaseOnnxRoboflowInferenceModel

ResNetClassification handles classification inference for ResNet models using ONNX.

Inherits

Attributes:

Name Type Description
multiclass bool

A flag that specifies if the model should handle multiclass classification.

Source code in inference/models/resnet/resnet_classification.py
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
class ResNetClassification(ClassificationBaseOnnxRoboflowInferenceModel):
    """ResNetClassification handles classification inference
    for ResNet models using ONNX.

    Inherits:
        ClassificationBaseOnnxRoboflowInferenceModel: Base class for ONNX Roboflow Inference.
        ClassificationMixin: Mixin class providing classification-specific methods.

    Attributes:
        multiclass (bool): A flag that specifies if the model should handle multiclass classification.
        preprocess_means (List[float]): Per-channel means used for input normalization (ImageNet values).
        preprocess_stds (List[float]): Per-channel stds used for input normalization (ImageNet values).
    """

    # Standard ImageNet RGB normalization statistics.
    preprocess_means = [0.485, 0.456, 0.406]
    preprocess_stds = [0.229, 0.224, 0.225]

    def __init__(self, *args, **kwargs):
        """Initializes the ResNetClassification instance.

        Args:
            *args: Variable length argument list.
            **kwargs: Arbitrary keyword arguments.
        """
        super().__init__(*args, **kwargs)
        # MULTICLASS flag read from the model environment; defaults to False.
        self.multiclass = self.environment.get("MULTICLASS", False)

    @property
    def weights_file(self) -> str:
        """Determines the weights file to be used based on the availability of AWS keys.

        If AWS_ACCESS_KEY_ID, AWS_SECRET_ACCESS_KEY and LAMBDA are all set, it returns
        the path to 'weights.onnx'. Otherwise, it returns the path to 'best.onnx'.

        Returns:
            str: Path to the weights file.
        """
        if AWS_ACCESS_KEY_ID and AWS_SECRET_ACCESS_KEY and LAMBDA:
            return "weights.onnx"
        else:
            return "best.onnx"
Attributes
weights_file property
weights_file

Determines the weights file to be used based on the availability of AWS keys.

If AWS_ACCESS_KEY_ID and AWS_SECRET_ACCESS_KEY are set, it returns the path to 'weights.onnx'. Otherwise, it returns the path to 'best.onnx'.

Returns:

Name Type Description
str str

Path to the weights file.

Functions
__init__
__init__(*args, **kwargs)

Initializes the ResNetClassification instance.

Parameters:

Name Type Description Default
*args

Variable length argument list.

()
**kwargs

Arbitrary keyword arguments.

{}
Source code in inference/models/resnet/resnet_classification.py
22
23
24
25
26
27
28
29
30
def __init__(self, *args, **kwargs):
    """Initializes the ResNetClassification instance.

    Args:
        *args: Variable length argument list.
        **kwargs: Arbitrary keyword arguments.
    """
    super().__init__(*args, **kwargs)
    # MULTICLASS flag read from the model environment; defaults to False.
    self.multiclass = self.environment.get("MULTICLASS", False)

models/rfdetr

inference.models.rfdetr.rfdetr

Classes

RFDETRInstanceSegmentation

Bases: RFDETRObjectDetection, InstanceSegmentationBaseOnnxRoboflowInferenceModel

Source code in inference/models/rfdetr/rfdetr.py
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
class RFDETRInstanceSegmentation(
    RFDETRObjectDetection, InstanceSegmentationBaseOnnxRoboflowInferenceModel
):
    """Roboflow ONNX instance segmentation built on top of the RFDETR detector.

    Reuses the RFDETR detection pipeline (box decoding, letterbox undoing) and
    adds a third ONNX output holding per-query masks that are cropped, resized
    and traced into polygons during post-processing.
    """

    task_type = "instance-segmentation"

    def initialize_model(self, **kwargs) -> None:
        """Initialize the ONNX session and cache the mask output resolution."""
        super().initialize_model(**kwargs)
        # Output index 2 is the mask head (see `predict`); the trailing dims
        # are assumed to be the spatial mask resolution — TODO confirm layout.
        mask_shape = self.onnx_session.get_outputs()[2].shape
        self.mask_shape = mask_shape[2:]

    def predict(self, img_in: ImageMetaType, **kwargs) -> Tuple[np.ndarray]:
        """Performs object detection on the given image using the ONNX session with the RFDETR model.

        Args:
            img_in (np.ndarray): Input image as a NumPy array.

        Returns:
            Tuple[np.ndarray]: NumPy arrays with the raw predictions: boxes,
                class logits, and segmentation masks.
        """
        # The session lock serializes access to the shared ONNX session.
        with self._session_lock:
            predictions = run_session_via_iobinding(
                self.onnx_session, self.input_name, img_in
            )
        bboxes = predictions[0]
        logits = predictions[1]
        masks = predictions[2]

        return (bboxes, logits, masks)

    def postprocess(
        self,
        predictions: Tuple[np.ndarray, ...],
        preproc_return_metadata: PreprocessReturnMetadata,
        confidence: float = DEFAULT_CONFIDENCE,
        max_detections: int = DEFAUlT_MAX_DETECTIONS,
        **kwargs,
    ) -> List[InstanceSegmentationInferenceResponse]:
        """Decode raw RFDETR outputs into instance segmentation responses.

        Selects the top-scoring (query, class) pairs, maps boxes back to the
        original image space (undoing letterboxing when the resize method is
        not "Stretch to"), crops/resizes the matching masks, and converts them
        to polygons.

        Args:
            predictions: Raw model outputs ``(bboxes, logits, masks)``.
            preproc_return_metadata: Metadata from preprocessing; must contain
                ``img_dims`` with the original (h, w) of each image.
            confidence: Minimum sigmoid score for a detection to be kept.
            max_detections: Upper bound on detections per image.

        Returns:
            One ``InstanceSegmentationInferenceResponse`` per input image.
        """
        bboxes, logits, masks = predictions
        bboxes = bboxes.astype(np.float32)
        logits = logits.astype(np.float32)

        batch_size, num_queries, num_classes = logits.shape
        logits_sigmoid = self.sigmoid_stable(logits)

        img_dims = preproc_return_metadata["img_dims"]

        processed_predictions = []
        processed_masks = []

        for batch_idx in range(batch_size):
            orig_h, orig_w = img_dims[batch_idx]

            logits_flat = logits_sigmoid[batch_idx].reshape(-1)

            # Use argpartition for better performance when max_detections is smaller than logits_flat
            if len(logits_flat) > max_detections:
                partition_indices = np.argpartition(-logits_flat, max_detections)[
                    :max_detections
                ]
                sorted_indices = partition_indices[
                    np.argsort(-logits_flat[partition_indices])
                ]
            else:
                sorted_indices = np.argsort(-logits_flat)
            topk_scores = logits_flat[sorted_indices]

            # Drop candidates below the confidence threshold.
            conf_mask = topk_scores > confidence
            sorted_indices = sorted_indices[conf_mask]
            topk_scores = topk_scores[conf_mask]

            # Flattened index -> (query index, class index) pair.
            topk_boxes = sorted_indices // num_classes
            topk_labels = sorted_indices % num_classes

            # Remove the background class and shift labels back to 0-indexed.
            if self.is_one_indexed:
                class_filter_mask = topk_labels != self.background_class_index

                topk_labels[topk_labels > self.background_class_index] -= 1
                topk_scores = topk_scores[class_filter_mask]
                topk_labels = topk_labels[class_filter_mask]
                topk_boxes = topk_boxes[class_filter_mask]

            selected_boxes = bboxes[batch_idx, topk_boxes]
            selected_masks = masks[batch_idx, topk_boxes]

            # cxcywh (normalized) -> xyxy (still normalized).
            cxcy = selected_boxes[:, :2]
            wh = selected_boxes[:, 2:]
            xy_min = cxcy - 0.5 * wh
            xy_max = cxcy + 0.5 * wh
            boxes_xyxy = np.concatenate([xy_min, xy_max], axis=1)

            if self.resize_method == "Stretch to":
                scale_fct = np.array([orig_w, orig_h, orig_w, orig_h], dtype=np.float32)
                boxes_xyxy *= scale_fct
            else:
                # Undo letterbox: map normalized boxes to padded input pixels,
                # subtract the padding, then rescale to original image pixels.
                if self._needs_nonsquare_preproc:
                    input_h, input_w = self._preproc_resize_h, self._preproc_resize_w
                else:
                    input_h, input_w = self.img_size_h, self.img_size_w

                scale = min(input_w / orig_w, input_h / orig_h)
                scaled_w = int(orig_w * scale)
                scaled_h = int(orig_h * scale)

                pad_x = (input_w - scaled_w) / 2
                pad_y = (input_h - scaled_h) / 2

                boxes_input = boxes_xyxy * np.array(
                    [input_w, input_h, input_w, input_h], dtype=np.float32
                )

                boxes_input[:, 0] -= pad_x
                boxes_input[:, 1] -= pad_y
                boxes_input[:, 2] -= pad_x
                boxes_input[:, 3] -= pad_y

                boxes_xyxy = boxes_input / scale

            # Clamp boxes to the original image bounds, in place.
            np.clip(
                boxes_xyxy,
                [0, 0, 0, 0],
                [orig_w, orig_h, orig_w, orig_h],
                out=boxes_xyxy,
            )

            # Columns: x1, y1, x2, y2, confidence, 0 (placeholder), class_id.
            batch_predictions = np.column_stack(
                (
                    boxes_xyxy,
                    topk_scores,
                    np.zeros((len(topk_scores), 1), dtype=np.float32),
                    topk_labels,
                )
            )
            valid_pred_mask = batch_predictions[:, 6] < len(self.class_names)

            outputs_predictions = []
            outputs_polygons = []
            class_filter_local = kwargs.get("class_filter")
            for i, pred in enumerate(batch_predictions):
                if not valid_pred_mask[i]:
                    continue
                # Early class filtering to avoid unnecessary mask processing
                if class_filter_local:
                    try:
                        pred_class_name = self.class_names[int(pred[6])]
                    except Exception:
                        continue
                    if pred_class_name not in class_filter_local:
                        continue
                mask = selected_masks[i]

                # Crop the letterbox padding region off the low-res mask so it
                # corresponds to the original (unpadded) image content.
                if self.resize_method != "Stretch to":
                    if self._needs_nonsquare_preproc:
                        input_h, input_w = (
                            self._preproc_resize_h,
                            self._preproc_resize_w,
                        )
                    else:
                        input_h, input_w = self.img_size_h, self.img_size_w
                    mask_h, mask_w = mask.shape[0], mask.shape[1]

                    letterbox_scale = min(input_w / orig_w, input_h / orig_h)
                    scaled_w = int(orig_w * letterbox_scale)
                    scaled_h = int(orig_h * letterbox_scale)

                    pad_x_input = (input_w - scaled_w) / 2
                    pad_y_input = (input_h - scaled_h) / 2

                    crop_x1 = int(round(pad_x_input * mask_w / input_w))
                    crop_y1 = int(round(pad_y_input * mask_h / input_h))
                    crop_x2 = int(round((pad_x_input + scaled_w) * mask_w / input_w))
                    crop_y2 = int(round((pad_y_input + scaled_h) * mask_h / input_h))

                    mask = mask[crop_y1:crop_y2, crop_x1:crop_x2]

                # "accurate" upsamples the mask to full image resolution;
                # "tradeoff" interpolates between mask res and full res.
                mask_decode_mode = kwargs.get("mask_decode_mode", "accurate")
                if mask_decode_mode == "accurate":
                    target_res = (orig_w, orig_h)
                    if mask.shape[1] != target_res[0] or mask.shape[0] != target_res[1]:
                        mask = cv2.resize(
                            mask.astype(np.float32),
                            target_res,
                            interpolation=cv2.INTER_LINEAR,
                        )
                elif mask_decode_mode == "tradeoff":
                    tradeoff_factor = kwargs.get("tradeoff_factor", 0.0)
                    mask_res = (mask.shape[1], mask.shape[0])  # (w, h)
                    full_res = (orig_w, orig_h)  # (w, h)
                    target_res = (
                        int(
                            mask_res[0] * (1 - tradeoff_factor)
                            + full_res[0] * tradeoff_factor
                        ),
                        int(
                            mask_res[1] * (1 - tradeoff_factor)
                            + full_res[1] * tradeoff_factor
                        ),
                    )
                    if mask.shape[1] != target_res[0] or mask.shape[0] != target_res[1]:
                        mask = cv2.resize(
                            mask.astype(np.float32),
                            target_res,
                            interpolation=cv2.INTER_LINEAR,
                        )

                # Binarize (logit > 0 ~ prob > 0.5) and trace the outline.
                mask_bin = (mask > 0).astype(np.uint8)
                points = mask2poly(mask_bin)

                # After letterbox cropping, both paths reduce to a simple
                # linear rescale from prediction dims to original dims.
                new_points = []
                prediction_h, prediction_w = mask_bin.shape[0], mask_bin.shape[1]
                for point in points:
                    new_x = point[0] * (orig_w / prediction_w)
                    new_y = point[1] * (orig_h / prediction_h)
                    new_points.append(np.array([new_x, new_y]))
                outputs_polygons.append(new_points)
                outputs_predictions.append(list(pred))

            processed_predictions.append(outputs_predictions)
            processed_masks.append(outputs_polygons)

        res = self.make_response(
            processed_predictions, processed_masks, img_dims, **kwargs
        )
        return res

    def make_response(
        self,
        predictions: List[List[List[float]]],
        masks: List[List[List[np.ndarray]]],
        img_dims: List[Tuple[int, int]],
        class_filter: Optional[List[str]] = None,
        *args,
        **kwargs,
    ) -> List[InstanceSegmentationInferenceResponse]:
        """Constructs instance segmentation response objects from preprocessed predictions and polygons."""
        # Align to actual number of real images; predictions/masks may include padded slots
        if isinstance(img_dims, dict) and "img_dims" in img_dims:
            img_dims = img_dims["img_dims"]
        effective_len = min(len(img_dims), len(predictions), len(masks))

        responses = []
        for ind in range(effective_len):
            batch_predictions = predictions[ind]
            batch_masks = masks[ind]
            preds_out = []
            for pred, mask in zip(batch_predictions, batch_masks):
                if class_filter and self.class_names[int(pred[6])] not in class_filter:
                    continue
                preds_out.append(
                    InstanceSegmentationPrediction(
                        **{
                            "x": (pred[0] + pred[2]) / 2,
                            "y": (pred[1] + pred[3]) / 2,
                            "width": pred[2] - pred[0],
                            "height": pred[3] - pred[1],
                            "confidence": pred[4],
                            "class": self.class_names[int(pred[6])],
                            "class_id": int(pred[6]),
                            "points": [Point(x=point[0], y=point[1]) for point in mask],
                        }
                    )
                )
            responses.append(
                InstanceSegmentationInferenceResponse(
                    predictions=preds_out,
                    image=InferenceResponseImage(
                        width=img_dims[ind][1], height=img_dims[ind][0]
                    ),
                )
            )
        return responses
Functions
make_response
make_response(
    predictions,
    masks,
    img_dims,
    class_filter=None,
    *args,
    **kwargs
)

Constructs instance segmentation response objects from preprocessed predictions and polygons.

Source code in inference/models/rfdetr/rfdetr.py
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
def make_response(
    self,
    predictions: List[List[List[float]]],
    masks: List[List[List[np.ndarray]]],
    img_dims: List[Tuple[int, int]],
    class_filter: Optional[List[str]] = None,
    *args,
    **kwargs,
) -> List[InstanceSegmentationInferenceResponse]:
    """Constructs instance segmentation response objects from preprocessed predictions and polygons."""
    # `img_dims` occasionally arrives still wrapped in preprocessing metadata.
    if isinstance(img_dims, dict) and "img_dims" in img_dims:
        img_dims = img_dims["img_dims"]
    # Padded batch slots can make predictions/masks longer than the real image count.
    n_images = min(len(img_dims), len(predictions), len(masks))

    responses = []
    for idx in range(n_images):
        image_preds = []
        for det, polygon in zip(predictions[idx], masks[idx]):
            class_id = int(det[6])
            if class_filter and self.class_names[class_id] not in class_filter:
                continue
            x1, y1, x2, y2 = det[0], det[1], det[2], det[3]
            image_preds.append(
                InstanceSegmentationPrediction(
                    **{
                        "x": (x1 + x2) / 2,
                        "y": (y1 + y2) / 2,
                        "width": x2 - x1,
                        "height": y2 - y1,
                        "confidence": det[4],
                        "class": self.class_names[class_id],
                        "class_id": class_id,
                        "points": [Point(x=p[0], y=p[1]) for p in polygon],
                    }
                )
            )
        orig_h, orig_w = img_dims[idx][0], img_dims[idx][1]
        responses.append(
            InstanceSegmentationInferenceResponse(
                predictions=image_preds,
                image=InferenceResponseImage(width=orig_w, height=orig_h),
            )
        )
    return responses
predict
predict(img_in, **kwargs)

Performs object detection on the given image using the ONNX session with the RFDETR model.

Parameters:

Name Type Description Default
img_in ndarray

Input image as a NumPy array.

required

Returns:

Type Description
Tuple[ndarray]

Tuple[np.ndarray]: NumPy arrays representing the predictions, including boxes, confidence logits, and segmentation masks.

Source code in inference/models/rfdetr/rfdetr.py
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
def predict(self, img_in: ImageMetaType, **kwargs) -> Tuple[np.ndarray]:
    """Run the RFDETR segmentation ONNX session on a preprocessed image.

    Args:
        img_in (np.ndarray): Preprocessed input tensor.

    Returns:
        Tuple[np.ndarray]: Raw model outputs ``(bboxes, logits, masks)``.
    """
    # Serialize access to the shared ONNX session.
    with self._session_lock:
        outputs = run_session_via_iobinding(
            self.onnx_session, self.input_name, img_in
        )
    bboxes, logits, masks = outputs[0], outputs[1], outputs[2]

    return (bboxes, logits, masks)

RFDETRObjectDetection

Bases: ObjectDetectionBaseOnnxRoboflowInferenceModel

Roboflow ONNX Object detection with the RFDETR model.

This class is responsible for performing object detection using the RFDETR model with ONNX runtime.

Attributes:

Name Type Description
weights_file str

Path to the ONNX weights file.

Methods:

Name Description
predict

Performs object detection on the given image using the ONNX session.

Source code in inference/models/rfdetr/rfdetr.py
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
class RFDETRObjectDetection(ObjectDetectionBaseOnnxRoboflowInferenceModel):
    """Roboflow ONNX Object detection with the RFDETR model.

    This class is responsible for performing object detection using the RFDETR model
    with ONNX runtime.

    Attributes:
        weights_file (str): Path to the ONNX weights file.

    Methods:
        predict: Performs object detection on the given image using the ONNX session.
    """

    preprocess_means = [0.485, 0.456, 0.406]
    preprocess_stds = [0.229, 0.224, 0.225]

    @property
    def weights_file(self) -> str:
        """Name of the ONNX weights artefact used by the RFDETR model.

        Returns:
            str: Relative path of the weights file inside the model package.
        """
        return "weights.onnx"

    def preproc_image(
        self,
        image: Union[Any, InferenceRequestImage],
        disable_preproc_auto_orient: bool = False,
        disable_preproc_contrast: bool = False,
        disable_preproc_grayscale: bool = False,
        disable_preproc_static_crop: bool = False,
    ) -> Tuple[np.ndarray, Tuple[int, int]]:
        """
        Preprocesses an inference request image by loading it, then applying any pre-processing specified by the Roboflow platform, then scaling it to the inference input dimensions.

        Args:
            image (Union[Any, InferenceRequestImage]): An object containing information necessary to load the image for inference.
            disable_preproc_auto_orient (bool, optional): If true, the auto orient preprocessing step is disabled for this call. Default is False.
            disable_preproc_contrast (bool, optional): If true, the contrast preprocessing step is disabled for this call. Default is False.
            disable_preproc_grayscale (bool, optional): If true, the grayscale preprocessing step is disabled for this call. Default is False.
            disable_preproc_static_crop (bool, optional): If true, the static crop preprocessing step is disabled for this call. Default is False.

        Returns:
            Tuple[np.ndarray, Tuple[int, int]]: A tuple containing a numpy array of the preprocessed image pixel data and a tuple of the images original size.
        """
        # PIL input on the torch path: wrap pixels without copy and move to
        # GPU if available. is_bgr=False — PIL images are presumed RGB here.
        if isinstance(image, Image.Image) and USE_PYTORCH_FOR_PREPROCESSING:
            if CUDA_IS_AVAILABLE:
                np_image = torch.from_numpy(np.asarray(image, copy=False)).cuda()
            else:
                np_image = torch.from_numpy(np.asarray(image, copy=False))
            is_bgr = False
        else:
            np_image, is_bgr = load_image(
                image,
                disable_preproc_auto_orient=disable_preproc_auto_orient
                or "auto-orient" not in self.preproc.keys()
                or DISABLE_PREPROC_AUTO_ORIENT,
            )
        # Ensure a torch tensor (on GPU when possible) for the torch path.
        if USE_PYTORCH_FOR_PREPROCESSING:
            if not isinstance(np_image, torch.Tensor):
                np_image = torch.from_numpy(np_image)
            if torch.cuda.is_available():
                np_image = np_image.cuda()

        preprocessed_image, img_dims = self.preprocess_image(
            np_image,
            disable_preproc_contrast=disable_preproc_contrast,
            disable_preproc_grayscale=disable_preproc_grayscale,
            disable_preproc_static_crop=disable_preproc_static_crop,
        )

        # Normalize to ImageNet statistics. Torch path: HWC -> NCHW first,
        # then per-channel (x/255 - mean) / std via broadcasting.
        if USE_PYTORCH_FOR_PREPROCESSING:
            preprocessed_image = (
                preprocessed_image.permute(2, 0, 1).unsqueeze(0).contiguous()
            )
            preprocessed_image = preprocessed_image.float()

            preprocessed_image /= 255.0

            means = torch.tensor(
                self.preprocess_means, device=preprocessed_image.device
            ).view(3, 1, 1)
            stds = torch.tensor(
                self.preprocess_stds, device=preprocessed_image.device
            ).view(3, 1, 1)
            preprocessed_image = (preprocessed_image - means) / stds
        else:
            # Numpy path: stays HWC; normalize each channel in place.
            preprocessed_image = preprocessed_image.astype(np.float32)
            preprocessed_image /= 255.0

            preprocessed_image[:, :, 0] = (
                preprocessed_image[:, :, 0] - self.preprocess_means[0]
            ) / self.preprocess_stds[0]
            preprocessed_image[:, :, 1] = (
                preprocessed_image[:, :, 1] - self.preprocess_means[1]
            ) / self.preprocess_stds[1]
            preprocessed_image[:, :, 2] = (
                preprocessed_image[:, :, 2] - self.preprocess_means[2]
            ) / self.preprocess_stds[2]

        # Non-square preprocessing first letterboxes to an intermediate
        # (w, h), then resizes to the square model input further below.
        if self._needs_nonsquare_preproc:
            intermediate_size = (self._preproc_resize_w, self._preproc_resize_h)
        else:
            intermediate_size = None

        if self.resize_method == "Stretch to":
            if isinstance(preprocessed_image, np.ndarray):
                preprocessed_image = preprocessed_image.astype(np.float32)
                resized = cv2.resize(
                    preprocessed_image,
                    (self.img_size_w, self.img_size_h),
                )
            elif USE_PYTORCH_FOR_PREPROCESSING:
                resized = torch.nn.functional.interpolate(
                    preprocessed_image,
                    size=(self.img_size_h, self.img_size_w),
                    mode="bilinear",
                )
            else:
                raise ValueError(
                    f"Received an image of unknown type, {type(preprocessed_image)}; "
                    "This is most likely a bug. Contact Roboflow team through github issues "
                    "(https://github.com/roboflow/inference/issues) providing full context of the problem"
                )

        elif self.resize_method == "Fit (black edges) in":
            resized = letterbox_image(
                preprocessed_image,
                intermediate_size or (self.img_size_w, self.img_size_h),
            )
        elif self.resize_method == "Fit (white edges) in":
            resized = letterbox_image(
                preprocessed_image,
                intermediate_size or (self.img_size_w, self.img_size_h),
                color=(255, 255, 255),
            )
        elif self.resize_method == "Fit (grey edges) in":
            resized = letterbox_image(
                preprocessed_image,
                intermediate_size or (self.img_size_w, self.img_size_h),
                color=(114, 114, 114),
            )

        # Second stage of the non-square path: intermediate -> model input.
        if intermediate_size is not None:
            if isinstance(resized, np.ndarray):
                resized = cv2.resize(
                    resized.astype(np.float32),
                    (self.img_size_w, self.img_size_h),
                )
            elif USE_PYTORCH_FOR_PREPROCESSING:
                resized = torch.nn.functional.interpolate(
                    resized,
                    size=(self.img_size_h, self.img_size_w),
                    mode="bilinear",
                )
            else:
                raise ValueError(
                    f"Received an image of unknown type, {type(resized)}; "
                    "This is most likely a bug. Contact Roboflow team through github issues "
                    "(https://github.com/roboflow/inference/issues) providing full context of the problem"
                )

        # Channel-order fix: numpy path is HWC (cvtColor); tensor path is
        # NCHW, so BGR->RGB is a channel-dim permutation.
        if is_bgr:
            if isinstance(resized, np.ndarray):
                resized = cv2.cvtColor(resized, cv2.COLOR_BGR2RGB)
            else:
                resized = resized[:, [2, 1, 0], :, :]

        # Final layout: float32 NCHW with a leading batch dim of 1.
        if isinstance(resized, np.ndarray):
            img_in = np.transpose(resized, (2, 0, 1))
            img_in = img_in.astype(np.float32)
            img_in = np.expand_dims(img_in, axis=0)
        elif USE_PYTORCH_FOR_PREPROCESSING:
            img_in = resized.float()
        else:
            raise ValueError(
                f"Received an image of unknown type, {type(resized)}; "
                "This is most likely a bug. Contact Roboflow team through github issues "
                "(https://github.com/roboflow/inference/issues) providing full context of the problem"
            )
        return img_in, img_dims

    def preprocess(
        self,
        image: Any,
        disable_preproc_auto_orient: bool = False,
        disable_preproc_contrast: bool = False,
        disable_preproc_grayscale: bool = False,
        disable_preproc_static_crop: bool = False,
        fix_batch_size: bool = False,
        **kwargs,
    ) -> Tuple[np.ndarray, PreprocessReturnMetadata]:
        """Load the image(s), cast to float, and optionally pad the batch.

        Args:
            image: Image (or batch of images) in any format accepted by ``load_image``.
            disable_preproc_auto_orient: Skip the auto-orient preprocessing step.
            disable_preproc_contrast: Skip the contrast preprocessing step.
            disable_preproc_grayscale: Skip the grayscale preprocessing step.
            disable_preproc_static_crop: Skip the static-crop preprocessing step.
            fix_batch_size: Zero-pad the batch up to ``MAX_BATCH_SIZE``.

        Returns:
            Tuple of the preprocessed input tensor and metadata carrying the
            original image dimensions needed by ``postprocess``.

        Raises:
            ValueError: If more images than ``MAX_BATCH_SIZE`` are passed while
                batch-size fixing is requested, or the tensor type is unknown.
        """
        img_in, img_dims = self.load_image(
            image,
            disable_preproc_auto_orient=disable_preproc_auto_orient,
            disable_preproc_contrast=disable_preproc_contrast,
            disable_preproc_grayscale=disable_preproc_grayscale,
            disable_preproc_static_crop=disable_preproc_static_crop,
        )
        if not USE_PYTORCH_FOR_PREPROCESSING:
            img_in = img_in.astype(np.float32)
        else:
            img_in = img_in.float()

        if self.batching_enabled:
            batch_padding = 0
            if FIX_BATCH_SIZE or fix_batch_size:
                if MAX_BATCH_SIZE == float("inf"):
                    logger.warning(
                        "Requested fix_batch_size but MAX_BATCH_SIZE is not set. Using dynamic batching."
                    )
                    batch_padding = 0
                else:
                    batch_padding = MAX_BATCH_SIZE - img_in.shape[0]
            if batch_padding < 0:
                raise ValueError(
                    f"Requested fix_batch_size but passed in {img_in.shape[0]} images "
                    f"when the model's batch size is {MAX_BATCH_SIZE}\n"
                    f"Consider turning off fix_batch_size, changing `MAX_BATCH_SIZE` in"
                    f"your inference server config, or passing at most {MAX_BATCH_SIZE} images at a time"
                )
            # BUGFIX: the previous version raised "unknown type" unconditionally
            # whenever batching was enabled; the non-negative case must instead
            # zero-pad the batch to the fixed size (no-op when padding is 0).
            if batch_padding > 0:
                if isinstance(img_in, np.ndarray):
                    padding = np.zeros(
                        (batch_padding, *img_in.shape[1:]), dtype=img_in.dtype
                    )
                    img_in = np.concatenate([img_in, padding], axis=0)
                elif USE_PYTORCH_FOR_PREPROCESSING:
                    padding = torch.zeros(
                        (batch_padding, *img_in.shape[1:]),
                        dtype=img_in.dtype,
                        device=img_in.device,
                    )
                    img_in = torch.cat([img_in, padding], dim=0)
                else:
                    raise ValueError(
                        f"Received an image of unknown type, {type(img_in)}; "
                        "This is most likely a bug. Contact Roboflow team through github issues "
                        "(https://github.com/roboflow/inference/issues) providing full context of the problem"
                    )

        return img_in, PreprocessReturnMetadata(
            {
                "img_dims": img_dims,
                "disable_preproc_static_crop": disable_preproc_static_crop,
            }
        )

    def predict(self, img_in: ImageMetaType, **kwargs) -> Tuple[np.ndarray]:
        """Run the RFDETR detection ONNX session on a preprocessed input.

        Args:
            img_in (np.ndarray): Preprocessed input image batch.

        Returns:
            Tuple[np.ndarray]: Raw model outputs ``(bboxes, logits)``.
        """
        # Serialize access to the shared ONNX session.
        with self._session_lock:
            outputs = run_session_via_iobinding(
                self.onnx_session, self.input_name, img_in
            )
        bboxes, logits = outputs[0], outputs[1]

        return (bboxes, logits)

    def sigmoid_stable(self, x):
        """Numerically stable, branchless element-wise sigmoid.

        Computes exp(-|x|) so the exponential can never overflow, then picks
        the algebraically equivalent form per element based on the sign of x.
        """
        exp_neg_abs = np.exp(-np.abs(x))
        positive_form = 1 / (1 + exp_neg_abs)
        negative_form = exp_neg_abs / (1 + exp_neg_abs)
        return np.where(x >= 0, positive_form, negative_form)

    def postprocess(
        self,
        predictions: Tuple[np.ndarray, ...],
        preproc_return_metadata: PreprocessReturnMetadata,
        confidence: float = DEFAULT_CONFIDENCE,
        max_detections: int = DEFAUlT_MAX_DETECTIONS,
        **kwargs,
    ) -> List[ObjectDetectionInferenceResponse]:
        """Decode raw RFDETR outputs into object detection responses.

        Selects the top-scoring (query, class) pairs per image, converts
        normalized cxcywh boxes to xyxy in original image pixels (undoing
        letterboxing when the resize method is not "Stretch to"), and builds
        the response objects.

        Args:
            predictions: Raw model outputs ``(bboxes, logits)``.
            preproc_return_metadata: Metadata from preprocessing; must contain
                ``img_dims`` with the original (h, w) of each image.
            confidence: Minimum sigmoid score for a detection to be kept.
            max_detections: Upper bound on detections per image.

        Returns:
            One ``ObjectDetectionInferenceResponse`` per input image.
        """
        bboxes, logits = predictions
        bboxes = bboxes.astype(np.float32)
        logits = logits.astype(np.float32)

        batch_size, num_queries, num_classes = logits.shape
        logits_sigmoid = self.sigmoid_stable(logits)

        img_dims = preproc_return_metadata["img_dims"]

        processed_predictions = []

        for batch_idx in range(batch_size):
            orig_h, orig_w = img_dims[batch_idx]

            logits_flat = logits_sigmoid[batch_idx].reshape(-1)

            # Use argpartition for better performance when max_detections is smaller than logits_flat
            if len(logits_flat) > max_detections:
                partition_indices = np.argpartition(-logits_flat, max_detections)[
                    :max_detections
                ]
                sorted_indices = partition_indices[
                    np.argsort(-logits_flat[partition_indices])
                ]
            else:
                sorted_indices = np.argsort(-logits_flat)
            topk_scores = logits_flat[sorted_indices]

            # Drop candidates below the confidence threshold.
            conf_mask = topk_scores > confidence
            sorted_indices = sorted_indices[conf_mask]
            topk_scores = topk_scores[conf_mask]

            # Flattened index -> (query index, class index) pair.
            topk_boxes = sorted_indices // num_classes
            topk_labels = sorted_indices % num_classes

            # Remove the background class and re-index labels to 0-based.
            if self.is_one_indexed:
                class_filter_mask = topk_labels != self.background_class_index

                topk_labels[topk_labels > self.background_class_index] -= 1
                topk_scores = topk_scores[class_filter_mask]
                topk_labels = topk_labels[class_filter_mask]
                topk_boxes = topk_boxes[class_filter_mask]

            selected_boxes = bboxes[batch_idx, topk_boxes]

            # cxcywh (normalized) -> xyxy (still normalized).
            cxcy = selected_boxes[:, :2]
            wh = selected_boxes[:, 2:]
            xy_min = cxcy - 0.5 * wh
            xy_max = cxcy + 0.5 * wh
            boxes_xyxy = np.concatenate([xy_min, xy_max], axis=1)

            if self.resize_method == "Stretch to":
                scale_fct = np.array([orig_w, orig_h, orig_w, orig_h], dtype=np.float32)
                boxes_xyxy *= scale_fct
            else:
                # Undo letterbox: map normalized boxes to padded input pixels,
                # subtract the padding, then rescale to original image pixels.
                if self._needs_nonsquare_preproc:
                    input_h, input_w = self._preproc_resize_h, self._preproc_resize_w
                else:
                    input_h, input_w = self.img_size_h, self.img_size_w

                scale = min(input_w / orig_w, input_h / orig_h)
                scaled_w = int(orig_w * scale)
                scaled_h = int(orig_h * scale)

                pad_x = (input_w - scaled_w) / 2
                pad_y = (input_h - scaled_h) / 2

                boxes_input = boxes_xyxy * np.array(
                    [input_w, input_h, input_w, input_h], dtype=np.float32
                )

                boxes_input[:, 0] -= pad_x
                boxes_input[:, 1] -= pad_y
                boxes_input[:, 2] -= pad_x
                boxes_input[:, 3] -= pad_y

                boxes_xyxy = boxes_input / scale

            # Clamp boxes to the original image bounds, in place.
            np.clip(
                boxes_xyxy,
                [0, 0, 0, 0],
                [orig_w, orig_h, orig_w, orig_h],
                out=boxes_xyxy,
            )

            # Columns: x1, y1, x2, y2, confidence, 0 (placeholder), class_id.
            batch_predictions = np.column_stack(
                (
                    boxes_xyxy,
                    topk_scores,
                    np.zeros((len(topk_scores), 1), dtype=np.float32),
                    topk_labels,
                )
            )
            # Discard any prediction whose class id falls outside class_names.
            batch_predictions = batch_predictions[
                batch_predictions[:, 6] < len(self.class_names)
            ]

            processed_predictions.append(batch_predictions)

        res = self.make_response(processed_predictions, img_dims, **kwargs)
        return res

    def initialize_model(self, **kwargs) -> None:
        """Initialize the ONNX model: create the inference session and derive model properties.

        Loads model artefacts, rejects input resolutions at or above
        ``RFDETR_ONNX_MAX_RESOLUTION``, then either builds an onnxruntime
        session (reading the input shape from the graph) or restores the
        input shape from cached model metadata. Finally configures dynamic
        batching, non-square preprocessing detection, and background-class
        handling.

        Raises:
            CannotInitialiseModelError: If the configured input resolution
                meets or exceeds ``RFDETR_ONNX_MAX_RESOLUTION``.
            ModelArtefactError: If the ONNX session cannot be created.
        """
        logger.debug("Getting model artefacts")
        self.get_model_artifacts(**kwargs)

        # Resolve the configured input resolution; fall back to the preproc
        # resize width when the environment does not specify one.
        input_resolution = self.environment.get("RESOLUTION")
        if input_resolution is None:
            input_resolution = self.preproc.get("resize", {}).get("width")
        if isinstance(input_resolution, (list, tuple)):
            input_resolution = input_resolution[0]
        try:
            input_resolution = int(input_resolution)
        except (TypeError, ValueError):
            input_resolution = None
        if (
            input_resolution is not None
            and input_resolution >= RFDETR_ONNX_MAX_RESOLUTION
        ):
            logger.error(
                "NOT loading '%s' model, input resolution is '%s', ONNX max resolution limit set to '%s' (limit can be increased via RFDETR_ONNX_MAX_RESOLUTION env variable)",
                self.endpoint,
                input_resolution,
                RFDETR_ONNX_MAX_RESOLUTION,
            )
            raise CannotInitialiseModelError("Resolution too high for RFDETR")

        logger.debug("Creating inference session")
        if self.load_weights or not self.has_model_metadata:
            t1_session = perf_counter()
            providers = get_onnxruntime_execution_providers(
                ONNXRUNTIME_EXECUTION_PROVIDERS
            )

            if not self.load_weights:
                # Metadata-only load: CPU EP is sufficient and avoids GPU init.
                providers = [
                    "CPUExecutionProvider"
                ]  # "OpenVINOExecutionProvider" dropped until further investigation is done

            try:
                session_options = onnxruntime.SessionOptions()
                session_options.log_severity_level = 3
                # TensorRT does better graph optimization for its EP than onnx
                if has_trt(providers):
                    session_options.graph_optimization_level = (
                        onnxruntime.GraphOptimizationLevel.ORT_DISABLE_ALL
                    )
                expanded_execution_providers = []
                for ep in self.onnxruntime_execution_providers:
                    if ep == "TensorrtExecutionProvider":
                        # Expand the bare EP name into a (name, options) pair
                        # with engine caching enabled so the TensorRT engine
                        # is not rebuilt on every process start.
                        ep = (
                            "TensorrtExecutionProvider",
                            {
                                "trt_max_workspace_size": str(1 << 30),
                                "trt_engine_cache_enable": True,
                                "trt_engine_cache_path": os.path.join(
                                    TENSORRT_CACHE_PATH, self.endpoint
                                ),
                                "trt_fp16_enable": True,
                                "trt_dump_subgraphs": False,
                                "trt_force_sequential_engine_build": False,
                                "trt_dla_enable": False,
                            },
                        )
                    expanded_execution_providers.append(ep)

                if "OpenVINOExecutionProvider" in expanded_execution_providers:
                    expanded_execution_providers.remove("OpenVINOExecutionProvider")

                self.onnx_session = onnxruntime.InferenceSession(
                    self.cache_file(self.weights_file),
                    providers=expanded_execution_providers,
                    sess_options=session_options,
                )
            except Exception as e:
                self.clear_cache(delete_from_disk=DISK_CACHE_CLEANUP)
                raise ModelArtefactError(
                    f"Unable to load ONNX session. Cause: {e}"
                ) from e
            logger.debug("Session created in %s seconds", perf_counter() - t1_session)

            inputs = self.onnx_session.get_inputs()[0]
            input_shape = inputs.shape
            self.batch_size = input_shape[0]
            self.img_size_h = input_shape[2]
            self.img_size_w = input_shape[3]
            self.input_name = inputs.name
            # Dynamic spatial dims surface as strings; fall back to the
            # preproc resize config (or 640x640) for a concrete input size.
            if isinstance(self.img_size_h, str) or isinstance(self.img_size_w, str):
                if "resize" in self.preproc:
                    self.img_size_h = int(self.preproc["resize"]["height"])
                    self.img_size_w = int(self.preproc["resize"]["width"])
                else:
                    self.img_size_h = 640
                    self.img_size_w = 640

            model_metadata = {
                "batch_size": self.batch_size,
                "img_size_h": self.img_size_h,
                "img_size_w": self.img_size_w,
            }
            logger.debug("Writing model metadata to memcache")
            self.write_model_metadata_to_memcache(model_metadata)
            if not self.load_weights:  # had to load weights to get metadata
                del self.onnx_session
        else:
            if not self.has_model_metadata:
                raise ValueError(
                    "This should be unreachable, should get weights if we don't have model metadata"
                )
            logger.debug("Loading model metadata from memcache")
            metadata = self.model_metadata_from_memcache()
            self.batch_size = metadata["batch_size"]
            self.img_size_h = metadata["img_size_h"]
            self.img_size_w = metadata["img_size_w"]

        # A string batch dimension means the graph accepts dynamic batch
        # sizes. (This check was previously duplicated in both branches.)
        if isinstance(self.batch_size, str):
            self.batching_enabled = True
            logger.debug(
                "Model %s is loaded with dynamic batching enabled", self.endpoint
            )
        else:
            self.batching_enabled = False
            logger.debug(
                "Model %s is loaded with dynamic batching disabled", self.endpoint
            )

        self._needs_nonsquare_preproc = False
        if self.preproc.get("resize"):
            preproc_w = int(self.preproc["resize"].get("width", self.img_size_w))
            preproc_h = int(self.preproc["resize"].get("height", self.img_size_h))
            # Only relevant when the model expects a square input while the
            # platform preprocessing resizes to a non-square shape.
            self._needs_nonsquare_preproc = (
                self.resize_method != "Stretch to"
                and preproc_w != preproc_h
                and self.img_size_h == self.img_size_w
            )
            if self._needs_nonsquare_preproc:
                self._preproc_resize_w = preproc_w
                self._preproc_resize_h = preproc_h
                logger.debug(
                    "Non-square preprocessing detected: resize to %dx%d then stretch to %dx%d",
                    preproc_w,
                    preproc_h,
                    self.img_size_w,
                    self.img_size_h,
                )

        # Models trained with an explicit background class are one-indexed;
        # strip the background entry so predicted labels map onto class_names.
        if ROBOFLOW_BACKGROUND_CLASS in self.class_names:
            self.is_one_indexed = True
            self.background_class_index = self.class_names.index(
                ROBOFLOW_BACKGROUND_CLASS
            )
            self.class_names = (
                self.class_names[: self.background_class_index]
                + self.class_names[self.background_class_index + 1 :]
            )
        else:
            self.is_one_indexed = False
        logger.debug("Model initialisation finished.")

    def validate_model_classes(self) -> None:
        """No-op override: class-list validation is intentionally skipped for this model."""
        pass
Attributes
weights_file property
weights_file

Gets the weights file for the RFDETR model.

Returns:

Name Type Description
str str

Path to the ONNX weights file.

Functions
initialize_model
initialize_model(**kwargs)

Initializes the ONNX model, setting up the inference session and other necessary properties.

Source code in inference/models/rfdetr/rfdetr.py
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
def initialize_model(self, **kwargs) -> None:
    """Initializes the ONNX model, setting up the inference session and other necessary properties."""
    logger.debug("Getting model artefacts")
    self.get_model_artifacts(**kwargs)

    input_resolution = self.environment.get("RESOLUTION")
    if input_resolution is None:
        input_resolution = self.preproc.get("resize", {}).get("width")
    if isinstance(input_resolution, (list, tuple)):
        input_resolution = input_resolution[0]
    try:
        input_resolution = int(input_resolution)
    except (TypeError, ValueError):
        input_resolution = None
    if (
        input_resolution is not None
        and input_resolution >= RFDETR_ONNX_MAX_RESOLUTION
    ):
        logger.error(
            "NOT loading '%s' model, input resolution is '%s', ONNX max resolution limit set to '%s' (limit can be increased via RFDETR_ONNX_MAX_RESOLUTION env variable)",
            self.endpoint,
            input_resolution,
            RFDETR_ONNX_MAX_RESOLUTION,
        )
        raise CannotInitialiseModelError(f"Resolution too high for RFDETR")

    logger.debug("Creating inference session")
    if self.load_weights or not self.has_model_metadata:
        t1_session = perf_counter()
        providers = get_onnxruntime_execution_providers(
            ONNXRUNTIME_EXECUTION_PROVIDERS
        )

        if not self.load_weights:
            providers = [
                "CPUExecutionProvider"
            ]  # "OpenVINOExecutionProvider" dropped until further investigation is done

        try:
            session_options = onnxruntime.SessionOptions()
            session_options.log_severity_level = 3
            # TensorRT does better graph optimization for its EP than onnx
            if has_trt(providers):
                session_options.graph_optimization_level = (
                    onnxruntime.GraphOptimizationLevel.ORT_DISABLE_ALL
                )
            expanded_execution_providers = []
            for ep in self.onnxruntime_execution_providers:
                if ep == "TensorrtExecutionProvider":
                    ep = (
                        "TensorrtExecutionProvider",
                        {
                            "trt_max_workspace_size": str(1 << 30),
                            "trt_engine_cache_enable": True,
                            "trt_engine_cache_path": os.path.join(
                                TENSORRT_CACHE_PATH, self.endpoint
                            ),
                            "trt_fp16_enable": True,
                            "trt_dump_subgraphs": False,
                            "trt_force_sequential_engine_build": False,
                            "trt_dla_enable": False,
                        },
                    )
                expanded_execution_providers.append(ep)

            if "OpenVINOExecutionProvider" in expanded_execution_providers:
                expanded_execution_providers.remove("OpenVINOExecutionProvider")

            self.onnx_session = onnxruntime.InferenceSession(
                self.cache_file(self.weights_file),
                providers=expanded_execution_providers,
                sess_options=session_options,
            )
        except Exception as e:
            self.clear_cache(delete_from_disk=DISK_CACHE_CLEANUP)
            raise ModelArtefactError(
                f"Unable to load ONNX session. Cause: {e}"
            ) from e
        logger.debug(f"Session created in {perf_counter() - t1_session} seconds")

        inputs = self.onnx_session.get_inputs()[0]
        input_shape = inputs.shape
        self.batch_size = input_shape[0]
        self.img_size_h = input_shape[2]
        self.img_size_w = input_shape[3]
        self.input_name = inputs.name
        if isinstance(self.img_size_h, str) or isinstance(self.img_size_w, str):
            if "resize" in self.preproc:
                self.img_size_h = int(self.preproc["resize"]["height"])
                self.img_size_w = int(self.preproc["resize"]["width"])
            else:
                self.img_size_h = 640
                self.img_size_w = 640

        if isinstance(self.batch_size, str):
            self.batching_enabled = True
            logger.debug(
                f"Model {self.endpoint} is loaded with dynamic batching enabled"
            )
        else:
            self.batching_enabled = False
            logger.debug(
                f"Model {self.endpoint} is loaded with dynamic batching disabled"
            )

        model_metadata = {
            "batch_size": self.batch_size,
            "img_size_h": self.img_size_h,
            "img_size_w": self.img_size_w,
        }
        logger.debug(f"Writing model metadata to memcache")
        self.write_model_metadata_to_memcache(model_metadata)
        if not self.load_weights:  # had to load weights to get metadata
            del self.onnx_session
    else:
        if not self.has_model_metadata:
            raise ValueError(
                "This should be unreachable, should get weights if we don't have model metadata"
            )
        logger.debug(f"Loading model metadata from memcache")
        metadata = self.model_metadata_from_memcache()
        self.batch_size = metadata["batch_size"]
        self.img_size_h = metadata["img_size_h"]
        self.img_size_w = metadata["img_size_w"]
        if isinstance(self.batch_size, str):
            self.batching_enabled = True
            logger.debug(
                f"Model {self.endpoint} is loaded with dynamic batching enabled"
            )
        else:
            self.batching_enabled = False
            logger.debug(
                f"Model {self.endpoint} is loaded with dynamic batching disabled"
            )

    self._needs_nonsquare_preproc = False
    if self.preproc.get("resize"):
        preproc_w = int(self.preproc["resize"].get("width", self.img_size_w))
        preproc_h = int(self.preproc["resize"].get("height", self.img_size_h))
        self._needs_nonsquare_preproc = (
            self.resize_method != "Stretch to"
            and preproc_w != preproc_h
            and self.img_size_h == self.img_size_w
        )
        if self._needs_nonsquare_preproc:
            self._preproc_resize_w = preproc_w
            self._preproc_resize_h = preproc_h
            logger.debug(
                "Non-square preprocessing detected: resize to %dx%d then stretch to %dx%d",
                preproc_w,
                preproc_h,
                self.img_size_w,
                self.img_size_h,
            )

    if ROBOFLOW_BACKGROUND_CLASS in self.class_names:
        self.is_one_indexed = True
        self.background_class_index = self.class_names.index(
            ROBOFLOW_BACKGROUND_CLASS
        )
        self.class_names = (
            self.class_names[: self.background_class_index]
            + self.class_names[self.background_class_index + 1 :]
        )
    else:
        self.is_one_indexed = False
    logger.debug("Model initialisation finished.")
predict
predict(img_in, **kwargs)

Performs object detection on the given image using the ONNX session with the RFDETR model.

Parameters:

Name Type Description Default
img_in ndarray

Input image as a NumPy array.

required

Returns:

Type Description
Tuple[ndarray]

Tuple[np.ndarray]: NumPy array representing the predictions, including boxes, confidence scores, and class IDs.

Source code in inference/models/rfdetr/rfdetr.py
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
def predict(self, img_in: ImageMetaType, **kwargs) -> Tuple[np.ndarray]:
    """Performs object detection on the given image using the ONNX session with the RFDETR model.

    Args:
        img_in (np.ndarray): Input image as a NumPy array.

    Returns:
        Tuple[np.ndarray]: NumPy array representing the predictions, including boxes, confidence scores, and class IDs.
    """
    with self._session_lock:
        predictions = run_session_via_iobinding(
            self.onnx_session, self.input_name, img_in
        )
    bboxes = predictions[0]
    logits = predictions[1]

    return (bboxes, logits)
preproc_image
preproc_image(
    image,
    disable_preproc_auto_orient=False,
    disable_preproc_contrast=False,
    disable_preproc_grayscale=False,
    disable_preproc_static_crop=False,
)

Preprocesses an inference request image by loading it, then applying any pre-processing specified by the Roboflow platform, then scaling it to the inference input dimensions.

Parameters:

Name Type Description Default
image Union[Any, InferenceRequestImage]

An object containing information necessary to load the image for inference.

required
disable_preproc_auto_orient bool

If true, the auto orient preprocessing step is disabled for this call. Default is False.

False
disable_preproc_contrast bool

If true, the contrast preprocessing step is disabled for this call. Default is False.

False
disable_preproc_grayscale bool

If true, the grayscale preprocessing step is disabled for this call. Default is False.

False
disable_preproc_static_crop bool

If true, the static crop preprocessing step is disabled for this call. Default is False.

False

Returns:

Type Description
Tuple[ndarray, Tuple[int, int]]

Tuple[np.ndarray, Tuple[int, int]]: A tuple containing a numpy array of the preprocessed image pixel data and a tuple of the image's original size.

Source code in inference/models/rfdetr/rfdetr.py
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
def preproc_image(
    self,
    image: Union[Any, InferenceRequestImage],
    disable_preproc_auto_orient: bool = False,
    disable_preproc_contrast: bool = False,
    disable_preproc_grayscale: bool = False,
    disable_preproc_static_crop: bool = False,
) -> Tuple[np.ndarray, Tuple[int, int]]:
    """
    Preprocesses an inference request image by loading it, then applying any pre-processing specified by the Roboflow platform, then scaling it to the inference input dimensions.

    Args:
        image (Union[Any, InferenceRequestImage]): An object containing information necessary to load the image for inference.
        disable_preproc_auto_orient (bool, optional): If true, the auto orient preprocessing step is disabled for this call. Default is False.
        disable_preproc_contrast (bool, optional): If true, the contrast preprocessing step is disabled for this call. Default is False.
        disable_preproc_grayscale (bool, optional): If true, the grayscale preprocessing step is disabled for this call. Default is False.
        disable_preproc_static_crop (bool, optional): If true, the static crop preprocessing step is disabled for this call. Default is False.

    Returns:
        Tuple[np.ndarray, Tuple[int, int]]: A tuple containing a numpy array of the preprocessed image pixel data and a tuple of the images original size.
    """
    if isinstance(image, Image.Image) and USE_PYTORCH_FOR_PREPROCESSING:
        if CUDA_IS_AVAILABLE:
            np_image = torch.from_numpy(np.asarray(image, copy=False)).cuda()
        else:
            np_image = torch.from_numpy(np.asarray(image, copy=False))
        is_bgr = False
    else:
        np_image, is_bgr = load_image(
            image,
            disable_preproc_auto_orient=disable_preproc_auto_orient
            or "auto-orient" not in self.preproc.keys()
            or DISABLE_PREPROC_AUTO_ORIENT,
        )
    if USE_PYTORCH_FOR_PREPROCESSING:
        if not isinstance(np_image, torch.Tensor):
            np_image = torch.from_numpy(np_image)
        if torch.cuda.is_available():
            np_image = np_image.cuda()

    preprocessed_image, img_dims = self.preprocess_image(
        np_image,
        disable_preproc_contrast=disable_preproc_contrast,
        disable_preproc_grayscale=disable_preproc_grayscale,
        disable_preproc_static_crop=disable_preproc_static_crop,
    )

    if USE_PYTORCH_FOR_PREPROCESSING:
        preprocessed_image = (
            preprocessed_image.permute(2, 0, 1).unsqueeze(0).contiguous()
        )
        preprocessed_image = preprocessed_image.float()

        preprocessed_image /= 255.0

        means = torch.tensor(
            self.preprocess_means, device=preprocessed_image.device
        ).view(3, 1, 1)
        stds = torch.tensor(
            self.preprocess_stds, device=preprocessed_image.device
        ).view(3, 1, 1)
        preprocessed_image = (preprocessed_image - means) / stds
    else:
        preprocessed_image = preprocessed_image.astype(np.float32)
        preprocessed_image /= 255.0

        preprocessed_image[:, :, 0] = (
            preprocessed_image[:, :, 0] - self.preprocess_means[0]
        ) / self.preprocess_stds[0]
        preprocessed_image[:, :, 1] = (
            preprocessed_image[:, :, 1] - self.preprocess_means[1]
        ) / self.preprocess_stds[1]
        preprocessed_image[:, :, 2] = (
            preprocessed_image[:, :, 2] - self.preprocess_means[2]
        ) / self.preprocess_stds[2]

    if self._needs_nonsquare_preproc:
        intermediate_size = (self._preproc_resize_w, self._preproc_resize_h)
    else:
        intermediate_size = None

    if self.resize_method == "Stretch to":
        if isinstance(preprocessed_image, np.ndarray):
            preprocessed_image = preprocessed_image.astype(np.float32)
            resized = cv2.resize(
                preprocessed_image,
                (self.img_size_w, self.img_size_h),
            )
        elif USE_PYTORCH_FOR_PREPROCESSING:
            resized = torch.nn.functional.interpolate(
                preprocessed_image,
                size=(self.img_size_h, self.img_size_w),
                mode="bilinear",
            )
        else:
            raise ValueError(
                f"Received an image of unknown type, {type(preprocessed_image)}; "
                "This is most likely a bug. Contact Roboflow team through github issues "
                "(https://github.com/roboflow/inference/issues) providing full context of the problem"
            )

    elif self.resize_method == "Fit (black edges) in":
        resized = letterbox_image(
            preprocessed_image,
            intermediate_size or (self.img_size_w, self.img_size_h),
        )
    elif self.resize_method == "Fit (white edges) in":
        resized = letterbox_image(
            preprocessed_image,
            intermediate_size or (self.img_size_w, self.img_size_h),
            color=(255, 255, 255),
        )
    elif self.resize_method == "Fit (grey edges) in":
        resized = letterbox_image(
            preprocessed_image,
            intermediate_size or (self.img_size_w, self.img_size_h),
            color=(114, 114, 114),
        )

    if intermediate_size is not None:
        if isinstance(resized, np.ndarray):
            resized = cv2.resize(
                resized.astype(np.float32),
                (self.img_size_w, self.img_size_h),
            )
        elif USE_PYTORCH_FOR_PREPROCESSING:
            resized = torch.nn.functional.interpolate(
                resized,
                size=(self.img_size_h, self.img_size_w),
                mode="bilinear",
            )
        else:
            raise ValueError(
                f"Received an image of unknown type, {type(resized)}; "
                "This is most likely a bug. Contact Roboflow team through github issues "
                "(https://github.com/roboflow/inference/issues) providing full context of the problem"
            )

    if is_bgr:
        if isinstance(resized, np.ndarray):
            resized = cv2.cvtColor(resized, cv2.COLOR_BGR2RGB)
        else:
            resized = resized[:, [2, 1, 0], :, :]

    if isinstance(resized, np.ndarray):
        img_in = np.transpose(resized, (2, 0, 1))
        img_in = img_in.astype(np.float32)
        img_in = np.expand_dims(img_in, axis=0)
    elif USE_PYTORCH_FOR_PREPROCESSING:
        img_in = resized.float()
    else:
        raise ValueError(
            f"Received an image of unknown type, {type(resized)}; "
            "This is most likely a bug. Contact Roboflow team through github issues "
            "(https://github.com/roboflow/inference/issues) providing full context of the problem"
        )
    return img_in, img_dims

Functions

models/sam

inference.models.sam.segment_anything

Classes

SegmentAnything

Bases: RoboflowCoreModel

SegmentAnything class for handling segmentation tasks.

Attributes:

Name Type Description
sam

The segmentation model.

predictor

The predictor for the segmentation model.

ort_session

ONNX runtime inference session.

embedding_cache

Cache for embeddings.

image_size_cache

Cache for image sizes.

embedding_cache_keys

Keys for the embedding cache.

low_res_logits_cache

Cache for low resolution logits.

segmentation_cache_keys

Keys for the segmentation cache.

Source code in inference/models/sam/segment_anything.py
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
class SegmentAnything(RoboflowCoreModel):
    """SegmentAnything class for handling segmentation tasks.

    Attributes:
        sam: The segmentation model.
        predictor: The predictor for the segmentation model.
        ort_session: ONNX runtime inference session.
        embedding_cache: Cache for embeddings.
        image_size_cache: Cache for image sizes.
        embedding_cache_keys: Keys for the embedding cache (oldest first).
        low_res_logits_cache: Cache for low resolution logits.
        segmentation_cache_keys: Keys for the segmentation cache (oldest first).
    """

    def __init__(self, *args, model_id: str = f"sam/{SAM_VERSION_ID}", **kwargs):
        """Initializes the SegmentAnything.

        Args:
            *args: Variable length argument list.
            model_id: Roboflow model id; defaults to the configured SAM version.
            **kwargs: Arbitrary keyword arguments.
        """
        super().__init__(*args, model_id=model_id, **kwargs)
        self.sam = sam_model_registry[self.version_id](
            checkpoint=self.cache_file("encoder.pth")
        )
        self.sam.to(device="cuda" if torch.cuda.is_available() else "cpu")
        self.predictor = SamPredictor(self.sam)
        self.ort_session = onnxruntime.InferenceSession(
            self.cache_file("decoder.onnx"),
            providers=[
                "CUDAExecutionProvider",
                "OpenVINOExecutionProvider",
                "CPUExecutionProvider",
            ],
        )
        # Guards the predictor and all caches against concurrent requests.
        self._state_lock = Lock()
        self.embedding_cache = {}
        self.image_size_cache = {}
        self.embedding_cache_keys = []

        self.low_res_logits_cache = {}
        self.segmentation_cache_keys = []
        self.task_type = "unsupervised-segmentation"

    def get_infer_bucket_file_list(self) -> List[str]:
        """Gets the list of files required for inference.

        Returns:
            List[str]: List of file names.
        """
        return ["encoder.pth", "decoder.onnx"]

    def embed_image(self, image: Any, image_id: Optional[str] = None, **kwargs):
        """
        Embeds an image and caches the result if an image_id is provided. If the image has been embedded before and cached,
        the cached result will be returned.

        Args:
            image (Any): The image to be embedded. The format should be compatible with the preproc_image method.
            image_id (Optional[str]): An identifier for the image. If provided, the embedding result will be cached
                                      with this ID. Defaults to None.
            **kwargs: Additional keyword arguments.

        Returns:
            Tuple[np.ndarray, Tuple[int, int]]: A tuple where the first element is the embedding of the image
                                               and the second element is the shape (height, width) of the processed image.

        Notes:
            - Embeddings and image sizes are cached to improve performance on repeated requests for the same image.
            - The cache has a maximum size defined by SAM_MAX_EMBEDDING_CACHE_SIZE. When the cache exceeds this size,
              the oldest entries are removed.

        Example:
            >>> img_array = ... # some image array
            >>> embed_image(img_array, image_id="sample123")
            (array([...]), (224, 224))
        """
        if image_id and image_id in self.embedding_cache:
            return (
                self.embedding_cache[image_id],
                self.image_size_cache[image_id],
            )
        img_in = self.preproc_image(image)
        self.predictor.set_image(img_in)
        embedding = self.predictor.get_image_embedding().cpu().numpy()
        if image_id:
            self.embedding_cache[image_id] = embedding
            self.image_size_cache[image_id] = img_in.shape[:2]
            # FIFO eviction: drop the oldest cached embedding once the cap is hit.
            self.embedding_cache_keys.append(image_id)
            if len(self.embedding_cache_keys) > SAM_MAX_EMBEDDING_CACHE_SIZE:
                cache_key = self.embedding_cache_keys.pop(0)
                del self.embedding_cache[cache_key]
                del self.image_size_cache[cache_key]
        return (embedding, img_in.shape[:2])

    def infer_from_request(self, request: SamInferenceRequest):
        """Performs inference based on the request type.

        Args:
            request (SamInferenceRequest): The inference request.

        Returns:
            Union[SamEmbeddingResponse, SamSegmentationResponse]: The inference response.

        Raises:
            ValueError: If the request format is not 'json' or 'binary'.
        """
        with self._state_lock:
            t1 = perf_counter()
            if isinstance(request, SamEmbeddingRequest):
                embedding, _ = self.embed_image(**request.dict())
                inference_time = perf_counter() - t1
                if request.format == "json":
                    return SamEmbeddingResponse(
                        embeddings=embedding.tolist(), time=inference_time
                    )
                elif request.format == "binary":
                    binary_vector = BytesIO()
                    np.save(binary_vector, embedding)
                    binary_vector.seek(0)
                    return SamEmbeddingResponse(
                        embeddings=binary_vector.getvalue(), time=inference_time
                    )
                # Previously an unknown format silently returned None; fail
                # loudly, mirroring the segmentation branch below.
                raise ValueError(f"Invalid format {request.format}")
            elif isinstance(request, SamSegmentationRequest):
                masks, low_res_masks = self.segment_image(**request.dict())
                if request.format == "json":
                    masks = masks > self.predictor.model.mask_threshold
                    masks = masks2poly(masks)
                    low_res_masks = low_res_masks > self.predictor.model.mask_threshold
                    low_res_masks = masks2poly(low_res_masks)
                elif request.format == "binary":
                    binary_vector = BytesIO()
                    np.savez_compressed(
                        binary_vector, masks=masks, low_res_masks=low_res_masks
                    )
                    binary_vector.seek(0)
                    binary_data = binary_vector.getvalue()
                    return binary_data
                else:
                    raise ValueError(f"Invalid format {request.format}")

                response = SamSegmentationResponse(
                    masks=[m.tolist() for m in masks],
                    low_res_masks=[m.tolist() for m in low_res_masks],
                    time=perf_counter() - t1,
                )
                return response

    def preproc_image(self, image: InferenceRequestImage):
        """Preprocesses an image.

        Args:
            image (InferenceRequestImage): The image to preprocess.

        Returns:
            np.array: The preprocessed image as an RGB numpy array.
        """
        np_image = load_image_rgb(image)
        return np_image

    def segment_image(
        self,
        image: Any,
        embeddings: Optional[Union[np.ndarray, List[List[float]]]] = None,
        embeddings_format: Optional[str] = "json",
        has_mask_input: Optional[bool] = False,
        image_id: Optional[str] = None,
        mask_input: Optional[Union[np.ndarray, List[List[List[float]]]]] = None,
        mask_input_format: Optional[str] = "json",
        orig_im_size: Optional[List[int]] = None,
        point_coords: Optional[List[List[float]]] = None,
        point_labels: Optional[List[int]] = None,
        use_mask_input_cache: Optional[bool] = True,
        **kwargs,
    ):
        """
        Segments an image based on provided embeddings, points, masks, or cached results.
        If embeddings are not directly provided, the function can derive them from the input image or cache.

        Args:
            image (Any): The image to be segmented.
            embeddings (Optional[Union[np.ndarray, List[List[float]]]]): The embeddings of the image.
                Defaults to None, in which case the image is used to compute embeddings.
            embeddings_format (Optional[str]): Format of the provided embeddings; either 'json' or 'binary'. Defaults to 'json'.
            has_mask_input (Optional[bool]): Specifies whether mask input is provided. Defaults to False.
            image_id (Optional[str]): A cached identifier for the image. Useful for accessing cached embeddings or masks.
            mask_input (Optional[Union[np.ndarray, List[List[List[float]]]]]): Input mask for the image.
            mask_input_format (Optional[str]): Format of the provided mask input; either 'json' or 'binary'. Defaults to 'json'.
            orig_im_size (Optional[List[int]]): Original size of the image when providing embeddings directly.
            point_coords (Optional[List[List[float]]]): Coordinates of points in the image. Defaults to None (no points).
            point_labels (Optional[List[int]]): Labels associated with the provided points. Defaults to None (no labels).
            use_mask_input_cache (Optional[bool]): Flag to determine if cached mask input should be used. Defaults to True.
            **kwargs: Additional keyword arguments.

        Returns:
            Tuple[np.ndarray, np.ndarray]: A tuple where the first element is the segmentation masks of the image
                                          and the second element is the low resolution segmentation masks.

        Raises:
            ValueError: If necessary inputs are missing or inconsistent.

        Notes:
            - Embeddings, segmentations, and low-resolution logits can be cached to improve performance
              on repeated requests for the same image.
            - The cache has a maximum size defined by SAM_MAX_EMBEDDING_CACHE_SIZE. When the cache exceeds this size,
              the oldest entries are removed.
        """
        # `not embeddings` raised "truth value of an array is ambiguous" for
        # np.ndarray inputs even though the signature accepts them; test the
        # ndarray case explicitly while keeping None/[] treated as "missing".
        if embeddings is None or (
            not isinstance(embeddings, np.ndarray) and not embeddings
        ):
            if not image and not image_id:
                raise ValueError(
                    "Must provide either image, cached image_id, or embeddings"
                )
            elif image_id and not image and image_id not in self.embedding_cache:
                raise ValueError(
                    f"Image ID {image_id} not in embedding cache, must provide the image or embeddings"
                )
            embedding, original_image_size = self.embed_image(
                image=image, image_id=image_id
            )
        else:
            if not orig_im_size:
                raise ValueError(
                    "Must provide original image size if providing embeddings"
                )
            original_image_size = orig_im_size
            if embeddings_format == "json":
                embedding = np.array(embeddings)
            elif embeddings_format == "binary":
                embedding = np.load(BytesIO(embeddings))
            else:
                # Previously an unknown format fell through to a NameError.
                raise ValueError(f"Invalid embeddings format {embeddings_format}")

        # Copy the caller-provided lists before appending the mandatory SAM
        # padding point/label; the previous mutable [] defaults accumulated
        # entries across calls and mutated caller-owned lists.
        point_coords = list(point_coords) if point_coords else []
        point_coords.append([0, 0])
        point_coords = np.array(point_coords, dtype=np.float32)
        point_coords = np.expand_dims(point_coords, axis=0)
        point_coords = self.predictor.transform.apply_coords(
            point_coords,
            original_image_size,
        )

        point_labels = list(point_labels) if point_labels else []
        point_labels.append(-1)
        point_labels = np.array(point_labels, dtype=np.float32)
        point_labels = np.expand_dims(point_labels, axis=0)

        if has_mask_input:
            if (
                image_id
                and image_id in self.low_res_logits_cache
                and use_mask_input_cache
            ):
                mask_input = self.low_res_logits_cache[image_id]
            elif not mask_input and (
                not image_id or image_id not in self.low_res_logits_cache
            ):
                raise ValueError("Must provide either mask_input or cached image_id")
            else:
                if mask_input_format == "json":
                    # Rasterize each polygon onto a 256x256 grid, one channel
                    # per polygon, to build the decoder's mask input.
                    polys = mask_input
                    mask_input = np.zeros((1, len(polys), 256, 256), dtype=np.uint8)
                    for i, poly in enumerate(polys):
                        poly = ShapelyPolygon(poly)
                        raster = rasterio.features.rasterize(
                            [poly], out_shape=(256, 256)
                        )
                        mask_input[0, i, :, :] = raster
                elif mask_input_format == "binary":
                    binary_data = base64.b64decode(mask_input)
                    mask_input = np.load(BytesIO(binary_data))
        else:
            mask_input = np.zeros((1, 1, 256, 256), dtype=np.float32)

        ort_inputs = {
            "image_embeddings": embedding.astype(np.float32),
            "point_coords": point_coords.astype(np.float32),
            "point_labels": point_labels,
            "mask_input": mask_input.astype(np.float32),
            "has_mask_input": (
                np.zeros(1, dtype=np.float32)
                if not has_mask_input
                else np.ones(1, dtype=np.float32)
            ),
            "orig_im_size": np.array(original_image_size, dtype=np.float32),
        }
        masks, _, low_res_logits = self.ort_session.run(None, ort_inputs)
        if image_id:
            # Cache logits for potential reuse as mask input; FIFO eviction.
            self.low_res_logits_cache[image_id] = low_res_logits
            if image_id not in self.segmentation_cache_keys:
                self.segmentation_cache_keys.append(image_id)
            if len(self.segmentation_cache_keys) > SAM_MAX_EMBEDDING_CACHE_SIZE:
                cache_key = self.segmentation_cache_keys.pop(0)
                del self.low_res_logits_cache[cache_key]
        masks = masks[0]
        low_res_masks = low_res_logits[0]

        return masks, low_res_masks
Functions
__init__
__init__(*args, model_id=f'sam/{SAM_VERSION_ID}', **kwargs)

Initializes the SegmentAnything.

Parameters:

Name Type Description Default
*args

Variable length argument list.

()
**kwargs

Arbitrary keyword arguments.

{}
Source code in inference/models/sam/segment_anything.py
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
def __init__(self, *args, model_id: str = f"sam/{SAM_VERSION_ID}", **kwargs):
    """Set up the SAM encoder, predictor, ONNX decoder session, and caches.

    Args:
        *args: Positional arguments forwarded to RoboflowCoreModel.
        model_id: Roboflow model identifier; defaults to the configured SAM version.
        **kwargs: Keyword arguments forwarded to RoboflowCoreModel.
    """
    super().__init__(*args, model_id=model_id, **kwargs)
    device = "cuda" if torch.cuda.is_available() else "cpu"
    self.sam = sam_model_registry[self.version_id](
        checkpoint=self.cache_file("encoder.pth")
    )
    self.sam.to(device=device)
    self.predictor = SamPredictor(self.sam)
    providers = [
        "CUDAExecutionProvider",
        "OpenVINOExecutionProvider",
        "CPUExecutionProvider",
    ]
    self.ort_session = onnxruntime.InferenceSession(
        self.cache_file("decoder.onnx"),
        providers=providers,
    )
    self._state_lock = Lock()
    self.embedding_cache = {}
    self.image_size_cache = {}
    self.embedding_cache_keys = []

    self.low_res_logits_cache = {}
    self.segmentation_cache_keys = []
    self.task_type = "unsupervised-segmentation"
embed_image
embed_image(image, image_id=None, **kwargs)

Embeds an image and caches the result if an image_id is provided. If the image has been embedded before and cached, the cached result will be returned.

Parameters:

Name Type Description Default
image Any

The image to be embedded. The format should be compatible with the preproc_image method.

required
image_id Optional[str]

An identifier for the image. If provided, the embedding result will be cached with this ID. Defaults to None.

None
**kwargs

Additional keyword arguments.

{}

Returns:

Type Description

Tuple[np.ndarray, Tuple[int, int]]: A tuple where the first element is the embedding of the image and the second element is the shape (height, width) of the processed image.

Notes
  • Embeddings and image sizes are cached to improve performance on repeated requests for the same image.
  • The cache has a maximum size defined by SAM_MAX_EMBEDDING_CACHE_SIZE. When the cache exceeds this size, the oldest entries are removed.
Example

>>> img_array = ...  # some image array
>>> embed_image(img_array, image_id="sample123")
(array([...]), (224, 224))

Source code in inference/models/sam/segment_anything.py
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
def embed_image(self, image: Any, image_id: Optional[str] = None, **kwargs):
    """Compute (or fetch from cache) the SAM embedding for an image.

    Args:
        image (Any): Image in any format accepted by ``preproc_image``.
        image_id (Optional[str]): Optional cache key. When supplied, the
            embedding and preprocessed image size are cached under this id,
            with FIFO eviction once SAM_MAX_EMBEDDING_CACHE_SIZE is exceeded.
        **kwargs: Additional keyword arguments (ignored).

    Returns:
        Tuple[np.ndarray, Tuple[int, int]]: The embedding and the
        (height, width) of the preprocessed image.

    Example:
        >>> img_array = ... # some image array
        >>> embed_image(img_array, image_id="sample123")
        (array([...]), (224, 224))
    """
    if image_id and image_id in self.embedding_cache:
        # Cache hit: skip preprocessing and the encoder entirely.
        return self.embedding_cache[image_id], self.image_size_cache[image_id]
    preprocessed = self.preproc_image(image)
    self.predictor.set_image(preprocessed)
    embedding = self.predictor.get_image_embedding().cpu().numpy()
    if image_id:
        self.embedding_cache[image_id] = embedding
        self.image_size_cache[image_id] = preprocessed.shape[:2]
        self.embedding_cache_keys.append(image_id)
        # FIFO eviction keeps both caches bounded and in sync.
        while len(self.embedding_cache_keys) > SAM_MAX_EMBEDDING_CACHE_SIZE:
            evicted = self.embedding_cache_keys.pop(0)
            del self.embedding_cache[evicted]
            del self.image_size_cache[evicted]
    return (embedding, preprocessed.shape[:2])
get_infer_bucket_file_list
get_infer_bucket_file_list()

Gets the list of files required for inference.

Returns:

Type Description
List[str]

List[str]: List of file names.

Source code in inference/models/sam/segment_anything.py
74
75
76
77
78
79
80
def get_infer_bucket_file_list(self) -> List[str]:
    """Return the artifact file names this model needs for inference.

    Returns:
        List[str]: The encoder checkpoint (PyTorch) and decoder graph (ONNX).
    """
    return ["encoder.pth", "decoder.onnx"]
infer_from_request
infer_from_request(request)

Performs inference based on the request type.

Parameters:

Name Type Description Default
request SamInferenceRequest

The inference request.

required

Returns:

Type Description

Union[SamEmbeddingResponse, SamSegmentationResponse]: The inference response.

Source code in inference/models/sam/segment_anything.py
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
def infer_from_request(self, request: SamInferenceRequest):
    """Dispatch an inference request to embedding or segmentation handling.

    Args:
        request (SamInferenceRequest): Either a SamEmbeddingRequest or a
            SamSegmentationRequest.

    Returns:
        Union[SamEmbeddingResponse, SamSegmentationResponse]: The inference
        response (raw npz bytes for binary-format segmentation requests).
    """
    with self._state_lock:
        started_at = perf_counter()
        if isinstance(request, SamEmbeddingRequest):
            embedding, _ = self.embed_image(**request.dict())
            elapsed = perf_counter() - started_at
            if request.format == "json":
                return SamEmbeddingResponse(
                    embeddings=embedding.tolist(), time=elapsed
                )
            if request.format == "binary":
                buffer = BytesIO()
                np.save(buffer, embedding)
                buffer.seek(0)
                return SamEmbeddingResponse(
                    embeddings=buffer.getvalue(), time=elapsed
                )
        elif isinstance(request, SamSegmentationRequest):
            masks, low_res_masks = self.segment_image(**request.dict())
            if request.format == "binary":
                # Binary output short-circuits: raw compressed npz bytes.
                buffer = BytesIO()
                np.savez_compressed(
                    buffer, masks=masks, low_res_masks=low_res_masks
                )
                buffer.seek(0)
                return buffer.getvalue()
            if request.format != "json":
                raise ValueError(f"Invalid format {request.format}")
            # JSON output: threshold logits and convert masks to polygons.
            threshold = self.predictor.model.mask_threshold
            masks = masks2poly(masks > threshold)
            low_res_masks = masks2poly(low_res_masks > threshold)
            return SamSegmentationResponse(
                masks=[m.tolist() for m in masks],
                low_res_masks=[m.tolist() for m in low_res_masks],
                time=perf_counter() - started_at,
            )
preproc_image
preproc_image(image)

Preprocesses an image.

Parameters:

Name Type Description Default
image InferenceRequestImage

The image to preprocess.

required

Returns:

Type Description

np.array: The preprocessed image.

Source code in inference/models/sam/segment_anything.py
175
176
177
178
179
180
181
182
183
184
185
def preproc_image(self, image: InferenceRequestImage):
    """Convert an inference-request image payload into an RGB numpy array.

    Args:
        image (InferenceRequestImage): The image to preprocess.

    Returns:
        np.array: The image decoded via ``load_image_rgb``.
    """
    return load_image_rgb(image)
segment_image
segment_image(
    image,
    embeddings=None,
    embeddings_format="json",
    has_mask_input=False,
    image_id=None,
    mask_input=None,
    mask_input_format="json",
    orig_im_size=None,
    point_coords=[],
    point_labels=[],
    use_mask_input_cache=True,
    **kwargs
)

Segments an image based on provided embeddings, points, masks, or cached results. If embeddings are not directly provided, the function can derive them from the input image or cache.

Parameters:

Name Type Description Default
image Any

The image to be segmented.

required
embeddings Optional[Union[ndarray, List[List[float]]]]

The embeddings of the image. Defaults to None, in which case the image is used to compute embeddings.

None
embeddings_format Optional[str]

Format of the provided embeddings; either 'json' or 'binary'. Defaults to 'json'.

'json'
has_mask_input Optional[bool]

Specifies whether mask input is provided. Defaults to False.

False
image_id Optional[str]

A cached identifier for the image. Useful for accessing cached embeddings or masks.

None
mask_input Optional[Union[ndarray, List[List[List[float]]]]]

Input mask for the image.

None
mask_input_format Optional[str]

Format of the provided mask input; either 'json' or 'binary'. Defaults to 'json'.

'json'
orig_im_size Optional[List[int]]

Original size of the image when providing embeddings directly.

None
point_coords Optional[List[List[float]]]

Coordinates of points in the image. Defaults to an empty list.

[]
point_labels Optional[List[int]]

Labels associated with the provided points. Defaults to an empty list.

[]
use_mask_input_cache Optional[bool]

Flag to determine if cached mask input should be used. Defaults to True.

True
**kwargs

Additional keyword arguments.

{}

Returns:

Type Description

Tuple[np.ndarray, np.ndarray]: A tuple where the first element is the segmentation masks of the image and the second element is the low resolution segmentation masks.

Raises:

Type Description
ValueError

If necessary inputs are missing or inconsistent.

Notes
  • Embeddings, segmentations, and low-resolution logits can be cached to improve performance on repeated requests for the same image.
  • The cache has a maximum size defined by SAM_MAX_EMBEDDING_CACHE_SIZE. When the cache exceeds this size, the oldest entries are removed.
Source code in inference/models/sam/segment_anything.py
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
def segment_image(
    self,
    image: Any,
    embeddings: Optional[Union[np.ndarray, List[List[float]]]] = None,
    embeddings_format: Optional[str] = "json",
    has_mask_input: Optional[bool] = False,
    image_id: Optional[str] = None,
    mask_input: Optional[Union[np.ndarray, List[List[List[float]]]]] = None,
    mask_input_format: Optional[str] = "json",
    orig_im_size: Optional[List[int]] = None,
    point_coords: Optional[List[List[float]]] = None,
    point_labels: Optional[List[int]] = None,
    use_mask_input_cache: Optional[bool] = True,
    **kwargs,
):
    """
    Segments an image based on provided embeddings, points, masks, or cached results.
    If embeddings are not directly provided, the function can derive them from the input image or cache.

    Args:
        image (Any): The image to be segmented.
        embeddings (Optional[Union[np.ndarray, List[List[float]]]]): The embeddings of the image.
            Defaults to None, in which case the image is used to compute embeddings.
        embeddings_format (Optional[str]): Format of the provided embeddings; either 'json' or 'binary'. Defaults to 'json'.
        has_mask_input (Optional[bool]): Specifies whether mask input is provided. Defaults to False.
        image_id (Optional[str]): A cached identifier for the image. Useful for accessing cached embeddings or masks.
        mask_input (Optional[Union[np.ndarray, List[List[List[float]]]]]): Input mask for the image.
        mask_input_format (Optional[str]): Format of the provided mask input; either 'json' or 'binary'. Defaults to 'json'.
        orig_im_size (Optional[List[int]]): Original size of the image when providing embeddings directly.
        point_coords (Optional[List[List[float]]]): Coordinates of points in the image. Defaults to None (no points).
        point_labels (Optional[List[int]]): Labels associated with the provided points. Defaults to None (no labels).
        use_mask_input_cache (Optional[bool]): Flag to determine if cached mask input should be used. Defaults to True.
        **kwargs: Additional keyword arguments.

    Returns:
        Tuple[np.ndarray, np.ndarray]: A tuple where the first element is the segmentation masks of the image
                                      and the second element is the low resolution segmentation masks.

    Raises:
        ValueError: If necessary inputs are missing or inconsistent.

    Notes:
        - Embeddings, segmentations, and low-resolution logits can be cached to improve performance
          on repeated requests for the same image.
        - The cache has a maximum size defined by SAM_MAX_EMBEDDING_CACHE_SIZE. When the cache exceeds this size,
          the oldest entries are removed.
    """
    # `not embeddings` raised "truth value of an array is ambiguous" for
    # np.ndarray inputs even though the signature accepts them; test the
    # ndarray case explicitly while keeping None/[] treated as "missing".
    if embeddings is None or (
        not isinstance(embeddings, np.ndarray) and not embeddings
    ):
        if not image and not image_id:
            raise ValueError(
                "Must provide either image, cached image_id, or embeddings"
            )
        elif image_id and not image and image_id not in self.embedding_cache:
            raise ValueError(
                f"Image ID {image_id} not in embedding cache, must provide the image or embeddings"
            )
        embedding, original_image_size = self.embed_image(
            image=image, image_id=image_id
        )
    else:
        if not orig_im_size:
            raise ValueError(
                "Must provide original image size if providing embeddings"
            )
        original_image_size = orig_im_size
        if embeddings_format == "json":
            embedding = np.array(embeddings)
        elif embeddings_format == "binary":
            embedding = np.load(BytesIO(embeddings))
        else:
            # Previously an unknown format fell through to a NameError.
            raise ValueError(f"Invalid embeddings format {embeddings_format}")

    # Copy caller-provided lists before appending the mandatory SAM padding
    # point/label; the previous mutable [] defaults accumulated entries
    # across calls and mutated caller-owned lists.
    point_coords = list(point_coords) if point_coords else []
    point_coords.append([0, 0])
    point_coords = np.array(point_coords, dtype=np.float32)
    point_coords = np.expand_dims(point_coords, axis=0)
    point_coords = self.predictor.transform.apply_coords(
        point_coords,
        original_image_size,
    )

    point_labels = list(point_labels) if point_labels else []
    point_labels.append(-1)
    point_labels = np.array(point_labels, dtype=np.float32)
    point_labels = np.expand_dims(point_labels, axis=0)

    if has_mask_input:
        if (
            image_id
            and image_id in self.low_res_logits_cache
            and use_mask_input_cache
        ):
            mask_input = self.low_res_logits_cache[image_id]
        elif not mask_input and (
            not image_id or image_id not in self.low_res_logits_cache
        ):
            raise ValueError("Must provide either mask_input or cached image_id")
        else:
            if mask_input_format == "json":
                # Rasterize each polygon onto a 256x256 grid, one channel
                # per polygon, to build the decoder's mask input.
                polys = mask_input
                mask_input = np.zeros((1, len(polys), 256, 256), dtype=np.uint8)
                for i, poly in enumerate(polys):
                    poly = ShapelyPolygon(poly)
                    raster = rasterio.features.rasterize(
                        [poly], out_shape=(256, 256)
                    )
                    mask_input[0, i, :, :] = raster
            elif mask_input_format == "binary":
                binary_data = base64.b64decode(mask_input)
                mask_input = np.load(BytesIO(binary_data))
    else:
        mask_input = np.zeros((1, 1, 256, 256), dtype=np.float32)

    ort_inputs = {
        "image_embeddings": embedding.astype(np.float32),
        "point_coords": point_coords.astype(np.float32),
        "point_labels": point_labels,
        "mask_input": mask_input.astype(np.float32),
        "has_mask_input": (
            np.zeros(1, dtype=np.float32)
            if not has_mask_input
            else np.ones(1, dtype=np.float32)
        ),
        "orig_im_size": np.array(original_image_size, dtype=np.float32),
    }
    masks, _, low_res_logits = self.ort_session.run(None, ort_inputs)
    if image_id:
        # Cache logits for potential reuse as mask input; FIFO eviction.
        self.low_res_logits_cache[image_id] = low_res_logits
        if image_id not in self.segmentation_cache_keys:
            self.segmentation_cache_keys.append(image_id)
        if len(self.segmentation_cache_keys) > SAM_MAX_EMBEDDING_CACHE_SIZE:
            cache_key = self.segmentation_cache_keys.pop(0)
            del self.low_res_logits_cache[cache_key]
    masks = masks[0]
    low_res_masks = low_res_logits[0]

    return masks, low_res_masks

Functions

models/sam2

inference.models.sam2.segment_anything2

Classes

SegmentAnything2

Bases: RoboflowCoreModel

SegmentAnything class for handling segmentation tasks.

Attributes:

Name Type Description
sam

The segmentation model.

embedding_cache

Cache for embeddings.

image_size_cache

Cache for image sizes.

embedding_cache_keys

Keys for the embedding cache.

Source code in inference/models/sam2/segment_anything2.py
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
class SegmentAnything2(RoboflowCoreModel):
    """SegmentAnything2 (SAM2) model for promptable image segmentation.

    Attributes:
        sam: The underlying SAM2 segmentation model.
        embedding_cache: Maps image_id -> cached image embedding features.
        image_size_cache: Maps image_id -> (height, width) of the source image.
        embedding_cache_keys: Insertion-ordered image_ids used for FIFO
            eviction of the embedding/image-size caches.
        low_res_logits_cache: Maps (image_id, prompt hash) -> cached low-res
            logits together with the prompt set that produced them.
        low_res_logits_cache_keys: Insertion-ordered keys used for FIFO
            eviction of the logits cache.

    """

    def __init__(
        self,
        *args,
        model_id: str = f"sam2/{SAM2_VERSION_ID}",
        low_res_logits_cache_size: int = SAM2_MAX_LOGITS_CACHE_SIZE,
        embedding_cache_size: int = SAM2_MAX_EMBEDDING_CACHE_SIZE,
        **kwargs,
    ):
        """Initializes the SegmentAnything2 model.

        Args:
            *args: Variable length argument list forwarded to RoboflowCoreModel.
            model_id: Identifier of the SAM2 variant to load.
            low_res_logits_cache_size: Maximum number of entries kept in the
                low-resolution logits cache before the oldest is evicted.
            embedding_cache_size: Maximum number of entries kept in the
                embedding cache before the oldest is evicted.
            **kwargs: Arbitrary keyword arguments forwarded to RoboflowCoreModel.

        Raises:
            KeyError: If ``self.version_id`` is not one of the known SAM2 variants.
        """
        super().__init__(*args, model_id=model_id, **kwargs)
        checkpoint = self.cache_file("weights.pt")
        # Map the model version id to the matching SAM2 architecture config.
        model_cfg = {
            "hiera_large": "sam2_hiera_l.yaml",
            "hiera_small": "sam2_hiera_s.yaml",
            "hiera_tiny": "sam2_hiera_t.yaml",
            "hiera_b_plus": "sam2_hiera_b+.yaml",
        }[self.version_id]

        self.sam = build_sam2(model_cfg, checkpoint, device=DEVICE)
        self.low_res_logits_cache_size = low_res_logits_cache_size
        self.embedding_cache_size = embedding_cache_size

        # FIFO caches: the *_keys lists record insertion order for eviction.
        self.embedding_cache = {}
        self.image_size_cache = {}
        self.embedding_cache_keys = []
        self.low_res_logits_cache: Dict[Tuple[str, str], LogitsCacheType] = {}
        self.low_res_logits_cache_keys = []
        # Re-entrant lock guarding all cache mutations.
        self._state_lock = RLock()
        self.task_type = "unsupervised-segmentation"

    def get_infer_bucket_file_list(self) -> List[str]:
        """Gets the list of files required for inference.

        Returns:
            List[str]: List of file names.
        """
        return ["weights.pt"]

    def embed_image(
        self,
        image: Optional[InferenceRequestImage],
        image_id: Optional[str] = None,
        **kwargs,
    ):
        """
        Embeds an image and caches the result. If the image has been embedded
        before and cached, the cached result is returned.

        Args:
            image (Any): The image to be embedded. The format should be compatible
                with the preproc_image method.
            image_id (Optional[str]): An identifier for the image. If provided, the
                embedding result will be cached with this ID; otherwise an ID is
                derived from the image content. Defaults to None.
            **kwargs: Additional keyword arguments.

        Returns:
            Tuple[Any, Tuple[int, int], str]: A tuple of the image embedding, the
                shape (height, width) of the processed image, and the image_id
                under which the result is cached.

        Notes:
            - Embeddings and image sizes are cached to improve performance on
              repeated requests for the same image.
            - The cache has a maximum size given by ``embedding_cache_size``.
              When the cache exceeds this size, the oldest entries are removed.

        Example:
            >>> img_array = ... # some image array
            >>> embed_image(img_array, image_id="sample123")
            (array([...]), (224, 224), "sample123")
        """
        # Fast path: a caller-supplied id that is already cached avoids
        # decoding the image at all.
        if image_id:
            embedding_cache_content = self.embedding_cache.get(image_id)
            image_size_content = self.image_size_cache.get(image_id)
            if embedding_cache_content is not None and image_size_content is not None:
                return embedding_cache_content, image_size_content, image_id

        img_in = self.preproc_image(image)
        if image_id is None:
            # Derive a deterministic id from image content (md5 used only as a
            # cache key, not for security).
            image_id = hashlib.md5(img_in.tobytes()).hexdigest()[:12]

        # Second lookup: the content-derived id may already be cached.
        embedding_cache_content = self.embedding_cache.get(image_id)
        image_size_content = self.image_size_cache.get(image_id)
        if embedding_cache_content is not None and image_size_content is not None:
            return (
                embedding_cache_content,
                image_size_content,
                image_id,
            )

        with torch.inference_mode():
            with _temporarily_disable_torch_jit_script():
                predictor = SAM2ImagePredictor(self.sam)
            predictor.set_image(img_in)
            embedding_dict = predictor._features

        with self._state_lock:
            self.embedding_cache[image_id] = embedding_dict
            self.image_size_cache[image_id] = img_in.shape[:2]
            # Re-append to keep the key list duplicate-free and in
            # most-recently-inserted order.
            safe_remove_from_list(values=self.embedding_cache_keys, element=image_id)
            self.embedding_cache_keys.append(image_id)
            if len(self.embedding_cache_keys) > self.embedding_cache_size:
                # Evict the oldest entry from both caches.
                cache_key = safe_pop_from_list(values=self.embedding_cache_keys)
                if cache_key is not None:
                    safe_remove_from_dict(values=self.embedding_cache, key=cache_key)
                    safe_remove_from_dict(values=self.image_size_cache, key=cache_key)
            return embedding_dict, img_in.shape[:2], image_id

    @usage_collector("model")
    def infer_from_request(self, request: Sam2InferenceRequest):
        """Performs inference based on the request type.

        Args:
            request (Sam2InferenceRequest): The inference request, either an
                embedding request or a segmentation request.

        Returns:
            Union[Sam2EmbeddingResponse, Sam2SegmentationResponse, bytes]:
                The inference response; raw bytes for the "binary" format.

        Raises:
            ValueError: If the request type or segmentation format is not supported.
        """
        t1 = perf_counter()
        if isinstance(request, Sam2EmbeddingRequest):
            _, _, image_id = self.embed_image(**request.dict())
            inference_time = perf_counter() - t1
            return Sam2EmbeddingResponse(time=inference_time, image_id=image_id)
        elif isinstance(request, Sam2SegmentationRequest):
            masks, scores, low_resolution_logits = self.segment_image(**request.dict())

            if request.format == "json":
                return turn_segmentation_results_into_api_response(
                    masks=masks,
                    scores=scores,
                    mask_threshold=0.0,
                    inference_start_timestamp=t1,
                )
            elif request.format == "rle":
                return turn_segmentation_results_into_rle_response(
                    masks=masks,
                    scores=scores,
                    mask_threshold=0.0,
                    inference_start_timestamp=t1,
                )
            elif request.format == "binary":
                # Serialize masks + logits as a compressed npz payload.
                binary_vector = BytesIO()
                np.savez_compressed(
                    binary_vector, masks=masks, low_res_masks=low_resolution_logits
                )
                binary_vector.seek(0)
                binary_data = binary_vector.getvalue()
                return binary_data
            else:
                raise ValueError(f"Invalid format {request.format}")

        else:
            raise ValueError(f"Invalid request type {type(request)}")

    def preproc_image(self, image: InferenceRequestImage):
        """Preprocesses an image.

        Args:
            image (InferenceRequestImage): The image to preprocess.

        Returns:
            np.array: The image decoded as an RGB numpy array.
        """
        np_image = load_image_rgb(image)
        return np_image

    def segment_image(
        self,
        image: Optional[InferenceRequestImage],
        image_id: Optional[str] = None,
        prompts: Optional[Union[Sam2PromptSet, dict]] = None,
        multimask_output: Optional[bool] = True,
        mask_input: Optional[Union[np.ndarray, List[List[List[float]]]]] = None,
        save_logits_to_cache: bool = False,
        load_logits_from_cache: bool = False,
        **kwargs,
    ):
        """
        Segments an image based on provided embeddings, points, masks, or cached results.
        If embeddings are not directly provided, the function can derive them from the
        input image or cache.

        Args:
            image (Any): The image to be segmented.
            image_id (Optional[str]): A cached identifier for the image. Useful for
                accessing cached embeddings or masks.
            prompts (Optional[Union[Sam2PromptSet, dict]]): Prompt set to use for
                segmentation, either as a Sam2PromptSet or a dict of its fields.
                Defaults to None.
            multimask_output (Optional[bool]): Flag to decide if multiple mask
                proposals are predicted (among which the most confident is returned).
            mask_input (Optional[Union[np.ndarray, List[List[List[float]]]]]):
                Input low_res_logits for the image.
            save_logits_to_cache (bool): Whether to cache the resulting low-resolution
                logits for subsequent prompting.
            load_logits_from_cache (bool): Whether to seed ``mask_input`` from logits
                cached by prior prompting.
            **kwargs: Additional keyword arguments.

        Returns:
            Tuple[np.ndarray, np.ndarray, np.ndarray]: Tuple of np.array, where:
                - first element is of size (prompt_set_size, h, w) and represents the
                    mask with the highest confidence for each prompt element
                - second element is of size (prompt_set_size, ) and represents the
                    score for the most confident mask of each prompt element
                - third element is of size (prompt_set_size, 256, 256) and represents
                    the low resolution logits for the most confident mask of each
                    prompt element

        Raises:
            ValueError: If necessary inputs are missing or inconsistent.

        Notes:
            - Embeddings, segmentations, and low-resolution logits can be cached to
              improve performance on repeated requests for the same image.
            - The caches are bounded by ``embedding_cache_size`` and
              ``low_res_logits_cache_size``; when a cache exceeds its size,
              the oldest entries are removed.
        """
        # DISABLE_SAM2_LOGITS_CACHE globally overrides both cache flags.
        load_logits_from_cache = (
            load_logits_from_cache and not DISABLE_SAM2_LOGITS_CACHE
        )
        save_logits_to_cache = save_logits_to_cache and not DISABLE_SAM2_LOGITS_CACHE
        with torch.inference_mode():
            if image is None and not image_id:
                raise ValueError("Must provide either image or cached image_id")
            elif image_id and image is None and image_id not in self.embedding_cache:
                raise ValueError(
                    f"Image ID {image_id} not in embedding cache, must provide the image or embeddings"
                )
            embedding, original_image_size, image_id = self.embed_image(
                image=image, image_id=image_id
            )
            # Build a predictor around the (possibly cached) embedding instead
            # of re-running image encoding.
            with _temporarily_disable_torch_jit_script():
                predictor = SAM2ImagePredictor(self.sam)
            predictor._is_image_set = True
            predictor._features = embedding
            predictor._orig_hw = [original_image_size]
            predictor._is_batch = False
            args = dict()
            prompt_set: Sam2PromptSet
            if prompts:
                if isinstance(prompts, dict):
                    prompt_set = Sam2PromptSet(**prompts)
                    args = prompt_set.to_sam2_inputs()
                else:
                    prompt_set = prompts
                    args = prompts.to_sam2_inputs()
            else:
                prompt_set = Sam2PromptSet()

            if mask_input is None and load_logits_from_cache:
                mask_input = maybe_load_low_res_logits_from_cache(
                    image_id, prompt_set, self.low_res_logits_cache
                )

            args = pad_points(args)
            if not any(args.values()):
                # No prompts at all: fall back to a single point with label -1
                # so predictor.predict receives valid inputs.
                args = {"point_coords": [[0, 0]], "point_labels": [-1], "box": None}
            masks, scores, low_resolution_logits = predictor.predict(
                mask_input=mask_input,
                multimask_output=multimask_output,
                return_logits=True,
                normalize_coords=True,
                **args,
            )
            masks, scores, low_resolution_logits = choose_most_confident_sam_prediction(
                masks=masks,
                scores=scores,
                low_resolution_logits=low_resolution_logits,
            )

            if save_logits_to_cache:
                self.add_low_res_logits_to_cache(
                    low_resolution_logits, image_id, prompt_set
                )

            return masks, scores, low_resolution_logits

    def add_low_res_logits_to_cache(
        self, logits: np.ndarray, image_id: str, prompt_set: Sam2PromptSet
    ) -> None:
        """Caches low-resolution logits keyed by (image, prompt set) with FIFO eviction.

        Args:
            logits: Low-resolution logits of shape (prompt_set_size, 256, 256).
            image_id: Identifier of the image the logits belong to.
            prompt_set: The prompt set that produced the logits.
        """
        # Insert a channel axis so cached logits match the mask_input layout
        # expected by predictor.predict on reuse.
        logits = logits[:, None, :, :]
        prompt_id = hash_prompt_set(image_id, prompt_set)
        with self._state_lock:
            self.low_res_logits_cache[prompt_id] = {
                "logits": logits,
                "prompt_set": prompt_set,
            }
            safe_remove_from_list(
                values=self.low_res_logits_cache_keys, element=prompt_id
            )
            self.low_res_logits_cache_keys.append(prompt_id)
            if len(self.low_res_logits_cache_keys) > self.low_res_logits_cache_size:
                # Evict the oldest cached prompt logits.
                cache_key = safe_pop_from_list(values=self.low_res_logits_cache_keys)
                if cache_key is not None:
                    safe_remove_from_dict(
                        values=self.low_res_logits_cache, key=cache_key
                    )
Functions
__init__
__init__(
    *args,
    model_id=f"sam2/{SAM2_VERSION_ID}",
    low_res_logits_cache_size=SAM2_MAX_LOGITS_CACHE_SIZE,
    embedding_cache_size=SAM2_MAX_EMBEDDING_CACHE_SIZE,
    **kwargs,
)

Initializes the SegmentAnything.

Parameters:

Name Type Description Default
*args

Variable length argument list.

()
**kwargs

Arbitrary keyword arguments.

{}
Source code in inference/models/sam2/segment_anything2.py
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
def __init__(
    self,
    *args,
    model_id: str = f"sam2/{SAM2_VERSION_ID}",
    low_res_logits_cache_size: int = SAM2_MAX_LOGITS_CACHE_SIZE,
    embedding_cache_size: int = SAM2_MAX_EMBEDDING_CACHE_SIZE,
    **kwargs,
):
    """Initializes the SegmentAnything2 model.

    Args:
        *args: Variable length argument list forwarded to RoboflowCoreModel.
        model_id: Identifier of the SAM2 variant to load.
        low_res_logits_cache_size: Maximum number of cached low-resolution
            logits entries before the oldest is evicted.
        embedding_cache_size: Maximum number of cached image embeddings
            before the oldest is evicted.
        **kwargs: Arbitrary keyword arguments forwarded to RoboflowCoreModel.

    Raises:
        KeyError: If ``self.version_id`` is not a known SAM2 variant.
    """
    super().__init__(*args, model_id=model_id, **kwargs)
    checkpoint = self.cache_file("weights.pt")
    # Map the model version id to the matching SAM2 architecture config.
    model_cfg = {
        "hiera_large": "sam2_hiera_l.yaml",
        "hiera_small": "sam2_hiera_s.yaml",
        "hiera_tiny": "sam2_hiera_t.yaml",
        "hiera_b_plus": "sam2_hiera_b+.yaml",
    }[self.version_id]

    self.sam = build_sam2(model_cfg, checkpoint, device=DEVICE)
    self.low_res_logits_cache_size = low_res_logits_cache_size
    self.embedding_cache_size = embedding_cache_size

    # FIFO caches: the *_keys lists record insertion order for eviction.
    self.embedding_cache = {}
    self.image_size_cache = {}
    self.embedding_cache_keys = []
    self.low_res_logits_cache: Dict[Tuple[str, str], LogitsCacheType] = {}
    self.low_res_logits_cache_keys = []
    # Re-entrant lock guarding all cache mutations.
    self._state_lock = RLock()
    self.task_type = "unsupervised-segmentation"
embed_image
embed_image(image, image_id=None, **kwargs)

Embeds an image and caches the result if an image_id is provided. If the image has been embedded before and cached, the cached result will be returned.

Parameters:

Name Type Description Default
image Any

The image to be embedded. The format should be compatible with the preproc_image method.

required
image_id Optional[str]

An identifier for the image. If provided, the embedding result will be cached with this ID. Defaults to None.

None
**kwargs

Additional keyword arguments.

{}

Returns:

Type Description

Tuple[np.ndarray, Tuple[int, int]]: A tuple where the first element is the embedding of the image and the second element is the shape (height, width) of the processed image.

Notes
  • Embeddings and image sizes are cached to improve performance on repeated requests for the same image.
  • The cache has a maximum size defined by SAM2_MAX_CACHE_SIZE. When the cache exceeds this size, the oldest entries are removed.
Example

img_array = ... # some image array embed_image(img_array, image_id="sample123") (array([...]), (224, 224))

Source code in inference/models/sam2/segment_anything2.py
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
def embed_image(
    self,
    image: Optional[InferenceRequestImage],
    image_id: Optional[str] = None,
    **kwargs,
):
    """
    Embeds an image and caches the result. If the image has been embedded before
    and cached, the cached result will be returned.

    Args:
        image (Any): The image to be embedded. The format should be compatible
            with the preproc_image method.
        image_id (Optional[str]): An identifier for the image. If provided, the
            embedding result will be cached with this ID; otherwise an ID is
            derived from the image content. Defaults to None.
        **kwargs: Additional keyword arguments.

    Returns:
        Tuple[Any, Tuple[int, int], str]: A tuple of the image embedding, the
            shape (height, width) of the processed image, and the image_id under
            which the result is cached.

    Notes:
        - Embeddings and image sizes are cached to improve performance on
          repeated requests for the same image.
        - The cache is bounded by ``self.embedding_cache_size``; when it
          exceeds this size, the oldest entries are removed.

    Example:
        >>> img_array = ... # some image array
        >>> embed_image(img_array, image_id="sample123")
        (array([...]), (224, 224), "sample123")
    """
    # Fast path: caller-supplied id already cached -> skip decoding entirely.
    if image_id:
        embedding_cache_content = self.embedding_cache.get(image_id)
        image_size_content = self.image_size_cache.get(image_id)
        if embedding_cache_content is not None and image_size_content is not None:
            return embedding_cache_content, image_size_content, image_id

    img_in = self.preproc_image(image)
    if image_id is None:
        # Deterministic id derived from image bytes (md5 as cache key only).
        image_id = hashlib.md5(img_in.tobytes()).hexdigest()[:12]

    # Second lookup: the content-derived id may already be cached.
    embedding_cache_content = self.embedding_cache.get(image_id)
    image_size_content = self.image_size_cache.get(image_id)
    if embedding_cache_content is not None and image_size_content is not None:
        return (
            embedding_cache_content,
            image_size_content,
            image_id,
        )

    with torch.inference_mode():
        with _temporarily_disable_torch_jit_script():
            predictor = SAM2ImagePredictor(self.sam)
        predictor.set_image(img_in)
        embedding_dict = predictor._features

    with self._state_lock:
        self.embedding_cache[image_id] = embedding_dict
        self.image_size_cache[image_id] = img_in.shape[:2]
        # Re-append so the key list stays duplicate-free, insertion-ordered.
        safe_remove_from_list(values=self.embedding_cache_keys, element=image_id)
        self.embedding_cache_keys.append(image_id)
        if len(self.embedding_cache_keys) > self.embedding_cache_size:
            # Evict the oldest entry from both caches.
            cache_key = safe_pop_from_list(values=self.embedding_cache_keys)
            if cache_key is not None:
                safe_remove_from_dict(values=self.embedding_cache, key=cache_key)
                safe_remove_from_dict(values=self.image_size_cache, key=cache_key)
        return embedding_dict, img_in.shape[:2], image_id
get_infer_bucket_file_list
get_infer_bucket_file_list()

Gets the list of files required for inference.

Returns:

Type Description
List[str]

List[str]: List of file names.

Source code in inference/models/sam2/segment_anything2.py
105
106
107
108
109
110
111
def get_infer_bucket_file_list(self) -> List[str]:
    """Return the names of the files required to run inference.

    Returns:
        List[str]: The required file names.
    """
    required_files = ["weights.pt"]
    return required_files
infer_from_request
infer_from_request(request)

Performs inference based on the request type.

Parameters:

Name Type Description Default
request SamInferenceRequest

The inference request.

required

Returns:

Type Description

Union[SamEmbeddingResponse, SamSegmentationResponse]: The inference response.

Source code in inference/models/sam2/segment_anything2.py
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
@usage_collector("model")
def infer_from_request(self, request: Sam2InferenceRequest):
    """Performs inference based on the request type.

    Args:
        request (Sam2InferenceRequest): The inference request, either an
            embedding request or a segmentation request.

    Returns:
        Union[Sam2EmbeddingResponse, Sam2SegmentationResponse, bytes]: The
            inference response; raw bytes for the "binary" format.

    Raises:
        ValueError: If the request type or segmentation format is unsupported.
    """
    t1 = perf_counter()
    if isinstance(request, Sam2EmbeddingRequest):
        _, _, image_id = self.embed_image(**request.dict())
        inference_time = perf_counter() - t1
        return Sam2EmbeddingResponse(time=inference_time, image_id=image_id)
    elif isinstance(request, Sam2SegmentationRequest):
        masks, scores, low_resolution_logits = self.segment_image(**request.dict())

        if request.format == "json":
            return turn_segmentation_results_into_api_response(
                masks=masks,
                scores=scores,
                mask_threshold=0.0,
                inference_start_timestamp=t1,
            )
        elif request.format == "rle":
            return turn_segmentation_results_into_rle_response(
                masks=masks,
                scores=scores,
                mask_threshold=0.0,
                inference_start_timestamp=t1,
            )
        elif request.format == "binary":
            # Serialize masks + logits as a compressed npz payload.
            binary_vector = BytesIO()
            np.savez_compressed(
                binary_vector, masks=masks, low_res_masks=low_resolution_logits
            )
            binary_vector.seek(0)
            binary_data = binary_vector.getvalue()
            return binary_data
        else:
            raise ValueError(f"Invalid format {request.format}")

    else:
        raise ValueError(f"Invalid request type {type(request)}")
preproc_image
preproc_image(image)

Preprocesses an image.

Parameters:

Name Type Description Default
image InferenceRequestImage

The image to preprocess.

required

Returns:

Type Description

np.array: The preprocessed image.

Source code in inference/models/sam2/segment_anything2.py
226
227
228
229
230
231
232
233
234
235
236
def preproc_image(self, image: InferenceRequestImage):
    """Convert the incoming request image into an RGB numpy array.

    Args:
        image (InferenceRequestImage): The image to preprocess.

    Returns:
        np.array: The image decoded as an RGB numpy array.
    """
    return load_image_rgb(image)
segment_image
segment_image(
    image,
    image_id=None,
    prompts=None,
    multimask_output=True,
    mask_input=None,
    save_logits_to_cache=False,
    load_logits_from_cache=False,
    **kwargs
)

Segments an image based on provided embeddings, points, masks, or cached results. If embeddings are not directly provided, the function can derive them from the input image or cache.

Parameters:

Name Type Description Default
image Any

The image to be segmented.

required
image_id Optional[str]

A cached identifier for the image. Useful for accessing cached embeddings or masks.

None
prompts Optional[List[Sam2Prompt]]

List of prompts to use for segmentation. Defaults to None.

None
mask_input Optional[Union[ndarray, List[List[List[float]]]]]

Input low_res_logits for the image.

None
multimask_output Optional[bool]

(bool): Flag to decide if multiple masks proposal to be predicted (among which the most promising will be returned

True
use_logits_cache

(bool): Flag to decide to use cached logits from prior prompting

required
**kwargs

Additional keyword arguments.

{}

Returns:

Type Description

Tuple[np.ndarray, np.ndarray, np.ndarray]: Tuple of np.array, where: - first element is of size (prompt_set_size, h, w) and represents the mask with the highest confidence for each prompt element - second element is of size (prompt_set_size, ) and represents the score for the most confident mask of each prompt element - third element is of size (prompt_set_size, 256, 256) and represents the low resolution logits for the most confident mask of each prompt element

Raises:

Type Description
ValueError

If necessary inputs are missing or inconsistent.

Notes
  • Embeddings, segmentations, and low-resolution logits can be cached to improve performance on repeated requests for the same image.
  • The cache has a maximum size defined by SAM_MAX_EMBEDDING_CACHE_SIZE. When the cache exceeds this size, the oldest entries are removed.
Source code in inference/models/sam2/segment_anything2.py
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
def segment_image(
    self,
    image: Optional[InferenceRequestImage],
    image_id: Optional[str] = None,
    prompts: Optional[Union[Sam2PromptSet, dict]] = None,
    multimask_output: Optional[bool] = True,
    mask_input: Optional[Union[np.ndarray, List[List[List[float]]]]] = None,
    save_logits_to_cache: bool = False,
    load_logits_from_cache: bool = False,
    **kwargs,
):
    """
    Segments an image based on provided embeddings, points, masks, or cached results.
    If embeddings are not directly provided, the function can derive them from the
    input image or cache.

    Args:
        image (Any): The image to be segmented.
        image_id (Optional[str]): A cached identifier for the image. Useful for
            accessing cached embeddings or masks.
        prompts (Optional[Union[Sam2PromptSet, dict]]): Prompt set to use for
            segmentation, either as a Sam2PromptSet or a dict of its fields.
            Defaults to None.
        multimask_output (Optional[bool]): Flag to decide if multiple mask
            proposals are predicted (among which the most confident is returned).
        mask_input (Optional[Union[np.ndarray, List[List[List[float]]]]]):
            Input low_res_logits for the image.
        save_logits_to_cache (bool): Whether to cache the resulting
            low-resolution logits for subsequent prompting.
        load_logits_from_cache (bool): Whether to seed ``mask_input`` from
            logits cached by prior prompting.
        **kwargs: Additional keyword arguments.

    Returns:
        Tuple[np.ndarray, np.ndarray, np.ndarray]: Tuple of np.array, where:
            - first element is of size (prompt_set_size, h, w) and represents
                the mask with the highest confidence for each prompt element
            - second element is of size (prompt_set_size, ) and represents the
                score for the most confident mask of each prompt element
            - third element is of size (prompt_set_size, 256, 256) and
                represents the low resolution logits for the most confident
                mask of each prompt element

    Raises:
        ValueError: If necessary inputs are missing or inconsistent.

    Notes:
        - Embeddings, segmentations, and low-resolution logits can be cached
          to improve performance on repeated requests for the same image.
        - The caches are bounded by ``embedding_cache_size`` and
          ``low_res_logits_cache_size``; when a cache exceeds its size, the
          oldest entries are removed.
    """
    # DISABLE_SAM2_LOGITS_CACHE globally overrides both cache flags.
    load_logits_from_cache = (
        load_logits_from_cache and not DISABLE_SAM2_LOGITS_CACHE
    )
    save_logits_to_cache = save_logits_to_cache and not DISABLE_SAM2_LOGITS_CACHE
    with torch.inference_mode():
        if image is None and not image_id:
            raise ValueError("Must provide either image or  cached image_id")
        elif image_id and image is None and image_id not in self.embedding_cache:
            raise ValueError(
                f"Image ID {image_id} not in embedding cache, must provide the image or embeddings"
            )
        embedding, original_image_size, image_id = self.embed_image(
            image=image, image_id=image_id
        )
        # Build a predictor around the (possibly cached) embedding instead of
        # re-running image encoding.
        with _temporarily_disable_torch_jit_script():
            predictor = SAM2ImagePredictor(self.sam)
        predictor._is_image_set = True
        predictor._features = embedding
        predictor._orig_hw = [original_image_size]
        predictor._is_batch = False
        args = dict()
        prompt_set: Sam2PromptSet
        if prompts:
            if type(prompts) is dict:
                prompt_set = Sam2PromptSet(**prompts)
                args = prompt_set.to_sam2_inputs()
            else:
                prompt_set = prompts
                args = prompts.to_sam2_inputs()
        else:
            prompt_set = Sam2PromptSet()

        if mask_input is None and load_logits_from_cache:
            mask_input = maybe_load_low_res_logits_from_cache(
                image_id, prompt_set, self.low_res_logits_cache
            )

        args = pad_points(args)
        if not any(args.values()):
            # No prompts at all: fall back to a single point with label -1 so
            # predictor.predict receives valid inputs (label -1 presumably
            # acts as padding/no-prompt — confirm against SAM2 docs).
            args = {"point_coords": [[0, 0]], "point_labels": [-1], "box": None}
        masks, scores, low_resolution_logits = predictor.predict(
            mask_input=mask_input,
            multimask_output=multimask_output,
            return_logits=True,
            normalize_coords=True,
            **args,
        )
        masks, scores, low_resolution_logits = choose_most_confident_sam_prediction(
            masks=masks,
            scores=scores,
            low_resolution_logits=low_resolution_logits,
        )

        if save_logits_to_cache:
            self.add_low_res_logits_to_cache(
                low_resolution_logits, image_id, prompt_set
            )

        return masks, scores, low_resolution_logits

Functions

choose_most_confident_sam_prediction

choose_most_confident_sam_prediction(
    masks, scores, low_resolution_logits
)

This function is supposed to post-process SAM2 inference and choose the most confident mask regardless of the multimask_output parameter value Args: masks: np array with values 0.0 and 1.0 representing the predicted mask of size (prompt_set_size, proposed_masks, h, w) or (proposed_masks, h, w) - depending on prompt set size - unfortunately, prompt_set_size=1 causes a squeeze operation in the SAM2 library, so to handle inference uniformly, we need to compensate with this function. scores: array of size (prompt_set_size, proposed_masks) or (proposed_masks, ) depending on prompt set size - this array gives the confidence score for each mask proposal low_resolution_logits: array of size (prompt_set_size, proposed_masks, 256, 256) or (proposed_masks, 256, 256) - depending on prompt set size. These low resolution logits can be passed to a subsequent iteration as mask input. Returns: Tuple of np.array, where: - first element is of size (prompt_set_size, h, w) and represents the mask with the highest confidence for each prompt element - second element is of size (prompt_set_size, ) and represents the score for the most confident mask of each prompt element - third element is of size (prompt_set_size, 256, 256) and represents the low resolution logits for the most confident mask of each prompt element

Source code in inference/models/sam2/segment_anything2.py
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
def choose_most_confident_sam_prediction(
    masks: np.ndarray,
    scores: np.ndarray,
    low_resolution_logits: np.ndarray,
) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
    """Post-process SAM2 output, keeping the highest-confidence proposal per prompt.

    Works regardless of the ``multimask_output`` parameter value. When the
    prompt set has exactly one element, the SAM2 library squeezes the leading
    axis; this function re-adds it so all inputs are handled uniformly.

    Args:
        masks: Array of 0.0/1.0 mask values of shape
            (prompt_set_size, proposed_masks, h, w) or (proposed_masks, h, w),
            depending on prompt set size.
        scores: Confidence score per mask proposal, of shape
            (prompt_set_size, proposed_masks) or (proposed_masks, ).
        low_resolution_logits: Low-resolution logits of shape
            (prompt_set_size, proposed_masks, 256, 256) or
            (proposed_masks, 256, 256); can be fed back as mask input in a
            subsequent iteration.

    Returns:
        Tuple of np.array, where:
            - first element is of size (prompt_set_size, h, w): the most
              confident mask for each prompt element
            - second element is of size (prompt_set_size, ): the score of the
              most confident mask of each prompt element
            - third element is of size (prompt_set_size, 256, 256): the low
              resolution logits of the most confident mask of each prompt element
    """
    if masks.ndim == 3:
        # Single-prompt case: restore the prompt-set axis squeezed away by SAM2.
        masks = masks[np.newaxis, ...]
        scores = scores[np.newaxis, ...]
        low_resolution_logits = low_resolution_logits[np.newaxis, ...]
    accumulators = ([], [], [])
    for prompt_masks, prompt_scores, prompt_logits in zip(
        masks, scores, low_resolution_logits
    ):
        best = choose_most_confident_prompt_set_element_prediction(
            mask=prompt_masks,
            score=prompt_scores,
            low_resolution_logit=prompt_logits,
        )
        for accumulator, value in zip(accumulators, best):
            accumulator.append(value)
    selected_masks, selected_scores, selected_logits = accumulators
    return (
        np.asarray(selected_masks),
        np.asarray(selected_scores),
        np.asarray(selected_logits),
    )

find_prior_prompt_in_cache

find_prior_prompt_in_cache(
    initial_prompt_set, image_id, cache
)

Performs a search over the cache to see if previously used prompts are a subset of this one.

Source code in inference/models/sam2/segment_anything2.py
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
def find_prior_prompt_in_cache(
    initial_prompt_set: Sam2PromptSet,
    image_id: str,
    cache: Dict[Tuple[str, str], LogitsCacheType],
) -> Optional[np.ndarray]:
    """Search the cache for logits of a prior prompt set that is a strict subset of this one.

    Entries for the image are scanned newest-first. The search short-circuits
    as soon as a cached prompt set with exactly one point fewer than
    ``initial_prompt_set`` is found (the most recent possible mask); otherwise
    the viable entry with the most points wins.
    """
    entries_for_image = [cache[key] for key in cache if key[0] == image_id]
    target_size = initial_prompt_set.num_points() - 1
    best_logits: Optional[np.ndarray] = None
    best_size = 0
    for entry in reversed(entries_for_image):
        candidate_prompts: Sam2PromptSet = entry["prompt_set"]
        if not is_prompt_strict_subset(candidate_prompts, initial_prompt_set):
            continue
        candidate_size = candidate_prompts.num_points()
        if candidate_size == target_size:
            # One point fewer than the query: the freshest possible prior mask.
            return entry["logits"]
        if candidate_size >= best_size:
            best_size = candidate_size
            best_logits = entry["logits"]
    return best_logits

hash_prompt_set

hash_prompt_set(image_id, prompt_set)

Computes unique hash from a prompt set.

Source code in inference/models/sam2/segment_anything2.py
364
365
366
367
368
def hash_prompt_set(image_id: str, prompt_set: Sam2PromptSet) -> Tuple[str, str]:
    """Compute a cache key: the image id paired with a short hash of the prompt set.

    The hash is the first 12 hex characters of the MD5 digest of the prompt
    set's string representation (used only as a cache key, not for security).
    """
    digest = hashlib.md5(str(prompt_set).encode("utf-8")).hexdigest()
    return image_id, digest[:12]

maybe_load_low_res_logits_from_cache

maybe_load_low_res_logits_from_cache(
    image_id, prompt_set, cache
)

Loads prior masks from the cache by searching over possible prior prompts.

Source code in inference/models/sam2/segment_anything2.py
371
372
373
374
375
376
377
378
379
380
381
def maybe_load_low_res_logits_from_cache(
    image_id: str,
    prompt_set: Sam2PromptSet,
    cache: Dict[Tuple[str, str], LogitsCacheType],
) -> Optional[np.ndarray]:
    """Load prior low-resolution masks from the cache by searching over possible prior prompts.

    Returns None immediately for an empty prompt set, since there can be no
    prior prompt to build on.
    """
    if not prompt_set.prompts:
        return None
    return find_prior_prompt_in_cache(prompt_set, image_id, cache)

pad_points

pad_points(args)

Pad arguments to be passed to sam2 model with not_a_point label (-1). This is necessary when there are multiple prompts per image so that a tensor can be created.

Also pads empty point lists with a dummy non-point entry.

Source code in inference/models/sam2/segment_anything2.py
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
def pad_points(args: Dict[str, Any]) -> Dict[str, Any]:
    """
    Pad arguments to be passed to sam2 model with not_a_point label (-1).
    This is necessary when there are multiple prompts per image so that a
    tensor can be created.

    Also pads empty point lists with a dummy non-point entry ([0, 0] with
    label -1). An empty "point_coords" list is returned unchanged (the
    previous implementation crashed with ValueError on `max` of an empty
    sequence in that case).

    Args:
        args: Mapping with "point_coords" and "point_labels" entries, each
            either None or a list of per-prompt lists.

    Returns:
        A deep copy of ``args`` with every per-prompt coordinate/label list
        padded to the length of the longest one (minimum length 1).

    Raises:
        ValueError: If "point_labels" is provided without "point_coords".
    """
    args = copy.deepcopy(args)  # never mutate the caller's structures
    if args["point_coords"] is None:
        if args["point_labels"] is not None:
            raise ValueError(
                "Can't have point labels without corresponding point coordinates"
            )
        return args
    # Pad to the longest prompt; at least one (dummy) point per prompt.
    # `default=1` keeps an empty list of prompts from crashing `max`.
    max_len = max((len(prompt) for prompt in args["point_coords"]), default=1)
    max_len = max(max_len, 1)
    for prompt in args["point_coords"]:
        # Fresh [0, 0] per pad slot — avoid aliasing one shared inner list.
        prompt.extend([0, 0] for _ in range(max_len - len(prompt)))
    for label in args["point_labels"]:
        label.extend([-1] * (max_len - len(label)))
    return args

inference.models.sam2.segment_anything2_inference_models

Classes

InferenceModelsSAM2Adapter

Bases: Model

SegmentAnything class for handling segmentation tasks.

Attributes:

Name Type Description
sam

The segmentation model.

embedding_cache

Cache for embeddings.

image_size_cache

Cache for image sizes.

embedding_cache_keys

Keys for the embedding cache.

Source code in inference/models/sam2/segment_anything2_inference_models.py
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
class InferenceModelsSAM2Adapter(Model):
    """SegmentAnything class for handling segmentation tasks.

    Adapter exposing SAM2 through the `Model` interface. The underlying model
    is loaded via `AutoModel.from_pretrained` and is configured with bounded
    in-memory caches for image embeddings and low-resolution masks.

    Attributes:
        metrics: Simple inference counters (num_inferences, avg_inference_time).
        api_key: API key used when loading the model.
        task_type: Fixed task identifier ("unsupervised-segmentation").
        _model: The underlying SAM2 model instance.

    """

    def __init__(
        self,
        *args,
        model_id: str = f"sam2/{SAM2_VERSION_ID}",
        api_key: Optional[str] = None,
        low_res_logits_cache_size: int = SAM2_MAX_LOGITS_CACHE_SIZE,
        embedding_cache_size: int = SAM2_MAX_EMBEDDING_CACHE_SIZE,
        **kwargs,
    ):
        """Initializes the SegmentAnything.

        Args:
            *args: Variable length argument list.
            model_id: SAM2 model identifier or path; defaults to the configured SAM2 version.
            api_key: API key; falls back to the globally configured API_KEY when not given.
            low_res_logits_cache_size: Size limit of the low-resolution masks cache.
            embedding_cache_size: Size limit of the image embeddings cache.
            **kwargs: Arbitrary keyword arguments, forwarded to `AutoModel.from_pretrained`.
        """
        super().__init__()

        self.metrics = {"num_inferences": 0, "avg_inference_time": 0.0}

        self.api_key = api_key if api_key else API_KEY

        self.task_type = "unsupervised-segmentation"

        # Caches move tensors to CPU so cached entries do not hold GPU memory.
        sam2_image_embeddings_cache = Sam2ImageEmbeddingsInMemoryCache.init(
            size_limit=embedding_cache_size,
            send_to_cpu=True,
        )
        sam2_low_resolution_masks_cache = Sam2LowResolutionMasksInMemoryCache.init(
            size_limit=low_res_logits_cache_size,
            send_to_cpu=True,
        )
        extra_weights_provider_headers = get_extra_weights_provider_headers(
            countinference=kwargs.get("countinference"),
            service_secret=kwargs.get("service_secret"),
        )
        # Use every valid backend that has not been explicitly disabled.
        backend = list(
            VALID_INFERENCE_MODELS_BACKENDS.difference(
                DISABLED_INFERENCE_MODELS_BACKENDS
            )
        )
        self._model: SAM2Torch = AutoModel.from_pretrained(
            model_id_or_path=model_id,
            api_key=self.api_key,
            allow_untrusted_packages=ALLOW_INFERENCE_MODELS_UNTRUSTED_PACKAGES,
            allow_direct_local_storage_loading=ALLOW_INFERENCE_MODELS_DIRECTLY_ACCESS_LOCAL_PACKAGES,
            sam2_image_embeddings_cache=sam2_image_embeddings_cache,
            sam2_low_resolution_masks_cache=sam2_low_resolution_masks_cache,
            sam2_allow_client_generated_hash_ids=True,
            weights_provider_extra_headers=extra_weights_provider_headers,
            backend=backend,
            **kwargs,
        )

    @usage_collector("model")
    def infer_from_request(self, request: Sam2InferenceRequest):
        """Performs inference based on the request type.

        Args:
            request (Sam2InferenceRequest): The inference request; either an
                embedding request or a segmentation request.

        Returns:
            Union[Sam2EmbeddingResponse, SamSegmentationResponse, bytes]: The inference response.
        """
        t1 = perf_counter()
        if isinstance(request, Sam2EmbeddingRequest):
            _, _, image_id = self.embed_image(**request.dict())
            inference_time = perf_counter() - t1
            return Sam2EmbeddingResponse(time=inference_time, image_id=image_id)
        elif isinstance(request, Sam2SegmentationRequest):
            masks, scores, low_resolution_logits = self.segment_image(**request.dict())

            if request.format == "json":
                return turn_segmentation_results_into_api_response(
                    masks=masks,
                    scores=scores,
                    mask_threshold=MASK_THRESHOLD,
                    inference_start_timestamp=t1,
                )
            elif request.format == "rle":
                return turn_segmentation_results_into_rle_response(
                    masks=masks,
                    scores=scores,
                    mask_threshold=0.0,
                    inference_start_timestamp=t1,
                )
            elif request.format == "binary":
                # Ship full masks and low-res logits as one compressed npz blob.
                binary_vector = BytesIO()
                np.savez_compressed(
                    binary_vector, masks=masks, low_res_masks=low_resolution_logits
                )
                binary_vector.seek(0)
                binary_data = binary_vector.getvalue()
                return binary_data
            else:
                raise ValueError(f"Invalid format {request.format}")

        else:
            raise ValueError(f"Invalid request type {type(request)}")

    def embed_image(
        self,
        image: Optional[InferenceRequestImage],
        image_id: Optional[str] = None,
        **kwargs,
    ):
        """
        Embeds an image and caches the result if an image_id is provided. If the image has been embedded before and cached,
        the cached result will be returned.

        Args:
            image (Any): The image to be embedded. The format should be compatible with the preproc_image method.
            image_id (Optional[str]): An identifier for the image. If provided, the embedding result will be cached
                                      with this ID. Defaults to None.
            **kwargs: Additional keyword arguments.

        Returns:
            Tuple[dict, Tuple[int, int], str]: A tuple where the first element is a dict holding the image embedding
                                               and high-resolution features (as numpy arrays), the second element is
                                               the (height, width) of the processed image, and the third element is
                                               the image hash used for caching.

        Notes:
            - Embeddings and image sizes are cached to improve performance on repeated requests for the same image.
            - The embeddings cache is bounded (see the `embedding_cache_size` constructor argument). When the cache
              exceeds this size, the oldest entries are removed.

        Example:
            >>> img_array = ... # some image array
            >>> embed_image(img_array, image_id="sample123")
            ({...}, (224, 224), "sample123")
        """
        loaded_image = self.preproc_image(image)
        if loaded_image is None:
            raise ValueError("Image must be provided to handle this request.")
        embeddings = self._model.embed_images(
            images=loaded_image, image_hashes=image_id, **kwargs
        )[0]
        embedding_dict = {
            "image_embed": embeddings.embeddings.cpu().numpy(),
            "high_res_feats": [
                f.cpu().numpy() for f in embeddings.high_resolution_features
            ],
        }
        return embedding_dict, embeddings.image_size_hw, embeddings.image_hash

    def preproc_image(self, image: InferenceRequestImage):
        """Preprocesses an image.

        Args:
            image (InferenceRequestImage): The image to preprocess, or None.

        Returns:
            np.array: The preprocessed BGR image, or None when no image was given.
        """
        if image is not None:
            return load_image_bgr(image)
        return None

    def segment_image(
        self,
        image: Optional[InferenceRequestImage],
        image_id: Optional[str] = None,
        prompts: Optional[Union[Sam2PromptSet, dict]] = None,
        multimask_output: Optional[bool] = True,
        mask_input: Optional[Union[np.ndarray, List[List[List[float]]]]] = None,
        save_logits_to_cache: bool = False,
        load_logits_from_cache: bool = False,
        **kwargs,
    ):
        """
        Segments an image based on provided embeddings, points, masks, or cached results.
        If embeddings are not directly provided, the function can derive them from the input image or cache.

        Args:
            image (Any): The image to be segmented.
            image_id (Optional[str]): A cached identifier for the image. Useful for accessing cached embeddings or masks.
            prompts (Optional[Union[Sam2PromptSet, dict]]): Prompt set (or its dict form) to use for segmentation.
                Defaults to None.
            multimask_output (Optional[bool]): Flag to decide if multiple mask proposals are predicted (among which
                the most promising will be returned).
            mask_input (Optional[Union[np.ndarray, List[List[List[float]]]]]): Input low_res_logits for the image.
            save_logits_to_cache (bool): Flag to decide whether resulting low-resolution logits are stored in the
                cache (ignored when DISABLE_SAM2_LOGITS_CACHE is set).
            load_logits_from_cache (bool): Flag to decide whether cached logits from prior prompting are used
                (ignored when DISABLE_SAM2_LOGITS_CACHE is set).
            **kwargs: Additional keyword arguments.

        Returns:
            Tuple[np.ndarray, np.ndarray, np.ndarray]: Tuple of np.array, where:
                - first element is of size (prompt_set_size, h, w) and represents the mask with the highest confidence
                    for each prompt element
                - second element is of size (prompt_set_size, ) and represents the score for the most confident mask
                    of each prompt element
                - third element is of size (prompt_set_size, 256, 256) and represents the low resolution logits
                    for the most confident mask of each prompt element

        Raises:
            ValueError: If necessary inputs are missing or inconsistent.

        Notes:
            - Embeddings, segmentations, and low-resolution logits can be cached to improve performance
              on repeated requests for the same image.
            - The caches are bounded (see the constructor's cache-size arguments). When a cache exceeds its size,
              the oldest entries are removed.
        """
        # Logits caching can be globally disabled via environment configuration.
        load_logits_from_cache = (
            load_logits_from_cache and not DISABLE_SAM2_LOGITS_CACHE
        )
        save_logits_to_cache = save_logits_to_cache and not DISABLE_SAM2_LOGITS_CACHE
        loaded_image = self.preproc_image(image)
        if prompts is not None:
            if type(prompts) is dict:
                prompts = Sam2PromptSet(**prompts)
        else:
            prompts = Sam2PromptSet()
        args = prompts.to_sam2_inputs()
        args = pad_points(args)
        # With no prompts at all, feed a single dummy non-point so the model still runs.
        if not any(args.values()):
            args = {"point_coords": [[0, 0]], "point_labels": [-1], "box": None}
        if args["point_coords"] is not None:
            args["point_coords"] = np.array(args["point_coords"])
        if args["point_labels"] is not None:
            args["point_labels"] = np.array(args["point_labels"])
        if args["box"] is not None:
            args["box"] = np.array(args["box"])
        if mask_input is not None and isinstance(mask_input, list):
            mask_input = np.array(mask_input)
        prediction = self._model.segment_images(
            images=loaded_image,
            image_hashes=image_id,
            point_coordinates=args["point_coords"],
            point_labels=args["point_labels"],
            boxes=args["box"],
            mask_input=mask_input,
            multi_mask_output=multimask_output,
            threshold=MASK_THRESHOLD,
            load_from_mask_input_cache=load_logits_from_cache,
            save_to_mask_input_cache=save_logits_to_cache,
            use_embeddings_cache=True,
            return_logits=True,
        )[0]
        # Collapse per-prompt proposals to the single most confident one each.
        return choose_most_confident_sam_prediction(
            masks=prediction.masks.cpu().numpy(),
            scores=prediction.scores.cpu().numpy(),
            low_resolution_logits=prediction.logits.cpu().numpy(),
        )
Functions
__init__
__init__(
    *args,
    model_id=f"sam2/{SAM2_VERSION_ID}",
    api_key=None,
    low_res_logits_cache_size=SAM2_MAX_LOGITS_CACHE_SIZE,
    embedding_cache_size=SAM2_MAX_EMBEDDING_CACHE_SIZE,
    **kwargs,
)

Initializes the SegmentAnything.

Parameters:

Name Type Description Default
*args

Variable length argument list.

()
**kwargs

Arbitrary keyword arguments.

{}
Source code in inference/models/sam2/segment_anything2_inference_models.py
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
def __init__(
    self,
    *args,
    model_id: str = f"sam2/{SAM2_VERSION_ID}",
    api_key: Optional[str] = None,
    low_res_logits_cache_size: int = SAM2_MAX_LOGITS_CACHE_SIZE,
    embedding_cache_size: int = SAM2_MAX_EMBEDDING_CACHE_SIZE,
    **kwargs,
):
    """Initializes the SegmentAnything.

    Args:
        *args: Variable length argument list.
        model_id: SAM2 model identifier or path; defaults to the configured SAM2 version.
        api_key: API key; falls back to the globally configured API_KEY when not given.
        low_res_logits_cache_size: Size limit of the low-resolution masks cache.
        embedding_cache_size: Size limit of the image embeddings cache.
        **kwargs: Arbitrary keyword arguments, forwarded to `AutoModel.from_pretrained`.
    """
    super().__init__()

    self.metrics = {"num_inferences": 0, "avg_inference_time": 0.0}

    self.api_key = api_key if api_key else API_KEY

    self.task_type = "unsupervised-segmentation"

    # Caches move tensors to CPU so cached entries do not hold GPU memory.
    sam2_image_embeddings_cache = Sam2ImageEmbeddingsInMemoryCache.init(
        size_limit=embedding_cache_size,
        send_to_cpu=True,
    )
    sam2_low_resolution_masks_cache = Sam2LowResolutionMasksInMemoryCache.init(
        size_limit=low_res_logits_cache_size,
        send_to_cpu=True,
    )
    extra_weights_provider_headers = get_extra_weights_provider_headers(
        countinference=kwargs.get("countinference"),
        service_secret=kwargs.get("service_secret"),
    )
    # Use every valid backend that has not been explicitly disabled.
    backend = list(
        VALID_INFERENCE_MODELS_BACKENDS.difference(
            DISABLED_INFERENCE_MODELS_BACKENDS
        )
    )
    self._model: SAM2Torch = AutoModel.from_pretrained(
        model_id_or_path=model_id,
        api_key=self.api_key,
        allow_untrusted_packages=ALLOW_INFERENCE_MODELS_UNTRUSTED_PACKAGES,
        allow_direct_local_storage_loading=ALLOW_INFERENCE_MODELS_DIRECTLY_ACCESS_LOCAL_PACKAGES,
        sam2_image_embeddings_cache=sam2_image_embeddings_cache,
        sam2_low_resolution_masks_cache=sam2_low_resolution_masks_cache,
        sam2_allow_client_generated_hash_ids=True,
        weights_provider_extra_headers=extra_weights_provider_headers,
        backend=backend,
        **kwargs,
    )
embed_image
embed_image(image, image_id=None, **kwargs)

Embeds an image and caches the result if an image_id is provided. If the image has been embedded before and cached, the cached result will be returned.

Parameters:

Name Type Description Default
image Any

The image to be embedded. The format should be compatible with the preproc_image method.

required
image_id Optional[str]

An identifier for the image. If provided, the embedding result will be cached with this ID. Defaults to None.

None
**kwargs

Additional keyword arguments.

{}

Returns:

Type Description

Tuple[np.ndarray, Tuple[int, int]]: A tuple where the first element is the embedding of the image and the second element is the shape (height, width) of the processed image.

Notes
  • Embeddings and image sizes are cached to improve performance on repeated requests for the same image.
  • The cache has a maximum size defined by SAM2_MAX_CACHE_SIZE. When the cache exceeds this size, the oldest entries are removed.
Example

img_array = ... # some image array embed_image(img_array, image_id="sample123") (array([...]), (224, 224))

Source code in inference/models/sam2/segment_anything2_inference_models.py
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
def embed_image(
    self,
    image: Optional[InferenceRequestImage],
    image_id: Optional[str] = None,
    **kwargs,
):
    """Compute (or fetch cached) SAM2 embeddings for a single image.

    If the image has been embedded before under the same ``image_id``, the
    underlying model returns the cached result.

    Args:
        image (Any): Image payload compatible with the preproc_image method.
        image_id (Optional[str]): Identifier used as the cache hash for the
            image. Defaults to None.
        **kwargs: Forwarded to the underlying model's ``embed_images``.

    Returns:
        Tuple of (embedding dict with "image_embed" and "high_res_feats"
        numpy arrays, (height, width) of the processed image, image hash).

    Raises:
        ValueError: If no image is provided.
    """
    preprocessed = self.preproc_image(image)
    if preprocessed is None:
        raise ValueError("Image must be provided to handle this request.")
    result = self._model.embed_images(
        images=preprocessed, image_hashes=image_id, **kwargs
    )[0]
    high_res_feats = [
        feature.cpu().numpy() for feature in result.high_resolution_features
    ]
    embedding_payload = {
        "image_embed": result.embeddings.cpu().numpy(),
        "high_res_feats": high_res_feats,
    }
    return embedding_payload, result.image_size_hw, result.image_hash
infer_from_request
infer_from_request(request)

Performs inference based on the request type.

Parameters:

Name Type Description Default
request SamInferenceRequest

The inference request.

required

Returns:

Type Description

Union[SamEmbeddingResponse, SamSegmentationResponse]: The inference response.

Source code in inference/models/sam2/segment_anything2_inference_models.py
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
@usage_collector("model")
def infer_from_request(self, request: Sam2InferenceRequest):
    """Performs inference based on the request type.

    Args:
        request (SamInferenceRequest): The inference request.

    Returns:
        Union[SamEmbeddingResponse, SamSegmentationResponse]: The inference response.
    """
    t1 = perf_counter()
    if isinstance(request, Sam2EmbeddingRequest):
        _, _, image_id = self.embed_image(**request.dict())
        inference_time = perf_counter() - t1
        return Sam2EmbeddingResponse(time=inference_time, image_id=image_id)
    elif isinstance(request, Sam2SegmentationRequest):
        masks, scores, low_resolution_logits = self.segment_image(**request.dict())

        if request.format == "json":
            return turn_segmentation_results_into_api_response(
                masks=masks,
                scores=scores,
                mask_threshold=MASK_THRESHOLD,
                inference_start_timestamp=t1,
            )
        elif request.format == "rle":
            return turn_segmentation_results_into_rle_response(
                masks=masks,
                scores=scores,
                mask_threshold=0.0,
                inference_start_timestamp=t1,
            )
        elif request.format == "binary":
            binary_vector = BytesIO()
            np.savez_compressed(
                binary_vector, masks=masks, low_res_masks=low_resolution_logits
            )
            binary_vector.seek(0)
            binary_data = binary_vector.getvalue()
            return binary_data
        else:
            raise ValueError(f"Invalid format {request.format}")

    else:
        raise ValueError(f"Invalid request type {type(request)}")
preproc_image
preproc_image(image)

Preprocesses an image.

Parameters:

Name Type Description Default
image InferenceRequestImage

The image to preprocess.

required

Returns:

Type Description

np.array: The preprocessed image.

Source code in inference/models/sam2/segment_anything2_inference_models.py
217
218
219
220
221
222
223
224
225
226
227
228
def preproc_image(self, image: InferenceRequestImage):
    """Decode an inference-request image into a BGR numpy array.

    Args:
        image (InferenceRequestImage): The image to preprocess, or None.

    Returns:
        np.array: The decoded BGR image, or None when no image was given.
    """
    return None if image is None else load_image_bgr(image)
segment_image
segment_image(
    image,
    image_id=None,
    prompts=None,
    multimask_output=True,
    mask_input=None,
    save_logits_to_cache=False,
    load_logits_from_cache=False,
    **kwargs
)

Segments an image based on provided embeddings, points, masks, or cached results. If embeddings are not directly provided, the function can derive them from the input image or cache.

Parameters:

Name Type Description Default
image Any

The image to be segmented.

required
image_id Optional[str]

A cached identifier for the image. Useful for accessing cached embeddings or masks.

None
prompts Optional[List[Sam2Prompt]]

List of prompts to use for segmentation. Defaults to None.

None
mask_input Optional[Union[ndarray, List[List[List[float]]]]]

Input low_res_logits for the image.

None
multimask_output Optional[bool]

(bool): Flag to decide whether multiple mask proposals are predicted (among which the most promising will be returned).

True
use_logits_cache

(bool): Flag to decide to use cached logits from prior prompting

required
**kwargs

Additional keyword arguments.

{}

Returns:

Type Description

Tuple[np.ndarray, np.ndarray, np.ndarray]: Tuple of np.array, where: - first element is of size (prompt_set_size, h, w) and represents the mask with the highest confidence for each prompt element - second element is of size (prompt_set_size, ) and represents the score for the most confident mask of each prompt element - third element is of size (prompt_set_size, 256, 256) and represents the low resolution logits for the most confident mask of each prompt element

Raises:

Type Description
ValueError

If necessary inputs are missing or inconsistent.

Notes
  • Embeddings, segmentations, and low-resolution logits can be cached to improve performance on repeated requests for the same image.
  • The cache has a maximum size defined by SAM_MAX_EMBEDDING_CACHE_SIZE. When the cache exceeds this size, the oldest entries are removed.
Source code in inference/models/sam2/segment_anything2_inference_models.py
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
def segment_image(
    self,
    image: Optional[InferenceRequestImage],
    image_id: Optional[str] = None,
    prompts: Optional[Union[Sam2PromptSet, dict]] = None,
    multimask_output: Optional[bool] = True,
    mask_input: Optional[Union[np.ndarray, List[List[List[float]]]]] = None,
    save_logits_to_cache: bool = False,
    load_logits_from_cache: bool = False,
    **kwargs,
):
    """
    Segments an image based on provided embeddings, points, masks, or cached results.
    If embeddings are not directly provided, the function can derive them from the input image or cache.

    Args:
        image (Optional[InferenceRequestImage]): The image to be segmented.
        image_id (Optional[str]): A cached identifier for the image. Useful for accessing cached
            embeddings or masks.
        prompts (Optional[Union[Sam2PromptSet, dict]]): Prompt set to use for segmentation. A plain
            dict is coerced into a `Sam2PromptSet`. Defaults to None (an empty prompt set is used).
        multimask_output (Optional[bool]): Flag to decide if multiple mask proposals are predicted
            (among which the most promising will be returned).
        mask_input (Optional[Union[np.ndarray, List[List[List[float]]]]]): Input low_res_logits
            for the image.
        save_logits_to_cache (bool): Flag to decide whether logits produced by this call are cached
            for subsequent prompting. Ignored when DISABLE_SAM2_LOGITS_CACHE is set.
        load_logits_from_cache (bool): Flag to decide whether cached logits from prior prompting
            are used. Ignored when DISABLE_SAM2_LOGITS_CACHE is set.
        **kwargs: Additional keyword arguments.

    Returns:
        Tuple[np.ndarray, np.ndarray, np.ndarray]: Tuple of np.array, where:
            - first element is of size (prompt_set_size, h, w) and represents the mask with the
                highest confidence for each prompt element
            - second element is of size (prompt_set_size, ) and represents the score for the most
                confident mask of each prompt element
            - third element is of size (prompt_set_size, 256, 256) and represents the low resolution
                logits for the most confident mask of each prompt element

    Raises:
        ValueError: If necessary inputs are missing or inconsistent.

    Notes:
        - Embeddings, segmentations, and low-resolution logits can be cached to improve performance
          on repeated requests for the same image.
        - The cache has a maximum size defined by SAM_MAX_EMBEDDING_CACHE_SIZE. When the cache exceeds this size,
          the oldest entries are removed.
    """
    # Honour the global kill-switch for the logits cache.
    load_logits_from_cache = (
        load_logits_from_cache and not DISABLE_SAM2_LOGITS_CACHE
    )
    save_logits_to_cache = save_logits_to_cache and not DISABLE_SAM2_LOGITS_CACHE
    loaded_image = self.preproc_image(image)
    if prompts is not None:
        if type(prompts) is dict:
            prompts = Sam2PromptSet(**prompts)
    else:
        prompts = Sam2PromptSet()
    args = prompts.to_sam2_inputs()
    # Equalize per-prompt point counts so they can be stacked into a tensor.
    args = pad_points(args)
    # No prompts at all: fall back to a single dummy "not a point" entry.
    if not any(args.values()):
        args = {"point_coords": [[0, 0]], "point_labels": [-1], "box": None}
    if args["point_coords"] is not None:
        args["point_coords"] = np.array(args["point_coords"])
    if args["point_labels"] is not None:
        args["point_labels"] = np.array(args["point_labels"])
    if args["box"] is not None:
        args["box"] = np.array(args["box"])
    if mask_input is not None and isinstance(mask_input, list):
        mask_input = np.array(mask_input)
    prediction = self._model.segment_images(
        images=loaded_image,
        image_hashes=image_id,
        point_coordinates=args["point_coords"],
        point_labels=args["point_labels"],
        boxes=args["box"],
        mask_input=mask_input,
        multi_mask_output=multimask_output,
        threshold=MASK_THRESHOLD,
        load_from_mask_input_cache=load_logits_from_cache,
        save_to_mask_input_cache=save_logits_to_cache,
        use_embeddings_cache=True,
        return_logits=True,
    )[0]
    # Collapse multi-mask proposals to the single best mask per prompt element.
    return choose_most_confident_sam_prediction(
        masks=prediction.masks.cpu().numpy(),
        scores=prediction.scores.cpu().numpy(),
        low_resolution_logits=prediction.logits.cpu().numpy(),
    )

Functions

choose_most_confident_sam_prediction

choose_most_confident_sam_prediction(
    masks, scores, low_resolution_logits
)

This function is supposed to post-process SAM2 inference and choose the most confident mask regardless of the multimask_output parameter value. Args: masks: np array with values 0.0 and 1.0 representing the predicted mask of size (prompt_set_size, proposed_masks, h, w) or (proposed_masks, h, w) - depending on prompt set size - unfortunately, prompt_set_size=1 causes a squeeze operation in the SAM2 library, so to handle inference uniformly, we need to compensate with this function. scores: array of size (prompt_set_size, proposed_masks) or (proposed_masks, ) depending on prompt set size - this array gives a confidence score for each mask proposal. low_resolution_logits: array of size (prompt_set_size, proposed_masks, 256, 256) or (proposed_masks, 256, 256) - depending on prompt set size. These low resolution logits can be passed to a subsequent iteration as mask input. Returns: Tuple of np.array, where: - first element is of size (prompt_set_size, h, w) and represents the mask with the highest confidence for each prompt element - second element is of size (prompt_set_size, ) and represents the score for the most confident mask of each prompt element - third element is of size (prompt_set_size, 256, 256) and represents the low resolution logits for the most confident mask of each prompt element

Source code in inference/models/sam2/segment_anything2_inference_models.py
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
def choose_most_confident_sam_prediction(
    masks: np.ndarray,
    scores: np.ndarray,
    low_resolution_logits: np.ndarray,
) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
    """
    Post-process SAM2 inference output, keeping only the most confident mask
    proposal for every prompt-set element, independent of the
    `multimask_output` setting.

    Args:
        masks: Binary (0.0 / 1.0) mask predictions shaped either
            (prompt_set_size, proposed_masks, h, w) or (proposed_masks, h, w).
            The SAM2 library squeezes the batch dimension when
            prompt_set_size == 1, so this function re-adds it in order to
            process results uniformly.
        scores: Confidence per proposal, shaped
            (prompt_set_size, proposed_masks) or (proposed_masks, ).
        low_resolution_logits: Low-resolution logits shaped
            (prompt_set_size, proposed_masks, 256, 256) or
            (proposed_masks, 256, 256); usable as mask input in a follow-up
            prompting iteration.

    Returns:
        Tuple of np.array, where:
            - first element is of size (prompt_set_size, h, w) and holds the
                highest-confidence mask for each prompt element
            - second element is of size (prompt_set_size, ) and holds the score
                of that mask for each prompt element
            - third element is of size (prompt_set_size, 256, 256) and holds
                its low-resolution logits for each prompt element
    """
    # Compensate for the squeeze SAM2 applies when prompt_set_size == 1.
    if masks.ndim == 3:
        masks = masks[np.newaxis, ...]
        scores = scores[np.newaxis, ...]
        low_resolution_logits = low_resolution_logits[np.newaxis, ...]
    chosen = [
        choose_most_confident_prompt_set_element_prediction(
            mask=element_mask,
            score=element_score,
            low_resolution_logit=element_logit,
        )
        for element_mask, element_score, element_logit in zip(
            masks, scores, low_resolution_logits
        )
    ]
    # Transpose [(mask, score, logit), ...] into three parallel sequences.
    best_masks, best_scores, best_logits = (
        zip(*chosen) if chosen else ((), (), ())
    )
    return (
        np.asarray(best_masks),
        np.asarray(best_scores),
        np.asarray(best_logits),
    )

pad_points

pad_points(args)

Pad arguments to be passed to sam2 model with not_a_point label (-1). This is necessary when there are multiple prompts per image so that a tensor can be created.

Also pads empty point lists with a dummy non-point entry.

Source code in inference/models/sam2/segment_anything2_inference_models.py
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
def pad_points(args: Dict[str, Any]) -> Dict[str, Any]:
    """
    Pad arguments to be passed to the sam2 model with the not_a_point label (-1).
    This is necessary when there are multiple prompts per image so that a tensor
    can be created.

    Also pads empty point lists with a dummy non-point entry.
    """
    padded = copy.deepcopy(args)
    coords = padded["point_coords"]
    if coords is None:
        # Labels without coordinates cannot be matched up - reject them.
        if padded["point_labels"] is not None:
            raise ValueError(
                "Can't have point labels without corresponding point coordinates"
            )
        return padded
    # Every prompt is padded to the longest one (at least 1 entry).
    target_length = max(1, max(len(prompt) for prompt in coords))
    for point_list in coords:
        point_list.extend([0, 0] for _ in range(target_length - len(point_list)))
    for label_list in padded["point_labels"]:
        label_list.extend(-1 for _ in range(target_length - len(label_list)))
    return padded

models/sam3

inference.models.sam3.segment_anything3

Classes

SegmentAnything3

Bases: RoboflowCoreModel

SAM3 wrapper with a similar interface to SAM2 in this codebase.

Source code in inference/models/sam3/segment_anything3.py
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
class SegmentAnything3(RoboflowCoreModel):
    """SAM3 wrapper with a similar interface to SAM2 in this codebase."""

    def __init__(
        self,
        *args,
        model_id: str = "sam3/sam3_final",
        **kwargs,
    ):
        """Initializes the SAM3 model from locally cached artefacts.

        Args:
            *args: Positional arguments forwarded to RoboflowCoreModel.
            model_id (str): Identifier of the model to load; defaults to the
                base SAM3 checkpoint.
            **kwargs: Keyword arguments forwarded to RoboflowCoreModel.
        """
        super().__init__(*args, model_id=model_id, **kwargs)

        # Lazy import SAM3 to avoid hard dependency when disabled
        from sam3 import build_sam3_image_model

        checkpoint = self.cache_file("weights.pt")
        bpe_path = self.cache_file("bpe_simple_vocab_16e6.txt.gz")

        # Reentrant lock for serializing access to the model (not currently
        # acquired in infer_from_request - see commented-out usage below).
        self.sam3_lock = threading.RLock()

        self.model = build_sam3_image_model(
            bpe_path=bpe_path,
            checkpoint_path=checkpoint,
            device="cuda" if torch.cuda.is_available() else "cpu",
            load_from_HF=False,
            compile=False,
        )

        # Preprocessing and postprocessing for PCS image path
        self.transform = ComposeAPI(
            transforms=[
                RandomResizeAPI(
                    sizes=SAM3_IMAGE_SIZE,
                    max_size=SAM3_IMAGE_SIZE,
                    square=True,
                    consistent_transform=False,
                ),
                ToTensorAPI(),
                NormalizeAPI(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
            ]
        )

        self.image_size = SAM3_IMAGE_SIZE
        self.task_type = "unsupervised-segmentation"

    def _is_core_sam3_endpoint(self) -> bool:
        # Base SAM3 models live under the "sam3/" namespace; anything else is
        # treated as a fine-tuned model.
        return isinstance(self.endpoint, str) and self.endpoint.startswith("sam3/")

    @property
    def model_artifact_bucket(self):
        # Use CORE bucket for base SAM3, standard INFER bucket for fine-tuned models
        return CORE_MODEL_BUCKET if self._is_core_sam3_endpoint() else INFER_BUCKET

    def download_weights(self) -> None:
        """Ensures model artefacts are present in the local cache.

        Resolution order: local cache -> model artefacts S3 bucket ->
        Roboflow API (core-model endpoint for base SAM3, ORT endpoint for
        fine-tuned models).

        Raises:
            RoboflowAPINotAuthorizedError: If cache auth is enabled and the
                API key lacks access to this model.
            ModelArtefactError: If the ORT response for a fine-tuned model is
                malformed or missing the weights map.
        """
        infer_bucket_files = self.get_infer_bucket_file_list()

        # Auth check aligned with chosen endpoint type
        if MODELS_CACHE_AUTH_ENABLED:
            endpoint_type = (
                ModelEndpointType.CORE_MODEL
                if self._is_core_sam3_endpoint()
                else ModelEndpointType.ORT
            )
            if not _check_if_api_key_has_access_to_model(
                api_key=self.api_key,
                model_id=self.endpoint,
                endpoint_type=endpoint_type,
            ):
                raise RoboflowAPINotAuthorizedError(
                    f"API key {self.api_key} does not have access to model {self.endpoint}"
                )

        # Already cached
        if are_all_files_cached(files=infer_bucket_files, model_id=self.endpoint):
            return None

        # S3 path works for both; keys are {endpoint}/<file>
        if is_model_artefacts_bucket_available():
            self.download_model_artefacts_from_s3()
            return None

        # API fallback
        if self._is_core_sam3_endpoint():
            # Base SAM3 from core_model endpoint; preserves filenames
            return super().download_model_from_roboflow_api()

        # Fine-tuned SAM3: use ORT endpoint to fetch weights map or model url
        api_data = get_roboflow_model_data(
            api_key=self.api_key,
            model_id=self.endpoint,
            endpoint_type=ModelEndpointType.ORT,
            device_id=self.device_id,
        )

        ort = api_data.get("ort") if isinstance(api_data, dict) else None
        if not isinstance(ort, dict):
            raise ModelArtefactError("ORT response malformed for fine-tuned SAM3")

        # Preferred: explicit weights map of filename -> URL
        weights_map = ort.get("weights")
        if isinstance(weights_map, dict) and len(weights_map) > 0:
            for filename, url in weights_map.items():
                resp = get_from_url(url, json_response=False)
                save_bytes_in_cache(
                    content=resp.content,
                    file=str(filename),
                    model_id=self.endpoint,
                )
            return None

        raise ModelArtefactError(
            "ORT response missing both 'weights' for fine-tuned SAM3"
        )

    def get_infer_bucket_file_list(self) -> List[str]:
        """Gets the list of files required for inference.

        Returns:
            List[str]: List of file names.
        """
        # SAM3 weights managed by env; no core bucket artifacts

        return [
            "weights.pt",
            "bpe_simple_vocab_16e6.txt.gz",
        ]

    def preproc_image(self, image: InferenceRequestImage) -> np.ndarray:
        """Decodes the request image into an RGB numpy array."""
        np_image = load_image_rgb(image)
        return np_image

    @usage_collector("model")
    def infer_from_request(self, request: Sam3InferenceRequest):
        """Dispatches an inference request to the matching handler.

        Args:
            request (Sam3InferenceRequest): Incoming request; only
                Sam3SegmentationRequest is supported.

        Returns:
            The result of segment_image (bytes or a response model).

        Raises:
            ValueError: If the request type is not supported.
        """
        # with self.sam3_lock:
        # NOTE(review): t1 is unused; segment_image times itself via start_ts.
        t1 = perf_counter()
        if isinstance(request, Sam3SegmentationRequest):
            # Pass strongly-typed fields to preserve Sam3Prompt objects
            result = self.segment_image(
                image=request.image,
                image_id=request.image_id,
                prompts=request.prompts,
                output_prob_thresh=request.output_prob_thresh or 0.5,
                format=request.format or "polygon",
                nms_iou_threshold=request.nms_iou_threshold,
            )
            # segment_image now returns either bytes or a response model
            return result
        else:
            raise ValueError(f"Invalid request type {type(request)}")

    def segment_image(
        self,
        image: Optional[InferenceRequestImage],
        image_id: Optional[str] = None,
        prompts: Optional[List[Sam3Prompt]] = None,
        output_prob_thresh: float = 0.5,
        format: Optional[str] = "polygon",
        nms_iou_threshold: Optional[float] = None,
        **kwargs,
    ):
        """Runs promptable concept segmentation on a single image.

        Args:
            image (Optional[InferenceRequestImage]): Image to segment.
            image_id (Optional[str]): Identifier of the image; not used by
                this code path.
            prompts (Optional[List[Sam3Prompt]]): Text and/or box prompts;
                each prompt yields its own entry in the result.
            output_prob_thresh (float): Default detection threshold, used for
                prompts that do not carry their own output_prob_thresh.
            format (Optional[str]): Output mask format forwarded to
                _masks_to_predictions (e.g. "polygon").
            nms_iou_threshold (Optional[float]): When set, non-maximum
                suppression is applied across the masks of ALL prompts before
                regrouping results per prompt.
            **kwargs: Additional keyword arguments (ignored).

        Returns:
            Sam3SegmentationResponse: Per-prompt segmentation results with
                elapsed inference time.
        """
        np_image = load_image_rgb(image)
        h, w = np_image.shape[:2]
        pil_image = Image.fromarray(np_image)

        # Inference-only path; disable autograd throughout
        with torch.inference_mode():
            # NOTE(review): autocast targets "cuda" even when the model runs
            # on CPU - confirm this is intended on CPU-only hosts.
            with torch.autocast(device_type="cuda", dtype=torch.bfloat16):
                start_ts = perf_counter()

                # TODO this can also take tensor directly instead of PIL image, so we want to avoid double conversion
                # TODO: this also supports multiple images for multi batch inference
                datapoint = Sam3Datapoint(
                    find_queries=[],
                    images=[Sam3ImageDP(data=pil_image, objects=[], size=(h, w))],
                )

                # Build prompts in order
                prompts = prompts or []

                # Map prompt_index -> prompt_id to retrieve results later
                prompt_ids: List[int] = []
                for idx, p in enumerate(prompts):
                    if getattr(p, "boxes", None):
                        q = _build_visual_query(
                            coco_id=idx,
                            h=h,
                            w=w,
                            boxes=p.boxes,
                            labels=p.box_labels or [],
                            text=p.text,
                        )
                    else:
                        q = _build_text_query(
                            coco_id=idx,
                            h=h,
                            w=w,
                            text=p.text,
                        )
                    datapoint.find_queries.append(q)
                    prompt_ids.append(idx)

                # Transform and collate to BatchedDatapoint
                datapoint = self.transform(datapoint)
                batch = collate_fn_api(batch=[datapoint], dict_key="dummy")["dummy"]
                batch = copy_data_to_device(
                    batch,
                    torch.device("cuda" if torch.cuda.is_available() else "cpu"),
                    non_blocking=True,
                )

                # Forward
                output = self.model(batch)

                # Calculate minimum threshold for initial filtering
                # (we'll apply per-prompt thresholds later)
                min_threshold = output_prob_thresh
                for p in prompts:
                    prompt_thresh = getattr(p, "output_prob_thresh", None)
                    if prompt_thresh is not None:
                        min_threshold = min(min_threshold, prompt_thresh)

                # Postprocess to original size and build per-prompt results
                post = PostProcessImage(
                    max_dets_per_img=-1,
                    iou_type="segm",
                    use_original_sizes_box=True,
                    use_original_sizes_mask=True,
                    convert_mask_to_rle=False,
                    detection_threshold=float(
                        min_threshold if min_threshold is not None else 0.35
                    ),
                    to_cpu=True,
                )
                processed = post.process_results(output, batch.find_metadatas)

        needs_cross_prompt_nms = nms_iou_threshold is not None
        prompt_results: List[Sam3PromptResult] = []

        if needs_cross_prompt_nms and len(prompts) > 0:
            # Cross-prompt NMS: pool masks from all prompts, suppress
            # overlaps, then regroup survivors by originating prompt.
            all_masks = _collect_masks_with_per_prompt_threshold(
                processed=processed,
                prompt_ids=prompt_ids,
                prompts=prompts,
                default_threshold=output_prob_thresh,
            )

            if len(all_masks) > 0:
                all_masks = _apply_nms_cross_prompt(all_masks, nms_iou_threshold)

            regrouped = _regroup_masks_by_prompt(all_masks, len(prompts))

            # Build prompt results from regrouped masks
            for idx, coco_id in enumerate(prompt_ids):
                has_visual = bool(getattr(prompts[idx], "boxes", None))
                num_boxes = len(prompts[idx].boxes or []) if has_visual else 0
                echo = Sam3PromptEcho(
                    prompt_index=idx,
                    type=("visual" if has_visual else "text"),
                    text=prompts[idx].text,
                    num_boxes=num_boxes,
                )

                # Convert regrouped masks to predictions
                prompt_masks = regrouped.get(idx, [])
                if prompt_masks:
                    masks_np = np.stack([m for m, _ in prompt_masks], axis=0)
                    scores = [s for _, s in prompt_masks]
                else:
                    masks_np = np.zeros((0, 0, 0), dtype=np.uint8)
                    scores = []

                preds = _masks_to_predictions(masks_np, scores, format)
                prompt_results.append(
                    Sam3PromptResult(prompt_index=idx, echo=echo, predictions=preds)
                )
        else:
            # No cross-prompt NMS: filter each prompt's masks independently
            # by its own threshold (if any).
            for idx, coco_id in enumerate(prompt_ids):
                has_visual = bool(getattr(prompts[idx], "boxes", None))
                num_boxes = len(prompts[idx].boxes or []) if has_visual else 0
                echo = Sam3PromptEcho(
                    prompt_index=idx,
                    type=("visual" if has_visual else "text"),
                    text=prompts[idx].text,
                    num_boxes=num_boxes,
                )
                masks_np = _to_numpy_masks(processed[coco_id].get("masks"))
                scores = list(processed[coco_id].get("scores", []))
                prompt_thresh = getattr(prompts[idx], "output_prob_thresh", None)
                if prompt_thresh is not None:
                    masks_np, scores = _filter_by_threshold(
                        masks_np, scores, prompt_thresh
                    )
                preds = _masks_to_predictions(masks_np, scores, format)
                prompt_results.append(
                    Sam3PromptResult(prompt_index=idx, echo=echo, predictions=preds)
                )

        return Sam3SegmentationResponse(
            time=perf_counter() - start_ts, prompt_results=prompt_results
        )

Functions

inference.models.sam3.visual_segmentation

Classes

Sam3ForInteractiveImageSegmentation

Bases: RoboflowCoreModel

SegmentAnything3 class for handling segmentation tasks on images with box prompting and point prompting, the same way as SAM2 does.

Source code in inference/models/sam3/visual_segmentation.py
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
class Sam3ForInteractiveImageSegmentation(RoboflowCoreModel):
    """
    SegmentAnything3 class for handling segmentation tasks on images with
    box prompting and point prompting, the same way as SAM2 does.
    """

    def __init__(
        self,
        *args,
        model_id: str = "sam3/sam3_final",
        low_res_logits_cache_size: int = SAM3_MAX_LOGITS_CACHE_SIZE,
        embedding_cache_size: int = SAM3_MAX_EMBEDDING_CACHE_SIZE,
        **kwargs,
    ):
        """Initializes the SegmentAnything.

        Args:
            *args: Variable length argument list.
            model_id: Identifier of the SAM3 checkpoint to load.
            low_res_logits_cache_size: Maximum number of entries kept in the
                low-resolution logits cache.
            embedding_cache_size: Maximum number of image embeddings kept in
                the embedding cache.
            **kwargs: Arbitrary keyword arguments.
        """
        super().__init__(*args, model_id=model_id, **kwargs)
        checkpoint = self.cache_file("weights.pt")
        bpe_path = self.cache_file("bpe_simple_vocab_16e6.txt.gz")

        self.sam_model = build_sam3_image_model(
            bpe_path=bpe_path,
            checkpoint_path=checkpoint,
            device="cuda" if torch.cuda.is_available() else "cpu",
            load_from_HF=False,
            compile=False,
            enable_inst_interactivity=True,
        )
        self.low_res_logits_cache_size = low_res_logits_cache_size
        self.embedding_cache_size = embedding_cache_size
        # Caches are evicted oldest-first; the *_keys lists record insertion
        # order so the oldest entry can be dropped when a cache overflows.
        self.embedding_cache = {}
        self.image_size_cache = {}
        self.embedding_cache_keys = []
        self.low_res_logits_cache: Dict[Tuple[str, str], LogitsCacheType] = {}
        self.low_res_logits_cache_keys = []
        # Guards all cache mutations performed under concurrent requests.
        self._state_lock = RLock()
        self.task_type = "unsupervised-segmentation"

    def get_infer_bucket_file_list(self) -> List[str]:
        """Gets the list of files required for inference.

        Returns:
            List[str]: List of file names.
        """
        return ["weights.pt"]

    @torch.inference_mode()
    def embed_image(
        self,
        image: Optional[InferenceRequestImage],
        image_id: Optional[str] = None,
        **kwargs,
    ):
        """
        Embeds an image and caches the result if an image_id is provided. If the image has been embedded before and cached,
        the cached result will be returned.

        Args:
            image (Any): The image to be embedded. The format should be compatible with the preproc_image method.
            image_id (Optional[str]): An identifier for the image. If provided, the embedding result will be cached
                                      with this ID. Defaults to None.
            **kwargs: Additional keyword arguments.

        Returns:
            Tuple[np.ndarray, Tuple[int, int]]: A tuple where the first element is the embedding of the image
                                               and the second element is the shape (height, width) of the processed image.

        Notes:
            - Embeddings and image sizes are cached to improve performance on repeated requests for the same image.
            - The cache holds at most `embedding_cache_size` entries. When the cache exceeds this size,
              the oldest entries are removed.

        Example:
            >>> img_array = ... # some image array
            >>> embed_image(img_array, image_id="sample123")
            (array([...]), (224, 224))
        """
        if image_id:
            embedding_cache_content = self.embedding_cache.get(image_id)
            image_size_content = self.image_size_cache.get(image_id)
            if embedding_cache_content is not None and image_size_content is not None:
                return embedding_cache_content, image_size_content, image_id

        img_in = self.preproc_image(image)
        if image_id is None:
            # Derive a stable identifier from the raw pixel bytes so repeated
            # requests for the same image content hit the cache.
            image_id = hashlib.md5(img_in.tobytes()).hexdigest()[:12]

        # Second lookup: the content-derived id may already be cached.
        embedding_cache_content = self.embedding_cache.get(image_id)
        image_size_content = self.image_size_cache.get(image_id)
        if embedding_cache_content is not None and image_size_content is not None:
            return (
                embedding_cache_content,
                image_size_content,
                image_id,
            )

        with torch.autocast(device_type="cuda", dtype=torch.bfloat16):
            with _temporarily_disable_torch_jit_script():
                processor = Sam3Processor(self.sam_model)
            state = processor.set_image(torch.from_numpy(img_in).permute(2, 0, 1))
            embedding_dict = state

        with self._state_lock:
            self.embedding_cache[image_id] = embedding_dict
            self.image_size_cache[image_id] = img_in.shape[:2]
            safe_remove_from_list(values=self.embedding_cache_keys, element=image_id)
            self.embedding_cache_keys.append(image_id)
            if len(self.embedding_cache_keys) > self.embedding_cache_size:
                cache_key = safe_pop_from_list(values=self.embedding_cache_keys)
                if cache_key is not None:
                    safe_remove_from_dict(values=self.embedding_cache, key=cache_key)
                    safe_remove_from_dict(values=self.image_size_cache, key=cache_key)
            return embedding_dict, img_in.shape[:2], image_id

    @usage_collector("model")
    def infer_from_request(self, request: Sam2InferenceRequest):
        """Performs inference based on the request type.

        Args:
            request (SamInferenceRequest): The inference request.

        Returns:
            Union[SamEmbeddingResponse, SamSegmentationResponse]: The inference response.
        """
        t1 = perf_counter()
        if isinstance(request, Sam2EmbeddingRequest):
            _, _, image_id = self.embed_image(**request.dict())
            inference_time = perf_counter() - t1
            return Sam2EmbeddingResponse(time=inference_time, image_id=image_id)
        elif isinstance(request, Sam2SegmentationRequest):
            masks, scores, low_resolution_logits = self.segment_image(**request.dict())
            predictions = _masks_to_predictions(masks, scores, request.format)
            return Sam2SegmentationResponse(
                time=perf_counter() - t1,
                predictions=predictions,
            )
        else:
            raise ValueError(f"Invalid request type {type(request)}")

    def preproc_image(self, image: InferenceRequestImage):
        """Preprocesses an image.

        Args:
            image (InferenceRequestImage): The image to preprocess.

        Returns:
            np.array: The preprocessed image.
        """
        np_image = load_image_rgb(image)
        return np_image

    def segment_image(
        self,
        image: Optional[InferenceRequestImage],
        image_id: Optional[str] = None,
        prompts: Optional[Union[Sam2PromptSet, dict]] = None,
        multimask_output: Optional[bool] = True,
        mask_input: Optional[Union[np.ndarray, List[List[List[float]]]]] = None,
        save_logits_to_cache: bool = False,
        load_logits_from_cache: bool = False,
        **kwargs,
    ):
        """
        Segments an image based on provided embeddings, points, masks, or cached results.
        If embeddings are not directly provided, the function can derive them from the input image or cache.

        Args:
            image (Any): The image to be segmented.
            image_id (Optional[str]): A cached identifier for the image. Useful for accessing cached embeddings or masks.
            prompts (Optional[Union[Sam2PromptSet, dict]]): Prompt set to use for segmentation. Defaults to None.
            mask_input (Optional[Union[np.ndarray, List[List[List[float]]]]]): Input low_res_logits for the image.
            multimask_output (Optional[bool]): Flag to decide if multiple mask proposals are to be predicted
                (among which the most promising will be returned).
            save_logits_to_cache (bool): Flag to decide whether computed low-resolution logits are stored in cache.
            load_logits_from_cache (bool): Flag to decide whether cached logits from prior prompting are reused.
            **kwargs: Additional keyword arguments.

        Returns:
            Tuple[np.ndarray, np.ndarray, np.ndarray]: Tuple of np.array, where:
                - first element is of size (prompt_set_size, h, w) and represents the mask with the highest confidence
                    for each prompt element
                - second element is of size (prompt_set_size, ) and represents the score for the most confident mask
                    of each prompt element
                - third element is of size (prompt_set_size, 256, 256) and represents the low resolution logits
                    for the most confident mask of each prompt element

        Raises:
            ValueError: If necessary inputs are missing or inconsistent.

        Notes:
            - Embeddings, segmentations, and low-resolution logits can be cached to improve performance
              on repeated requests for the same image.
            - The caches are bounded by `embedding_cache_size` and `low_res_logits_cache_size` entries;
              when a cache exceeds its size, the oldest entries are removed.
        """
        # Logits caching can be disabled globally via DISABLE_SAM3_LOGITS_CACHE.
        load_logits_from_cache = (
            load_logits_from_cache and not DISABLE_SAM3_LOGITS_CACHE
        )
        save_logits_to_cache = save_logits_to_cache and not DISABLE_SAM3_LOGITS_CACHE
        with torch.inference_mode():
            if image is None and not image_id:
                raise ValueError("Must provide either image or  cached image_id")
            elif image_id and image is None and image_id not in self.embedding_cache:
                raise ValueError(
                    f"Image ID {image_id} not in embedding cache, must provide the image or embeddings"
                )
            embedding, original_image_size, image_id = self.embed_image(
                image=image, image_id=image_id
            )

            # with _temporarily_disable_torch_jit_script():
            # processor = Sam3Processor(self.sam_model)

            # processor._is_image_set = True
            # processor._features = embedding
            # processor._orig_hw = [original_image_size]
            # processor._is_batch = False
            args = dict()
            prompt_set: Sam2PromptSet
            if prompts:
                if type(prompts) is dict:
                    prompt_set = Sam2PromptSet(**prompts)
                    args = prompt_set.to_sam2_inputs()
                else:
                    prompt_set = prompts
                    args = prompts.to_sam2_inputs()
            else:
                prompt_set = Sam2PromptSet()

            if mask_input is None and load_logits_from_cache:
                mask_input = maybe_load_low_res_logits_from_cache(
                    image_id, prompt_set, self.low_res_logits_cache
                )

            args = pad_points(args)
            # No prompts at all: fall back to a single point with label -1 —
            # presumably a sentinel/ignore point so the model still receives a
            # well-formed input; verify against predict_inst semantics.
            if not any(args.values()):
                args = {"point_coords": [[0, 0]], "point_labels": [-1], "box": None}

            masks, scores, low_resolution_logits = self.sam_model.predict_inst(
                embedding,
                mask_input=mask_input,
                multimask_output=multimask_output,
                return_logits=True,
                normalize_coords=True,
                **args,
            )
            # Collapse multimask proposals to the single most confident one
            # per prompt element.
            masks, scores, low_resolution_logits = choose_most_confident_sam_prediction(
                masks=masks,
                scores=scores,
                low_resolution_logits=low_resolution_logits,
            )

            if save_logits_to_cache:
                self.add_low_res_logits_to_cache(
                    low_resolution_logits, image_id, prompt_set
                )

            return masks, scores, low_resolution_logits

    def add_low_res_logits_to_cache(
        self, logits: np.ndarray, image_id: str, prompt_set: Sam2PromptSet
    ) -> None:
        """Store low-resolution logits keyed by the (image, prompt set) hash.

        The oldest entry is evicted once the cache exceeds
        `low_res_logits_cache_size` entries.
        """
        # Insert a channel axis — presumably to match the layout expected by
        # mask_input on subsequent calls; TODO confirm against predict_inst.
        logits = logits[:, None, :, :]
        prompt_id = hash_prompt_set(image_id, prompt_set)
        with self._state_lock:
            self.low_res_logits_cache[prompt_id] = {
                "logits": logits,
                "prompt_set": prompt_set,
            }
            safe_remove_from_list(
                values=self.low_res_logits_cache_keys, element=prompt_id
            )
            self.low_res_logits_cache_keys.append(prompt_id)
            if len(self.low_res_logits_cache_keys) > self.low_res_logits_cache_size:
                cache_key = safe_pop_from_list(values=self.low_res_logits_cache_keys)
                if cache_key is not None:
                    safe_remove_from_dict(
                        values=self.low_res_logits_cache, key=cache_key
                    )

    @property
    def model_artifact_bucket(self):
        # Use CORE bucket for base SAM3, standard INFER bucket for fine-tuned models
        return CORE_MODEL_BUCKET if self._is_core_sam3_endpoint() else INFER_BUCKET

    def _is_core_sam3_endpoint(self) -> bool:
        """True when the endpoint refers to the base SAM3 core model."""
        return isinstance(self.endpoint, str) and self.endpoint.startswith("sam3/")

    def download_weights(self) -> None:
        """Ensure the model weight files are present in the local cache.

        Resolution order: local cache, S3 artefacts bucket, then the Roboflow
        API (core-model endpoint for base SAM3, ORT endpoint for fine-tuned
        models).

        Raises:
            RoboflowAPINotAuthorizedError: If cache auth is enabled and the
                API key has no access to the model.
            ModelArtefactError: If the ORT response is malformed or lacks a
                weights map.
        """
        infer_bucket_files = self.get_infer_bucket_file_list()

        # Auth check aligned with chosen endpoint type
        if MODELS_CACHE_AUTH_ENABLED:
            endpoint_type = (
                ModelEndpointType.CORE_MODEL
                if self._is_core_sam3_endpoint()
                else ModelEndpointType.ORT
            )
            if not _check_if_api_key_has_access_to_model(
                api_key=self.api_key,
                model_id=self.endpoint,
                endpoint_type=endpoint_type,
            ):
                raise RoboflowAPINotAuthorizedError(
                    f"API key {self.api_key} does not have access to model {self.endpoint}"
                )
        # Already cached
        if are_all_files_cached(files=infer_bucket_files, model_id=self.endpoint):
            return None
        # S3 path works for both; keys are {endpoint}/<file>
        if is_model_artefacts_bucket_available():
            self.download_model_artefacts_from_s3()
            return None
        # API fallback
        if self._is_core_sam3_endpoint():
            # Base SAM3 from core_model endpoint; preserves filenames
            return super().download_model_from_roboflow_api()

        # Fine-tuned SAM3: use ORT endpoint to fetch weights map or model url
        api_data = get_roboflow_model_data(
            api_key=self.api_key,
            model_id=self.endpoint,
            endpoint_type=ModelEndpointType.ORT,
            device_id=self.device_id,
        )

        ort = api_data.get("ort") if isinstance(api_data, dict) else None
        if not isinstance(ort, dict):
            raise ModelArtefactError("ORT response malformed for fine-tuned SAM3")

        # Preferred: explicit weights map of filename -> URL
        weights_map = ort.get("weights")
        if isinstance(weights_map, dict) and len(weights_map) > 0:
            for filename, url in weights_map.items():
                resp = get_from_url(url, json_response=False)
                save_bytes_in_cache(
                    content=resp.content,
                    file=str(filename),
                    model_id=self.endpoint,
                )
            return None
        raise ModelArtefactError(
            "ORT response missing both 'weights' for fine-tuned SAM3"
        )
Functions
__init__
__init__(
    *args,
    model_id="sam3/sam3_final",
    low_res_logits_cache_size=SAM3_MAX_LOGITS_CACHE_SIZE,
    embedding_cache_size=SAM3_MAX_EMBEDDING_CACHE_SIZE,
    **kwargs
)

Initializes the SegmentAnything.

Parameters:

Name Type Description Default
*args

Variable length argument list.

()
**kwargs

Arbitrary keyword arguments.

{}
Source code in inference/models/sam3/visual_segmentation.py
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
def __init__(
    self,
    *args,
    model_id: str = "sam3/sam3_final",
    low_res_logits_cache_size: int = SAM3_MAX_LOGITS_CACHE_SIZE,
    embedding_cache_size: int = SAM3_MAX_EMBEDDING_CACHE_SIZE,
    **kwargs,
):
    """Initializes the SegmentAnything.

    Args:
        *args: Variable length argument list.
        model_id: Identifier of the SAM3 checkpoint to load.
        low_res_logits_cache_size: Maximum number of entries kept in the
            low-resolution logits cache.
        embedding_cache_size: Maximum number of image embeddings kept in
            the embedding cache.
        **kwargs: Arbitrary keyword arguments.
    """
    super().__init__(*args, model_id=model_id, **kwargs)
    checkpoint = self.cache_file("weights.pt")
    bpe_path = self.cache_file("bpe_simple_vocab_16e6.txt.gz")

    self.sam_model = build_sam3_image_model(
        bpe_path=bpe_path,
        checkpoint_path=checkpoint,
        device="cuda" if torch.cuda.is_available() else "cpu",
        load_from_HF=False,
        compile=False,
        enable_inst_interactivity=True,
    )
    self.low_res_logits_cache_size = low_res_logits_cache_size
    self.embedding_cache_size = embedding_cache_size
    # Caches are evicted oldest-first; the *_keys lists record insertion order.
    self.embedding_cache = {}
    self.image_size_cache = {}
    self.embedding_cache_keys = []
    self.low_res_logits_cache: Dict[Tuple[str, str], LogitsCacheType] = {}
    self.low_res_logits_cache_keys = []
    # Guards cache mutations under concurrent requests.
    self._state_lock = RLock()
    self.task_type = "unsupervised-segmentation"
embed_image
embed_image(image, image_id=None, **kwargs)

Embeds an image and caches the result if an image_id is provided. If the image has been embedded before and cached, the cached result will be returned.

Parameters:

Name Type Description Default
image Any

The image to be embedded. The format should be compatible with the preproc_image method.

required
image_id Optional[str]

An identifier for the image. If provided, the embedding result will be cached with this ID. Defaults to None.

None
**kwargs

Additional keyword arguments.

{}

Returns:

Type Description

Tuple[np.ndarray, Tuple[int, int]]: A tuple where the first element is the embedding of the image and the second element is the shape (height, width) of the processed image.

Notes
  • Embeddings and image sizes are cached to improve performance on repeated requests for the same image.
  • The cache has a maximum size defined by SAM2_MAX_CACHE_SIZE. When the cache exceeds this size, the oldest entries are removed.
Example

img_array = ... # some image array embed_image(img_array, image_id="sample123") (array([...]), (224, 224))

Source code in inference/models/sam3/visual_segmentation.py
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
@torch.inference_mode()
def embed_image(
    self,
    image: Optional[InferenceRequestImage],
    image_id: Optional[str] = None,
    **kwargs,
):
    """
    Embeds an image and caches the result if an image_id is provided. If the image has been embedded before and cached,
    the cached result will be returned.

    Args:
        image (Any): The image to be embedded. The format should be compatible with the preproc_image method.
        image_id (Optional[str]): An identifier for the image. If provided, the embedding result will be cached
                                  with this ID. Defaults to None.
        **kwargs: Additional keyword arguments.

    Returns:
        Tuple[np.ndarray, Tuple[int, int]]: A tuple where the first element is the embedding of the image
                                           and the second element is the shape (height, width) of the processed image.

    Notes:
        - Embeddings and image sizes are cached to improve performance on repeated requests for the same image.
        - The cache holds at most `embedding_cache_size` entries. When the cache exceeds this size,
          the oldest entries are removed.

    Example:
        >>> img_array = ... # some image array
        >>> embed_image(img_array, image_id="sample123")
        (array([...]), (224, 224))
    """
    if image_id:
        embedding_cache_content = self.embedding_cache.get(image_id)
        image_size_content = self.image_size_cache.get(image_id)
        if embedding_cache_content is not None and image_size_content is not None:
            return embedding_cache_content, image_size_content, image_id

    img_in = self.preproc_image(image)
    if image_id is None:
        # Derive a stable identifier from the raw pixel bytes so repeated
        # requests for the same image content hit the cache.
        image_id = hashlib.md5(img_in.tobytes()).hexdigest()[:12]

    # Second lookup: the content-derived id may already be cached.
    embedding_cache_content = self.embedding_cache.get(image_id)
    image_size_content = self.image_size_cache.get(image_id)
    if embedding_cache_content is not None and image_size_content is not None:
        return (
            embedding_cache_content,
            image_size_content,
            image_id,
        )

    with torch.autocast(device_type="cuda", dtype=torch.bfloat16):
        with _temporarily_disable_torch_jit_script():
            processor = Sam3Processor(self.sam_model)
        state = processor.set_image(torch.from_numpy(img_in).permute(2, 0, 1))
        embedding_dict = state

    with self._state_lock:
        self.embedding_cache[image_id] = embedding_dict
        self.image_size_cache[image_id] = img_in.shape[:2]
        safe_remove_from_list(values=self.embedding_cache_keys, element=image_id)
        self.embedding_cache_keys.append(image_id)
        if len(self.embedding_cache_keys) > self.embedding_cache_size:
            cache_key = safe_pop_from_list(values=self.embedding_cache_keys)
            if cache_key is not None:
                safe_remove_from_dict(values=self.embedding_cache, key=cache_key)
                safe_remove_from_dict(values=self.image_size_cache, key=cache_key)
        return embedding_dict, img_in.shape[:2], image_id
get_infer_bucket_file_list
get_infer_bucket_file_list()

Gets the list of files required for inference.

Returns:

Type Description
List[str]

List[str]: List of file names.

Source code in inference/models/sam3/visual_segmentation.py
114
115
116
117
118
119
120
def get_infer_bucket_file_list(self) -> List[str]:
    """Return the artifact file names that must be cached before inference.

    Returns:
        List[str]: List of file names.
    """
    return ["weights.pt"]
infer_from_request
infer_from_request(request)

Performs inference based on the request type.

Parameters:

Name Type Description Default
request SamInferenceRequest

The inference request.

required

Returns:

Type Description

Union[SamEmbeddingResponse, SamSegmentationResponse]: The inference response.

Source code in inference/models/sam3/visual_segmentation.py
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
@usage_collector("model")
def infer_from_request(self, request: Sam2InferenceRequest):
    """Performs inference based on the request type.

    Args:
        request (SamInferenceRequest): The inference request.

    Returns:
        Union[SamEmbeddingResponse, SamSegmentationResponse]: The inference response.
    """
    t1 = perf_counter()
    if isinstance(request, Sam2EmbeddingRequest):
        _, _, image_id = self.embed_image(**request.dict())
        inference_time = perf_counter() - t1
        return Sam2EmbeddingResponse(time=inference_time, image_id=image_id)
    elif isinstance(request, Sam2SegmentationRequest):
        masks, scores, low_resolution_logits = self.segment_image(**request.dict())
        predictions = _masks_to_predictions(masks, scores, request.format)
        return Sam2SegmentationResponse(
            time=perf_counter() - t1,
            predictions=predictions,
        )
    else:
        raise ValueError(f"Invalid request type {type(request)}")
preproc_image
preproc_image(image)

Preprocesses an image.

Parameters:

Name Type Description Default
image InferenceRequestImage

The image to preprocess.

required

Returns:

Type Description

np.array: The preprocessed image.

Source code in inference/models/sam3/visual_segmentation.py
215
216
217
218
219
220
221
222
223
224
225
def preproc_image(self, image: InferenceRequestImage):
    """Load the request image into an RGB numpy array.

    Args:
        image (InferenceRequestImage): The image to preprocess.

    Returns:
        np.array: The preprocessed image.
    """
    return load_image_rgb(image)
segment_image
segment_image(
    image,
    image_id=None,
    prompts=None,
    multimask_output=True,
    mask_input=None,
    save_logits_to_cache=False,
    load_logits_from_cache=False,
    **kwargs
)

Segments an image based on provided embeddings, points, masks, or cached results. If embeddings are not directly provided, the function can derive them from the input image or cache.

Parameters:

Name Type Description Default
image Any

The image to be segmented.

required
image_id Optional[str]

A cached identifier for the image. Useful for accessing cached embeddings or masks.

None
prompts Optional[List[Sam2Prompt]]

List of prompts to use for segmentation. Defaults to None.

None
mask_input Optional[Union[ndarray, List[List[List[float]]]]]

Input low_res_logits for the image.

None
multimask_output Optional[bool]

(bool): Flag to decide if multiple masks proposal to be predicted (among which the most promising will be returned

True
use_logits_cache

(bool): Flag to decide to use cached logits from prior prompting

required
**kwargs

Additional keyword arguments.

{}

Returns:

Type Description

Tuple[np.ndarray, np.ndarray, np.ndarray]: Tuple of np.array, where: - first element is of size (prompt_set_size, h, w) and represent mask with the highest confidence for each prompt element - second element is of size (prompt_set_size, ) and represents ths score for most confident mask of each prompt element - third element is of size (prompt_set_size, 256, 256) and represents the low resolution logits for most confident mask of each prompt element

Raises:

Type Description
ValueError

If necessary inputs are missing or inconsistent.

Notes
  • Embeddings, segmentations, and low-resolution logits can be cached to improve performance on repeated requests for the same image.
  • The cache has a maximum size defined by SAM_MAX_EMBEDDING_CACHE_SIZE. When the cache exceeds this size, the oldest entries are removed.
Source code in inference/models/sam3/visual_segmentation.py
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
def segment_image(
    self,
    image: Optional[InferenceRequestImage],
    image_id: Optional[str] = None,
    prompts: Optional[Union[Sam2PromptSet, dict]] = None,
    multimask_output: Optional[bool] = True,
    mask_input: Optional[Union[np.ndarray, List[List[List[float]]]]] = None,
    save_logits_to_cache: bool = False,
    load_logits_from_cache: bool = False,
    **kwargs,
):
    """
    Segments an image based on provided embeddings, points, masks, or cached results.
    If embeddings are not directly provided, the function can derive them from the input image or cache.

    Args:
        image (Any): The image to be segmented.
        image_id (Optional[str]): A cached identifier for the image. Useful for accessing cached embeddings or masks.
        prompts (Optional[Union[Sam2PromptSet, dict]]): Prompt set to use for segmentation. Defaults to None.
        mask_input (Optional[Union[np.ndarray, List[List[List[float]]]]]): Input low_res_logits for the image.
        multimask_output (Optional[bool]): Flag to decide if multiple mask proposals are to be predicted
            (among which the most promising will be returned).
        save_logits_to_cache (bool): Flag to decide whether computed low-resolution logits are stored in cache.
        load_logits_from_cache (bool): Flag to decide whether cached logits from prior prompting are reused.
        **kwargs: Additional keyword arguments.

    Returns:
        Tuple[np.ndarray, np.ndarray, np.ndarray]: Tuple of np.array, where:
            - first element is of size (prompt_set_size, h, w) and represents the mask with the highest confidence
                for each prompt element
            - second element is of size (prompt_set_size, ) and represents the score for the most confident mask
                of each prompt element
            - third element is of size (prompt_set_size, 256, 256) and represents the low resolution logits
                for the most confident mask of each prompt element

    Raises:
        ValueError: If necessary inputs are missing or inconsistent.

    Notes:
        - Embeddings, segmentations, and low-resolution logits can be cached to improve performance
          on repeated requests for the same image.
        - The caches are bounded by `embedding_cache_size` and `low_res_logits_cache_size` entries;
          when a cache exceeds its size, the oldest entries are removed.
    """
    # Logits caching can be disabled globally via DISABLE_SAM3_LOGITS_CACHE.
    load_logits_from_cache = (
        load_logits_from_cache and not DISABLE_SAM3_LOGITS_CACHE
    )
    save_logits_to_cache = save_logits_to_cache and not DISABLE_SAM3_LOGITS_CACHE
    with torch.inference_mode():
        if image is None and not image_id:
            raise ValueError("Must provide either image or  cached image_id")
        elif image_id and image is None and image_id not in self.embedding_cache:
            raise ValueError(
                f"Image ID {image_id} not in embedding cache, must provide the image or embeddings"
            )
        embedding, original_image_size, image_id = self.embed_image(
            image=image, image_id=image_id
        )

        # with _temporarily_disable_torch_jit_script():
        # processor = Sam3Processor(self.sam_model)

        # processor._is_image_set = True
        # processor._features = embedding
        # processor._orig_hw = [original_image_size]
        # processor._is_batch = False
        args = dict()
        prompt_set: Sam2PromptSet
        if prompts:
            if type(prompts) is dict:
                prompt_set = Sam2PromptSet(**prompts)
                args = prompt_set.to_sam2_inputs()
            else:
                prompt_set = prompts
                args = prompts.to_sam2_inputs()
        else:
            prompt_set = Sam2PromptSet()

        if mask_input is None and load_logits_from_cache:
            mask_input = maybe_load_low_res_logits_from_cache(
                image_id, prompt_set, self.low_res_logits_cache
            )

        args = pad_points(args)
        # No prompts at all: fall back to a single point with label -1 —
        # presumably a sentinel/ignore point so the model still receives a
        # well-formed input; verify against predict_inst semantics.
        if not any(args.values()):
            args = {"point_coords": [[0, 0]], "point_labels": [-1], "box": None}

        masks, scores, low_resolution_logits = self.sam_model.predict_inst(
            embedding,
            mask_input=mask_input,
            multimask_output=multimask_output,
            return_logits=True,
            normalize_coords=True,
            **args,
        )
        # Collapse multimask proposals to the single most confident one
        # per prompt element.
        masks, scores, low_resolution_logits = choose_most_confident_sam_prediction(
            masks=masks,
            scores=scores,
            low_resolution_logits=low_resolution_logits,
        )

        if save_logits_to_cache:
            self.add_low_res_logits_to_cache(
                low_resolution_logits, image_id, prompt_set
            )

        return masks, scores, low_resolution_logits

Functions

choose_most_confident_sam_prediction

choose_most_confident_sam_prediction(
    masks, scores, low_resolution_logits
)

This function is supposed to post-process SAM2 inference and choose the most confident mask regardless of multimask_output parameter value Args: masks: np array with values 0.0 and 1.0 representing predicted mask of size (prompt_set_size, proposed_masks, h, w) or (proposed_masks, h, w) - depending on prompt set size - unfortunately, prompt_set_size=1 causes squeeze operation in SAM2 library, so to handle inference uniformly, we need to compensate with this function. scores: array of size (prompt_set_size, proposed_masks) or (proposed_masks, ) depending on prompt set size - this array gives confidence score for mask proposal low_resolution_logits: array of size (prompt_set_size, proposed_masks, 256, 256) or (proposed_masks, 256, 256) - depending on prompt set size. These low resolution logits can be passed to a subsequent iteration as mask input. Returns: Tuple of np.array, where: - first element is of size (prompt_set_size, h, w) and represents the mask with the highest confidence for each prompt element - second element is of size (prompt_set_size, ) and represents the score for the most confident mask of each prompt element - third element is of size (prompt_set_size, 256, 256) and represents the low resolution logits for the most confident mask of each prompt element

Source code in inference/models/sam3/visual_segmentation.py
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
def choose_most_confident_sam_prediction(
    masks: np.ndarray,
    scores: np.ndarray,
    low_resolution_logits: np.ndarray,
) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
    """
    Post-process SAM2 predictions, keeping only the highest-confidence mask
    proposal for each prompt-set element, regardless of the
    `multimask_output` setting.

    Args:
        masks: 0.0/1.0 mask array of shape (prompt_set_size, proposals, h, w),
            or (proposals, h, w) when the SAM2 library squeezed out the prompt
            dimension (which happens for prompt_set_size == 1).
        scores: per-proposal confidence, shape (prompt_set_size, proposals)
            or (proposals,) accordingly.
        low_resolution_logits: logits of shape
            (prompt_set_size, proposals, 256, 256) or (proposals, 256, 256);
            usable as mask input for a subsequent iteration.

    Returns:
        Tuple of np.ndarray:
            - best mask per prompt element, shape (prompt_set_size, h, w)
            - score of that mask, shape (prompt_set_size,)
            - its low-resolution logits, shape (prompt_set_size, 256, 256)
    """
    # Re-insert the prompt dimension that SAM2 squeezes for single prompts,
    # so the selection loop below can treat every input uniformly.
    if masks.ndim == 3:
        masks = masks[np.newaxis, ...]
        scores = scores[np.newaxis, ...]
        low_resolution_logits = low_resolution_logits[np.newaxis, ...]
    picked = [
        choose_most_confident_prompt_set_element_prediction(
            mask=mask,
            score=score,
            low_resolution_logit=logit,
        )
        for mask, score, logit in zip(masks, scores, low_resolution_logits)
    ]
    best_masks, best_scores, best_logits = zip(*picked) if picked else ((), (), ())
    return (
        np.asarray(best_masks),
        np.asarray(best_scores),
        np.asarray(best_logits),
    )

find_prior_prompt_in_cache

find_prior_prompt_in_cache(
    initial_prompt_set, image_id, cache
)

Performs a search over the cache to see whether previously used prompts are a subset of this one.

Source code in inference/models/sam3/visual_segmentation.py
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
def find_prior_prompt_in_cache(
    initial_prompt_set: Sam2PromptSet,
    image_id: str,
    cache: Dict[Tuple[str, str], LogitsCacheType],
) -> Optional[np.ndarray]:
    """
    Search the cache for logits of a previously used prompt set that is a
    strict subset of `initial_prompt_set`, preferring the largest such set.
    """
    candidates = [value for key, value in cache.items() if key[0] == image_id]
    target_size = initial_prompt_set.num_points() - 1
    best_size = 0
    best_logits: Optional[np.ndarray] = None
    # Scan newest-first so the most recent viable prompt is found soonest.
    for entry in reversed(candidates):
        cached_prompt_set: Sam2PromptSet = entry["prompt_set"]
        if not is_prompt_strict_subset(cached_prompt_set, initial_prompt_set):
            continue
        size = cached_prompt_set.num_points()
        # A subset with exactly one fewer point is the most recent possible
        # mask - stop the search immediately.
        if size == target_size:
            return entry["logits"]
        if size >= best_size:
            best_size = size
            best_logits = entry["logits"]
    return best_logits

hash_prompt_set

hash_prompt_set(image_id, prompt_set)

Computes unique hash from a prompt set.

Source code in inference/models/sam3/visual_segmentation.py
422
423
424
425
426
def hash_prompt_set(image_id: str, prompt_set: Sam2PromptSet) -> Tuple[str, str]:
    """Compute a short, unique cache key for a prompt set on a given image."""
    digest = hashlib.md5(str(prompt_set).encode("utf-8")).hexdigest()
    return image_id, digest[:12]

maybe_load_low_res_logits_from_cache

maybe_load_low_res_logits_from_cache(
    image_id, prompt_set, cache
)

Loads prior masks from the cache by searching over possible prior prompts.

Source code in inference/models/sam3/visual_segmentation.py
429
430
431
432
433
434
435
436
437
438
def maybe_load_low_res_logits_from_cache(
    image_id: str,
    prompt_set: Sam2PromptSet,
    cache: Dict[Tuple[str, str], LogitsCacheType],
) -> Optional[np.ndarray]:
    """Load prior mask logits from the cache, searching over possible prior prompts."""
    # Nothing to match against when the prompt set is empty.
    if not prompt_set.prompts:
        return None
    return find_prior_prompt_in_cache(prompt_set, image_id, cache)

pad_points

pad_points(args)

Pad arguments to be passed to sam2 model with not_a_point label (-1). This is necessary when there are multiple prompts per image so that a tensor can be created.

Also pads empty point lists with a dummy non-point entry.

Source code in inference/models/sam3/visual_segmentation.py
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
def pad_points(args: Dict[str, Any]) -> Dict[str, Any]:
    """
    Pad SAM2 point prompts so every prompt has the same number of points,
    appending dummy [0, 0] coordinates with the not_a_point label (-1).

    Needed so prompts of differing lengths can be stacked into one tensor;
    empty point lists receive a single dummy entry. The input dict is not
    mutated - a deep copy is padded and returned.

    Raises:
        ValueError: if point labels are supplied without point coordinates.
    """
    args = copy.deepcopy(args)
    coords = args["point_coords"]
    if coords is None:
        if args["point_labels"] is not None:
            raise ValueError(
                "Can't have point labels without corresponding point coordinates"
            )
        return args
    # At least one slot per prompt so empty prompts still get a dummy point.
    target_len = max(1, max(len(prompt) for prompt in coords))
    for prompt in coords:
        prompt.extend([0, 0] for _ in range(target_len - len(prompt)))
    for labels in args["point_labels"]:
        labels.extend(-1 for _ in range(target_len - len(labels)))
    return args

models/sam3_3d

inference.models.sam3_3d.segment_anything_3d

Classes

Sam3_3D_ObjectsPipelineSingleton

Singleton to cache the heavy 3D pipeline initialization.

Source code in inference/models/sam3_3d/segment_anything_3d.py
237
238
239
240
241
242
243
244
245
246
247
248
249
class Sam3_3D_ObjectsPipelineSingleton:
    """Per-config singleton that caches the heavy 3D pipeline initialization.

    Instances are held in a WeakValueDictionary, so a cached entry lives only
    as long as some caller still references it.
    """

    _instances = weakref.WeakValueDictionary()
    _lock = Lock()

    def __new__(cls, config_key: str):
        # Serialize creation so concurrent callers with the same key cannot
        # race and build two instances.
        with cls._lock:
            # Single `get` (instead of `in` check + separate lookup): the dict
            # holds only weak references, so an entry could otherwise vanish
            # between the membership test and the indexing, raising KeyError.
            # Binding the result also holds a strong reference until return.
            instance = cls._instances.get(config_key)
            if instance is None:
                instance = super().__new__(cls)
                instance.config_key = config_key
                cls._instances[config_key] = instance
            return instance

SegmentAnything3_3D_Objects

Bases: RoboflowCoreModel

Source code in inference/models/sam3_3d/segment_anything_3d.py
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
class SegmentAnything3_3D_Objects(RoboflowCoreModel):
    """Image + mask(s) -> 3D reconstruction via the tdfy sam3d_v1 pipeline.

    The heavy pipeline is built at most once per (device, config) pair via
    Sam3_3D_ObjectsPipelineSingleton and shared by all model instances.
    """

    task_type = "3d-reconstruction"

    def __init__(
        self,
        *args,
        model_id: str = "sam3-3d-objects",
        torch_compile: bool = False,
        compile_res: int = 518,
        **kwargs,
    ):
        """Resolve checkpoint paths and build (or reuse) the shared pipeline.

        Args:
            model_id: Roboflow model id whose cache directory holds weights.
            torch_compile: forwarded to the pipeline config as `compile_model`.
            compile_res: forwarded to the pipeline config as `compile_res`.
        """
        super().__init__(model_id=model_id, **kwargs)

        self.cache_dir = Path(get_cache_dir(model_id=self.endpoint))

        # The pipeline config ships inside the tdfy package; all checkpoint
        # files live in this model's local cache directory.
        tdfy_dir = files(tdfy.sam3d_v1)
        pipeline_config_path = tdfy_dir / "checkpoints_configs" / "pipeline.yaml"
        moge_checkpoint_path = self.cache_dir / "moge-vitl.pth"
        ss_generator_checkpoint_path = self.cache_dir / "ss_generator.ckpt"
        slat_generator_checkpoint_path = self.cache_dir / "slat_generator.ckpt"
        ss_decoder_checkpoint_path = self.cache_dir / "ss_decoder.ckpt"
        slat_decoder_checkpoint_path = self.cache_dir / "slat_decoder_gs.ckpt"
        slat_decodergs4_checkpoint_path = self.cache_dir / "slat_decoder_gs_4.ckpt"
        slat_decoder_mesh_checkpoint_path = self.cache_dir / "slat_decoder_mesh.pt"
        dinov2_ckpt_path = self.cache_dir / "dinov2_vitl14_reg4_pretrain.pth"

        # One pipeline per (device, config file); the singleton caches it.
        config_key = f"{DEVICE}_{pipeline_config_path}"
        singleton = Sam3_3D_ObjectsPipelineSingleton(config_key)

        # Only the first instance for this config builds the pipeline.
        # NOTE(review): self.pipeline_config is set only on that builder
        # instance; later instances reuse singleton.pipeline without it.
        if not hasattr(singleton, "pipeline"):
            self.pipeline_config = OmegaConf.load(str(pipeline_config_path))
            self.pipeline_config["device"] = DEVICE
            self.pipeline_config["workspace_dir"] = str(tdfy_dir)
            self.pipeline_config["compile_model"] = torch_compile
            self.pipeline_config["compile_res"] = compile_res
            # Point every checkpoint entry of the config at the local cache.
            self.pipeline_config["depth_model"]["model"][
                "pretrained_model_name_or_path"
            ] = str(moge_checkpoint_path)
            self.pipeline_config["ss_generator_ckpt_path"] = str(
                ss_generator_checkpoint_path
            )
            self.pipeline_config["slat_generator_ckpt_path"] = str(
                slat_generator_checkpoint_path
            )
            self.pipeline_config["ss_decoder_ckpt_path"] = str(
                ss_decoder_checkpoint_path
            )
            self.pipeline_config["slat_decoder_gs_ckpt_path"] = str(
                slat_decoder_checkpoint_path
            )
            self.pipeline_config["slat_decoder_gs_4_ckpt_path"] = str(
                slat_decodergs4_checkpoint_path
            )
            self.pipeline_config["slat_decoder_mesh_ckpt_path"] = str(
                slat_decoder_mesh_checkpoint_path
            )
            self.pipeline_config["dinov2_ckpt_path"] = str(dinov2_ckpt_path)
            singleton.pipeline = instantiate(self.pipeline_config)

        # Reference the singleton's pipeline
        self.pipeline = singleton.pipeline
        # Serializes requests on this instance (see infer_from_request).
        self._state_lock = Lock()

    def get_infer_bucket_file_list(self) -> list:
        """Get the list of required files for inference.

        Returns:
            list: checkpoint filenames that must be present in the cache.
        """
        return [
            "moge-vitl.pth",
            "ss_generator.ckpt",
            "slat_generator.ckpt",
            "ss_decoder.ckpt",
            "slat_decoder_gs.ckpt",
            "slat_decoder_gs_4.ckpt",
            "slat_decoder_mesh.pt",
        ]

    def download_model_from_roboflow_api(self) -> None:
        """Override parent method to use streaming downloads for large SAM3_3D model files.

        A file lock serializes concurrent downloads into the same cache dir.

        Raises:
            ModelArtefactError: if the API response carries no `weights` entry.
        """
        lock_dir = MODEL_CACHE_DIR + "/_file_locks"
        os.makedirs(lock_dir, exist_ok=True)
        lock_file = os.path.join(lock_dir, f"{os.path.basename(self.cache_dir)}.lock")
        lock = FileLock(lock_file, timeout=120)
        with lock:
            api_data = get_roboflow_model_data(
                api_key=self.api_key,
                model_id="sam3-3d-weights-vc6vz/1",
                endpoint_type=ModelEndpointType.ORT,
                device_id=self.device_id,
            )["ort"]
            if "weights" not in api_data:
                # NOTE(review): f-string below has no placeholders (F541).
                raise ModelArtefactError(
                    f"`weights` key not available in Roboflow API response while downloading model weights."
                )
            for weights_url_key in api_data["weights"]:
                weights_url = api_data["weights"][weights_url_key]
                # Filename comes from the URL path, query string stripped.
                filename = weights_url.split("?")[0].split("/")[-1]
                stream_url_to_cache(
                    url=weights_url,
                    filename=filename,
                    model_id=self.endpoint,
                )

    def infer_from_request(
        self, request: Sam3_3D_Objects_InferenceRequest
    ) -> Sam3_3D_Objects_Response:
        """Run 3D generation for an API request and time it.

        Requests are serialized per instance via `_state_lock`.
        """
        with self._state_lock:
            t1 = perf_counter()
            raw_result = self.create_3d(**request.model_dump())
            inference_time = perf_counter() - t1
            return convert_3d_objects_result_to_api_response(
                raw_result=raw_result,
                inference_time=inference_time,
            )

    def create_3d(
        self,
        image: Optional[InferenceRequestImage],
        mask_input: Optional[Any] = None,
        *,
        output_meshes: bool = True,
        output_scene: bool = True,
        with_mesh_postprocess: bool = True,
        with_texture_baking: bool = True,
        use_distillations: bool = False,
        **kwargs,
    ):
        """
        Generate 3D from image and mask(s).

        Args:
            image: Input image
            mask_input: Mask in any supported format:
                - np.ndarray (H,W) or (N,H,W): Binary mask(s)
                - List[float]: COCO polygon [x1,y1,x2,y2,...]
                - List[List[float]]: Multiple polygons
                - Dict with 'counts'/'size': RLE mask
                - List[Dict]: Multiple RLE masks

        Returns:
            Dict with keys "gs" (scene Gaussian or None), "glb" (mesh scene
            or None), and "objects" (raw per-mask pipeline outputs).

        Raises:
            ValueError: if image or mask_input is missing.
        """
        with torch.inference_mode():
            if image is None or mask_input is None:
                raise ValueError("Must provide image and mask!")

            image_np = load_image_rgb(image)
            if image_np.dtype != np.uint8:
                # Float images in [0, 1] are rescaled; anything else is cast.
                if image_np.max() <= 1:
                    image_np = (image_np * 255).astype(np.uint8)
                else:
                    image_np = image_np.astype(np.uint8)
            image_shape = (image_np.shape[0], image_np.shape[1])

            # Normalize mask input into a list of binary (H, W) masks.
            # NOTE(review): the elif branch is byte-identical to the else
            # branch - they look collapsible into one.
            if _is_single_mask_input(mask_input):
                masks = [convert_mask_to_binary(mask_input, image_shape)]
            elif isinstance(mask_input, np.ndarray) and mask_input.ndim == 3:
                masks = [convert_mask_to_binary(m, image_shape) for m in mask_input]
            else:
                masks = [convert_mask_to_binary(m, image_shape) for m in mask_input]

            # NOTE: mesh depends on gaussian, so we always decode gaussian
            decode_formats = ["gaussian"]
            if output_meshes:
                decode_formats.append("mesh")

            outputs = []
            for mask in masks:
                result = self.pipeline.run(
                    image=image_np,
                    mask=mask,
                    decode_formats=decode_formats,
                    with_mesh_postprocess=with_mesh_postprocess,
                    with_texture_baking=with_texture_baking,
                    use_stage1_distillation=use_distillations,
                    use_stage2_distillation=use_distillations,
                )
                outputs.append(result)

            if len(outputs) == 1:
                # Single object: its Gaussian/GLB already is the scene.
                result = outputs[0]
                scene_gs = (
                    ready_gaussian_for_video_rendering(result["gs"])
                    if output_scene
                    else None
                )
                glb = result["glb"] if output_meshes else None
                return {
                    "gs": scene_gs,
                    "glb": glb,
                    "objects": outputs,
                }
            else:
                # Multiple objects: compose them into a single scene first.
                if output_scene:
                    scene_gs = make_scene(*outputs)
                    scene_gs = ready_gaussian_for_video_rendering(scene_gs)
                    scene_gs = apply_gaussian_view_correction(scene_gs)
                    scene_glb = make_scene_glb(*outputs) if output_meshes else None
                else:
                    scene_gs = None
                    scene_glb = None
                return {
                    "gs": scene_gs,
                    "glb": scene_glb,
                    "objects": outputs,
                }
Functions
create_3d
create_3d(
    image,
    mask_input=None,
    *,
    output_meshes=True,
    output_scene=True,
    with_mesh_postprocess=True,
    with_texture_baking=True,
    use_distillations=False,
    **kwargs
)

Generate 3D from image and mask(s).

Parameters:

Name Type Description Default
image Optional[InferenceRequestImage]

Input image

required
mask_input Optional[Any]

Mask in any supported format: - np.ndarray (H,W) or (N,H,W): Binary mask(s) - List[float]: COCO polygon [x1,y1,x2,y2,...] - List[List[float]]: Multiple polygons - Dict with 'counts'/'size': RLE mask - List[Dict]: Multiple RLE masks

None
Source code in inference/models/sam3_3d/segment_anything_3d.py
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
def create_3d(
    self,
    image: Optional[InferenceRequestImage],
    mask_input: Optional[Any] = None,
    *,
    output_meshes: bool = True,
    output_scene: bool = True,
    with_mesh_postprocess: bool = True,
    with_texture_baking: bool = True,
    use_distillations: bool = False,
    **kwargs,
):
    """
    Generate 3D from image and mask(s).

    Args:
        image: Input image
        mask_input: Mask in any supported format:
            - np.ndarray (H,W) or (N,H,W): Binary mask(s)
            - List[float]: COCO polygon [x1,y1,x2,y2,...]
            - List[List[float]]: Multiple polygons
            - Dict with 'counts'/'size': RLE mask
            - List[Dict]: Multiple RLE masks
        output_meshes: also decode (and return) mesh outputs.
        output_scene: build the combined scene Gaussian (and, for multi-mask
            inputs, the scene GLB).
        with_mesh_postprocess: forwarded to the pipeline run.
        with_texture_baking: forwarded to the pipeline run.
        use_distillations: enables stage-1 and stage-2 distillation paths.

    Returns:
        Dict with keys "gs" (scene Gaussian or None), "glb" (mesh scene or
        None), and "objects" (raw per-mask pipeline outputs).

    Raises:
        ValueError: if image or mask_input is missing.
    """
    with torch.inference_mode():
        if image is None or mask_input is None:
            raise ValueError("Must provide image and mask!")

        image_np = load_image_rgb(image)
        if image_np.dtype != np.uint8:
            # Float images in [0, 1] are rescaled to 0-255; anything else
            # is cast directly.
            if image_np.max() <= 1:
                image_np = (image_np * 255).astype(np.uint8)
            else:
                image_np = image_np.astype(np.uint8)
        image_shape = (image_np.shape[0], image_np.shape[1])

        # Normalize the input into a list of binary (H, W) masks.
        # (The previous ndarray-ndim==3 special case had a body identical to
        # the generic iteration, so both collapse into one branch.)
        if _is_single_mask_input(mask_input):
            masks = [convert_mask_to_binary(mask_input, image_shape)]
        else:
            masks = [convert_mask_to_binary(m, image_shape) for m in mask_input]

        # NOTE: mesh depends on gaussian, so we always decode gaussian
        decode_formats = ["gaussian"]
        if output_meshes:
            decode_formats.append("mesh")

        outputs = []
        for mask in masks:
            result = self.pipeline.run(
                image=image_np,
                mask=mask,
                decode_formats=decode_formats,
                with_mesh_postprocess=with_mesh_postprocess,
                with_texture_baking=with_texture_baking,
                use_stage1_distillation=use_distillations,
                use_stage2_distillation=use_distillations,
            )
            outputs.append(result)

        if len(outputs) == 1:
            # Single object: its Gaussian/GLB already is the scene.
            result = outputs[0]
            scene_gs = (
                ready_gaussian_for_video_rendering(result["gs"])
                if output_scene
                else None
            )
            glb = result["glb"] if output_meshes else None
            return {
                "gs": scene_gs,
                "glb": glb,
                "objects": outputs,
            }
        else:
            # Multiple objects: compose them into a single scene first.
            if output_scene:
                scene_gs = make_scene(*outputs)
                scene_gs = ready_gaussian_for_video_rendering(scene_gs)
                scene_gs = apply_gaussian_view_correction(scene_gs)
                scene_glb = make_scene_glb(*outputs) if output_meshes else None
            else:
                scene_gs = None
                scene_glb = None
            return {
                "gs": scene_gs,
                "glb": scene_glb,
                "objects": outputs,
            }
download_model_from_roboflow_api
download_model_from_roboflow_api()

Override parent method to use streaming downloads for large SAM3_3D model files.

Source code in inference/models/sam3_3d/segment_anything_3d.py
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
def download_model_from_roboflow_api(self) -> None:
    """Override parent method to use streaming downloads for large SAM3_3D model files.

    A file lock serializes concurrent downloads into the same cache dir.

    Raises:
        ModelArtefactError: if the API response carries no `weights` entry.
    """
    lock_dir = MODEL_CACHE_DIR + "/_file_locks"
    os.makedirs(lock_dir, exist_ok=True)
    lock_file = os.path.join(lock_dir, f"{os.path.basename(self.cache_dir)}.lock")
    lock = FileLock(lock_file, timeout=120)
    with lock:
        api_data = get_roboflow_model_data(
            api_key=self.api_key,
            model_id="sam3-3d-weights-vc6vz/1",
            endpoint_type=ModelEndpointType.ORT,
            device_id=self.device_id,
        )["ort"]
        if "weights" not in api_data:
            # Plain string literal: the original used an f-string with no
            # placeholders (ruff F541).
            raise ModelArtefactError(
                "`weights` key not available in Roboflow API response while downloading model weights."
            )
        # Iterate the values directly instead of iterating keys and
        # re-indexing the dict (ruff PERF102).
        for weights_url in api_data["weights"].values():
            # Filename comes from the URL path, query string stripped.
            filename = weights_url.split("?")[0].split("/")[-1]
            stream_url_to_cache(
                url=weights_url,
                filename=filename,
                model_id=self.endpoint,
            )
get_infer_bucket_file_list
get_infer_bucket_file_list()

Get the list of required files for inference.

Returns:

Name Type Description
list list

A list of required files for inference, e.g., ["environment.json"].

Source code in inference/models/sam3_3d/segment_anything_3d.py
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
def get_infer_bucket_file_list(self) -> list:
    """List the weight files that must be present in the cache for inference.

    Returns:
        list: required checkpoint filenames.
    """
    required_files = (
        "moge-vitl.pth",
        "ss_generator.ckpt",
        "slat_generator.ckpt",
        "ss_decoder.ckpt",
        "slat_decoder_gs.ckpt",
        "slat_decoder_gs_4.ckpt",
        "slat_decoder_mesh.pt",
    )
    return list(required_files)

Functions

apply_gaussian_view_correction

apply_gaussian_view_correction(scene_gs)

Apply view correction to Gaussian scene to match GLB orientation. Used for combined scene PLY.

Source code in inference/models/sam3_3d/segment_anything_3d.py
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
def apply_gaussian_view_correction(scene_gs):
    """
    Rotate a Gaussian scene so its orientation matches the exported GLB.
    Used for the combined scene PLY. Mutates and returns `scene_gs`.
    """
    positions = scene_gs.get_xyz
    # diag(-1, 1, -1): a 180-degree rotation about the Y axis.
    correction = torch.tensor(
        [[-1.0, 0.0, 0.0], [0.0, 1.0, 0.0], [0.0, 0.0, -1.0]],
        device=positions.device,
        dtype=positions.dtype,
    )

    scene_gs.from_xyz(positions @ correction)

    # Compose the same correction into every Gaussian's orientation quaternion.
    q_correction = matrix_to_quaternion(correction.unsqueeze(0)).squeeze(0)
    rotations = scene_gs.get_rotation
    scene_gs.from_rotation(
        quaternion_multiply(
            q_correction.unsqueeze(0).expand(rotations.shape[0], -1), rotations
        )
    )

    return scene_gs

convert_mask_to_binary

convert_mask_to_binary(mask_input, image_shape)

Convert polygon, RLE, or binary mask to binary mask (H, W) with values 0/255.

Source code in inference/models/sam3_3d/segment_anything_3d.py
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
def convert_mask_to_binary(mask_input: Any, image_shape: Tuple[int, int]) -> np.ndarray:
    """Convert polygon, RLE, or binary mask input to a binary (H, W) mask with values 0/255.

    Raises:
        ImportError: for RLE input when pycocotools is unavailable.
        TypeError: for unsupported mask types.
    """
    height, width = image_shape

    # Already an array or a PIL image: just normalize.
    if isinstance(mask_input, np.ndarray):
        return _normalize_binary_mask(mask_input, image_shape)
    if isinstance(mask_input, Image.Image):
        return _normalize_binary_mask(np.array(mask_input.convert("L")), image_shape)

    # COCO RLE dict.
    if isinstance(mask_input, dict) and "counts" in mask_input:
        if not PYCOCOTOOLS_AVAILABLE:
            raise ImportError(
                "pycocotools required for RLE. Install: pip install pycocotools"
            )
        rle = dict(mask_input)
        counts = rle.get("counts")
        if isinstance(counts, str):
            rle["counts"] = counts.encode("utf-8")
        return _normalize_binary_mask(mask_utils.decode(rle), image_shape)

    # Polygon given as a flat or nested coordinate list.
    if isinstance(mask_input, list):
        points = _parse_polygon_to_points(mask_input)
        if not points or len(points) < 3:
            # Degenerate polygon: all-zero mask.
            return np.zeros((height, width), dtype=np.uint8)
        canvas = Image.new("L", (width, height), 0)
        ImageDraw.Draw(canvas).polygon(points, outline=255, fill=255)
        return np.array(canvas, dtype=np.uint8)

    raise TypeError(f"Unsupported mask type: {type(mask_input)}")

make_scene_glb

make_scene_glb(*outputs)

Combine multiple GLB meshes into a single scene. Applies layout transforms and a final view correction rotation.

Source code in inference/models/sam3_3d/segment_anything_3d.py
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
def make_scene_glb(*outputs):
    """
    Merge per-object GLB meshes into a single trimesh scene: each mesh is
    moved to world coordinates via its layout transform, then the whole
    scene receives a final view-correction rotation.
    """
    scene = trimesh.Scene()

    for index, output in enumerate(outputs):
        # Work on a copy so the caller's GLB is left untouched.
        world_mesh = transform_glb_to_world(
            output["glb"].copy(),
            output["rotation"],
            output["translation"],
            output["scale"],
        )
        scene.add_geometry(world_mesh, node_name=f"object_{index}")

    view_correction = np.array(
        [[-1, 0, 0], [0, 0, -1], [0, -1, 0]], dtype=np.float32
    )
    for mesh in scene.geometry.values():
        mesh.vertices = mesh.vertices.astype(np.float32) @ view_correction
        has_normals = (
            hasattr(mesh, "vertex_normals")
            and mesh.vertex_normals is not None
            and len(mesh.vertex_normals) > 0
        )
        if has_normals:
            mesh.vertex_normals = (
                mesh.vertex_normals.astype(np.float32) @ view_correction
            )

    return scene

prepare_individual_object_for_export

prepare_individual_object_for_export(gs)

Prepare an individual object Gaussian for PLY export.

Source code in inference/models/sam3_3d/segment_anything_3d.py
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
def prepare_individual_object_for_export(gs):
    """
    Return a deep copy of an object's Gaussian, readied for video rendering
    and rotated into PLY-export orientation. The input is left untouched.
    """
    from copy import deepcopy

    exported = ready_gaussian_for_video_rendering(deepcopy(gs))

    positions = exported.get_xyz
    # Rotation about the X axis (maps Y/Z into the export frame).
    correction = torch.tensor(
        [[1.0, 0.0, 0.0], [0.0, 0.0, -1.0], [0.0, 1.0, 0.0]],
        device=positions.device,
        dtype=positions.dtype,
    )
    exported.from_xyz(positions @ correction)

    # Compose the same correction into each Gaussian's orientation quaternion.
    q_correction = matrix_to_quaternion(correction.unsqueeze(0)).squeeze(0)
    rotations = exported.get_rotation
    exported.from_rotation(
        quaternion_multiply(
            q_correction.unsqueeze(0).expand(rotations.shape[0], -1), rotations
        )
    )

    return exported

transform_glb_to_world

transform_glb_to_world(
    glb_mesh, rotation, translation, scale
)

Transform a GLB mesh from local to world coordinates.

Source code in inference/models/sam3_3d/segment_anything_3d.py
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
def transform_glb_to_world(glb_mesh, rotation, translation, scale):
    """
    Apply an object's layout transform (scale, rotation, translation) to a
    GLB mesh, converting through z-up space and back to y-up.
    Mutates and returns `glb_mesh`.
    """
    quat = rotation.squeeze()
    R_layout = quaternion_to_matrix(quat / quat.norm()).cpu().numpy()
    t = translation.squeeze().cpu().numpy()
    s = scale.squeeze().cpu().numpy()[0]

    z_to_y_up = np.array([[1, 0, 0], [0, 0, -1], [0, 1, 0]], dtype=np.float32)
    y_to_z_up = np.array([[1, 0, 0], [0, 0, 1], [0, -1, 0]], dtype=np.float32)

    def _to_world(vectors, with_offset):
        # y-up -> z-up, apply the layout transform, convert back to y-up.
        out = vectors.astype(np.float32) @ y_to_z_up
        if with_offset:
            out = out * s
        out = out @ R_layout
        if with_offset:
            out = out + t
        return out @ z_to_y_up

    glb_mesh.vertices = _to_world(glb_mesh.vertices.copy(), with_offset=True)

    if (
        hasattr(glb_mesh, "vertex_normals")
        and glb_mesh.vertex_normals is not None
        and len(glb_mesh.vertex_normals) > 0
    ):
        # Normals only rotate - no scaling, no translation.
        glb_mesh.vertex_normals = _to_world(
            glb_mesh.vertex_normals.copy(), with_offset=False
        )

    return glb_mesh

models/vit

inference.models.vit.vit_classification

Classes

VitClassification

Bases: ClassificationBaseOnnxRoboflowInferenceModel

VitClassification handles classification inference for Vision Transformer (ViT) models using ONNX.

Inherits

Attributes:

Name Type Description
multiclass bool

A flag that specifies if the model should handle multiclass classification.

Source code in inference/models/vit/vit_classification.py
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
class VitClassification(ClassificationBaseOnnxRoboflowInferenceModel):
    """VitClassification handles classification inference
    for Vision Transformer (ViT) models using ONNX.

    Inherits:
        ClassificationBaseOnnxRoboflowInferenceModel: Base class for ONNX Roboflow Inference.

    Attributes:
        multiclass (bool): A flag that specifies if the model should handle multiclass classification.
    """

    # Per-channel normalization constants used during preprocessing
    # (presumably applied by the base class; verify against its pipeline).
    preprocess_means = [0.5, 0.5, 0.5]
    preprocess_stds = [0.5, 0.5, 0.5]

    def __init__(self, *args, **kwargs):
        """Initializes the VitClassification instance.

        Reads the `MULTICLASS` flag from the model's environment config,
        defaulting to False.

        Args:
            *args: Variable length argument list.
            **kwargs: Arbitrary keyword arguments.
        """
        super().__init__(*args, **kwargs)
        self.multiclass = self.environment.get("MULTICLASS", False)

    @property
    def weights_file(self) -> str:
        """Determines the weights file to be used based on the availability of AWS keys.

        If AWS_ACCESS_KEY_ID, AWS_SECRET_ACCESS_KEY and LAMBDA are all set,
        it returns the path to 'weights.onnx'. Otherwise, it returns the
        path to 'best.onnx'.

        Returns:
            str: Path to the weights file.
        """
        if AWS_ACCESS_KEY_ID and AWS_SECRET_ACCESS_KEY and LAMBDA:
            return "weights.onnx"
        else:
            return "best.onnx"
Attributes
weights_file property
weights_file

Determines the weights file to be used based on the availability of AWS keys.

If AWS_ACCESS_KEY_ID and AWS_SECRET_ACCESS_KEY are set, it returns the path to 'weights.onnx'. Otherwise, it returns the path to 'best.onnx'.

Returns:

Name Type Description
str str

Path to the weights file.

Functions
__init__
__init__(*args, **kwargs)

Initializes the VitClassification instance.

Parameters:

Name Type Description Default
*args

Variable length argument list.

()
**kwargs

Arbitrary keyword arguments.

{}
Source code in inference/models/vit/vit_classification.py
22
23
24
25
26
27
28
29
30
def __init__(self, *args, **kwargs):
    """Create a VitClassification model instance.

    Args:
        *args: Positional arguments forwarded to the base class initializer.
        **kwargs: Keyword arguments forwarded to the base class initializer.
    """
    super().__init__(*args, **kwargs)
    # The MULTICLASS flag in the model environment toggles multi-label mode.
    multiclass_flag = self.environment.get("MULTICLASS", False)
    self.multiclass = multiclass_flag

models/yolact

inference.models.yolact.yolact_instance_segmentation

Classes

YOLACT

Bases: OnnxRoboflowInferenceModel

Roboflow ONNX Object detection model (Implements an object detection specific infer method)

Source code in inference/models/yolact/yolact_instance_segmentation.py
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
class YOLACT(OnnxRoboflowInferenceModel):
    """Roboflow ONNX YOLACT model.

    Despite the original one-line description mentioning object detection,
    this model performs instance segmentation (see ``task_type``): it decodes
    boxes against prior anchors, runs non-max suppression, and reconstructs
    per-instance masks from prototype masks.
    """

    task_type = "instance-segmentation"

    @property
    def weights_file(self) -> str:
        """Gets the weights file.

        Returns:
            str: Path to the weights file.
        """
        return "weights.onnx"

    def infer(
        self,
        image: Any,
        class_agnostic_nms: bool = False,
        confidence: float = 0.5,
        iou_threshold: float = 0.5,
        max_candidates: int = 3000,
        max_detections: int = 300,
        return_image_dims: bool = False,
        **kwargs,
    ) -> List[List[dict]]:
        """
        Performs instance segmentation inference on a given image, post-processes the results,
        and returns the segmented instances as dictionaries containing their properties.

        Args:
            image (Any): The image or list of images to segment.
                - can be a BGR numpy array, filepath, InferenceRequestImage, PIL Image, byte-string, etc.
            class_agnostic_nms (bool, optional): Whether to perform class-agnostic non-max suppression. Defaults to False.
            confidence (float, optional): Confidence threshold for filtering weak detections. Defaults to 0.5.
            iou_threshold (float, optional): Intersection-over-union threshold for non-max suppression. Defaults to 0.5.
            max_candidates (int, optional): Maximum number of candidate detections to consider. Defaults to 3000.
            max_detections (int, optional): Maximum number of detections to return after non-max suppression. Defaults to 300.
            return_image_dims (bool, optional): Whether to return the dimensions of the input image(s). Defaults to False.
            **kwargs: Additional keyword arguments.

        Returns:
            List[List[dict]]: Each list contains dictionaries of segmented instances for a given image. Each dictionary contains:
                - x, y: Center coordinates of the instance.
                - width, height: Width and height of the bounding box around the instance.
                - class: Name of the detected class.
                - confidence: Confidence score of the detection.
                - points: List of points describing the segmented mask's boundary.
                - class_id: ID corresponding to the detected class.
            If `return_image_dims` is True, the function returns a tuple where the first element is the list of detections and the
            second element is the list of image dimensions.

        Notes:
            - The function supports processing multiple images in a batch.
            - If an input list of images is provided, the function returns a list of lists,
              where each inner list corresponds to the detections for a specific image.
            - The function internally uses an ONNX model for inference.
        """
        return super().infer(
            image,
            class_agnostic_nms=class_agnostic_nms,
            confidence=confidence,
            iou_threshold=iou_threshold,
            max_candidates=max_candidates,
            max_detections=max_detections,
            return_image_dims=return_image_dims,
            **kwargs,
        )

    def preprocess(
        self, image: Any, **kwargs
    ) -> Tuple[np.ndarray, PreprocessReturnMetadata]:
        """Loads and normalizes the input image(s) into a batched NCHW tensor.

        Args:
            image (Any): A single image or a list of images in any format
                accepted by ``preproc_image``.

        Returns:
            Tuple[np.ndarray, PreprocessReturnMetadata]: The normalized batch
            and metadata carrying the original image dimensions (``img_dims``)
            and the network input shape (``im_shape``).
        """
        if isinstance(image, list):
            imgs_with_dims = [self.preproc_image(i) for i in image]
            imgs, img_dims = zip(*imgs_with_dims)
            if isinstance(imgs[0], np.ndarray):
                img_in = np.concatenate(imgs, axis=0)
            elif USE_PYTORCH_FOR_PREPROCESSING:
                img_in = torch.cat(imgs, dim=0)
            else:
                raise ValueError(
                    f"Received a list of images of unknown type, {type(imgs[0])}; "
                    "This is most likely a bug. Contact Roboflow team through github issues "
                    "(https://github.com/roboflow/inference/issues) providing full context of the problem"
                )
        else:
            img_in, img_dims = self.preproc_image(image)
            img_dims = [img_dims]

        # Normalization constants are listed in BGR order (for some reason).
        mean = (103.94, 116.78, 123.68)
        std = (57.38, 57.12, 58.40)

        if isinstance(img_in, np.ndarray):
            img_in = img_in.astype(np.float32)
        elif USE_PYTORCH_FOR_PREPROCESSING:
            img_in = img_in.float()
        else:
            raise ValueError(
                f"Received an image of unknown type, {type(img_in)}; "
                "This is most likely a bug. Contact Roboflow team through github issues "
                "(https://github.com/roboflow/inference/issues) providing full context of the problem"
            )

        # Our channels are RGB, so the BGR mean/std are applied in reversed order.
        img_in[:, 0, :, :] = (img_in[:, 0, :, :] - mean[2]) / std[2]
        img_in[:, 1, :, :] = (img_in[:, 1, :, :] - mean[1]) / std[1]
        img_in[:, 2, :, :] = (img_in[:, 2, :, :] - mean[0]) / std[0]

        return img_in, PreprocessReturnMetadata(
            {
                "img_dims": img_dims,
                "im_shape": img_in.shape,
            }
        )

    def predict(
        self, img_in: np.ndarray, **kwargs
    ) -> Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray, np.ndarray]:
        """Runs the ONNX session on a preprocessed batch.

        Args:
            img_in (np.ndarray): Preprocessed NCHW input batch.

        Returns:
            Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray, np.ndarray]:
            Raw model outputs, consumed by ``postprocess`` in the order
            (loc, conf, mask, prior, proto).
        """
        # The lock serializes access to the shared ONNX session instance.
        with self._session_lock:
            return run_session_via_iobinding(self.onnx_session, self.input_name, img_in)

    def postprocess(
        self,
        predictions: Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray, np.ndarray],
        preprocess_return_metadata: PreprocessReturnMetadata,
        **kwargs,
    ) -> List[InstanceSegmentationInferenceResponse]:
        """Turns raw YOLACT outputs into instance segmentation responses.

        Decodes box regressions against the priors, applies non-max
        suppression, reconstructs per-instance masks from the prototypes,
        converts the masks to polygons, and rescales boxes/polygons back to
        each image's original dimensions.

        Args:
            predictions: Raw model outputs in the order
                (loc, conf, mask, prior, proto).
            preprocess_return_metadata: Metadata produced by ``preprocess``
                (``img_dims`` and ``im_shape``).
            **kwargs: Expects the NMS parameters (``confidence``,
                ``iou_threshold``, ``class_agnostic_nms``, ``max_detections``,
                ``max_candidates``) and ``return_image_dims``.

        Returns:
            List[InstanceSegmentationInferenceResponse]: One response per
            image; if ``kwargs["return_image_dims"]`` is truthy, a tuple of
            (responses, image dimensions) is returned instead.
        """
        # Outputs arrive as (loc, conf, mask, prior, proto); cast all to float32.
        loc_data = np.float32(predictions[0])
        conf_data = np.float32(predictions[1])
        mask_data = np.float32(predictions[2])
        prior_data = np.float32(predictions[3])
        proto_data = np.float32(predictions[4])

        batch_size = loc_data.shape[0]
        num_priors = prior_data.shape[0]

        # Decode box regressions against the prior boxes, one batch element at a time.
        boxes = np.zeros((batch_size, num_priors, 4))
        for batch_idx in range(batch_size):
            boxes[batch_idx, :, :] = self.decode_predicted_bboxes(
                loc_data[batch_idx], prior_data
            )

        conf_preds = np.reshape(
            conf_data, (batch_size, num_priors, self.num_classes + 1)
        )
        class_confs = conf_preds[:, :, 1:]  # remove background class
        box_confs = np.expand_dims(
            np.max(class_confs, axis=2), 2
        )  # get max conf for each box

        # Per-prior rows: [x1, y1, x2, y2, box_conf, class_confs..., mask_coeffs...]
        predictions = np.concatenate((boxes, box_confs, class_confs, mask_data), axis=2)

        # Scale normalized box coords up to the network input size (im_shape is NCHW).
        # NOTE(review): x coords use dim 2 and y coords dim 3 — these agree only
        # for square inputs; confirm intended behavior for non-square shapes.
        img_in_shape = preprocess_return_metadata["im_shape"]
        predictions[:, :, 0] *= img_in_shape[2]
        predictions[:, :, 1] *= img_in_shape[3]
        predictions[:, :, 2] *= img_in_shape[2]
        predictions[:, :, 3] *= img_in_shape[3]
        predictions = w_np_non_max_suppression(
            predictions,
            conf_thresh=kwargs["confidence"],
            iou_thresh=kwargs["iou_threshold"],
            class_agnostic=kwargs["class_agnostic_nms"],
            max_detections=kwargs["max_detections"],
            max_candidate_detections=kwargs["max_candidates"],
            num_masks=32,
            box_format="xyxy",
        )
        predictions = np.array(predictions)
        batch_preds = []
        # Shape (1, 0) is what an empty NMS result looks like for a
        # single-image batch; in that case emit one empty prediction list.
        if predictions.shape != (1, 0):
            for batch_idx, img_dim in enumerate(preprocess_return_metadata["img_dims"]):
                # Column layout after NMS: 0-3 box, 4 score, 6 class id, 7+ mask coeffs.
                boxes = predictions[batch_idx, :, :4]
                scores = predictions[batch_idx, :, 4]
                classes = predictions[batch_idx, :, 6]
                masks = predictions[batch_idx, :, 7:]
                proto = proto_data[batch_idx]
                # Reconstruct full-resolution masks and trace their boundaries.
                decoded_masks = self.decode_masks(boxes, masks, proto, img_in_shape[2:])
                polys = masks2poly(decoded_masks)
                infer_shape = (self.img_size_w, self.img_size_h)
                # Map boxes and polygons from network-input space back to the
                # original image dimensions.
                boxes = post_process_bboxes(
                    [boxes], infer_shape, [img_dim], self.preproc, self.resize_method
                )[0]
                polys = post_process_polygons(
                    img_in_shape[2:],
                    polys,
                    img_dim,
                    self.preproc,
                    resize_method=self.resize_method,
                )
                preds = []
                for box, poly, score, cls in zip(boxes, polys, scores, classes):
                    confidence = float(score)
                    class_name = self.class_names[int(cls)]
                    points = [{"x": round(x, 1), "y": round(y, 1)} for (x, y) in poly]
                    pred = {
                        "x": round((box[2] + box[0]) / 2, 1),
                        "y": round((box[3] + box[1]) / 2, 1),
                        "width": int(box[2] - box[0]),
                        "height": int(box[3] - box[1]),
                        "class": class_name,
                        "confidence": round(confidence, 3),
                        "points": points,
                        "class_id": int(cls),
                    }
                    preds.append(pred)
                batch_preds.append(preds)
        else:
            batch_preds.append([])
        img_dims = preprocess_return_metadata["img_dims"]
        responses = self.make_response(batch_preds, img_dims, **kwargs)
        if kwargs["return_image_dims"]:
            return responses, preprocess_return_metadata["img_dims"]
        else:
            return responses

    def make_response(
        self,
        predictions: List[List[dict]],
        img_dims: List[Tuple[int, int]],
        class_filter: List[str] = None,
        **kwargs,
    ) -> List[InstanceSegmentationInferenceResponse]:
        """
        Constructs a list of InstanceSegmentationInferenceResponse objects based on the provided predictions
        and image dimensions, optionally filtering by class name.

        Args:
            predictions (List[List[dict]]): A list containing batch predictions, where each inner list contains
                dictionaries of segmented instances for a given image.
            img_dims (List[Tuple[int, int]]): List of tuples specifying the dimensions of each image in the format
                (height, width).
            class_filter (List[str], optional): A list of class names to filter the predictions by. If not provided,
                all predictions are included.

        Returns:
            List[InstanceSegmentationInferenceResponse]: A list of response objects, each containing the filtered
            predictions and corresponding image dimensions for a given image.

        Examples:
            >>> predictions = [[{"class_name": "cat", ...}, {"class_name": "dog", ...}], ...]
            >>> img_dims = [(300, 400), ...]
            >>> responses = make_response(predictions, img_dims, class_filter=["cat"])
            >>> len(responses[0].predictions)  # Only predictions with "cat" class are included
            1
        """
        responses = [
            InstanceSegmentationInferenceResponse(
                predictions=[
                    InstanceSegmentationPrediction(**p)
                    for p in batch_pred
                    if not class_filter or p["class_name"] in class_filter
                ],
                image=InferenceResponseImage(
                    width=img_dims[i][1], height=img_dims[i][0]
                ),
            )
            for i, batch_pred in enumerate(predictions)
        ]
        return responses

    def decode_masks(self, boxes, masks, proto, img_dim):
        """Decodes the masks from the given parameters.

        Args:
            boxes (np.array): Bounding boxes.
            masks (np.array): Per-instance mask coefficients.
            proto (np.array): Prototype masks produced by the network.
            img_dim (tuple): Target image dimensions.

        Returns:
            np.array: Decoded, cropped, and binarized masks (one per box, CHW).
        """
        # Linearly combine the prototype masks with each instance's
        # coefficients, then squash through a sigmoid to get probabilities.
        ret_mask = np.matmul(proto, np.transpose(masks))
        ret_mask = 1 / (1 + np.exp(-ret_mask))
        # NOTE(review): names suggest (w, h) but these are simply the first two
        # dims of the prototype tensor — confirm orientation.
        w, h, _ = ret_mask.shape
        gain = min(h / img_dim[0], w / img_dim[1])  # gain  = old / new
        pad = (w - img_dim[1] * gain) / 2, (h - img_dim[0] * gain) / 2  # wh padding
        # Strip the letterbox padding introduced during preprocessing.
        top, left = int(pad[1]), int(pad[0])  # y, x
        bottom, right = int(h - pad[1]), int(w - pad[0])
        ret_mask = np.transpose(ret_mask, (2, 0, 1))
        ret_mask = ret_mask[:, top:bottom, left:right]
        if len(ret_mask.shape) == 2:
            ret_mask = np.expand_dims(ret_mask, axis=0)
        ret_mask = ret_mask.transpose((1, 2, 0))
        # Resize mask probabilities to the requested image dimensions.
        ret_mask = cv2.resize(ret_mask, img_dim, interpolation=cv2.INTER_LINEAR)
        if len(ret_mask.shape) == 2:
            ret_mask = np.expand_dims(ret_mask, axis=2)
        ret_mask = ret_mask.transpose((2, 0, 1))
        # Zero out mask pixels that fall outside each instance's bounding box,
        # then binarize: probabilities below 0.5 become background.
        ret_mask = crop_mask(ret_mask, boxes)  # CHW
        ret_mask[ret_mask < 0.5] = 0

        return ret_mask

    def decode_predicted_bboxes(self, loc, priors):
        """Decode predicted bounding box coordinates using the scheme employed by Yolov2.

        Args:
            loc (np.array): The predicted bounding boxes of size [num_priors, 4].
            priors (np.array): The prior box coordinates with size [num_priors, 4].

        Returns:
            np.array: A tensor of decoded relative coordinates in point form with size [num_priors, 4].
        """

        # Variances rescale the regression targets (0.1 for centers, 0.2 for sizes).
        variances = [0.1, 0.2]

        # Column layout here: [cx, cy, w, h].
        boxes = np.concatenate(
            [
                priors[:, :2] + loc[:, :2] * variances[0] * priors[:, 2:],
                priors[:, 2:] * np.exp(loc[:, 2:] * variances[1]),
            ],
            1,
        )
        # Convert (center, size) to corner (point) form in place.
        boxes[:, :2] -= boxes[:, 2:] / 2
        boxes[:, 2:] += boxes[:, :2]

        return boxes
Attributes
weights_file property
weights_file

Gets the weights file.

Returns:

Name Type Description
str str

Path to the weights file.

Functions
decode_masks
decode_masks(boxes, masks, proto, img_dim)

Decodes the masks from the given parameters.

Parameters:

Name Type Description Default
boxes array

Bounding boxes.

required
masks array

Masks.

required
proto array

Proto data.

required
img_dim tuple

Image dimensions.

required

Returns:

Type Description

np.array: Decoded masks.

Source code in inference/models/yolact/yolact_instance_segmentation.py
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
def decode_masks(self, boxes, masks, proto, img_dim):
    """Decodes the masks from the given parameters.

    Args:
        boxes (np.array): Bounding boxes.
        masks (np.array): Per-instance mask coefficients.
        proto (np.array): Prototype masks produced by the network.
        img_dim (tuple): Target image dimensions.

    Returns:
        np.array: Decoded, cropped, and binarized masks (one per box, CHW).
    """
    # Linearly combine the prototype masks with each instance's coefficients,
    # then squash through a sigmoid to get per-pixel probabilities.
    ret_mask = np.matmul(proto, np.transpose(masks))
    ret_mask = 1 / (1 + np.exp(-ret_mask))
    # NOTE(review): names suggest (w, h) but these are simply the first two
    # dims of the prototype tensor — confirm orientation.
    w, h, _ = ret_mask.shape
    gain = min(h / img_dim[0], w / img_dim[1])  # gain  = old / new
    pad = (w - img_dim[1] * gain) / 2, (h - img_dim[0] * gain) / 2  # wh padding
    # Strip the letterbox padding introduced during preprocessing.
    top, left = int(pad[1]), int(pad[0])  # y, x
    bottom, right = int(h - pad[1]), int(w - pad[0])
    ret_mask = np.transpose(ret_mask, (2, 0, 1))
    ret_mask = ret_mask[:, top:bottom, left:right]
    if len(ret_mask.shape) == 2:
        ret_mask = np.expand_dims(ret_mask, axis=0)
    ret_mask = ret_mask.transpose((1, 2, 0))
    # Resize mask probabilities to the requested image dimensions.
    ret_mask = cv2.resize(ret_mask, img_dim, interpolation=cv2.INTER_LINEAR)
    if len(ret_mask.shape) == 2:
        ret_mask = np.expand_dims(ret_mask, axis=2)
    ret_mask = ret_mask.transpose((2, 0, 1))
    # Zero out mask pixels outside each instance's bounding box, then
    # binarize: probabilities below 0.5 become background.
    ret_mask = crop_mask(ret_mask, boxes)  # CHW
    ret_mask[ret_mask < 0.5] = 0

    return ret_mask
decode_predicted_bboxes
decode_predicted_bboxes(loc, priors)

Decode predicted bounding box coordinates using the scheme employed by Yolov2.

Parameters:

Name Type Description Default
loc array

The predicted bounding boxes of size [num_priors, 4].

required
priors array

The prior box coordinates with size [num_priors, 4].

required

Returns:

Type Description

np.array: A tensor of decoded relative coordinates in point form with size [num_priors, 4].

Source code in inference/models/yolact/yolact_instance_segmentation.py
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
def decode_predicted_bboxes(self, loc, priors):
    """Decode predicted bounding box coordinates using the scheme employed by Yolov2.

    Args:
        loc (np.array): The predicted bounding boxes of size [num_priors, 4].
        priors (np.array): The prior box coordinates with size [num_priors, 4].

    Returns:
        np.array: A tensor of decoded relative coordinates in point form with size [num_priors, 4].
    """
    # Variance terms rescale the raw regression outputs.
    center_variance = 0.1
    size_variance = 0.2

    # Shift the prior centers by the predicted offsets (scaled by the prior
    # sizes) and scale the prior sizes by the exponentiated deltas.
    centers = priors[:, :2] + loc[:, :2] * center_variance * priors[:, 2:]
    sizes = priors[:, 2:] * np.exp(loc[:, 2:] * size_variance)

    # Convert from (center, size) to corner (point) form.
    top_left = centers - sizes / 2
    bottom_right = sizes + top_left

    return np.concatenate([top_left, bottom_right], 1)
infer
infer(
    image,
    class_agnostic_nms=False,
    confidence=0.5,
    iou_threshold=0.5,
    max_candidates=3000,
    max_detections=300,
    return_image_dims=False,
    **kwargs
)

Performs instance segmentation inference on a given image, post-processes the results, and returns the segmented instances as dictionaries containing their properties.

Parameters:

Name Type Description Default
image Any

The image or list of images to segment. - can be a BGR numpy array, filepath, InferenceRequestImage, PIL Image, byte-string, etc.

required
class_agnostic_nms bool

Whether to perform class-agnostic non-max suppression. Defaults to False.

False
confidence float

Confidence threshold for filtering weak detections. Defaults to 0.5.

0.5
iou_threshold float

Intersection-over-union threshold for non-max suppression. Defaults to 0.5.

0.5
max_candidates int

Maximum number of candidate detections to consider. Defaults to 3000.

3000
max_detections int

Maximum number of detections to return after non-max suppression. Defaults to 300.

300
return_image_dims bool

Whether to return the dimensions of the input image(s). Defaults to False.

False
**kwargs

Additional keyword arguments.

{}

Returns:

Type Description
List[List[dict]]

List[List[dict]]: Each list contains dictionaries of segmented instances for a given image. Each dictionary contains: - x, y: Center coordinates of the instance. - width, height: Width and height of the bounding box around the instance. - class: Name of the detected class. - confidence: Confidence score of the detection. - points: List of points describing the segmented mask's boundary. - class_id: ID corresponding to the detected class.

List[List[dict]]

If return_image_dims is True, the function returns a tuple where the first element is the list of detections and the

List[List[dict]]

second element is the list of image dimensions.

Notes
  • The function supports processing multiple images in a batch.
  • If an input list of images is provided, the function returns a list of lists, where each inner list corresponds to the detections for a specific image.
  • The function internally uses an ONNX model for inference.
Source code in inference/models/yolact/yolact_instance_segmentation.py
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
def infer(
    self,
    image: Any,
    class_agnostic_nms: bool = False,
    confidence: float = 0.5,
    iou_threshold: float = 0.5,
    max_candidates: int = 3000,
    max_detections: int = 300,
    return_image_dims: bool = False,
    **kwargs,
) -> List[List[dict]]:
    """Run instance segmentation on one image or a batch of images.

    This is a thin wrapper that forwards all segmentation parameters to the
    base-class ``infer`` implementation.

    Args:
        image (Any): The image or list of images to segment; may be a BGR
            numpy array, filepath, InferenceRequestImage, PIL Image,
            byte-string, etc.
        class_agnostic_nms (bool, optional): Whether to perform class-agnostic
            non-max suppression. Defaults to False.
        confidence (float, optional): Confidence threshold for filtering weak
            detections. Defaults to 0.5.
        iou_threshold (float, optional): Intersection-over-union threshold for
            non-max suppression. Defaults to 0.5.
        max_candidates (int, optional): Maximum number of candidate detections
            to consider. Defaults to 3000.
        max_detections (int, optional): Maximum number of detections kept after
            non-max suppression. Defaults to 300.
        return_image_dims (bool, optional): Whether to also return the input
            image dimensions. Defaults to False.
        **kwargs: Additional keyword arguments forwarded unchanged.

    Returns:
        List[List[dict]]: One list of instance dictionaries per input image,
        each with keys x, y, width, height, class, confidence, points, and
        class_id. When ``return_image_dims`` is True, a tuple of
        (detections, image dimensions) is returned instead.
    """
    forwarded = {
        "class_agnostic_nms": class_agnostic_nms,
        "confidence": confidence,
        "iou_threshold": iou_threshold,
        "max_candidates": max_candidates,
        "max_detections": max_detections,
        "return_image_dims": return_image_dims,
    }
    forwarded.update(kwargs)
    return super().infer(image, **forwarded)
make_response
make_response(
    predictions, img_dims, class_filter=None, **kwargs
)

Constructs a list of InstanceSegmentationInferenceResponse objects based on the provided predictions and image dimensions, optionally filtering by class name.

Parameters:

Name Type Description Default
predictions List[List[dict]]

A list containing batch predictions, where each inner list contains dictionaries of segmented instances for a given image.

required
img_dims List[Tuple[int, int]]

List of tuples specifying the dimensions of each image in the format (height, width).

required
class_filter List[str]

A list of class names to filter the predictions by. If not provided, all predictions are included.

None

Returns:

Type Description
List[InstanceSegmentationInferenceResponse]

List[InstanceSegmentationInferenceResponse]: A list of response objects, each containing the filtered

List[InstanceSegmentationInferenceResponse]

predictions and corresponding image dimensions for a given image.

Examples:

>>> predictions = [[{"class_name": "cat", ...}, {"class_name": "dog", ...}], ...]
>>> img_dims = [(300, 400), ...]
>>> responses = make_response(predictions, img_dims, class_filter=["cat"])
>>> len(responses[0].predictions)  # Only predictions with "cat" class are included
1
Source code in inference/models/yolact/yolact_instance_segmentation.py
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
def make_response(
    self,
    predictions: List[List[dict]],
    img_dims: List[Tuple[int, int]],
    class_filter: List[str] = None,
    **kwargs,
) -> List[InstanceSegmentationInferenceResponse]:
    """
    Builds one InstanceSegmentationInferenceResponse per image from the batch
    predictions, optionally keeping only instances of the requested classes.

    Args:
        predictions (List[List[dict]]): Batch predictions; each inner list
            holds the segmented-instance dictionaries for one image.
        img_dims (List[Tuple[int, int]]): Per-image dimensions as
            (height, width) tuples.
        class_filter (List[str], optional): Class names to keep. When omitted
            or empty, every prediction is included.

    Returns:
        List[InstanceSegmentationInferenceResponse]: One response object per
        image, carrying the (possibly filtered) predictions and that image's
        dimensions.

    Examples:
        >>> predictions = [[{"class_name": "cat", ...}, {"class_name": "dog", ...}], ...]
        >>> img_dims = [(300, 400), ...]
        >>> responses = make_response(predictions, img_dims, class_filter=["cat"])
        >>> len(responses[0].predictions)  # Only predictions with "cat" class are included
        1
    """
    responses = []
    for image_idx, image_predictions in enumerate(predictions):
        kept_predictions = [
            InstanceSegmentationPrediction(**prediction)
            for prediction in image_predictions
            if not class_filter or prediction["class_name"] in class_filter
        ]
        image_height, image_width = img_dims[image_idx][0], img_dims[image_idx][1]
        responses.append(
            InstanceSegmentationInferenceResponse(
                predictions=kept_predictions,
                image=InferenceResponseImage(
                    width=image_width, height=image_height
                ),
            )
        )
    return responses

Functions

models/yolo26

inference.models.yolo26.yolo26_instance_segmentation

Classes

YOLO26InstanceSegmentation

Bases: YOLOv11InstanceSegmentation

YOLO26 Instance Segmentation model with end-to-end ONNX output.

YOLO26 exports with NMS already applied, outputting: - predictions: (batch, num_detections, 38) where 38 = 6 + 32 mask coefficients Format: [x1, y1, x2, y2, confidence, class_index, mask_coeff_0, ..., mask_coeff_31] - protos: (batch, 32, H, W) mask prototypes

Source code in inference/models/yolo26/yolo26_instance_segmentation.py
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
class YOLO26InstanceSegmentation(YOLOv11InstanceSegmentation):
    """YOLO26 Instance Segmentation model with end-to-end ONNX output.

    YOLO26 exports with NMS already applied, outputting:
    - predictions: (batch, num_detections, 38) where 38 = 6 + 32 mask coefficients
      Format: [x1, y1, x2, y2, confidence, class_index, mask_coeff_0, ..., mask_coeff_31]
    - protos: (batch, 32, H, W) mask prototypes
    """

    def predict(self, img_in: np.ndarray, **kwargs) -> Tuple[np.ndarray, np.ndarray]:
        """Performs inference on the given image using the ONNX session.

        Args:
            img_in (np.ndarray): Input image as a NumPy array.

        Returns:
            Tuple[np.ndarray, np.ndarray]: Predictions and mask prototypes.
        """
        # The ONNX session is shared instance state; serialize Run calls
        # behind the session lock.
        with self._session_lock:
            predictions, protos = run_session_via_iobinding(
                self.onnx_session, self.input_name, img_in
            )

        return predictions, protos

    def postprocess(
        self,
        predictions: Tuple[np.ndarray, np.ndarray],
        preprocess_return_metadata: PreprocessReturnMetadata,
        confidence: float = DEFAULT_CONFIDENCE,
        mask_decode_mode: str = DEFAULT_MASK_DECODE_MODE,
        tradeoff_factor: float = DEFAULT_TRADEOFF_FACTOR,
        **kwargs,
    ) -> Union[
        InstanceSegmentationInferenceResponse,
        List[InstanceSegmentationInferenceResponse],
    ]:
        """Postprocesses the instance segmentation predictions.

        YOLO26 predictions come with NMS already applied, so we just need to:
        1. Filter by confidence
        2. Decode masks
        3. Format response

        Args:
            predictions: Pair of (detections, mask prototypes) as returned
                by ``predict``.
            preprocess_return_metadata: Metadata captured at preprocessing
                time (original image dims, network input shape, crop flag).
            confidence: Minimum detection confidence for a row to survive.
            mask_decode_mode: One of "accurate", "fast" or "tradeoff".
            tradeoff_factor: Resolution/speed tradeoff in [0, 1]; only
                consulted when ``mask_decode_mode == "tradeoff"``.

        Raises:
            InvalidMaskDecodeArgument: On an unknown ``mask_decode_mode`` or
                an out-of-range ``tradeoff_factor``.
        """
        predictions, protos = predictions
        infer_shape = (self.img_size_h, self.img_size_w)
        img_dims = preprocess_return_metadata["img_dims"]
        # NOTE(review): im_shape[2:] is used as the (H, W) of the network
        # input below — presumably an NCHW shape; confirm against preprocess.
        img_in_shape = preprocess_return_metadata["im_shape"]

        # Filter by confidence and process each batch
        masks = []
        filtered_predictions = []

        for batch_idx, batch_preds in enumerate(predictions):
            # Filter by confidence (conf is at index 4)
            keep = batch_preds[:, 4] > confidence
            batch_preds = batch_preds[keep]
            filtered_predictions.append(batch_preds)

            if batch_preds.size == 0:
                # No surviving detections for this image: empty polygon list.
                masks.append([])
                continue

            # Get mask coefficients (starting at index 6)
            mask_coeffs = batch_preds[:, 6:]
            boxes = batch_preds[:, :4]
            proto = protos[batch_idx]
            img_dim = img_dims[batch_idx]

            # Decode masks based on mode
            if mask_decode_mode == "accurate":
                batch_masks = process_mask_accurate(
                    proto, mask_coeffs, boxes, img_in_shape[2:]
                )
                output_mask_shape = img_in_shape[2:]
            elif mask_decode_mode == "tradeoff":
                if not 0 <= tradeoff_factor <= 1:
                    raise InvalidMaskDecodeArgument(
                        f"Invalid tradeoff_factor: {tradeoff_factor}. Must be in [0.0, 1.0]"
                    )
                batch_masks = process_mask_tradeoff(
                    proto, mask_coeffs, boxes, img_in_shape[2:], tradeoff_factor
                )
                output_mask_shape = batch_masks.shape[1:]
            elif mask_decode_mode == "fast":
                batch_masks = process_mask_fast(
                    proto, mask_coeffs, boxes, img_in_shape[2:]
                )
                output_mask_shape = batch_masks.shape[1:]
            else:
                raise InvalidMaskDecodeArgument(
                    f"Invalid mask_decode_mode: {mask_decode_mode}. Must be one of ['accurate', 'fast', 'tradeoff']"
                )

            # Convert masks to polygons
            polys = masks2poly(batch_masks)

            # Post-process bounding boxes: rescale the box columns in place
            # from network-input space back to the original image space.
            batch_preds[:, :4] = post_process_bboxes(
                [boxes],
                infer_shape,
                [img_dim],
                self.preproc,
                resize_method=self.resize_method,
                disable_preproc_static_crop=preprocess_return_metadata[
                    "disable_preproc_static_crop"
                ],
            )[0]

            # Post-process polygons: map polygon points from mask space
            # (output_mask_shape) back to the original image space.
            polys = post_process_polygons(
                img_dim,
                polys,
                output_mask_shape,
                self.preproc,
                resize_method=self.resize_method,
            )
            masks.append(polys)
            # batch_preds was modified in place via the slice write above;
            # filtered_predictions[batch_idx] already references this same
            # array, so this re-store is for clarity only.
            filtered_predictions[batch_idx] = batch_preds

        return self.make_response(filtered_predictions, masks, img_dims, **kwargs)

    def make_response(
        self,
        predictions: List[np.ndarray],
        masks: List[List[List[Tuple[float, float]]]],
        img_dims: List[Tuple[int, int]],
        class_filter: Optional[List[str]] = None,
        *args,
        **kwargs,
    ) -> List[InstanceSegmentationInferenceResponse]:
        """Constructs instance segmentation response objects.

        YOLO26 prediction format: [x1, y1, x2, y2, conf, class_idx, mask_coeffs...]

        Args:
            predictions: Per-image arrays of filtered detection rows.
            masks: Per-image lists of polygons, one polygon per detection.
            img_dims: Per-image (height, width) of the original images.
            class_filter: If given, only detections whose class name is in
                this list are included in the response.
        """
        # Some callers pass the whole preprocessing metadata dict here
        # instead of the bare dims list; unwrap it in that case.
        if isinstance(img_dims, dict) and "img_dims" in img_dims:
            img_dims = img_dims["img_dims"]

        responses = []
        for ind, (batch_predictions, batch_masks) in enumerate(zip(predictions, masks)):
            preds_list = []
            for pred, mask in zip(batch_predictions, batch_masks):
                class_idx = int(pred[5])  # class index is at position 5
                if class_filter and self.class_names[class_idx] not in class_filter:
                    continue
                preds_list.append(
                    InstanceSegmentationPrediction(
                        # Dict expansion because "class" is a reserved word
                        # and cannot be passed as a keyword argument.
                        **{
                            "x": (pred[0] + pred[2]) / 2,
                            "y": (pred[1] + pred[3]) / 2,
                            "width": pred[2] - pred[0],
                            "height": pred[3] - pred[1],
                            "points": [Point(x=point[0], y=point[1]) for point in mask],
                            "confidence": pred[4],
                            "class": self.class_names[class_idx],
                            "class_id": class_idx,
                        }
                    )
                )
            responses.append(
                InstanceSegmentationInferenceResponse(
                    predictions=preds_list,
                    image=InferenceResponseImage(
                        width=img_dims[ind][1], height=img_dims[ind][0]
                    ),
                )
            )
        return responses

    def validate_model_classes(self) -> None:
        # Intentionally a no-op: overrides the parent's class validation.
        pass
Functions
make_response
make_response(
    predictions,
    masks,
    img_dims,
    class_filter=None,
    *args,
    **kwargs
)

Constructs instance segmentation response objects.

YOLO26 prediction format: [x1, y1, x2, y2, conf, class_idx, mask_coeffs...]

Source code in inference/models/yolo26/yolo26_instance_segmentation.py
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
def make_response(
    self,
    predictions: List[np.ndarray],
    masks: List[List[List[Tuple[float, float]]]],
    img_dims: List[Tuple[int, int]],
    class_filter: Optional[List[str]] = None,
    *args,
    **kwargs,
) -> List[InstanceSegmentationInferenceResponse]:
    """Build one instance segmentation response per image in the batch.

    Each YOLO26 detection row is laid out as
    [x1, y1, x2, y2, conf, class_idx, mask_coeffs...].
    """
    # Some callers hand over the whole preprocessing metadata dict instead
    # of the bare dims list; unwrap it in that case.
    if isinstance(img_dims, dict) and "img_dims" in img_dims:
        img_dims = img_dims["img_dims"]

    responses = []
    for image_idx, (image_preds, image_polys) in enumerate(zip(predictions, masks)):
        detections = []
        for row, polygon in zip(image_preds, image_polys):
            label_idx = int(row[5])  # class index lives at position 5
            if class_filter and self.class_names[label_idx] not in class_filter:
                continue
            x1, y1, x2, y2 = row[0], row[1], row[2], row[3]
            # Built as a dict because "class" is a reserved word.
            fields = {
                "x": (x1 + x2) / 2,
                "y": (y1 + y2) / 2,
                "width": x2 - x1,
                "height": y2 - y1,
                "points": [Point(x=pt[0], y=pt[1]) for pt in polygon],
                "confidence": row[4],
                "class": self.class_names[label_idx],
                "class_id": label_idx,
            }
            detections.append(InstanceSegmentationPrediction(**fields))
        height, width = img_dims[image_idx][0], img_dims[image_idx][1]
        responses.append(
            InstanceSegmentationInferenceResponse(
                predictions=detections,
                image=InferenceResponseImage(width=width, height=height),
            )
        )
    return responses
postprocess
postprocess(
    predictions,
    preprocess_return_metadata,
    confidence=DEFAULT_CONFIDENCE,
    mask_decode_mode=DEFAULT_MASK_DECODE_MODE,
    tradeoff_factor=DEFAULT_TRADEOFF_FACTOR,
    **kwargs
)

Postprocesses the instance segmentation predictions.

YOLO26 predictions come with NMS already applied, so we just need to: 1. Filter by confidence 2. Decode masks 3. Format response

Source code in inference/models/yolo26/yolo26_instance_segmentation.py
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
def postprocess(
    self,
    predictions: Tuple[np.ndarray, np.ndarray],
    preprocess_return_metadata: PreprocessReturnMetadata,
    confidence: float = DEFAULT_CONFIDENCE,
    mask_decode_mode: str = DEFAULT_MASK_DECODE_MODE,
    tradeoff_factor: float = DEFAULT_TRADEOFF_FACTOR,
    **kwargs,
) -> Union[
    InstanceSegmentationInferenceResponse,
    List[InstanceSegmentationInferenceResponse],
]:
    """Postprocesses the instance segmentation predictions.

    YOLO26 predictions come with NMS already applied, so we just need to:
    1. Filter by confidence
    2. Decode masks
    3. Format response

    Args:
        predictions: Pair of (detections, mask prototypes) from ``predict``.
        preprocess_return_metadata: Metadata captured during preprocessing.
        confidence: Minimum detection confidence for a row to survive.
        mask_decode_mode: One of "accurate", "fast" or "tradeoff".
        tradeoff_factor: Resolution/speed tradeoff in [0, 1]; only used
            when ``mask_decode_mode == "tradeoff"``.

    Raises:
        InvalidMaskDecodeArgument: On an unknown ``mask_decode_mode`` or an
            out-of-range ``tradeoff_factor``.
    """
    predictions, protos = predictions
    infer_shape = (self.img_size_h, self.img_size_w)
    img_dims = preprocess_return_metadata["img_dims"]
    # NOTE(review): im_shape[2:] is used as the (H, W) of the network input
    # below — presumably an NCHW shape; confirm against preprocess.
    img_in_shape = preprocess_return_metadata["im_shape"]

    # Filter by confidence and process each batch
    masks = []
    filtered_predictions = []

    for batch_idx, batch_preds in enumerate(predictions):
        # Filter by confidence (conf is at index 4)
        keep = batch_preds[:, 4] > confidence
        batch_preds = batch_preds[keep]
        filtered_predictions.append(batch_preds)

        if batch_preds.size == 0:
            # No surviving detections for this image: empty polygon list.
            masks.append([])
            continue

        # Get mask coefficients (starting at index 6)
        mask_coeffs = batch_preds[:, 6:]
        boxes = batch_preds[:, :4]
        proto = protos[batch_idx]
        img_dim = img_dims[batch_idx]

        # Decode masks based on mode
        if mask_decode_mode == "accurate":
            batch_masks = process_mask_accurate(
                proto, mask_coeffs, boxes, img_in_shape[2:]
            )
            output_mask_shape = img_in_shape[2:]
        elif mask_decode_mode == "tradeoff":
            if not 0 <= tradeoff_factor <= 1:
                raise InvalidMaskDecodeArgument(
                    f"Invalid tradeoff_factor: {tradeoff_factor}. Must be in [0.0, 1.0]"
                )
            batch_masks = process_mask_tradeoff(
                proto, mask_coeffs, boxes, img_in_shape[2:], tradeoff_factor
            )
            output_mask_shape = batch_masks.shape[1:]
        elif mask_decode_mode == "fast":
            batch_masks = process_mask_fast(
                proto, mask_coeffs, boxes, img_in_shape[2:]
            )
            output_mask_shape = batch_masks.shape[1:]
        else:
            raise InvalidMaskDecodeArgument(
                f"Invalid mask_decode_mode: {mask_decode_mode}. Must be one of ['accurate', 'fast', 'tradeoff']"
            )

        # Convert masks to polygons
        polys = masks2poly(batch_masks)

        # Post-process bounding boxes: rescale the box columns in place from
        # network-input space back to the original image space.
        batch_preds[:, :4] = post_process_bboxes(
            [boxes],
            infer_shape,
            [img_dim],
            self.preproc,
            resize_method=self.resize_method,
            disable_preproc_static_crop=preprocess_return_metadata[
                "disable_preproc_static_crop"
            ],
        )[0]

        # Post-process polygons: map polygon points from mask space
        # (output_mask_shape) back to the original image space.
        polys = post_process_polygons(
            img_dim,
            polys,
            output_mask_shape,
            self.preproc,
            resize_method=self.resize_method,
        )
        masks.append(polys)
        # batch_preds was modified in place via the slice write above;
        # filtered_predictions[batch_idx] already references this array,
        # so the re-store is for clarity only.
        filtered_predictions[batch_idx] = batch_preds

    return self.make_response(filtered_predictions, masks, img_dims, **kwargs)
predict
predict(img_in, **kwargs)

Performs inference on the given image using the ONNX session.

Parameters:

Name Type Description Default
img_in ndarray

Input image as a NumPy array.

required

Returns:

Type Description
Tuple[ndarray, ndarray]

Tuple[np.ndarray, np.ndarray]: Predictions and mask prototypes.

Source code in inference/models/yolo26/yolo26_instance_segmentation.py
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
def predict(self, img_in: np.ndarray, **kwargs) -> Tuple[np.ndarray, np.ndarray]:
    """Run the ONNX session on a preprocessed image tensor.

    Args:
        img_in (np.ndarray): Input image as a NumPy array.

    Returns:
        Tuple[np.ndarray, np.ndarray]: Predictions and mask prototypes.
    """
    # The session is shared instance state — guard the Run call with the lock.
    with self._session_lock:
        outputs = run_session_via_iobinding(
            self.onnx_session, self.input_name, img_in
        )
    predictions, protos = outputs
    return predictions, protos

Functions

inference.models.yolo26.yolo26_keypoints_detection

Classes

YOLO26KeypointsDetection

Bases: YOLOv11KeypointsDetection

YOLO26 Keypoints Detection model with end-to-end ONNX output.

YOLO26 exports with NMS already applied, outputting: - predictions: (batch, num_detections, 57) for COCO pose (17 keypoints * 3 + 6) Format: [x1, y1, x2, y2, confidence, class_index, kp0_x, kp0_y, kp0_conf, ...]

Source code in inference/models/yolo26/yolo26_keypoints_detection.py
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
class YOLO26KeypointsDetection(YOLOv11KeypointsDetection):
    """YOLO26 Keypoints Detection model with end-to-end ONNX output.

    YOLO26 exports with NMS already applied, outputting:
    - predictions: (batch, num_detections, 57) for COCO pose (17 keypoints * 3 + 6)
      Format: [x1, y1, x2, y2, confidence, class_index, kp0_x, kp0_y, kp0_conf, ...]
    """

    def predict(self, img_in: np.ndarray, **kwargs) -> Tuple[np.ndarray, ...]:
        """Performs inference on the given image using the ONNX session.

        Args:
            img_in (np.ndarray): Input image as a NumPy array.

        Returns:
            Tuple[np.ndarray]: Predictions with boxes, confidence, class, and keypoints.
        """
        # The ONNX session is shared instance state; serialize Run calls
        # behind the session lock.
        with self._session_lock:
            predictions = run_session_via_iobinding(
                self.onnx_session, self.input_name, img_in
            )[0]

        return (predictions,)

    def postprocess(
        self,
        predictions: Tuple[np.ndarray],
        preproc_return_metadata: PreprocessReturnMetadata,
        confidence: float = DEFAULT_CONFIDENCE,
        **kwargs,
    ) -> List[KeypointsDetectionInferenceResponse]:
        """Postprocesses the keypoints detection predictions.

        YOLO26 predictions come with NMS already applied, so we just need to:
        1. Filter by confidence
        2. Scale coordinates to original image size
        3. Format response

        Args:
            predictions: One-element tuple holding the raw detection array.
            preproc_return_metadata: Metadata captured during preprocessing.
            confidence: Minimum detection confidence for a row to survive.

        Returns:
            One response object per image in the batch.
        """
        predictions = predictions[0]
        infer_shape = (self.img_size_h, self.img_size_w)
        img_dims = preproc_return_metadata["img_dims"]

        # Filter by confidence and process each batch
        filtered_predictions = []
        for batch_preds in predictions:
            # Filter by confidence (conf is at index 4)
            keep = batch_preds[:, 4] > confidence
            batch_preds = batch_preds[keep]
            filtered_predictions.append(batch_preds)

        # Post-process bounding boxes: rescale from network-input space back
        # to the original image space.
        filtered_predictions = post_process_bboxes(
            predictions=filtered_predictions,
            infer_shape=infer_shape,
            img_dims=img_dims,
            preproc=self.preproc,
            resize_method=self.resize_method,
            disable_preproc_static_crop=preproc_return_metadata[
                "disable_preproc_static_crop"
            ],
        )

        # Post-process keypoints: rescale every (x, y) keypoint pair the same way.
        filtered_predictions = post_process_keypoints(
            predictions=filtered_predictions,
            keypoints_start_index=6,  # keypoints start at index 6
            infer_shape=infer_shape,
            img_dims=img_dims,
            preproc=self.preproc,
            resize_method=self.resize_method,
            disable_preproc_static_crop=preproc_return_metadata[
                "disable_preproc_static_crop"
            ],
        )

        return self.make_response(filtered_predictions, img_dims, **kwargs)

    def make_response(
        self,
        predictions: List[np.ndarray],
        img_dims: List[Tuple[int, int]],
        class_filter: Optional[List[str]] = None,
        *args,
        **kwargs,
    ) -> List[KeypointsDetectionInferenceResponse]:
        """Constructs keypoints detection response objects.

        YOLO26 prediction format: [x1, y1, x2, y2, conf, class_idx, keypoints...]

        Args:
            predictions: Per-image arrays of filtered detection rows.
            img_dims: Per-image (height, width) of the original images.
            class_filter: If given, only detections whose class name is in
                this list are included in the response.
        """
        # Some callers pass the whole preprocessing metadata dict here
        # instead of the bare dims list; unwrap it in that case.
        if isinstance(img_dims, dict) and "img_dims" in img_dims:
            img_dims = img_dims["img_dims"]

        # Per-keypoint confidence cutoff comes from the request when present.
        keypoint_confidence_threshold = 0.0
        if "request" in kwargs:
            keypoint_confidence_threshold = kwargs["request"].keypoint_confidence

        responses = []
        for ind, batch_predictions in enumerate(predictions):
            preds_list = []
            for pred in batch_predictions:
                class_idx = int(pred[5])  # class index is at position 5
                if class_filter and self.class_names[class_idx] not in class_filter:
                    continue

                # Keypoints start at index 6
                keypoints_data = pred[6:]

                preds_list.append(
                    KeypointsPrediction(
                        # Dict expansion because "class" is a reserved word
                        # and cannot be passed as a keyword argument.
                        **{
                            "x": (pred[0] + pred[2]) / 2,
                            "y": (pred[1] + pred[3]) / 2,
                            "width": pred[2] - pred[0],
                            "height": pred[3] - pred[1],
                            "confidence": pred[4],
                            "class": self.class_names[class_idx],
                            "class_id": class_idx,
                            "keypoints": model_keypoints_to_response(
                                keypoints_metadata=self.keypoints_metadata,
                                keypoints=keypoints_data,
                                predicted_object_class_id=class_idx,
                                keypoint_confidence_threshold=keypoint_confidence_threshold,
                            ),
                        }
                    )
                )
            responses.append(
                KeypointsDetectionInferenceResponse(
                    predictions=preds_list,
                    image=InferenceResponseImage(
                        width=img_dims[ind][1], height=img_dims[ind][0]
                    ),
                )
            )
        return responses

    def validate_model_classes(self) -> None:
        # Intentionally a no-op: overrides the parent's class validation.
        pass
Functions
make_response
make_response(
    predictions,
    img_dims,
    class_filter=None,
    *args,
    **kwargs
)

Constructs keypoints detection response objects.

YOLO26 prediction format: [x1, y1, x2, y2, conf, class_idx, keypoints...]

Source code in inference/models/yolo26/yolo26_keypoints_detection.py
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
def make_response(
    self,
    predictions: List[np.ndarray],
    img_dims: List[Tuple[int, int]],
    class_filter: Optional[List[str]] = None,
    *args,
    **kwargs,
) -> List[KeypointsDetectionInferenceResponse]:
    """Build one keypoints detection response per image in the batch.

    Each YOLO26 detection row is laid out as
    [x1, y1, x2, y2, conf, class_idx, keypoints...].
    """
    # Unwrap the metadata dict if a caller passed it instead of the dims list.
    if isinstance(img_dims, dict) and "img_dims" in img_dims:
        img_dims = img_dims["img_dims"]

    # Per-keypoint confidence cutoff comes from the request when present.
    keypoint_confidence_threshold = 0.0
    if "request" in kwargs:
        keypoint_confidence_threshold = kwargs["request"].keypoint_confidence

    responses = []
    for image_idx, image_preds in enumerate(predictions):
        detections = []
        for row in image_preds:
            label_idx = int(row[5])  # class index lives at position 5
            if class_filter and self.class_names[label_idx] not in class_filter:
                continue

            x1, y1, x2, y2 = row[0], row[1], row[2], row[3]
            # Built as a dict because "class" is a reserved word.
            fields = {
                "x": (x1 + x2) / 2,
                "y": (y1 + y2) / 2,
                "width": x2 - x1,
                "height": y2 - y1,
                "confidence": row[4],
                "class": self.class_names[label_idx],
                "class_id": label_idx,
                "keypoints": model_keypoints_to_response(
                    keypoints_metadata=self.keypoints_metadata,
                    keypoints=row[6:],  # keypoint triples start at index 6
                    predicted_object_class_id=label_idx,
                    keypoint_confidence_threshold=keypoint_confidence_threshold,
                ),
            }
            detections.append(KeypointsPrediction(**fields))
        responses.append(
            KeypointsDetectionInferenceResponse(
                predictions=detections,
                image=InferenceResponseImage(
                    width=img_dims[image_idx][1], height=img_dims[image_idx][0]
                ),
            )
        )
    return responses
postprocess
postprocess(
    predictions,
    preproc_return_metadata,
    confidence=DEFAULT_CONFIDENCE,
    **kwargs
)

Postprocesses the keypoints detection predictions.

YOLO26 predictions come with NMS already applied, so we just need to: 1. Filter by confidence 2. Scale coordinates to original image size 3. Format response

Source code in inference/models/yolo26/yolo26_keypoints_detection.py
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
def postprocess(
    self,
    predictions: Tuple[np.ndarray],
    preproc_return_metadata: PreprocessReturnMetadata,
    confidence: float = DEFAULT_CONFIDENCE,
    **kwargs,
) -> List[KeypointsDetectionInferenceResponse]:
    """Postprocesses the keypoints detection predictions.

    YOLO26 predictions come with NMS already applied, so we just need to:
    1. Filter by confidence
    2. Scale coordinates to original image size
    3. Format response

    Args:
        predictions: One-element tuple holding the raw detection array.
        preproc_return_metadata: Metadata captured during preprocessing.
        confidence: Minimum detection confidence for a row to survive.

    Returns:
        One response object per image in the batch.
    """
    predictions = predictions[0]
    infer_shape = (self.img_size_h, self.img_size_w)
    img_dims = preproc_return_metadata["img_dims"]

    # Filter by confidence and process each batch
    filtered_predictions = []
    for batch_preds in predictions:
        # Filter by confidence (conf is at index 4)
        keep = batch_preds[:, 4] > confidence
        batch_preds = batch_preds[keep]
        filtered_predictions.append(batch_preds)

    # Post-process bounding boxes: rescale from network-input space back to
    # the original image space.
    filtered_predictions = post_process_bboxes(
        predictions=filtered_predictions,
        infer_shape=infer_shape,
        img_dims=img_dims,
        preproc=self.preproc,
        resize_method=self.resize_method,
        disable_preproc_static_crop=preproc_return_metadata[
            "disable_preproc_static_crop"
        ],
    )

    # Post-process keypoints: rescale every (x, y) keypoint pair the same way.
    filtered_predictions = post_process_keypoints(
        predictions=filtered_predictions,
        keypoints_start_index=6,  # keypoints start at index 6
        infer_shape=infer_shape,
        img_dims=img_dims,
        preproc=self.preproc,
        resize_method=self.resize_method,
        disable_preproc_static_crop=preproc_return_metadata[
            "disable_preproc_static_crop"
        ],
    )

    return self.make_response(filtered_predictions, img_dims, **kwargs)
predict
predict(img_in, **kwargs)

Performs inference on the given image using the ONNX session.

Parameters:

Name Type Description Default
img_in ndarray

Input image as a NumPy array.

required

Returns:

Type Description
Tuple[ndarray, ...]

Tuple[np.ndarray]: Predictions with boxes, confidence, class, and keypoints.

Source code in inference/models/yolo26/yolo26_keypoints_detection.py
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
def predict(self, img_in: np.ndarray, **kwargs) -> Tuple[np.ndarray, ...]:
    """Run the ONNX session and return the raw detection tensor.

    Args:
        img_in (np.ndarray): Input image as a NumPy array.

    Returns:
        Tuple[np.ndarray]: Predictions with boxes, confidence, class, and keypoints.
    """
    # The session is shared instance state — guard the Run call with the lock.
    with self._session_lock:
        outputs = run_session_via_iobinding(
            self.onnx_session, self.input_name, img_in
        )
    return (outputs[0],)

Functions

models/yolo_world

inference.models.yolo_world.yolo_world

Classes

YOLOWorld

Bases: RoboflowCoreModel

YOLO-World class for zero-shot object detection.

Attributes:

Name Type Description
model

The YOLO-World model.

Source code in inference/models/yolo_world/yolo_world.py
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
class YOLOWorld(RoboflowCoreModel):
    """YOLO-World class for zero-shot object detection.

    Class names are supplied at runtime (via ``set_classes`` or the ``text``
    argument of ``infer``) and are embedded with CLIP, so no retraining is
    required for a new vocabulary.

    Attributes:
        model: The YOLO-World model.
        clip_model: CLIP model used to embed class names.
        class_names: Currently configured class names (None until set).
    """

    task_type = "object-detection"

    def __init__(self, *args, model_id="yolo_world/l", **kwargs):
        """Initializes the YOLO-World model.

        Args:
            *args: Variable length argument list.
            model_id: Identifier of the YOLO-World checkpoint to load.
            **kwargs: Arbitrary keyword arguments.
        """

        super().__init__(*args, model_id=model_id, **kwargs)

        self.model = YOLO(self.cache_file("yolo-world.pt"))
        logger.debug("Loading CLIP ViT-B/32")
        # CLIP embeds class names so the detection vocabulary can change at runtime.
        clip_model = Clip(model_id="clip/ViT-B-32")
        logger.debug("CLIP loaded")
        self.clip_model = clip_model
        # Populated by set_classes(); infer() raises if still None and no text given.
        self.class_names = None
        self._state_lock = Lock()

    def preproc_image(self, image: Any):
        """Preprocesses an image.

        Args:
            image (InferenceRequestImage): The image to preprocess.

        Returns:
            np.array: The preprocessed image, with channel order reversed (RGB -> BGR).
        """
        np_image = load_image_rgb(image)
        # Reverse the channel axis: RGB -> BGR.
        return np_image[:, :, ::-1]

    def infer_from_request(
        self,
        request: YOLOWorldInferenceRequest,
    ) -> ObjectDetectionInferenceResponse:
        """
        Perform inference based on the details provided in the request, and return the associated responses.
        """
        # Lock guards mutable state (class embeddings) that infer() may update.
        with self._state_lock:
            return self.infer(**request.dict())

    def infer(
        self,
        image: Any = None,
        text: list = None,
        confidence: float = DEFAULT_CONFIDENCE,
        max_detections: Optional[int] = DEFAUlT_MAX_DETECTIONS,
        iou_threshold: float = DEFAULT_IOU_THRESH,
        max_candidates: int = DEFAULT_MAX_CANDIDATES,
        class_agnostic_nms=DEFAULT_CLASS_AGNOSTIC_NMS,
        **kwargs,
    ):
        """
        Run inference on a provided image.

        Args:
            image: Can be a BGR numpy array, filepath, InferenceRequestImage, PIL Image, byte-string, etc.
            text (Optional[list]): Class names to detect; required unless classes were already set.
            confidence (float): Confidence threshold for predictions.
            max_detections (Optional[int]): Maximum number of final detections.
            iou_threshold (float): IoU threshold used during NMS.
            max_candidates (int): Maximum candidate detections considered by NMS.
            class_agnostic_nms: Whether NMS ignores class labels.

        Returns:
            ObjectDetectionInferenceResponse: The inference response.
        """
        logger.debug("YOLOWorld infer() - image preprocessing.")
        t1 = perf_counter()
        image = self.preproc_image(image)
        logger.debug("YOLOWorld infer() - image ready.")
        img_dims = image.shape

        # Recompute class embeddings only when the vocabulary actually changed.
        if text is not None and text != self.class_names:
            logger.debug("YOLOWorld infer() - classes embeddings are calculated.")
            self.set_classes(text)
            logger.debug("YOLOWorld infer() - classes embeddings are ready.")
        if self.class_names is None:
            raise ValueError(
                "Class names not set and not provided in the request. Must set class names before inference or provide them via the argument `text`."
            )
        logger.debug("YOLOWorld infer() - prediction starts.")
        results = self.model.predict(
            image,
            conf=confidence,
            verbose=False,
        )[0]
        logger.debug("YOLOWorld infer() - predictions ready.")
        t2 = perf_counter() - t1

        logger.debug("YOLOWorld infer() - post-processing starting")
        if len(results) > 0:
            # Assemble [x, y, w, h, conf, per-class confidences] rows for NMS.
            bbox_array = np.array([box.xywh.tolist()[0] for box in results.boxes])
            conf_array = np.array([[float(box.conf)] for box in results.boxes])
            cls_array = np.array(
                [
                    self.get_cls_conf_array(
                        max_class_id=int(box.cls),
                        max_class_confidence=float(box.conf),
                    )
                    for box in results.boxes
                ]
            )

            pred_array = np.concatenate([bbox_array, conf_array, cls_array], axis=1)
            # NMS expects a batch dimension; take the single image back out with [0].
            pred_array = np.expand_dims(pred_array, axis=0)
            pred_array = w_np_non_max_suppression(
                pred_array,
                conf_thresh=confidence,
                iou_thresh=iou_threshold,
                class_agnostic=class_agnostic_nms,
                max_detections=max_detections,
                max_candidate_detections=max_candidates,
                box_format="xywh",
            )[0]
        else:
            pred_array = []
        predictions = []
        logger.debug("YOLOWorld infer() - post-processing done")
        for i, pred in enumerate(pred_array):
            # pred[0:4] are corner coordinates after NMS; convert to center format.
            predictions.append(
                ObjectDetectionPrediction(
                    **{
                        "x": (pred[0] + pred[2]) / 2,
                        "y": (pred[1] + pred[3]) / 2,
                        "width": pred[2] - pred[0],
                        "height": pred[3] - pred[1],
                        "confidence": pred[4],
                        "class": self.class_names[int(pred[6])],
                        "class_id": int(pred[6]),
                    }
                )
            )

        responses = ObjectDetectionInferenceResponse(
            predictions=predictions,
            image=InferenceResponseImage(width=img_dims[1], height=img_dims[0]),
            time=t2,
        )
        return responses

    def set_classes(self, text: list):
        """Set the class names for the model.

        Computes (or fetches from cache) one CLIP embedding per class name and
        injects the normalized embeddings into the YOLO-World model.

        Args:
            text (list): The class names.
        """
        class_names_to_calculate_embeddings = []
        classes_embeddings = {}
        # Check the cache first so repeated vocabularies skip CLIP entirely.
        for class_name in text:
            class_name_hash = f"clip-embedding:{get_text_hash(text=class_name)}"
            embedding_for_class = cache.get_numpy(class_name_hash)
            if embedding_for_class is not None:
                logger.debug(f"Cache hit for class: {class_name}")
                classes_embeddings[class_name] = embedding_for_class
            else:
                logger.debug(f"Cache miss for class: {class_name}")
                class_names_to_calculate_embeddings.append(class_name)
        # Embed all cache misses in one CLIP call.
        if len(class_names_to_calculate_embeddings) > 0:
            logger.debug(
                f"Calculating CLIP embeddings for {len(class_names_to_calculate_embeddings)} class names"
            )
            cache_miss_embeddings = self.clip_model.embed_text(
                text=class_names_to_calculate_embeddings
            )
        else:
            cache_miss_embeddings = []
        # Write freshly computed embeddings back to the cache.
        for missing_class_name, calculated_embedding in zip(
            class_names_to_calculate_embeddings, cache_miss_embeddings
        ):
            classes_embeddings[missing_class_name] = calculated_embedding
            missing_class_name_hash = (
                f"clip-embedding:{get_text_hash(text=missing_class_name)}"
            )
            cache.set_numpy(  # caching vectors of shape (512,)
                missing_class_name_hash,
                calculated_embedding,
                expire=EMBEDDINGS_EXPIRE_TIMEOUT,
            )
        # Stack embeddings in the order the classes were requested.
        embeddings_in_order = np.stack(
            [classes_embeddings[class_name] for class_name in text], axis=0
        )
        txt_feats = torch.from_numpy(embeddings_in_order)
        # L2-normalize each embedding.
        txt_feats = txt_feats / txt_feats.norm(p=2, dim=-1, keepdim=True)
        self.model.model.txt_feats = txt_feats.reshape(
            -1, len(text), txt_feats.shape[-1]
        ).detach()
        # Resize the detection head's class count to the new vocabulary.
        self.model.model.model[-1].nc = len(text)
        self.class_names = text

    def get_infer_bucket_file_list(self) -> list:
        """Get the list of required files for inference.

        Returns:
            list: A list of required files for inference, e.g., ["model.pt"].
        """
        return ["yolo-world.pt"]

    def get_cls_conf_array(
        self, max_class_id: int, max_class_confidence: float
    ) -> List[float]:
        """Build a per-class confidence vector with a single non-zero entry.

        Args:
            max_class_id (int): Index of the predicted class.
            max_class_confidence (float): Confidence assigned to that class.

        Returns:
            List[float]: Vector of length ``len(self.class_names)``.
        """
        arr = [0.0] * len(self.class_names)
        arr[max_class_id] = max_class_confidence
        return arr
Functions
__init__
__init__(*args, model_id='yolo_world/l', **kwargs)

Initializes the YOLO-World model.

Parameters:

Name Type Description Default
*args

Variable length argument list.

()
**kwargs

Arbitrary keyword arguments.

{}
Source code in inference/models/yolo_world/yolo_world.py
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
def __init__(self, *args, model_id="yolo_world/l", **kwargs):
    """Initializes the YOLO-World model.

    Args:
        *args: Variable length argument list.
        model_id: Identifier of the YOLO-World checkpoint to load.
        **kwargs: Arbitrary keyword arguments.
    """

    super().__init__(*args, model_id=model_id, **kwargs)

    self.model = YOLO(self.cache_file("yolo-world.pt"))
    logger.debug("Loading CLIP ViT-B/32")
    # CLIP embeds class names so the detection vocabulary can change at runtime.
    clip_model = Clip(model_id="clip/ViT-B-32")
    logger.debug("CLIP loaded")
    self.clip_model = clip_model
    # Populated later by set_classes(); inference requires it to be set.
    self.class_names = None
    self._state_lock = Lock()
get_infer_bucket_file_list
get_infer_bucket_file_list()

Get the list of required files for inference.

Returns:

Name Type Description
list list

A list of required files for inference, e.g., ["model.pt"].

Source code in inference/models/yolo_world/yolo_world.py
230
231
232
233
234
235
236
def get_infer_bucket_file_list(self) -> list:
    """List the files that must be present before inference can run.

    Returns:
        list: A list of required files for inference, e.g., ["model.pt"].
    """
    # Only the YOLO-World checkpoint is required.
    return ["yolo-world.pt"]
infer
infer(
    image=None,
    text=None,
    confidence=DEFAULT_CONFIDENCE,
    max_detections=DEFAUlT_MAX_DETECTIONS,
    iou_threshold=DEFAULT_IOU_THRESH,
    max_candidates=DEFAULT_MAX_CANDIDATES,
    class_agnostic_nms=DEFAULT_CLASS_AGNOSTIC_NMS,
    **kwargs
)

Run inference on a provided image.

Parameters:

Name Type Description Default
class_filter Optional[List[str]]

A list of class names to filter, if provided.

required

Returns:

Name Type Description
ObjectDetectionInferenceResponse

The inference response.

Source code in inference/models/yolo_world/yolo_world.py
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
def infer(
    self,
    image: Any = None,
    text: list = None,
    confidence: float = DEFAULT_CONFIDENCE,
    max_detections: Optional[int] = DEFAUlT_MAX_DETECTIONS,
    iou_threshold: float = DEFAULT_IOU_THRESH,
    max_candidates: int = DEFAULT_MAX_CANDIDATES,
    class_agnostic_nms=DEFAULT_CLASS_AGNOSTIC_NMS,
    **kwargs,
):
    """
    Run inference on a provided image.

    Args:
        image: Can be a BGR numpy array, filepath, InferenceRequestImage, PIL Image, byte-string, etc.
        text (Optional[list]): Class names to detect; required unless classes were already set.
        confidence (float): Confidence threshold for predictions.
        max_detections (Optional[int]): Maximum number of final detections.
        iou_threshold (float): IoU threshold used during NMS.
        max_candidates (int): Maximum candidate detections considered by NMS.
        class_agnostic_nms: Whether NMS ignores class labels.

    Returns:
        ObjectDetectionInferenceResponse: The inference response.
    """
    logger.debug("YOLOWorld infer() - image preprocessing.")
    t1 = perf_counter()
    image = self.preproc_image(image)
    logger.debug("YOLOWorld infer() - image ready.")
    img_dims = image.shape

    # Recompute class embeddings only when the vocabulary actually changed.
    if text is not None and text != self.class_names:
        logger.debug("YOLOWorld infer() - classes embeddings are calculated.")
        self.set_classes(text)
        logger.debug("YOLOWorld infer() - classes embeddings are ready.")
    if self.class_names is None:
        raise ValueError(
            "Class names not set and not provided in the request. Must set class names before inference or provide them via the argument `text`."
        )
    logger.debug("YOLOWorld infer() - prediction starts.")
    results = self.model.predict(
        image,
        conf=confidence,
        verbose=False,
    )[0]
    logger.debug("YOLOWorld infer() - predictions ready.")
    t2 = perf_counter() - t1

    logger.debug("YOLOWorld infer() - post-processing starting")
    if len(results) > 0:
        # Assemble [x, y, w, h, conf, per-class confidences] rows for NMS.
        bbox_array = np.array([box.xywh.tolist()[0] for box in results.boxes])
        conf_array = np.array([[float(box.conf)] for box in results.boxes])
        cls_array = np.array(
            [
                self.get_cls_conf_array(
                    max_class_id=int(box.cls),
                    max_class_confidence=float(box.conf),
                )
                for box in results.boxes
            ]
        )

        pred_array = np.concatenate([bbox_array, conf_array, cls_array], axis=1)
        # NMS expects a batch dimension; take the single image back out with [0].
        pred_array = np.expand_dims(pred_array, axis=0)
        pred_array = w_np_non_max_suppression(
            pred_array,
            conf_thresh=confidence,
            iou_thresh=iou_threshold,
            class_agnostic=class_agnostic_nms,
            max_detections=max_detections,
            max_candidate_detections=max_candidates,
            box_format="xywh",
        )[0]
    else:
        pred_array = []
    predictions = []
    logger.debug("YOLOWorld infer() - post-processing done")
    for i, pred in enumerate(pred_array):
        # pred[0:4] are corner coordinates after NMS; convert to center format.
        predictions.append(
            ObjectDetectionPrediction(
                **{
                    "x": (pred[0] + pred[2]) / 2,
                    "y": (pred[1] + pred[3]) / 2,
                    "width": pred[2] - pred[0],
                    "height": pred[3] - pred[1],
                    "confidence": pred[4],
                    "class": self.class_names[int(pred[6])],
                    "class_id": int(pred[6]),
                }
            )
        )

    responses = ObjectDetectionInferenceResponse(
        predictions=predictions,
        image=InferenceResponseImage(width=img_dims[1], height=img_dims[0]),
        time=t2,
    )
    return responses
infer_from_request
infer_from_request(request)

Perform inference based on the details provided in the request, and return the associated responses.

Source code in inference/models/yolo_world/yolo_world.py
76
77
78
79
80
81
82
83
84
def infer_from_request(
    self,
    request: YOLOWorldInferenceRequest,
) -> ObjectDetectionInferenceResponse:
    """
    Perform inference based on the details provided in the request, and return the associated responses.
    """
    # Lock guards mutable state (class embeddings) that infer() may update.
    with self._state_lock:
        return self.infer(**request.dict())
preproc_image
preproc_image(image)

Preprocesses an image.

Parameters:

Name Type Description Default
image InferenceRequestImage

The image to preprocess.

required

Returns:

Type Description

np.array: The preprocessed image.

Source code in inference/models/yolo_world/yolo_world.py
64
65
66
67
68
69
70
71
72
73
74
def preproc_image(self, image: Any):
    """Preprocesses an image.

    Args:
        image (InferenceRequestImage): The image to preprocess.

    Returns:
        np.array: The preprocessed image.
    """
    rgb_image = load_image_rgb(image)
    # Flip the channel axis (RGB -> BGR) with a negative-stride view.
    return rgb_image[:, :, ::-1]
set_classes
set_classes(text)

Set the class names for the model.

Parameters:

Name Type Description Default
text list

The class names.

required
Source code in inference/models/yolo_world/yolo_world.py
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
def set_classes(self, text: list):
    """Set the class names for the model.

    Computes (or fetches from cache) one CLIP embedding per class name and
    injects the normalized embeddings into the YOLO-World model.

    Args:
        text (list): The class names.
    """
    class_names_to_calculate_embeddings = []
    classes_embeddings = {}
    # Check the cache first so repeated vocabularies skip CLIP entirely.
    for class_name in text:
        class_name_hash = f"clip-embedding:{get_text_hash(text=class_name)}"
        embedding_for_class = cache.get_numpy(class_name_hash)
        if embedding_for_class is not None:
            logger.debug(f"Cache hit for class: {class_name}")
            classes_embeddings[class_name] = embedding_for_class
        else:
            logger.debug(f"Cache miss for class: {class_name}")
            class_names_to_calculate_embeddings.append(class_name)
    # Embed all cache misses in one CLIP call.
    if len(class_names_to_calculate_embeddings) > 0:
        logger.debug(
            f"Calculating CLIP embeddings for {len(class_names_to_calculate_embeddings)} class names"
        )
        cache_miss_embeddings = self.clip_model.embed_text(
            text=class_names_to_calculate_embeddings
        )
    else:
        cache_miss_embeddings = []
    # Write freshly computed embeddings back to the cache.
    for missing_class_name, calculated_embedding in zip(
        class_names_to_calculate_embeddings, cache_miss_embeddings
    ):
        classes_embeddings[missing_class_name] = calculated_embedding
        missing_class_name_hash = (
            f"clip-embedding:{get_text_hash(text=missing_class_name)}"
        )
        cache.set_numpy(  # caching vectors of shape (512,)
            missing_class_name_hash,
            calculated_embedding,
            expire=EMBEDDINGS_EXPIRE_TIMEOUT,
        )
    # Stack embeddings in the order the classes were requested.
    embeddings_in_order = np.stack(
        [classes_embeddings[class_name] for class_name in text], axis=0
    )
    txt_feats = torch.from_numpy(embeddings_in_order)
    # L2-normalize each embedding.
    txt_feats = txt_feats / txt_feats.norm(p=2, dim=-1, keepdim=True)
    self.model.model.txt_feats = txt_feats.reshape(
        -1, len(text), txt_feats.shape[-1]
    ).detach()
    # Resize the detection head's class count to the new vocabulary.
    self.model.model.model[-1].nc = len(text)
    self.class_names = text

Functions

models/yolov10

inference.models.yolov10.yolov10_object_detection

Classes

YOLOv10ObjectDetection

Bases: ObjectDetectionBaseOnnxRoboflowInferenceModel

Roboflow ONNX Object detection model (Implements an object detection specific infer method).

This class is responsible for performing object detection using the YOLOv10 model with ONNX runtime.

Attributes:

Name Type Description
weights_file str

Path to the ONNX weights file.

Methods:

Name Description
predict

Performs object detection on the given image using the ONNX session.

Source code in inference/models/yolov10/yolov10_object_detection.py
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
class YOLOv10ObjectDetection(ObjectDetectionBaseOnnxRoboflowInferenceModel):
    """Roboflow ONNX Object detection model (Implements an object detection specific infer method).

    This class is responsible for performing object detection using the YOLOv10 model
    with ONNX runtime.

    Attributes:
        weights_file (str): Path to the ONNX weights file.

    Methods:
        predict: Performs object detection on the given image using the ONNX session.
    """

    # YOLOv10 emits corner-format boxes, unlike the xywh default.
    box_format = "xyxy"

    @property
    def weights_file(self) -> str:
        """Gets the weights file for the YOLOv10 model.

        Returns:
            str: Path to the ONNX weights file.
        """
        return "weights.onnx"

    def predict(self, img_in: np.ndarray, **kwargs) -> Tuple[np.ndarray]:
        """Performs object detection on the given image using the ONNX session.

        Args:
            img_in (np.ndarray): Input image as a NumPy array.

        Returns:
            Tuple[np.ndarray]: NumPy array representing the predictions, including boxes, confidence scores, and class confidence scores.
        """
        # Serialize access to the shared ONNX session.
        with self._session_lock:
            predictions = run_session_via_iobinding(
                self.onnx_session, self.input_name, img_in
            )[0]

        return (predictions,)

    def postprocess(
        self,
        predictions: Tuple[np.ndarray, ...],
        preproc_return_metadata: PreprocessReturnMetadata,
        confidence: float = DEFAULT_CONFIDENCE,
        max_detections: int = DEFAUlT_MAX_DETECTIONS,
        **kwargs,
    ) -> List[ObjectDetectionInferenceResponse]:
        """Postprocesses the object detection predictions.

        Args:
            predictions (np.ndarray): Raw predictions from the model.
            preproc_return_metadata (PreprocessReturnMetadata): Metadata captured during preprocessing (image dims, crop flags).
            confidence (float): Confidence threshold for filtering detections. Default is 0.5.
            max_detections (int): Maximum number of final detections. Default is 300.

        Returns:
            List[ObjectDetectionInferenceResponse]: The post-processed predictions.
        """
        predictions = predictions[0]
        # Append a copy of the trailing (class) columns, then mirror the box
        # confidence into column 5 so the array matches the layout expected downstream.
        predictions = np.append(predictions, predictions[..., 5:], axis=-1)
        predictions[..., 5] = predictions[..., 4]

        # Per-image confidence filter, then cap at max_detections.
        mask = predictions[..., 4] > confidence
        predictions = [
            p[mask[idx]][:max_detections] for idx, p in enumerate(predictions)
        ]

        infer_shape = (self.img_size_h, self.img_size_w)
        img_dims = preproc_return_metadata["img_dims"]
        # Rescale boxes from the inference resolution back to original image space.
        predictions = post_process_bboxes(
            predictions,
            infer_shape,
            img_dims,
            self.preproc,
            resize_method=self.resize_method,
            disable_preproc_static_crop=preproc_return_metadata[
                "disable_preproc_static_crop"
            ],
        )
        return self.make_response(predictions, img_dims, **kwargs)

    def validate_model_classes(self) -> None:
        # No-op override: this model skips the base-class class validation.
        pass
Attributes
weights_file property
weights_file

Gets the weights file for the YOLOv10 model.

Returns:

Name Type Description
str str

Path to the ONNX weights file.

Functions
postprocess
postprocess(
    predictions,
    preproc_return_metadata,
    confidence=DEFAULT_CONFIDENCE,
    max_detections=DEFAUlT_MAX_DETECTIONS,
    **kwargs
)

Postprocesses the object detection predictions.

Parameters:

Name Type Description Default
predictions ndarray

Raw predictions from the model.

required
img_dims List[Tuple[int, int]]

Dimensions of the images.

required
confidence float

Confidence threshold for filtering detections. Default is 0.5.

DEFAULT_CONFIDENCE
max_detections int

Maximum number of final detections. Default is 300.

DEFAUlT_MAX_DETECTIONS

Returns:

Type Description
List[ObjectDetectionInferenceResponse]

List[ObjectDetectionInferenceResponse]: The post-processed predictions.

Source code in inference/models/yolov10/yolov10_object_detection.py
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
def postprocess(
    self,
    predictions: Tuple[np.ndarray, ...],
    preproc_return_metadata: PreprocessReturnMetadata,
    confidence: float = DEFAULT_CONFIDENCE,
    max_detections: int = DEFAUlT_MAX_DETECTIONS,
    **kwargs,
) -> List[ObjectDetectionInferenceResponse]:
    """Postprocesses the object detection predictions.

    Args:
        predictions (np.ndarray): Raw predictions from the model.
        preproc_return_metadata (PreprocessReturnMetadata): Metadata captured during preprocessing (image dims, crop flags).
        confidence (float): Confidence threshold for filtering detections. Default is 0.5.
        max_detections (int): Maximum number of final detections. Default is 300.

    Returns:
        List[ObjectDetectionInferenceResponse]: The post-processed predictions.
    """
    predictions = predictions[0]
    # Append a copy of the trailing (class) columns, then mirror the box
    # confidence into column 5 so the array matches the layout expected downstream.
    predictions = np.append(predictions, predictions[..., 5:], axis=-1)
    predictions[..., 5] = predictions[..., 4]

    # Per-image confidence filter, then cap at max_detections.
    mask = predictions[..., 4] > confidence
    predictions = [
        p[mask[idx]][:max_detections] for idx, p in enumerate(predictions)
    ]

    infer_shape = (self.img_size_h, self.img_size_w)
    img_dims = preproc_return_metadata["img_dims"]
    # Rescale boxes from the inference resolution back to original image space.
    predictions = post_process_bboxes(
        predictions,
        infer_shape,
        img_dims,
        self.preproc,
        resize_method=self.resize_method,
        disable_preproc_static_crop=preproc_return_metadata[
            "disable_preproc_static_crop"
        ],
    )
    return self.make_response(predictions, img_dims, **kwargs)
predict
predict(img_in, **kwargs)

Performs object detection on the given image using the ONNX session.

Parameters:

Name Type Description Default
img_in ndarray

Input image as a NumPy array.

required

Returns:

Type Description
Tuple[ndarray]

Tuple[np.ndarray]: NumPy array representing the predictions, including boxes, confidence scores, and class confidence scores.

Source code in inference/models/yolov10/yolov10_object_detection.py
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
def predict(self, img_in: np.ndarray, **kwargs) -> Tuple[np.ndarray]:
    """Performs object detection on the given image using the ONNX session.

    Args:
        img_in (np.ndarray): Input image as a NumPy array.

    Returns:
        Tuple[np.ndarray]: NumPy array representing the predictions, including boxes, confidence scores, and class confidence scores.
    """
    # Serialize access to the shared ONNX session.
    with self._session_lock:
        outputs = run_session_via_iobinding(
            self.onnx_session, self.input_name, img_in
        )

    # Single-output model: wrap the first output in a tuple.
    return (outputs[0],)

Functions

models/yolov5

inference.models.yolov5.yolov5_instance_segmentation

Classes

YOLOv5InstanceSegmentation

Bases: InstanceSegmentationBaseOnnxRoboflowInferenceModel

YOLOv5 Instance Segmentation ONNX Inference Model.

This class is responsible for performing instance segmentation using the YOLOv5 model with ONNX runtime.

Attributes:

Name Type Description
weights_file str

Path to the ONNX weights file.

Source code in inference/models/yolov5/yolov5_instance_segmentation.py
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
class YOLOv5InstanceSegmentation(InstanceSegmentationBaseOnnxRoboflowInferenceModel):
    """YOLOv5 Instance Segmentation ONNX Inference Model.

    Runs YOLOv5 instance segmentation through ONNX runtime.

    Attributes:
        weights_file (str): Path to the ONNX weights file.
    """

    @property
    def weights_file(self) -> str:
        """Name of the ONNX weights file for the YOLOv5 model.

        Returns:
            str: Path to the ONNX weights file.
        """
        return "yolov5s_weights.onnx"

    def predict(self, img_in: np.ndarray, **kwargs) -> Tuple[np.ndarray, np.ndarray]:
        """Run the ONNX session on the given image.

        Args:
            img_in (np.ndarray): Input image as a NumPy array.

        Returns:
            Tuple[np.ndarray, np.ndarray]: Tuple containing two NumPy arrays representing the predictions.
        """
        # Serialize access to the shared ONNX session.
        with self._session_lock:
            outputs = run_session_via_iobinding(
                self.onnx_session, self.input_name, img_in
            )
        detections, protos = outputs[0], outputs[1]
        return detections, protos
Attributes
weights_file property
weights_file

Gets the weights file for the YOLOv5 model.

Returns:

Name Type Description
str str

Path to the ONNX weights file.

Functions
predict
predict(img_in, **kwargs)

Performs inference on the given image using the ONNX session.

Parameters:

Name Type Description Default
img_in ndarray

Input image as a NumPy array.

required

Returns:

Type Description
Tuple[ndarray, ndarray]

Tuple[np.ndarray, np.ndarray]: Tuple containing two NumPy arrays representing the predictions.

Source code in inference/models/yolov5/yolov5_instance_segmentation.py
30
31
32
33
34
35
36
37
38
39
40
41
42
43
def predict(self, img_in: np.ndarray, **kwargs) -> Tuple[np.ndarray, np.ndarray]:
    """Run the ONNX session on the given image.

    Args:
        img_in (np.ndarray): Input image as a NumPy array.

    Returns:
        Tuple[np.ndarray, np.ndarray]: Tuple containing two NumPy arrays representing the predictions.
    """
    # Serialize access to the shared ONNX session.
    with self._session_lock:
        outputs = run_session_via_iobinding(
            self.onnx_session, self.input_name, img_in
        )
    detections, protos = outputs[0], outputs[1]
    return detections, protos

inference.models.yolov5.yolov5_object_detection

Classes

YOLOv5ObjectDetection

Bases: ObjectDetectionBaseOnnxRoboflowInferenceModel

Roboflow ONNX Object detection model (Implements an object detection specific infer method).

This class is responsible for performing object detection using the YOLOv5 model with ONNX runtime.

Attributes:

Name Type Description
weights_file str

Path to the ONNX weights file.

Source code in inference/models/yolov5/yolov5_object_detection.py
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
class YOLOv5ObjectDetection(ObjectDetectionBaseOnnxRoboflowInferenceModel):
    """Roboflow ONNX Object detection model (Implements an object detection specific infer method).

    Runs YOLOv5 object detection through ONNX runtime.

    Attributes:
        weights_file (str): Path to the ONNX weights file.
    """

    @property
    def weights_file(self) -> str:
        """Name of the ONNX weights file for the YOLOv5 model.

        Returns:
            str: Path to the ONNX weights file.
        """
        return "yolov5s_weights.onnx"

    def predict(self, img_in: np.ndarray, **kwargs) -> Tuple[np.ndarray]:
        """Run object detection on the given image via the ONNX session.

        Args:
            img_in (np.ndarray): Input image as a NumPy array.

        Returns:
            Tuple[np.ndarray]: NumPy array representing the predictions.
        """
        # Serialize access to the shared ONNX session.
        with self._session_lock:
            outputs = run_session_via_iobinding(
                self.onnx_session, self.input_name, img_in
            )
        # Single-output model: wrap the first output in a tuple.
        return (outputs[0],)
Attributes
weights_file property
weights_file

Gets the weights file for the YOLOv5 model.

Returns:

Name Type Description
str str

Path to the ONNX weights file.

Functions
predict
predict(img_in, **kwargs)

Performs object detection on the given image using the ONNX session.

Parameters:

Name Type Description Default
img_in ndarray

Input image as a NumPy array.

required

Returns:

Type Description
Tuple[ndarray]

Tuple[np.ndarray]: NumPy array representing the predictions.

Source code in inference/models/yolov5/yolov5_object_detection.py
30
31
32
33
34
35
36
37
38
39
40
41
42
43
def predict(self, img_in: np.ndarray, **kwargs) -> Tuple[np.ndarray]:
    """Run object detection on the given image via the ONNX session.

    Args:
        img_in (np.ndarray): Input image as a NumPy array.

    Returns:
        Tuple[np.ndarray]: NumPy array representing the predictions.
    """
    # Serialize access to the shared ONNX session.
    with self._session_lock:
        outputs = run_session_via_iobinding(
            self.onnx_session, self.input_name, img_in
        )
    # Single-output model: wrap the first output in a tuple.
    return (outputs[0],)

models/yolov7

inference.models.yolov7.yolov7_instance_segmentation

Classes

YOLOv7InstanceSegmentation

Bases: InstanceSegmentationBaseOnnxRoboflowInferenceModel

YOLOv7 Instance Segmentation ONNX Inference Model.

This class is responsible for performing instance segmentation using the YOLOv7 model with ONNX runtime.

Methods:

Name Description
predict

Performs inference on the given image using the ONNX session.

Source code in inference/models/yolov7/yolov7_instance_segmentation.py
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
class YOLOv7InstanceSegmentation(InstanceSegmentationBaseOnnxRoboflowInferenceModel):
    """YOLOv7 Instance Segmentation ONNX Inference Model.

    Runs YOLOv7 instance segmentation through ONNX runtime.

    Methods:
        predict: Performs inference on the given image using the ONNX session.
    """

    def predict(self, img_in: np.ndarray, **kwargs) -> Tuple[np.ndarray, np.ndarray]:
        """Run the ONNX session on the given image.

        Args:
            img_in (np.ndarray): Input image as a NumPy array.

        Returns:
            Tuple[np.ndarray, np.ndarray]: Tuple containing two NumPy arrays representing the predictions and protos.
        """
        # Serialize access to the shared ONNX session.
        with self._session_lock:
            outputs = run_session_via_iobinding(
                self.onnx_session, self.input_name, img_in
            )
        # Output 0 carries the detections; output 4 carries the mask protos.
        return outputs[0], outputs[4]
Functions
predict
predict(img_in, **kwargs)

Performs inference on the given image using the ONNX session.

Parameters:

Name Type Description Default
img_in ndarray

Input image as a NumPy array.

required

Returns:

Type Description
Tuple[ndarray, ndarray]

Tuple[np.ndarray, np.ndarray]: Tuple containing two NumPy arrays representing the predictions and protos.

Source code in inference/models/yolov7/yolov7_instance_segmentation.py
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
def predict(self, img_in: np.ndarray, **kwargs) -> Tuple[np.ndarray, np.ndarray]:
    """Performs inference on the given image using the ONNX session.

    Args:
        img_in (np.ndarray): Input image as a NumPy array, already preprocessed
            for the model (resizing/normalization happen upstream — confirm in caller).

    Returns:
        Tuple[np.ndarray, np.ndarray]: Tuple containing two NumPy arrays representing the predictions and protos.
    """
    # Serialize access to the shared ONNX session across threads.
    with self._session_lock:
        predictions = run_session_via_iobinding(
            self.onnx_session, self.input_name, img_in
        )
    # Session output 4 holds mask prototypes; output 0 holds detections.
    protos = predictions[4]
    predictions = predictions[0]
    return predictions, protos

models/yolov8

inference.models.yolov8.yolov8_instance_segmentation

Classes

YOLOv8InstanceSegmentation

Bases: InstanceSegmentationBaseOnnxRoboflowInferenceModel

YOLOv8 Instance Segmentation ONNX Inference Model.

This class is responsible for performing instance segmentation using the YOLOv8 model with ONNX runtime.

Attributes:

Name Type Description
weights_file str

Path to the ONNX weights file.

Methods:

Name Description
predict

Performs inference on the given image using the ONNX session.

Source code in inference/models/yolov8/yolov8_instance_segmentation.py
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
class YOLOv8InstanceSegmentation(InstanceSegmentationBaseOnnxRoboflowInferenceModel):
    """YOLOv8 Instance Segmentation ONNX Inference Model.

    Runs instance segmentation with a YOLOv8 model through the ONNX runtime.

    Attributes:
        weights_file (str): Path to the ONNX weights file.

    Methods:
        predict: Performs inference on the given image using the ONNX session.
    """

    @property
    def weights_file(self) -> str:
        """Path of the ONNX weights file used by this model.

        Returns:
            str: Path to the ONNX weights file.
        """
        return "weights.onnx"

    def predict(self, img_in: np.ndarray, **kwargs) -> Tuple[np.ndarray, np.ndarray]:
        """Run the ONNX session on a preprocessed input image.

        Args:
            img_in (np.ndarray): Input image as a NumPy array.

        Returns:
            Tuple[np.ndarray, np.ndarray]: Predictions laid out per anchor as
            [boxes (4), max confidence (1), per-class confidences, mask
            coefficients (32)], plus the mask prototypes array.
        """
        # Serialize access to the shared ONNX session across threads.
        with self._session_lock:
            raw_preds, protos = run_session_via_iobinding(
                self.onnx_session, self.input_name, img_in
            )
        # (batch, attributes, anchors) -> (batch, anchors, attributes)
        raw_preds = raw_preds.transpose(0, 2, 1)
        boxes = raw_preds[:, :, :4]
        # Attributes between the box and the final 32 columns are class scores;
        # the trailing 32 are mask coefficients.
        class_confs = raw_preds[:, :, 4:-32]
        masks = raw_preds[:, :, -32:]
        # Overall confidence column = max class confidence, kept explicit.
        top_confs = np.max(class_confs, axis=2, keepdims=True)
        stacked = np.concatenate([boxes, top_confs, class_confs, masks], axis=2)
        return stacked, protos
Attributes
weights_file property
weights_file

Gets the weights file for the YOLOv8 model.

Returns:

Name Type Description
str str

Path to the ONNX weights file.

Functions
predict
predict(img_in, **kwargs)

Performs inference on the given image using the ONNX session.

Parameters:

Name Type Description Default
img_in ndarray

Input image as a NumPy array.

required

Returns:

Type Description
Tuple[ndarray, ndarray]

Tuple[np.ndarray, np.ndarray]: Tuple containing two NumPy arrays representing the predictions and protos. The predictions include boxes, confidence scores, class confidence scores, and masks.

Source code in inference/models/yolov8/yolov8_instance_segmentation.py
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
def predict(self, img_in: np.ndarray, **kwargs) -> Tuple[np.ndarray, np.ndarray]:
    """Performs inference on the given image using the ONNX session.

    Args:
        img_in (np.ndarray): Input image as a NumPy array.

    Returns:
        Tuple[np.ndarray, np.ndarray]: Tuple containing two NumPy arrays representing the predictions and protos. The predictions include boxes, confidence scores, class confidence scores, and masks.
    """
    # Serialize access to the shared ONNX session across threads.
    with self._session_lock:
        predictions, protos = run_session_via_iobinding(
            self.onnx_session, self.input_name, img_in
        )
    # (batch, attributes, anchors) -> (batch, anchors, attributes)
    predictions = predictions.transpose(0, 2, 1)
    boxes = predictions[:, :, :4]
    # Columns between the box and the trailing 32 are class confidences;
    # the trailing 32 are mask coefficients.
    class_confs = predictions[:, :, 4:-32]
    # Overall confidence = max class confidence, inserted as its own column.
    confs = np.expand_dims(np.max(class_confs, axis=2), axis=2)
    masks = predictions[:, :, -32:]
    predictions = np.concatenate([boxes, confs, class_confs, masks], axis=2)
    return predictions, protos

inference.models.yolov8.yolov8_keypoints_detection

Classes

YOLOv8KeypointsDetection

Bases: KeypointsDetectionBaseOnnxRoboflowInferenceModel

Roboflow ONNX keypoints detection model (Implements an object detection specific infer method).

This class is responsible for performing keypoints detection using the YOLOv8 model with ONNX runtime.

Attributes:

Name Type Description
weights_file str

Path to the ONNX weights file.

Methods:

Name Description
predict

Performs object detection on the given image using the ONNX session.

Source code in inference/models/yolov8/yolov8_keypoints_detection.py
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
class YOLOv8KeypointsDetection(KeypointsDetectionBaseOnnxRoboflowInferenceModel):
    """Roboflow ONNX keypoints detection model (Implements an object detection specific infer method).

    This class is responsible for performing keypoints detection using the YOLOv8 model
    with ONNX runtime.

    Attributes:
        weights_file (str): Path to the ONNX weights file.

    Methods:
        predict: Performs keypoints detection on the given image using the ONNX session.
    """

    @property
    def weights_file(self) -> str:
        """Gets the weights file for the YOLOv8 model.

        Returns:
            str: Path to the ONNX weights file.
        """
        return "weights.onnx"

    def predict(self, img_in: np.ndarray, **kwargs) -> Tuple[np.ndarray, ...]:
        """Performs keypoints detection on the given image using the ONNX session.

        Args:
            img_in (np.ndarray): Input image as a NumPy array.

        Returns:
            Tuple[np.ndarray, ...]: Single-element tuple with the predictions,
            laid out per anchor as [boxes (4), max confidence (1),
            per-class confidences, keypoint values].
        """
        # Serialize access to the shared ONNX session across threads.
        with self._session_lock:
            predictions = run_session_via_iobinding(
                self.onnx_session, self.input_name, img_in
            )[0]
        # (batch, attributes, anchors) -> (batch, anchors, attributes)
        predictions = predictions.transpose(0, 2, 1)
        boxes = predictions[:, :, :4]
        number_of_classes = len(self.get_class_names)
        # After the 4 box values come the class confidences, then everything
        # remaining is the flattened keypoint data.
        class_confs = predictions[:, :, 4 : 4 + number_of_classes]
        keypoints_detections = predictions[:, :, 4 + number_of_classes :]
        # Overall confidence = max class confidence, inserted as its own column.
        confs = np.expand_dims(np.max(class_confs, axis=2), axis=2)
        bboxes_predictions = np.concatenate(
            [boxes, confs, class_confs, keypoints_detections], axis=2
        )
        return (bboxes_predictions,)

    def keypoints_count(self) -> int:
        """Returns the number of keypoints in the model.

        Raises:
            ModelArtefactError: If keypoints metadata is not available.
        """
        if self.keypoints_metadata is None:
            raise ModelArtefactError("Keypoints metadata not available.")
        return superset_keypoints_count(self.keypoints_metadata)
Attributes
weights_file property
weights_file

Gets the weights file for the YOLOv8 model.

Returns:

Name Type Description
str str

Path to the ONNX weights file.

Functions
keypoints_count
keypoints_count()

Returns the number of keypoints in the model.

Source code in inference/models/yolov8/yolov8_keypoints_detection.py
59
60
61
62
63
def keypoints_count(self) -> int:
    """Returns the number of keypoints in the model.

    Raises:
        ModelArtefactError: If keypoints metadata is not available.
    """
    if self.keypoints_metadata is None:
        raise ModelArtefactError("Keypoints metadata not available.")
    return superset_keypoints_count(self.keypoints_metadata)
predict
predict(img_in, **kwargs)

Performs object detection on the given image using the ONNX session.

Parameters:

Name Type Description Default
img_in ndarray

Input image as a NumPy array.

required

Returns:

Type Description
Tuple[ndarray, ...]

Tuple[np.ndarray]: NumPy array representing the predictions, including boxes, confidence scores, class confidence scores, and keypoint values.

Source code in inference/models/yolov8/yolov8_keypoints_detection.py
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
def predict(self, img_in: np.ndarray, **kwargs) -> Tuple[np.ndarray, ...]:
    """Performs keypoints detection on the given image using the ONNX session.

    Args:
        img_in (np.ndarray): Input image as a NumPy array.

    Returns:
        Tuple[np.ndarray, ...]: Single-element tuple with the predictions,
        laid out per anchor as [boxes (4), max confidence (1),
        per-class confidences, keypoint values].
    """
    # Serialize access to the shared ONNX session across threads.
    with self._session_lock:
        predictions = run_session_via_iobinding(
            self.onnx_session, self.input_name, img_in
        )[0]
    # (batch, attributes, anchors) -> (batch, anchors, attributes)
    predictions = predictions.transpose(0, 2, 1)
    boxes = predictions[:, :, :4]
    number_of_classes = len(self.get_class_names)
    # After the 4 box values come the class confidences, then keypoint data.
    class_confs = predictions[:, :, 4 : 4 + number_of_classes]
    keypoints_detections = predictions[:, :, 4 + number_of_classes :]
    # Overall confidence = max class confidence, inserted as its own column.
    confs = np.expand_dims(np.max(class_confs, axis=2), axis=2)
    bboxes_predictions = np.concatenate(
        [boxes, confs, class_confs, keypoints_detections], axis=2
    )
    return (bboxes_predictions,)

Functions

inference.models.yolov8.yolov8_object_detection

Classes

YOLOv8ObjectDetection

Bases: ObjectDetectionBaseOnnxRoboflowInferenceModel

Roboflow ONNX Object detection model (Implements an object detection specific infer method).

This class is responsible for performing object detection using the YOLOv8 model with ONNX runtime.

Attributes:

Name Type Description
weights_file str

Path to the ONNX weights file.

Methods:

Name Description
predict

Performs object detection on the given image using the ONNX session.

Source code in inference/models/yolov8/yolov8_object_detection.py
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
class YOLOv8ObjectDetection(ObjectDetectionBaseOnnxRoboflowInferenceModel):
    """Roboflow ONNX Object detection model (Implements an object detection specific infer method).

    Runs object detection with a YOLOv8 model through the ONNX runtime.

    Attributes:
        weights_file (str): Path to the ONNX weights file.

    Methods:
        predict: Performs object detection on the given image using the ONNX session.
    """

    @property
    def weights_file(self) -> str:
        """Path of the ONNX weights file used by this model.

        Returns:
            str: Path to the ONNX weights file.
        """
        return "weights.onnx"

    def predict(self, img_in: ImageMetaType, **kwargs) -> Tuple[np.ndarray]:
        """Run the ONNX session on a preprocessed input image.

        Args:
            img_in (np.ndarray): Input image as a NumPy array.

        Returns:
            Tuple[np.ndarray]: Single-element tuple with the predictions laid out
            per anchor as [boxes (4), max confidence (1), per-class confidences].
        """
        # Serialize access to the shared ONNX session across threads.
        with self._session_lock:
            raw_output = run_session_via_iobinding(
                self.onnx_session, self.input_name, img_in
            )[0]
        # (batch, attributes, anchors) -> (batch, anchors, attributes)
        raw_output = raw_output.transpose(0, 2, 1)
        boxes, class_confs = raw_output[:, :, :4], raw_output[:, :, 4:]
        # Overall confidence column = max class confidence, kept explicit.
        top_confs = np.max(class_confs, axis=2, keepdims=True)
        stacked = np.concatenate([boxes, top_confs, class_confs], axis=2)
        return (stacked,)
Attributes
weights_file property
weights_file

Gets the weights file for the YOLOv8 model.

Returns:

Name Type Description
str str

Path to the ONNX weights file.

Functions
predict
predict(img_in, **kwargs)

Performs object detection on the given image using the ONNX session.

Parameters:

Name Type Description Default
img_in ndarray

Input image as a NumPy array.

required

Returns:

Type Description
Tuple[ndarray]

Tuple[np.ndarray]: NumPy array representing the predictions, including boxes, confidence scores, and class confidence scores.

Source code in inference/models/yolov8/yolov8_object_detection.py
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
def predict(self, img_in: ImageMetaType, **kwargs) -> Tuple[np.ndarray]:
    """Performs object detection on the given image using the ONNX session.

    Args:
        img_in (np.ndarray): Input image as a NumPy array.

    Returns:
        Tuple[np.ndarray]: NumPy array representing the predictions, including boxes, confidence scores, and class confidence scores.
    """
    # Serialize access to the shared ONNX session across threads.
    with self._session_lock:
        predictions = run_session_via_iobinding(
            self.onnx_session, self.input_name, img_in
        )[0]
    # (batch, attributes, anchors) -> (batch, anchors, attributes)
    predictions = predictions.transpose(0, 2, 1)
    boxes = predictions[:, :, :4]
    # Everything after the 4 box values is per-class confidence.
    class_confs = predictions[:, :, 4:]
    # Overall confidence = max class confidence, inserted as its own column.
    confs = np.expand_dims(np.max(class_confs, axis=2), axis=2)
    predictions = np.concatenate([boxes, confs, class_confs], axis=2)
    return (predictions,)

models/yolov9

inference.models.yolov9.yolov9_object_detection

Classes

YOLOv9ObjectDetection

Bases: ObjectDetectionBaseOnnxRoboflowInferenceModel

Roboflow ONNX Object detection model (Implements an object detection specific infer method).

This class is responsible for performing object detection using the YOLOv9 model with ONNX runtime.

Attributes:

Name Type Description
weights_file str

Path to the ONNX weights file.

Source code in inference/models/yolov9/yolov9_object_detection.py
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
class YOLOv9ObjectDetection(ObjectDetectionBaseOnnxRoboflowInferenceModel):
    """Roboflow ONNX Object detection model (Implements an object detection specific infer method).

    Runs object detection with a YOLOv9 model through the ONNX runtime.

    Attributes:
        weights_file (str): Path to the ONNX weights file.
    """

    @property
    def weights_file(self) -> str:
        """Path of the ONNX weights file used by this model.

        Returns:
            str: Path to the ONNX weights file.
        """
        return "weights.onnx"

    def predict(self, img_in: np.ndarray, **kwargs) -> Tuple[np.ndarray]:
        """Run the ONNX session on a preprocessed input image.

        Args:
            img_in (np.ndarray): Input image as a NumPy array.

        Returns:
            Tuple[np.ndarray]: Single-element tuple with the predictions laid out
            per anchor as [boxes (4), max confidence (1), per-class confidences].
        """
        # Raw session output layout: (batch, attributes, anchors), e.g. (b x 8 x 8000).
        # Serialize access to the shared ONNX session across threads.
        with self._session_lock:
            raw_output = run_session_via_iobinding(
                self.onnx_session, self.input_name, img_in
            )[0]
        # (batch, attributes, anchors) -> (batch, anchors, attributes)
        raw_output = raw_output.transpose(0, 2, 1)
        boxes, class_confs = raw_output[:, :, :4], raw_output[:, :, 4:]
        # Overall confidence column = max class confidence, kept explicit.
        top_confs = np.max(class_confs, axis=2, keepdims=True)
        stacked = np.concatenate([boxes, top_confs, class_confs], axis=2)
        return (stacked,)
Attributes
weights_file property
weights_file

Gets the weights file for the YOLOv9 model.

Returns:

Name Type Description
str str

Path to the ONNX weights file.

Functions
predict
predict(img_in, **kwargs)

Performs object detection on the given image using the ONNX session.

Parameters:

Name Type Description Default
img_in ndarray

Input image as a NumPy array.

required

Returns:

Type Description
Tuple[ndarray]

Tuple[np.ndarray]: NumPy array representing the predictions.

Source code in inference/models/yolov9/yolov9_object_detection.py
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
def predict(self, img_in: np.ndarray, **kwargs) -> Tuple[np.ndarray]:
    """Performs object detection on the given image using the ONNX session.

    Args:
        img_in (np.ndarray): Input image as a NumPy array.

    Returns:
        Tuple[np.ndarray]: NumPy array representing the predictions.
    """
    # (b x 8 x 8000)
    # Serialize access to the shared ONNX session across threads.
    with self._session_lock:
        predictions = run_session_via_iobinding(
            self.onnx_session, self.input_name, img_in
        )[0]
    # (batch, attributes, anchors) -> (batch, anchors, attributes)
    predictions = predictions.transpose(0, 2, 1)
    boxes = predictions[:, :, :4]
    # Everything after the 4 box values is per-class confidence.
    class_confs = predictions[:, :, 4:]
    # Overall confidence = max class confidence, inserted as its own column.
    confs = np.expand_dims(np.max(class_confs, axis=2), axis=2)
    predictions = np.concatenate([boxes, confs, class_confs], axis=2)
    return (predictions,)

usage_tracking

Anonymous usage and telemetry reporting.

inference.usage_tracking.redis_queue

Classes

RedisQueue

Store and forget, keys with specified hash tag are handled by external service

Source code in inference/usage_tracking/redis_queue.py
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
class RedisQueue:
    """
    Store and forget, keys with specified hash tag are handled by external service
    """

    def __init__(
        self,
        hash_tag: str = "UsageCollector",
        redis_cache: Optional[RedisCache] = None,
    ):
        """Create a write-only queue backed by Redis.

        Args:
            hash_tag: Common key hash-tag so all keys land in one cluster slot.
            redis_cache: Redis cache wrapper; falls back to the module-level
                `cache` when not provided.
        """
        # prefix must contain hash-tag to avoid CROSSLOT errors when using mget
        # hash-tag is common part of the key wrapped within '{}'
        # removing hash-tag will cause clients utilizing mget to fail
        # timestamp + short uuid make the prefix unique per queue instance
        self._prefix: str = f"{{{hash_tag}}}:{time.time()}:{uuid4().hex[:5]}"
        self._redis_cache: RedisCache = redis_cache or cache
        self._increment: int = 0
        self._lock: Lock = Lock()

    def put(self, payload: Any):
        """Serialize `payload` to JSON (unless it is already a string) and store
        it under a unique key, registering the key in the 'UsageCollector'
        sorted set so an external consumer can discover it.

        Best-effort: serialization and Redis failures are logged and swallowed.
        """
        if not isinstance(payload, str):
            try:
                payload = json.dumps(payload)
            except Exception as exc:
                logger.error("Failed to parse payload '%s' to JSON - %s", payload, exc)
                return
        # Lock guards the increment counter so concurrent puts get unique keys.
        with self._lock:
            try:
                self._increment += 1
                redis_key = f"{self._prefix}:{self._increment}"
                # https://redis.io/docs/latest/develop/interact/transactions/
                # Pipeline issues SET + ZADD together to minimize round trips.
                redis_pipeline = self._redis_cache.client.pipeline()
                redis_pipeline.set(
                    name=redis_key,
                    value=payload,
                )
                # Score by timestamp so the consumer can process keys in order.
                redis_pipeline.zadd(
                    name="UsageCollector",
                    mapping={redis_key: time.time()},
                )
                results = redis_pipeline.execute()
                if not all(results):
                    # TODO: partial insert, retry
                    logger.error(
                        "Failed to store payload and sorted set (partial insert): %s",
                        results,
                    )
            except Exception as exc:
                logger.error("Failed to store usage records '%s', %s", payload, exc)

    @staticmethod
    def full() -> bool:
        """Always False — the queue never rejects writes."""
        return False

    def empty(self) -> bool:
        """Always True — entries are drained by an external service, never read here."""
        return True

    def get_nowait(self) -> List[Dict[str, Any]]:
        """Always returns an empty list — this queue is write-only from this process."""
        return []