Skip to content

Aggregator Node

graphorchestrator.nodes.nodes.AggregatorNode

Bases: Node

A node that aggregates multiple states into a single state.

This node takes a list of State objects, aggregates them, and returns a new State object representing the aggregated result.

Source code in graphorchestrator\nodes\nodes.py
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
class AggregatorNode(Node):
    """
    A node that aggregates multiple states into a single state.

    This node takes a list of State objects, aggregates them, and returns a
    new State object representing the aggregated result.

    The aggregation callable must carry a truthy ``is_aggregator_action``
    attribute; otherwise construction raises ``AggregatorActionNotDecorated``.
    The callable may be synchronous or return an awaitable; ``execute``
    handles both.
    """

    def __init__(
        self, node_id: str, aggregator_action: Callable[[List[State]], State]
    ) -> None:
        """
        Create an aggregator node.

        Args:
            node_id (str): Identifier of this node, forwarded to ``Node``.
            aggregator_action (Callable[[List[State]], State]): Callable that
                merges a list of states into one. It may return the ``State``
                directly or an awaitable that resolves to it.

        Raises:
            AggregatorActionNotDecorated: If ``aggregator_action`` does not
                have a truthy ``is_aggregator_action`` attribute.
        """
        super().__init__(node_id)
        # Validate before attaching so a rejected callable is never stored.
        if not getattr(aggregator_action, "is_aggregator_action", False):
            raise AggregatorActionNotDecorated(aggregator_action)
        self.aggregator_action = aggregator_action

        GraphLogger.get().info(
            **wrap_constants(
                message="AggregatorNode created",
                **{
                    LC.EVENT_TYPE: "node",
                    LC.NODE_ID: self.node_id,
                    LC.NODE_TYPE: "AggregatorNode",
                    LC.ACTION: "node_created",
                    # Not every callable has __name__ (e.g. functools.partial);
                    # fall back to repr() instead of crashing while logging.
                    LC.CUSTOM: {
                        "function": getattr(
                            aggregator_action, "__name__", repr(aggregator_action)
                        )
                    },
                },
            )
        )

    async def execute(self, states: List[State]) -> State:
        """
        Executes the aggregation logic of the node.

        The aggregator action is called with ``states``. If it returns an
        awaitable (coroutine, task, future, ...), that awaitable is awaited,
        so both sync and async actions are supported.

        Args:
            states (List[State]): The list of states to aggregate.

        Returns:
            State: The aggregated state.
        """
        # Function-scope import keeps this fix self-contained.
        import inspect

        log = GraphLogger.get()

        log.info(
            **wrap_constants(
                message="AggregatorNode execution started",
                **{
                    LC.EVENT_TYPE: "node",
                    LC.NODE_ID: self.node_id,
                    LC.NODE_TYPE: "AggregatorNode",
                    LC.ACTION: "execute_start",
                    LC.CUSTOM: {"input_batch_size": len(states)},
                },
            )
        )

        # Call first, then await if needed. Deciding up front with
        # asyncio.iscoroutinefunction() may not detect async callables hidden
        # behind a synchronous wrapper or an object with ``async def __call__``,
        # which would leave the coroutine un-awaited.
        result = self.aggregator_action(states)
        if inspect.isawaitable(result):
            result = await result

        log.info(
            **wrap_constants(
                message="AggregatorNode execution completed",
                **{
                    LC.EVENT_TYPE: "node",
                    LC.NODE_ID: self.node_id,
                    LC.NODE_TYPE: "AggregatorNode",
                    LC.ACTION: "execute_end",
                    # Reads result.messages; assumes the action returned a State.
                    LC.OUTPUT_SIZE: len(result.messages),
                },
            )
        )

        return result

execute(states) async

Executes the aggregation logic of the node.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| states | List[State] | The list of states to aggregate. | required |

Returns:

| Name | Type | Description |
| --- | --- | --- |
| State | State | The aggregated state. |

Source code in graphorchestrator\nodes\nodes.py
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
async def execute(self, states: List[State]) -> State:
    """
    Executes the aggregation logic of the node.

    The aggregator action is called with ``states``. If it returns an
    awaitable (coroutine, task, future, ...), that awaitable is awaited, so
    both sync and async actions are supported.

    Args:
        states (List[State]): The list of states to aggregate.

    Returns:
        State: The aggregated state.
    """
    # Function-scope import keeps this fix self-contained.
    import inspect

    log = GraphLogger.get()

    log.info(
        **wrap_constants(
            message="AggregatorNode execution started",
            **{
                LC.EVENT_TYPE: "node",
                LC.NODE_ID: self.node_id,
                LC.NODE_TYPE: "AggregatorNode",
                LC.ACTION: "execute_start",
                LC.CUSTOM: {"input_batch_size": len(states)},
            },
        )
    )

    # Call first, then await if needed. Deciding up front with
    # asyncio.iscoroutinefunction() may not detect async callables hidden
    # behind a synchronous wrapper or an object with ``async def __call__``,
    # which would leave the coroutine un-awaited.
    result = self.aggregator_action(states)
    if inspect.isawaitable(result):
        result = await result

    log.info(
        **wrap_constants(
            message="AggregatorNode execution completed",
            **{
                LC.EVENT_TYPE: "node",
                LC.NODE_ID: self.node_id,
                LC.NODE_TYPE: "AggregatorNode",
                LC.ACTION: "execute_end",
                # Reads result.messages; assumes the action returned a State.
                LC.OUTPUT_SIZE: len(result.messages),
            },
        )
    )

    return result