A simple neural network module for relational reasoning
Relational reasoning is a core component of human intelligence, yet deep neural networks are still far from human-level performance on it. This post covers a paper that proposes a structurally simple, pluggable reasoning module, evaluated on the text-only QA dataset bAbI.
The core idea is to constrain part of the network's capacity so that it captures the key properties common to relational reasoning. In other words, relational reasoning is built into the deep network's architecture itself, so it does not have to be learned separately.
The model can be expressed in a single formula:

$$\mathrm{RN}(O) = f_\phi\Big(\sum_{i,j} g_\theta(o_i, o_j)\Big)$$

Here the input is a set of objects $O = \lbrace o_1, o_2, \ldots, o_n \rbrace$; $f_\phi$ and $g_\theta$ are both MLPs, and $g_\theta$ can be viewed as the network that scores the relation between each pair of objects $(o_i, o_j)$.
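A minimal PyTorch sketch of this equation (toy dimensions, plain MLPs standing in for $f_\phi$ and $g_\theta$; all names and sizes here are illustrative, not from the paper):

import torch
import torch.nn as nn

class TinyRN(nn.Module):
    # RN(O) = f_phi( sum over pairs (i, j) of g_theta(o_i, o_j) )
    def __init__(self, obj_dim=8, hidden=16, out_dim=4):
        super().__init__()
        self.g_theta = nn.Sequential(nn.Linear(2 * obj_dim, hidden), nn.ReLU())
        self.f_phi = nn.Linear(hidden, out_dim)

    def forward(self, objects):  # objects: (batch, n, obj_dim)
        b, n, d = objects.shape
        # Broadcast to build all ordered pairs (o_i, o_j)
        o_i = objects.unsqueeze(2).expand(b, n, n, d)
        o_j = objects.unsqueeze(1).expand(b, n, n, d)
        pairs = torch.cat((o_i, o_j), dim=-1)       # (b, n, n, 2*obj_dim)
        rel = self.g_theta(pairs).sum(dim=(1, 2))   # sum over all pairs
        return self.f_phi(rel)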
The model has three notable properties:
- It focuses on the relations between the objects within a sample
- It is more data-efficient to train
- It operates on a set of objects, so its output is invariant to their order (see the check after this list)
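Because the aggregation is a plain sum over all object pairs, permuting the inputs cannot change the output. A quick check, continuing the TinyRN sketch above:

rn = TinyRN()
objs = torch.randn(1, 5, 8)
perm = objs[:, torch.randperm(5), :]
print(torch.allclose(rn(objs), rn(perm), atol=1e-5))  # True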
Dataset - bAbI
bAbI is a plain-text dataset (included under the 'data' directory in the code), comprising 20 separate tasks.
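For reference, an instance from bAbI task 1 (single supporting fact) looks like this: numbered story sentences, followed by a question line whose tab-separated fields hold the answer and the supporting sentence number:

1 Mary moved to the bathroom.
2 John went to the hallway.
3 Where is Mary? 	bathroom	1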
Code
import torch
import torch.nn as nn
import torch.nn.functional as F


class RelationNet(nn.Module):
    def __init__(self, word_size, answer_size,
                 max_s_len, max_q_len, use_cuda,
                 story_len=20,
                 emb_dim=32,
                 story_hsz=32,
                 story_layers=1,
                 question_hsz=32,
                 question_layers=1):
        super().__init__()
        self.use_cuda = use_cuda
        self.max_s_len = max_s_len
        self.max_q_len = max_q_len
        self.story_len = story_len
        self.emb_dim = emb_dim
        self.story_hsz = story_hsz
        self.emb = nn.Embedding(word_size, emb_dim)
        # Both the story sentences and the question are encoded with LSTMs
        self.story_rnn = nn.LSTM(input_size=emb_dim,
                                 hidden_size=story_hsz,
                                 num_layers=story_layers,
                                 batch_first=True)
        self.question_rnn = nn.LSTM(input_size=emb_dim,
                                    hidden_size=question_hsz,
                                    num_layers=question_layers,
                                    batch_first=True)
        # g_theta: relation MLP over object pairs; each object is a sentence
        # encoding plus a one-hot position tag, and the question encoding is
        # appended to every pair
        self.g1 = nn.Linear((2 * story_len) + (2 * story_hsz) + question_hsz, 256)
        self.g2 = nn.Linear(256, 256)
        self.g3 = nn.Linear(256, 256)
        self.g4 = nn.Linear(256, 256)
        # f_phi: maps the summed relation features to answer logits
        self.f1 = nn.Linear(256, 256)
        self.f2 = nn.Linear(256, 512)
        self.f3 = nn.Linear(512, answer_size)
        self._reset_parameters()

    def _reset_parameters(self, stddev=0.1):
        # Small Gaussian init for all weights, zeros for all biases
        self.emb.weight.data.normal_(std=stddev)
        for layer in (self.g1, self.g2, self.g3, self.g4,
                      self.f1, self.f2, self.f3):
            layer.weight.data.normal_(std=stddev)
            layer.bias.data.fill_(0)

    def g_theta(self, x):
        # The relation MLP uses ReLU activations throughout
        x = F.relu_(self.g1(x))
        x = F.relu_(self.g2(x))
        x = F.relu_(self.g3(x))
        x = F.relu_(self.g4(x))
        return x

    def init_tags(self):
        # One-hot position tags: row i marks the i-th sentence of the story
        tags = torch.eye(self.story_len)
        if self.use_cuda:
            tags = tags.cuda()
        return tags

    def forward(self, story, question):
        # story: (bsz, story_len, max_s_len), question: (bsz, max_q_len)
        tags = self.init_tags()
        bsz = story.shape[0]
        # Encode every story sentence independently with the story LSTM
        s_emb = self.emb(story)
        s_emb = s_emb.view(-1, self.max_s_len, self.emb_dim)
        _, (s_state, _) = self.story_rnn(s_emb)
        s_state = s_state[-1, :, :]  # final hidden state of the last layer
        s_state = s_state.view(-1, self.story_len, self.story_hsz)
        # Attach the position tags to form the object representations
        s_tags = tags.unsqueeze(0).repeat((bsz, 1, 1))
        story_objects = torch.cat((s_state, s_tags), dim=2)
        # Encode the question
        q_emb = self.emb(question)
        _, (q_state, _) = self.question_rnn(q_emb)
        q_state = q_state[-1, :, :]
        # Sum g_theta over all ordered object pairs, conditioned on the question
        sum_g_theta = 0
        for i in range(self.story_len):
            this_tensor = story_objects[:, i, :]
            for j in range(self.story_len):
                u = torch.cat(
                    (this_tensor, story_objects[:, j, :], q_state), dim=1)
                sum_g_theta = sum_g_theta + self.g_theta(u)
        # f_phi turns the aggregated relation vector into answer logits
        out = F.relu(self.f1(sum_g_theta))
        out = F.relu(self.f2(out))
        out = self.f3(out)
        return out
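A quick smoke test of the module, with hypothetical hyperparameters (a 50-word vocabulary, 20 candidate answers, a batch of 4 stories):

model = RelationNet(word_size=50, answer_size=20,
                    max_s_len=10, max_q_len=10, use_cuda=False)
story = torch.randint(0, 50, (4, 20, 10))  # (batch, story_len, max_s_len)
question = torch.randint(0, 50, (4, 10))   # (batch, max_q_len)
logits = model(story, question)
print(logits.shape)  # torch.Size([4, 20])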
Results