Skip to content

roboflow

RoboflowModelRegistry

Bases: ModelRegistry

A Roboflow-specific model registry which gets the model type using the model id, then returns a model class based on the model type.

Source code in inference/core/registries/roboflow.py
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
class RoboflowModelRegistry(ModelRegistry):
    """Model registry that resolves Roboflow model ids to model classes.

    The model type is looked up from the model id (see ``get_model_type``) and
    then used as the key into ``self.registry_dict``.
    """

    def get_model(self, model_id: str, api_key: str) -> Model:
        """Look up the model class registered for the given model id.

        Args:
            model_id (str): The ID of the model to be retrieved.
            api_key (str): The API key forwarded to ``get_model_type``.

        Returns:
            Model: The entry of ``self.registry_dict`` stored under the
                resolved model type.

        Raises:
            ModelNotRecognisedError: If the resolved model type has no entry
                in ``self.registry_dict``.
        """
        model_type = get_model_type(model_id, api_key)
        logger.debug(f"Model type: {model_type}")
        # Happy path first: a registered type maps straight to its class.
        if model_type in self.registry_dict:
            return self.registry_dict[model_type]
        raise ModelNotRecognisedError(f"Model type not supported: {model_type}")

get_model(model_id, api_key)

Returns the model class based on the given model id and API key.

Parameters:

Name Type Description Default
model_id str

The ID of the model to be retrieved.

required
api_key str

The API key used to authenticate.

required

Returns:

Name Type Description
Model Model

The model class corresponding to the given model ID and type.

Raises:

Type Description
ModelNotRecognisedError

If the model type is not supported or found.

Source code in inference/core/registries/roboflow.py
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
def get_model(self, model_id: str, api_key: str) -> Model:
    """Look up the model class registered for the given model id.

    Args:
        model_id (str): The ID of the model to be retrieved.
        api_key (str): The API key forwarded to ``get_model_type``.

    Returns:
        Model: The entry of ``self.registry_dict`` stored under the resolved
            model type.

    Raises:
        ModelNotRecognisedError: If the resolved model type has no entry in
            ``self.registry_dict``.
    """
    model_type = get_model_type(model_id, api_key)
    logger.debug(f"Model type: {model_type}")
    # Happy path first: a registered type maps straight to its class.
    if model_type in self.registry_dict:
        return self.registry_dict[model_type]
    raise ModelNotRecognisedError(f"Model type not supported: {model_type}")

get_model_type(model_id, api_key=None)

Retrieves the model type based on the given model ID and API key.

Parameters:

Name Type Description Default
model_id str

The ID of the model.

required
api_key str

The API key used to authenticate.

None

Returns:

Name Type Description
tuple Tuple[TaskType, ModelType]

The project task type and the model type.

Raises:

Type Description
WorkspaceLoadError

If the workspace could not be loaded or if the API key is invalid.

DatasetLoadError

If the dataset could not be loaded due to invalid ID, workspace ID or version ID.

MissingDefaultModelError

If the default model is not configured and the API does not provide this information.

MalformedRoboflowAPIResponseError

If the Roboflow API responds in an invalid format.

Source code in inference/core/registries/roboflow.py
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
def get_model_type(
    model_id: str,
    api_key: Optional[str] = None,
) -> Tuple[TaskType, ModelType]:
    """Resolve the project task type and model type for a model id.

    Resolution order:
        1. The model id is passed through ``resolve_roboflow_model_alias`` and
           split into dataset id and version id.
        2. If the dataset id is a key of ``GENERIC_MODELS``, that entry is
           returned as-is.
        3. If metadata is found in the cache, its first two items are returned.
        4. For the stub version id, the task type is fetched for the dataset
           and the model type is fixed to ``"stub"``.
        5. Otherwise the ``"ort"`` section of the Roboflow model data is
           inspected, falling back to ``MODEL_TYPE_DEFAULTS`` when the model
           type is missing or reported as ``"ort"``.

    Results of steps 4 and 5 are saved in the metadata cache before returning.

    Args:
        model_id (str): The ID of the model (aliases are resolved first).
        api_key (Optional[str]): The API key used to authenticate. Required
            when the version id is the stub version id.

    Returns:
        tuple: The project task type and the model type.

    Raises:
        MissingApiKeyError: If a stub model version is requested without an
            API key.
        ModelArtefactError: If the API response has no ``"ort"`` section, or
            if the task type / model type cannot be determined from it.
        WorkspaceLoadError, DatasetLoadError, MissingDefaultModelError,
        MalformedRoboflowAPIResponseError: Listed by the previous version of
            this docstring; not raised directly in this function - presumably
            propagated from the Roboflow API helpers (TODO confirm).
    """
    model_id = resolve_roboflow_model_alias(model_id=model_id)
    dataset_id, version_id = get_model_id_chunks(model_id=model_id)

    if dataset_id in GENERIC_MODELS:
        logger.debug(f"Loading generic model: {dataset_id}.")
        return GENERIC_MODELS[dataset_id]

    cache_entry = get_model_metadata_from_cache(
        dataset_id=dataset_id, version_id=version_id
    )
    if cache_entry is not None:
        return cache_entry[0], cache_entry[1]

    if version_id == STUB_VERSION_ID:
        if api_key is None:
            raise MissingApiKeyError(
                "Stub model version provided but no API key was provided. API key is required to load stub models."
            )
        owner_workspace = get_roboflow_workspace(api_key=api_key)
        task_type = get_roboflow_dataset_type(
            api_key=api_key, workspace_id=owner_workspace, dataset_id=dataset_id
        )
        resolved_model_type = "stub"
    else:
        model_data = get_roboflow_model_data(
            api_key=api_key,
            model_id=model_id,
            endpoint_type=ModelEndpointType.ORT,
            device_id=GLOBAL_DEVICE_ID,
        )
        ort_section = model_data.get("ort")
        if ort_section is None:
            raise ModelArtefactError(
                "Error loading model artifacts from Roboflow API."
            )
        # some older projects do not have type field - hence defaulting
        task_type = ort_section.get("type", "object-detection")
        resolved_model_type = ort_section.get("modelType")
        if resolved_model_type is None or resolved_model_type == "ort":
            # some very old model versions do not have modelType reported - and
            # the API responds in a generic way - so attempt to use the default
            # model for the given task type
            resolved_model_type = MODEL_TYPE_DEFAULTS.get(task_type)
        if task_type is None or resolved_model_type is None:
            raise ModelArtefactError(
                "Error loading model artifacts from Roboflow API."
            )

    # Shared tail for both the stub and the API-resolved paths.
    save_model_metadata_in_cache(
        dataset_id=dataset_id,
        version_id=version_id,
        project_task_type=task_type,
        model_type=resolved_model_type,
    )
    return task_type, resolved_model_type