Skip to content

Dispatch manager

ResultsChecker

Class responsible for queuing asynchronous inference runs, keeping track of running requests, and awaiting their results.

Source code in inference/enterprise/parallel/dispatch_manager.py (lines 22–91)
class ResultsChecker:
    """
    Class responsible for queuing asynchronous inference runs,
    keeping track of running requests, and awaiting their results.

    Concurrency is bounded by a semaphore with ``NUM_PARALLEL_TASKS`` slots:
    a slot is taken in :meth:`add_task` and given back by :meth:`loop` once the
    task's final result arrives on the Redis ``results`` channel.
    """

    def __init__(self, redis: Redis):
        # One event per in-flight task id; `loop` sets it when the result arrives.
        self.tasks: Dict[str, asyncio.Event] = {}
        # Success payloads keyed by task id, consumed by `get_result`.
        self.dones: Dict[str, Any] = {}
        # Failure payloads keyed by task id, re-raised by `get_result`.
        self.errors: Dict[str, Any] = {}
        self.running = True
        self.redis = redis
        self.semaphore: BoundedSemaphore = BoundedSemaphore(NUM_PARALLEL_TASKS)

    async def add_task(self, task_id: str, request: InferenceRequest):
        """
        Wait until a slot is available, then register and dispatch a task.

        Registers an ``asyncio.Event`` under ``task_id`` so the result can be
        awaited with :meth:`wait_for_response`. Then launches the ``preprocess``
        celery task with the serialized request.

        If serializing or dispatching the request raises, the slot is released
        and the task is unregistered before the exception propagates. Otherwise
        a failed dispatch would permanently reduce capacity.
        """
        await self.semaphore.acquire()
        self.tasks[task_id] = asyncio.Event()
        try:
            preprocess.s(request.dict()).delay()
        except Exception:
            # No result will ever be published for this task, so `loop` would
            # never release the slot; undo the bookkeeping here instead.
            self.tasks.pop(task_id, None)
            self.semaphore.release()
            raise

    def get_result(self, task_id: str) -> Any:
        """
        Pop and return the stored result for ``task_id``.

        Checks the successful results first, then the failed ones.

        Raises:
            Exception: carrying the failure payload, if the task failed.
            RuntimeError: if the task id is in neither results dictionary.
        """
        if task_id in self.dones:
            return self.dones.pop(task_id)
        elif task_id in self.errors:
            message = self.errors.pop(task_id)
            raise Exception(message)
        else:
            raise RuntimeError(
                "Task result not found in either success or error dict. Unreachable"
            )

    async def loop(self):
        """
        Main loop: consume task results published on the Redis ``results`` channel.

        For each result message about a tracked task:
        - release that task's slot;
        - store its payload in ``self.errors`` (failure) or ``self.dones`` (success);
        - set the task's event to wake the waiter.

        Messages for untracked task ids, and repeated results for a task whose
        event is already set, are ignored.
        """
        async with self.redis.pubsub() as pubsub:
            await pubsub.subscribe("results")
            async for message in pubsub.listen():
                # Skip subscribe confirmations and other non-data messages.
                if message["type"] != "message":
                    continue
                message = orjson.loads(message["data"])
                task_id = message.pop("task_id")
                event = self.tasks.get(task_id)
                # Skip untracked tasks and duplicate results. Releasing the
                # BoundedSemaphore twice for one task would raise ValueError
                # (stopping this loop) or over-grant capacity.
                if event is None or event.is_set():
                    continue
                self.semaphore.release()
                status = message.pop("status")
                if status == FAILURE_STATE:
                    self.errors[task_id] = message["payload"]
                elif status == SUCCESS_STATE:
                    self.dones[task_id] = message["payload"]
                else:
                    # Report the problem to the waiting request instead of
                    # raising here. Raising would end this loop and leave every
                    # in-flight waiter blocked forever.
                    self.errors[task_id] = (
                        f"Task {task_id} finished with unknown status {status!r}"
                    )
                event.set()
                # Yield so woken waiters can run before the next message.
                await asyncio.sleep(0)

    async def wait_for_response(self, key: str):
        """
        Wait for the result of task ``key`` and return it.

        Returns the success payload, or raises ``Exception`` carrying the
        failure payload (see :meth:`get_result`). Raises ``KeyError`` if ``key``
        is not currently registered.

        NOTE(review): if this coroutine is cancelled before the result arrives,
        ``key`` stays in ``self.tasks``. Its result is then stored by ``loop``
        but never consumed.
        """
        event = self.tasks[key]
        await event.wait()
        del self.tasks[key]
        return self.get_result(key)

