Skip to content

Dispatch manager

ResultsChecker

Class responsible for queuing asynchronous inference runs, keeping track of running requests, and awaiting their results.

Source code in inference/enterprise/parallel/dispatch_manager.py
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
class ResultsChecker:
    """
    Class responsible for queuing asynchronous inference runs,
    keeping track of running requests, and awaiting their results.

    The number of tasks in flight is capped by a bounded semaphore created
    with NUM_PARALLEL_TASKS slots: `add_task` acquires a slot and `loop`
    releases it when a result message for a tracked task arrives.
    """

    def __init__(self, redis: Redis):
        # task_id -> Event that `loop` sets once the task's result is stored.
        self.tasks: Dict[str, Event] = {}
        # task_id -> payload of a task reported with SUCCESS_STATE.
        self.dones = dict()
        # task_id -> payload of a task reported with FAILURE_STATE.
        self.errors = dict()
        # NOTE(review): not read by any method of this class.
        self.running = True
        self.redis = redis
        self.semaphore: BoundedSemaphore = BoundedSemaphore(NUM_PARALLEL_TASKS)

    def add_task(self, task_id: str, request: InferenceRequest):
        """
        Acquire a slot, register the task, and launch its preprocessing.

        Acquires one slot from the semaphore, stores a fresh Event under
        ``task_id`` so the result can be awaited, and launches the preprocess
        celery task with the request converted to a dict.

        If launching the task raises, the registration is removed and the
        slot is handed back before the exception propagates, so a failed
        dispatch does not permanently use up one of the parallel slots.
        """
        self.semaphore.acquire()
        self.tasks[task_id] = Event()
        try:
            preprocess.s(request.dict()).delay()
        except Exception:
            # Nothing will ever publish a result for this task, so undo the
            # bookkeeping here; otherwise the slot would never be released.
            self.tasks.pop(task_id, None)
            self.semaphore.release()
            raise

    def get_result(self, task_id: str) -> Any:
        """
        Check the done tasks and errored tasks for this task id.

        Returns the stored payload (removing it) when the task succeeded.
        Raises Exception carrying the stored payload (removing it) when the
        task failed, and RuntimeError when no result is stored at all.
        """
        if task_id in self.dones:
            return self.dones.pop(task_id)
        elif task_id in self.errors:
            message = self.errors.pop(task_id)
            raise Exception(message)
        else:
            raise RuntimeError(
                "Task result not found in either success or error dict. Unreachable"
            )

    def loop(self):
        """
        Main loop. Listen on the "results" pubsub channel and, for every
        tracked task whose final status (failure or success) is reported,
        add its payload to the appropriate results dictionary.

        Messages for task ids that are not in ``self.tasks`` are ignored.
        Raises RuntimeError if a tracked task reports any other status.
        """
        with self.redis.pubsub() as pubsub:
            pubsub.subscribe("results")
            for message in pubsub.listen():
                # Skip anything that is not a published data message.
                if message["type"] != "message":
                    continue
                message = orjson.loads(message["data"])
                task_id = message.pop("task_id")
                if task_id not in self.tasks:
                    continue
                # The task is finished: free its slot for the next add_task.
                self.semaphore.release()
                status = message.pop("status")
                if status == FAILURE_STATE:
                    self.errors[task_id] = message["payload"]
                elif status == SUCCESS_STATE:
                    self.dones[task_id] = message["payload"]
                else:
                    raise RuntimeError(
                        "Task result not found in possible states. Unreachable"
                    )
                # Wake whoever is blocked in wait_for_response for this task.
                self.tasks[task_id].set()

    def wait_for_response(self, key: str):
        """
        Wait for the task registered under ``key`` to finish and return its
        result.

        Waits on the task's Event, removes the task from ``self.tasks``, then
        delegates to `get_result`, which returns the payload or raises.
        Raises KeyError if ``key`` is not a registered task.
        """
        event = self.tasks[key]
        event.wait()
        del self.tasks[key]
        return self.get_result(key)

add_task(task_id, request)

Wait until there's an available cycle to queue a task. When there are cycles, add the task's id to a list to keep track of its results, launch the preprocess celery task, set the task's status to in progress in redis.

Source code in inference/enterprise/parallel/dispatch_manager.py
36
37
38
39
40
41
42
43
44
def add_task(self, task_id: str, request: InferenceRequest):
    """
    Acquire a slot, register the task, and launch its preprocessing.

    Acquires one slot from the semaphore, stores a fresh Event under
    ``task_id`` so the result can be awaited, and launches the preprocess
    celery task with the request converted to a dict.

    If launching the task raises, the registration is removed and the
    slot is handed back before the exception propagates, so a failed
    dispatch does not permanently use up one of the parallel slots.
    """
    self.semaphore.acquire()
    self.tasks[task_id] = Event()
    try:
        preprocess.s(request.dict()).delay()
    except Exception:
        # Nothing will ever publish a result for this task, so undo the
        # bookkeeping here; otherwise the slot would never be released.
        self.tasks.pop(task_id, None)
        self.semaphore.release()
        raise

get_result(task_id)

Check the done tasks and errored tasks for this task id.

Source code in inference/enterprise/parallel/dispatch_manager.py
46
47
48
49
50
51
52
53
54
55
56
57
58
def get_result(self, task_id: str) -> Any:
    """
    Fetch and discard the stored outcome for ``task_id``.

    A successful task's payload is removed from ``self.dones`` and returned.
    A failed task's payload is removed from ``self.errors`` and raised inside
    a plain Exception. If neither dictionary holds the id, RuntimeError is
    raised.
    """
    # Success path first: hand the payload back and forget it.
    if task_id in self.dones:
        return self.dones.pop(task_id)

    # No success and no recorded failure means there is nothing to report.
    if task_id not in self.errors:
        raise RuntimeError(
            "Task result not found in either success or error dict. Unreachable"
        )

    failure_payload = self.errors.pop(task_id)
    raise Exception(failure_payload)

loop()

Main loop. Check all in-progress tasks for their status and, if their status is final (either failure or success), add their results to the appropriate results dictionary.

Source code in inference/enterprise/parallel/dispatch_manager.py
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
def loop(self):
    """
    Consume the "results" pubsub channel and record finished tasks.

    Each published message is decoded with orjson. Messages whose task id is
    not in ``self.tasks`` are skipped. For a tracked task, one semaphore slot
    is released, the payload is stored in ``self.errors`` or ``self.dones``
    depending on the reported status, and the task's Event is set.

    Raises RuntimeError if a tracked task reports a status that is neither
    FAILURE_STATE nor SUCCESS_STATE.
    """
    with self.redis.pubsub() as subscription:
        subscription.subscribe("results")
        for raw in subscription.listen():
            # Only published data messages carry a result.
            if raw["type"] != "message":
                continue

            result = orjson.loads(raw["data"])
            task_id = result.pop("task_id")
            if task_id not in self.tasks:
                continue

            # The slot is released before the status is inspected.
            self.semaphore.release()

            status = result.pop("status")
            if status == FAILURE_STATE:
                destination = self.errors
            elif status == SUCCESS_STATE:
                destination = self.dones
            else:
                raise RuntimeError(
                    "Task result not found in possible states. Unreachable"
                )
            destination[task_id] = result["payload"]

            # Signal that the result for this task is now available.
            self.tasks[task_id].set()