Skip to content

Cors

PathAwareCORSMiddleware

Bases: CORSMiddleware

Extends Starlette's CORSMiddleware to allow specifying a regex of request paths that this middleware should apply to. If `match_paths` is given, only requests whose path matches that regex will have CORS headers applied; all other requests are passed straight to the wrapped app.

Also supports Private Network Access (PNA) for local development, allowing requests from public websites to localhost. This is opt-in via `allow_private_network`.

Source code in inference/core/interfaces/http/middlewares/cors.py (lines 12–118):
class PathAwareCORSMiddleware(StarletteCORSMiddleware):
    """
    Extends Starlette's CORSMiddleware to allow specifying a regex of paths that
    this middleware should apply to.
    If 'match_paths' is given, only requests whose path matches that regex
    (checked with ``re.match``, i.e. anchored at the start of the path) will
    have CORS headers applied; all other requests go straight to the wrapped app.

    Also supports Private Network Access (PNA) for local development, allowing
    requests from public websites to localhost. When ``allow_private_network``
    is enabled, OPTIONS requests that carry both an ``Origin`` and an
    ``Access-Control-Request-Private-Network`` header are answered directly by
    this middleware, including ``Access-Control-Allow-Private-Network: true``.
    """

    # Methods advertised when allow_methods contains "*". A literal "*" is not
    # treated as a wildcard by browsers for credentialed requests, so it is
    # expanded to an explicit list (the same list Starlette's CORSMiddleware uses).
    _ALL_METHODS = ("DELETE", "GET", "HEAD", "OPTIONS", "PATCH", "POST", "PUT")
    # CORS-safelisted request headers that Starlette's CORSMiddleware always
    # allows; included here so PNA preflights accept the same headers as
    # ordinary preflights handled by the parent class.
    _SAFELISTED_HEADERS = frozenset(
        {"Accept", "Accept-Language", "Content-Language", "Content-Type"}
    )

    def __init__(
        self,
        app: ASGIApp,
        match_paths: str | None = None,
        allow_origins: typing.Sequence[str] = (),
        allow_methods: typing.Sequence[str] = ("GET",),
        allow_headers: typing.Sequence[str] = (),
        allow_credentials: bool = False,
        allow_origin_regex: str | None = None,
        expose_headers: typing.Sequence[str] = (),
        max_age: int = 600,
        allow_private_network: bool = False,
    ) -> None:
        """
        Args:
            app: The wrapped ASGI application.
            match_paths: Optional regex; when set, CORS handling applies only to
                request paths matching it (``re.match`` semantics).
            allow_origins, allow_methods, allow_headers, allow_credentials,
            allow_origin_regex, expose_headers, max_age: Forwarded unchanged
                to Starlette's CORSMiddleware.
            allow_private_network: When True, answer PNA preflight requests
                directly (see ``_handle_pna_preflight``).
        """
        super().__init__(
            app=app,
            allow_origins=allow_origins,
            allow_methods=allow_methods,
            allow_headers=allow_headers,
            allow_credentials=allow_credentials,
            allow_origin_regex=allow_origin_regex,
            expose_headers=expose_headers,
            max_age=max_age,
        )
        self.match_paths_regex = re.compile(match_paths) if match_paths else None
        self.allow_private_network = allow_private_network
        # Stored for PNA preflight handling so this class does not depend on the
        # parent's internal attribute names.
        self._max_age = max_age
        self._allow_methods = allow_methods
        self._allow_headers = allow_headers
        self._allow_credentials = allow_credentials

    async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
        """
        Only apply the CORS logic if the path matches self.match_paths_regex
        (when provided). Otherwise, just call the wrapped 'app'.
        """
        # If it's not an HTTP request, skip the CORS processing:
        if scope["type"] != "http":
            await self.app(scope, receive, send)
            return

        # If match_paths was supplied, check if the current path matches
        if self.match_paths_regex is not None:
            path = scope.get("path", "")
            if not self.match_paths_regex.match(path):
                # If it does NOT match, just run the app without CORS
                await self.app(scope, receive, send)
                return

        # Handle Private Network Access preflight requests. Requests without an
        # Origin header are not CORS requests; let them fall through to the
        # normal handling instead of answering with an empty allowed origin.
        if self.allow_private_network:
            headers = Headers(scope=scope)
            if (
                scope["method"] == "OPTIONS"
                and "origin" in headers
                and "access-control-request-private-network" in headers
            ):
                await self._handle_pna_preflight(scope, receive, send, headers)
                return

        # If we got here, apply the normal Starlette CORSMiddleware behavior
        await super().__call__(scope, receive, send)

    async def _handle_pna_preflight(
        self, scope: Scope, receive: Receive, send: Send, request_headers: Headers
    ) -> None:
        """
        Handle preflight requests that include the Private Network Access header.

        Responds 200 with the CORS + PNA allow headers when the origin is
        allowed, otherwise 400 "Disallowed CORS origin".
        """
        origin = request_headers.get("origin", "")
        if not self.is_allowed_origin(origin=origin):
            response = PlainTextResponse(
                "Disallowed CORS origin", status_code=400, headers={"vary": "Origin"}
            )
            await response(scope, receive, send)
            return

        methods = (
            self._ALL_METHODS if "*" in self._allow_methods else self._allow_methods
        )
        response_headers = {
            "access-control-allow-origin": origin,
            "access-control-allow-private-network": "true",
            "access-control-allow-methods": ", ".join(methods),
            "access-control-max-age": str(self._max_age),
            # The request's origin is echoed back, so caches must key on it.
            "vary": "Origin",
        }
        if "*" in self._allow_headers:
            # Wildcard: reflect whatever headers the browser asked for.
            requested_headers = request_headers.get(
                "access-control-request-headers", ""
            )
            if requested_headers:
                response_headers["access-control-allow-headers"] = requested_headers
        else:
            allowed_headers = sorted(
                self._SAFELISTED_HEADERS | set(self._allow_headers)
            )
            response_headers["access-control-allow-headers"] = ", ".join(
                allowed_headers
            )
        if self._allow_credentials:
            response_headers["access-control-allow-credentials"] = "true"
        response = PlainTextResponse("OK", status_code=200, headers=response_headers)
        await response(scope, receive, send)

async __call__(scope, receive, send)

Only apply the CORS logic if the path matches self.match_paths_regex (when provided). Otherwise, just call the wrapped 'app'.

Source code in inference/core/interfaces/http/middlewares/cors.py (lines 54–83):
async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
    """
    Only apply the CORS logic if the path matches self.match_paths_regex
    (when provided). Otherwise, just call the wrapped 'app'.
    """
    # If it's not an HTTP request, skip the CORS processing:
    if scope["type"] != "http":
        await self.app(scope, receive, send)
        return

    # If match_paths was supplied, check if the current path matches
    if self.match_paths_regex is not None:
        path = scope.get("path", "")
        if not self.match_paths_regex.match(path):
            # If it does NOT match, just run the app without CORS
            await self.app(scope, receive, send)
            return

    # Handle Private Network Access preflight requests. Requests without an
    # Origin header are not CORS requests; let them fall through to the
    # normal handling instead of answering with an empty allowed origin.
    if self.allow_private_network:
        headers = Headers(scope=scope)
        if (
            scope["method"] == "OPTIONS"
            and "origin" in headers
            and "access-control-request-private-network" in headers
        ):
            await self._handle_pna_preflight(scope, receive, send, headers)
            return

    # If we got here, apply the normal Starlette CORSMiddleware behavior
    await super().__call__(scope, receive, send)