Legacy GAT vs HATT 维度分析

默认配置(来自 config.yaml):

一、公共输入维度

变量 数学形状 默认数值形状 说明
ego $[B, N_a, C+7]$ $[B, 12, 11]$ 每个 agent 的自身观测
tasks $[B, N_a, M, C+4]$ $[B, 12, 10, 8]$ 每个 agent 看到的所有任务(可重复或共享)
others $[B, N_a, N_a, C+7]$ $[B, 12, 12, 11]$ 每个 agent 看到的其他 agent,对角线置 0
扁平 obs $[B \cdot N_a, (C+7) + M(C+4) + N_a(C+7)]$ $[12B, 223]$ Dict 观测 flatten 后的维度

二、Legacy GAT 维度

2.1 图节点与边

变量 数学形状 默认数值形状 说明
agent_nodes $[B, N_a, 9]$ $[B, 12, 9]$ 从 ego/others 构建的 agent 节点
task_nodes $[B, M, 9]$ $[B, 10, 9]$ 从 tasks 构建的 task 节点
x $[B(N_a+M), 9]$ $[22B, 9]$ 拼接后的所有节点特征
edge_index $[2, B \cdot E]$ $[2, 372B]$ 边连接关系,$E = 2N_aM + N_a(N_a-1) = 372$
edge_attr $[B \cdot E, 9]$ $[372B, 9]$ 边特征:A-T 为特征差,A-A 为负相似度

2.2 GATConv 层间维度

变量 数学形状 默认数值形状 说明
$h^{(0)}$ $[B(N_a+M), 9]$ $[22B, 9]$ 输入节点特征
$W^{(0)}$ $[256, 9]$ $[256, 9]$ 第一层投影矩阵
$\alpha_{ij}^{(0)}$ $[B \cdot E, 1]$ $[372B, 1]$ 第一层边注意力权重
$h^{(1)}$ $[B(N_a+M), 256]$ $[22B, 256]$ 第一层输出(1-hop 感受野)
$W^{(1)}$ $[128, 256]$ $[128, 256]$ 第二层投影矩阵
$h^{(2)}$ $[B(N_a+M), 128]$ $[22B, 128]$ 第二层输出(2-hop ≈ 全图感受野)

2.3 输出与 Actor

变量 数学形状 默认数值形状 说明
agent_emb_batch $[B, N_a, 128]$ $[B, 12, 128]$ 每个 agent 对应的 GAT 节点嵌入
global_pool_batch $[B, 128]$ $[B, 128]$ 全图节点平均池化
gat 模式输入 $[B \cdot N_a, 128]$ $[12B, 128]$ 只用 agent_emb
cat 模式输入 $[B \cdot N_a, 223+128]$ $[12B, 351]$ obs 与 agent_emb 拼接
logits $[B \cdot N_a, M]$ $[12B, 10]$ 经 MLP 映射得到,与任务无显式对应

三、HATT v1.0 维度

3.1 Encoder 输出

变量 数学形状 默认数值形状 说明
self_h $[B, N_a, H]$ $[B, 12, 64]$ ego encoder 输出,$H=64$
other_h $[B, N_a, N_a, H]$ $[B, 12, 12, 64]$ other encoder 输出
task_h $[B, N_a, M, H]$ $[B, 12, 10, 64]$ task encoder 输出(共享编码后 expand)

3.2 RelationAttention 维度

变量 数学形状 默认数值形状 说明
aa_features $[B, N_a, N_a, 3]$ $[B, 12, 12, 3]$ agent-agent 边:相对位置 2D + 同联盟标志 1D
at_features $[B, N_a, M, 3]$ $[B, 12, 10, 3]$ agent-task 边:相对位置 2D + 已分配标志 1D
$q$ $[B, N_a, K, H/K]$ $[B, 12, 4, 16]$ query,$K=4$ 头,每头 16 维
$k, v$ $[B, N_a, N, K, H/K]$ $[B, 12, 12, 4, 16]$ key/value,$N$ 为邻居数
scores $[B, N_a, N, K]$ $[B, 12, 12, 4]$ 每头每邻居的注意力分数
weights $[B, N_a, N, K]$ $[B, 12, 12, 4]$ softmax 后的注意力权重
aa_context $[B, N_a, H]$ $[B, 12, 64]$ agent-agent 关系上下文
ta_context $[B, N_a, H]$ $[B, 12, 64]$ task-agent 关系上下文

3.3 Fusion 与 Decoder

变量 数学形状 默认数值形状 说明
fusion 输入 $[B, N_a, 3H]$ $[B, 12, 192]$ cat([self_h, aa_context, ta_context])
agent_context $[B, N_a, H]$ $[B, 12, 64]$ 融合后的 agent 表示
agent_features $[B, N_a, H_s]$ $[B, 12, 64]$ agent_projection 输出,$H_s=64$
query $[B, N_a, 1, H_s]$ $[B, 12, 1, 64]$ agent 作为 query
key $[B, N_a, M, H_s]$ $[B, 12, 10, 64]$ task 作为 key
logits $[B, N_a, M]$ $[B, 12, 10]$ 点积注意力 + 边偏置,每维显式对应一个任务

四、HATT v2.0 维度

4.1 Coalition 聚合

变量 数学形状 默认数值形状 说明
coalition_sum $[B, N_a, M, H]$ $[B, 12, 10, 64]$ 按任务 ID 聚合的成员特征和
coalition_mean $[B, N_a, M, H]$ $[B, 12, 10, 64]$ 按任务 ID 聚合的成员特征均值
coalition_count $[B, N_a, M, 1]$ $[B, 12, 10, 1]$ 每个任务的成员数
normalized_count $[B, N_a, M, 1]$ $[B, 12, 10, 1]$ 成员数归一化到 $[0, 1]$

4.2 Task Context 与 CandidateEdgeDecoder

变量 数学形状 默认数值形状 说明
task_fusion 输入 $[B, N_a, M, 3H+1]$ $[B, 12, 10, 193]$ cat([task_h, coalition_sum, coalition_mean, normalized_count])
task_context $[B, N_a, M, H]$ $[B, 12, 10, 64]$ 融合联盟信息的任务表示
decoder 输入 $[B, N_a, M, H_s+H+3]$ $[B, 12, 10, 131]$ cat([agent_features, task_context, at_features])
logits $[B, N_a, M]$ $[B, 12, 10]$ 每条 agent-task 边独立打分,显式对应任务

五、关键对比

特性 Legacy GAT HATT v1.0 HATT v2.0
注意力对象 图中所有节点 agent-agent + task-agent agent-agent + coalition 聚合
中间表示 节点 embedding agent context + task embedding agent context + task context
logits 生成 MLP(obs + graph embedding) 点积注意力 decoder CandidateEdgeDecoder
任务等变性 不满足(global pool 是 invariant) 满足 满足
联盟建模 无显式聚合 scatter_add_ 聚合
最终 logits 形状 $[12B, 10]$ $[12B, 10]$ $[12B, 10]$