Skip to content

Roboflow

RoboflowModelRegistry

Bases: ModelRegistry

A Roboflow-specific model registry which gets the model type using the model id, then returns a model class based on the model type.

Source code in inference/core/registries/roboflow.py (lines 62–84)
class RoboflowModelRegistry(ModelRegistry):
    """Model registry that resolves classes through the Roboflow model type.

    The model ID is first translated into a model type by ``get_model_type``;
    that value is then used as the key into ``registry_dict`` to pick the
    model class.
    """

    def get_model(self, model_id: ModelID, api_key: str) -> Model:
        """Look up the model class registered for ``model_id``.

        Args:
            model_id (str): Identifier of the model to resolve.
            api_key (str): API key forwarded to ``get_model_type``.

        Returns:
            Model: The class registered under the resolved model type.

        Raises:
            ModelNotRecognisedError: If nothing is registered under the
                resolved model type.
        """
        resolved_type = get_model_type(model_id, api_key)
        logger.debug(f"Model type: {resolved_type}")
        known_models = self.registry_dict
        if resolved_type in known_models:
            return known_models[resolved_type]
        raise ModelNotRecognisedError(f"Model type not supported: {resolved_type}")

get_model(model_id, api_key)

Returns the model class based on the given model id and API key.

Parameters:

Name Type Description Default
model_id str

The ID of the model to be retrieved.

required
api_key str

The API key used to authenticate.

required

Returns:

Name Type Description
Model Model

The model class corresponding to the given model ID and type.

Raises:

Type Description
ModelNotRecognisedError

If the model type is not found or not supported.

Source code in inference/core/registries/roboflow.py (lines 67–84)
def get_model(self, model_id: ModelID, api_key: str) -> Model:
    """Look up the model class registered for ``model_id``.

    The model type is obtained from ``get_model_type`` and used as the key
    into ``self.registry_dict``.

    Args:
        model_id (str): Identifier of the model to resolve.
        api_key (str): API key forwarded to ``get_model_type``.

    Returns:
        Model: The class registered under the resolved model type.

    Raises:
        ModelNotRecognisedError: If nothing is registered under the resolved
            model type.
    """
    resolved_type = get_model_type(model_id, api_key)
    logger.debug(f"Model type: {resolved_type}")
    known_models = self.registry_dict
    if resolved_type in known_models:
        return known_models[resolved_type]
    raise ModelNotRecognisedError(f"Model type not supported: {resolved_type}")

get_model_type(model_id, api_key=None)

Retrieves the model type based on the given model ID and API key.

Parameters:

Name Type Description Default
model_id str

The ID of the model.

required
api_key str

The API key used to authenticate.

None

Returns:

Name Type Description
tuple Tuple[TaskType, ModelType]

The project task type and the model type.

Raises:

Type Description
WorkspaceLoadError

If the workspace could not be loaded or if the API key is invalid.

DatasetLoadError

If the dataset could not be loaded due to an invalid ID, workspace ID, or version ID.

MissingDefaultModelError

If no default model is configured and the API does not provide this information.

MalformedRoboflowAPIResponseError

If the Roboflow API responds in an invalid format.

Source code in inference/core/registries/roboflow.py (lines 111–198)
def get_model_type(
    model_id: ModelID,
    api_key: Optional[str] = None,
) -> Tuple[TaskType, ModelType]:
    """Retrieves the model type based on the given model ID and API key.

    Resolution order:
      1. The model alias is resolved and the ID split into dataset/version.
      2. Generic models are answered directly from ``GENERIC_MODELS``.
      3. When ``MODELS_CACHE_AUTH_ENABLED`` is set, the API key's access to the
         model is checked.
      4. The metadata cache is consulted; on a hit the cached pair is returned.
      5. Stub versions resolve the task type from the workspace/dataset and use
         the ``"stub"`` model type.
      6. Otherwise the Roboflow API is queried (ORT endpoint for versioned
         models, the instant-model endpoint when no version is given).
    Results from steps 5 and 6 are written back to the metadata cache.

    Args:
        model_id (str): The ID of the model.
        api_key (str): The API key used to authenticate.

    Returns:
        tuple: The project task type and the model type.

    Raises:
        RoboflowAPINotAuthorizedError: If model-cache auth is enabled and the API key
            has no access to the model.
        MissingApiKeyError: If a stub model version is requested without an API key.
        ModelArtefactError: If the Roboflow API returns no model data, or the task type
            or model type cannot be determined.
        WorkspaceLoadError: If the workspace could not be loaded or if the API key is invalid.
        DatasetLoadError: If the dataset could not be loaded due to invalid ID, workspace ID or version ID.
        MissingDefaultModelError: If default model is not configured and API does not provide this info
        MalformedRoboflowAPIResponseError: Roboflow API responds in invalid format.
    """
    model_id = resolve_roboflow_model_alias(model_id=model_id)
    dataset_id, version_id = get_model_id_chunks(model_id=model_id)
    if dataset_id in GENERIC_MODELS:
        logger.debug(f"Loading generic model: {dataset_id}.")
        return GENERIC_MODELS[dataset_id]

    if MODELS_CACHE_AUTH_ENABLED:
        if not _check_if_api_key_has_access_to_model(
            api_key=api_key, model_id=model_id
        ):
            raise RoboflowAPINotAuthorizedError(
                f"API key {api_key} does not have access to model {model_id}"
            )

    cached_metadata = get_model_metadata_from_cache(
        dataset_id=dataset_id, version_id=version_id
    )
    if cached_metadata is not None:
        return cached_metadata[0], cached_metadata[1]
    if version_id == STUB_VERSION_ID:
        if api_key is None:
            raise MissingApiKeyError(
                "Stub model version provided but no API key was provided. API key is required to load stub models."
            )
        workspace_id = get_roboflow_workspace(api_key=api_key)
        project_task_type = get_roboflow_dataset_type(
            api_key=api_key, workspace_id=workspace_id, dataset_id=dataset_id
        )
        model_type = "stub"
        save_model_metadata_in_cache(
            dataset_id=dataset_id,
            version_id=version_id,
            project_task_type=project_task_type,
            model_type=model_type,
        )
        return project_task_type, model_type

    if version_id is not None:
        api_data = get_roboflow_model_data(
            api_key=api_key,
            model_id=model_id,
            endpoint_type=ModelEndpointType.ORT,
            device_id=GLOBAL_DEVICE_ID,
        ).get("ort")
        # The two endpoints report the task type under different keys.
        task_type_key = "type"
    else:
        api_data = get_roboflow_instant_model_data(
            api_key=api_key,
            model_id=model_id,
        )
        task_type_key = "taskType"
    # Validate before the first dereference: the "ort" section may be absent,
    # and reading fields from None would surface as an AttributeError instead
    # of the intended ModelArtefactError.
    if api_data is None:
        raise ModelArtefactError("Error loading model artifacts from Roboflow API.")
    # some older projects do not have type field - hence defaulting
    project_task_type = api_data.get(task_type_key, "object-detection")
    model_type = api_data.get("modelType")
    if model_type is None or model_type == "ort":
        # some very old model versions do not have modelType reported - and API respond in a generic way -
        # then we shall attempt using default model for given task type
        model_type = MODEL_TYPE_DEFAULTS.get(project_task_type)
    if model_type is None or project_task_type is None:
        raise ModelArtefactError("Error loading model artifacts from Roboflow API.")
    save_model_metadata_in_cache(
        dataset_id=dataset_id,
        version_id=version_id,
        project_task_type=project_task_type,
        model_type=model_type,
    )

    return project_task_type, model_type