add_task(task_id, request) async

Wait until a slot is available to queue a task. Once one is, register the task's id so its result can be tracked, and launch the preprocess Celery task.

Source code in inference/enterprise/parallel/dispatch_manager.py (lines 36–44)
async def add_task(self, task_id: str, request: InferenceRequest):
    """
    Wait until a slot is available, then register and dispatch a task.

    Registers an ``asyncio.Event`` under ``task_id`` so the result can be
    awaited. Then launches the ``preprocess`` celery task with the serialized
    request.

    If serializing or dispatching the request raises, the slot is released
    and the task is unregistered before the exception propagates. Otherwise
    a failed dispatch would permanently reduce capacity.
    """
    await self.semaphore.acquire()
    self.tasks[task_id] = asyncio.Event()
    try:
        preprocess.s(request.dict()).delay()
    except Exception:
        # No result will ever be published for this task, so the results loop
        # would never release the slot; undo the bookkeeping here instead.
        self.tasks.pop(task_id, None)
        self.semaphore.release()
        raise

get_result(task_id)

Check the done tasks and errored tasks for this task id.

Source code in inference/enterprise/parallel/dispatch_manager.py (lines 46–58)
def get_result(self, task_id: str) -> Any:
    """
    Pop and return the stored result for ``task_id``.

    Successful results (``self.dones``) take precedence over failed ones
    (``self.errors``). A failure payload is re-raised wrapped in a plain
    ``Exception``. A task id found in neither dictionary raises ``RuntimeError``.
    """
    absent = object()  # sentinel: payloads themselves may be falsy or None
    payload = self.dones.pop(task_id, absent)
    if payload is not absent:
        return payload
    failure = self.errors.pop(task_id, absent)
    if failure is absent:
        raise RuntimeError(
            "Task result not found in either success or error dict. Unreachable"
        )
    raise Exception(failure)

loop() async

Main loop. Listen on the Redis results pub/sub channel. For each message about a tracked task whose status is final (either failure or success), store its payload in the appropriate results dictionary and signal the waiting request.

Source code in inference/enterprise/parallel/dispatch_manager.py (lines 60–85)
async def loop(self):
    """
    Main loop: consume task results published on the Redis ``results`` channel.

    For each result message about a tracked task:
    - release that task's slot;
    - store its payload in ``self.errors`` (failure) or ``self.dones`` (success);
    - set the task's event to wake the waiter.

    Messages for untracked task ids, and repeated results for a task whose
    event is already set, are ignored.
    """
    async with self.redis.pubsub() as pubsub:
        await pubsub.subscribe("results")
        async for message in pubsub.listen():
            # Skip subscribe confirmations and other non-data messages.
            if message["type"] != "message":
                continue
            message = orjson.loads(message["data"])
            task_id = message.pop("task_id")
            event = self.tasks.get(task_id)
            # Skip untracked tasks and duplicate results. Releasing the
            # semaphore twice for one task would raise ValueError on a
            # BoundedSemaphore (stopping this loop) or over-grant capacity.
            if event is None or event.is_set():
                continue
            self.semaphore.release()
            status = message.pop("status")
            if status == FAILURE_STATE:
                self.errors[task_id] = message["payload"]
            elif status == SUCCESS_STATE:
                self.dones[task_id] = message["payload"]
            else:
                # Report the problem to the waiting request instead of raising
                # here. Raising would end this loop and leave every in-flight
                # waiter blocked forever.
                self.errors[task_id] = (
                    f"Task {task_id} finished with unknown status {status!r}"
                )
            event.set()
            # Yield so woken waiters can run before the next message.
            await asyncio.sleep(0)