当前位置:首页>AI快讯 >

图注意力网络原理与代码实现

发布时间:2025-10-12源自:融质(上海)科技有限公司作者:融质科技编辑部

图注意力网络(GAT)原理与代码实现

图注意力网络(Graph Attention Network, GAT)是一种基于注意力机制的图神经网络架构,专门用于处理图结构数据。它通过学习节点之间的重要性权重,动态地为邻居节点分配不同的注意力,从而更有效地聚合邻居信息。

核心原理

注意力系数计算

对于中心节点 i 和邻居节点 j,计算注意力系数:

math

e_{ij} = ext{LeakyReLU}(mathbf{a}^T [mathbf{W}mathbf{h}_i | mathbf{W}mathbf{h}_j])

其中:

h_i, h_j 是节点特征

W 是共享权重矩阵

a 是注意力向量

|| 表示向量拼接

归一化注意力权重

使用 softmax 归一化邻居节点的注意力系数:

math

lpha{ij} = rac{exp(e{ij})}{sum_{k in mathcal{N}i} exp(e{ik})}

特征聚合

加权聚合邻居节点特征:

math

mathbf{h}i’ = sigmaleft(sum{j in mathcal{N}i} lpha{ij} mathbf{W} mathbf{h}_j ight)

多头注意力

使用 K 个独立的注意力头增强模型稳定性:

math

mathbf{h}i’ = Big|{k=1}^K sigmaleft(sum_{j in mathcal{N}i} lpha{ij}^k mathbf{W}^k mathbf{h}_j ight)

PyTorch 代码实现


import torch

import torch.nn as nn

import torch.nn.functional as F

class GATLayer(nn.Module):

    def **init**(self, in_features, out_features, n_heads=1, dropout=0.6, alpha=0.2):

        super(GATLayer, self).**init**()

        self.n_heads = n_heads

        self.out_features = out_features

     共享线性变换

    self.W = nn.Linear(in_features, out_features  n_heads, bias=False)

     注意力机制参数

    self.a = nn.Parameter(torch.zeros(size=(2out_features, 1)))

    self.leakyrelu = nn.LeakyReLU(alpha)

    self.dropout = nn.Dropout(dropout)

    nn.init.xavier_uniform_(self.W.weight)

    nn.init.xavier_uniform_(self.a)

def forward(self, h, adj):

    """

    h: 节点特征矩阵 [N, in_features]

    adj: 邻接矩阵 [N, N]

    """

    N = h.size(0)

     线性变换 [N, out_featuresheads]

    h_trans = self.W(h).view(N, self.n_heads, self.out_features)

     计算注意力分数

    h_i = h_trans.repeat(1, 1, N).view(N, self.n_heads, N, self.out_features)

    h_j = h_trans.repeat(N, 1, 1).view(N, N, self.n_heads, self.out_features).permute(0,2,1,3)

     拼接特征 [N, heads, N, 2out_features]

    concat_features = torch.cat([h_i, h_j], dim=-1)

     计算注意力系数 [N, heads, N]

    e = self.leakyrelu(torch.matmul(concat_features, self.a).squeeze(-1))

     应用邻接矩阵掩码

    zero_vec = -9e15  torch.ones_like(e)

    attention = torch.where(adj > 0, e, zero_vec)

     归一化注意力权重

    attention = F.softmax(attention, dim=-1)

    attention = self.dropout(attention)

     特征聚合 [N, heads, out_features]

    h_prime = torch.matmul(attention, h_trans)

     多头输出拼接或平均

    if self.n_heads > 1:

        return h_prime.view(N, -1)

    else:

        return h_prime.squeeze(1)

**示例:两层GAT网络** class GAT(nn.Module): def **init**(self, nfeat, nhid, nclass, dropout=0.6, alpha=0.2, n_heads=8): super(GAT, self).**init**() self.layer1 = GATLayer(nfeat, nhid, n_heads, dropout, alpha) self.layer2 = GATLayer(nhidn_heads, nclass, dropout=dropout, alpha=alpha) self.dropout = dropout
def forward(self, x, adj):

    x = F.dropout(x, self.dropout, training=self.training)

    x = F.elu(self.layer1(x, adj))

    x = F.dropout(x, self.dropout, training=self.training)

    x = self.layer2(x, adj)

    return F.log_softmax(x, dim=1)

关键优势

高效计算:仅计算相邻节点对的注意力权重

可解释性:注意力权重反映节点间重要性

归纳学习:不依赖全局图结构,适用于动态图

并行计算:所有节点注意力可同时计算

典型应用场景

社交网络分析

推荐系统

分子结构预测

知识图谱推理

交通网络预测

图注意力网络通过动态学习邻居权重,克服了传统GNN的局限性,在处理复杂图结构时表现出更强的表征能力。多头注意力机制进一步增强了模型的稳定性和表达能力,使其成为图神经网络领域的重要突破。

欢迎分享转载→ https://www.shrzkj.com.cn/aikuaixun/144625.html

上一篇:图神经网络入门与实践指南

下一篇:没有了!

Copyright © 2025 融质(上海)科技有限公司 All Rights Reserved.沪ICP备2024065424号-2XML地图