Commit c1eacbd

tidy up
1 parent 3dd139a commit c1eacbd

3 files changed: 19 additions & 25 deletions

pykg2vec/models/pointwise.py

Lines changed: 2 additions & 2 deletions
@@ -208,10 +208,10 @@ def get_reg(self, h, r, t, reg_type="F2"):
 
         if reg_type.lower() == 'f2':
             regul_term = torch.mean(torch.sum(h_e_real ** 2, -1) + torch.sum(h_e_img ** 2, -1) + torch.sum(r_e_real ** 2, -1) +
-                                    torch.sum(r_e_img ** 2, -1) + torch.sum(t_e_real ** 2, -1) + torch.sum(t_e_img ** 2, -1))
+                                    torch.sum(r_e_img ** 2, -1) + torch.sum(t_e_real ** 2, -1) + torch.sum(t_e_img ** 2, -1))
         elif reg_type.lower() == 'n3':
             regul_term = torch.mean(torch.sum(h_e_real ** 3, -1) + torch.sum(h_e_img ** 3, -1) + torch.sum(r_e_real ** 3, -1) +
-                                    torch.sum(r_e_img ** 3, -1) + torch.sum(t_e_real ** 3, -1) + torch.sum(t_e_img ** 3, -1))
+                                    torch.sum(r_e_img ** 3, -1) + torch.sum(t_e_real ** 3, -1) + torch.sum(t_e_img ** 3, -1))
         else:
             raise NotImplementedError('Unknown regularizer type: %s' % reg_type)
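
The hunk above appears to be a whitespace-only re-alignment of the regularizer's continuation lines; the computed value does not change. For reference, a minimal self-contained sketch of what that F2/N3 branch computes, using standalone random tensors in place of the model's real/imaginary embedding lookups (illustrative only, not the pykg2vec class):

import torch

def reg_term(h_e_real, h_e_img, r_e_real, r_e_img, t_e_real, t_e_img, reg_type="F2"):
    # Same reduction as get_reg above: per-component sum over the embedding
    # dimension, then a mean over the batch.
    parts = (h_e_real, h_e_img, r_e_real, r_e_img, t_e_real, t_e_img)
    if reg_type.lower() == 'f2':
        return torch.mean(sum(torch.sum(x ** 2, -1) for x in parts))
    elif reg_type.lower() == 'n3':
        return torch.mean(sum(torch.sum(x ** 3, -1) for x in parts))
    raise NotImplementedError('Unknown regularizer type: %s' % reg_type)

# Example: batch of 4 triples, embedding dimension 8.
print(reg_term(*[torch.randn(4, 8) for _ in range(6)], reg_type="F2"))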

pykg2vec/models/projection.py

Lines changed: 14 additions & 14 deletions
@@ -214,9 +214,9 @@ def forward(self, e, r, er_e2, direction="tail"):
         emb_hr_r = self.rel_embeddings(r)  # [m, k]
 
         if direction == "tail":
-            ere2_sigmoid = self.g(torch.dropout(self.f1(emb_hr_e, emb_hr_r), p=self.hidden_dropout, train=True), self.ent_embeddings.weight)
+            ere2_sigmoid = ProjE_pointwise.g(torch.dropout(self.f1(emb_hr_e, emb_hr_r), p=self.hidden_dropout, train=True), self.ent_embeddings.weight)
         else:
-            ere2_sigmoid = self.g(torch.dropout(self.f2(emb_hr_e, emb_hr_r), p=self.hidden_dropout, train=True), self.ent_embeddings.weight)
+            ere2_sigmoid = ProjE_pointwise.g(torch.dropout(self.f2(emb_hr_e, emb_hr_r), p=self.hidden_dropout, train=True), self.ent_embeddings.weight)
 
         ere2_loss_left = -torch.sum((torch.log(torch.clamp(ere2_sigmoid, 1e-10, 1.0)) * torch.max(torch.FloatTensor([0]).to(self.device), er_e2)))
         ere2_loss_right = -torch.sum((torch.log(torch.clamp(1 - ere2_sigmoid, 1e-10, 1.0)) * torch.max(torch.FloatTensor([0]).to(self.device), torch.neg(er_e2))))
@@ -243,21 +243,11 @@ def f2(self, t, r):
         """
         return torch.tanh(t * self.De2.weight + r * self.Dr2.weight + self.bc2.weight)
 
-    def g(self, f, w):
-        """Defines activation layer.
-
-           Args:
-               f (Tensor): output of the forward layers.
-               w (Tensor): Matrix for multiplication.
-        """
-        # [b, k] [k, tot_ent]
-        return torch.sigmoid(torch.matmul(f, w.T))
-
     def predict_tail_rank(self, h, r, topk=-1):
         emb_h = self.ent_embeddings(h)  # [1, k]
         emb_r = self.rel_embeddings(r)  # [1, k]
 
-        hrt_sigmoid = -self.g(self.f1(emb_h, emb_r), self.ent_embeddings.weight)
+        hrt_sigmoid = -ProjE_pointwise.g(self.f1(emb_h, emb_r), self.ent_embeddings.weight)
         _, rank = torch.topk(hrt_sigmoid, k=topk)
 
         return rank
@@ -266,11 +256,21 @@ def predict_head_rank(self, t, r, topk=-1):
         emb_t = self.ent_embeddings(t)  # [m, k]
         emb_r = self.rel_embeddings(r)  # [m, k]
 
-        hrt_sigmoid = -self.g(self.f2(emb_t, emb_r), self.ent_embeddings.weight)
+        hrt_sigmoid = -ProjE_pointwise.g(self.f2(emb_t, emb_r), self.ent_embeddings.weight)
         _, rank = torch.topk(hrt_sigmoid, k=topk)
 
         return rank
 
+    @staticmethod
+    def g(f, w):
+        """Defines activation layer.
+
+           Args:
+               f (Tensor): output of the forward layers.
+               w (Tensor): Matrix for multiplication.
+        """
+        # [b, k] [k, tot_ent]
+        return torch.sigmoid(torch.matmul(f, w.T))
 
 class TuckER(ProjectionModel):
     """

pykg2vec/utils/trainer.py

Lines changed: 3 additions & 9 deletions
@@ -173,13 +173,7 @@ def train_step_projection(self, h, r, t, hr_t, tr_h):
 
         return loss
 
-    def train_step_pointwise(self, h, r, t, target):
-        preds = self.model(h, r, t)
-        loss = self.model.loss(preds, target)
-        loss += self.model.get_reg(h, r, t)
-        return loss
-
-    def train_step_hyperbolic(self, h, r, t, target):
+    def train_step_pointwise_hyperbolic(self, h, r, t, target):
         preds = self.model(h, r, t)
         loss = self.model.loss(preds, target)
         loss += self.model.get_reg(h, r, t)
@@ -289,7 +283,7 @@ def train_model_epoch(self, epoch_idx, tuning=False):
                 r = torch.LongTensor(data[1]).to(self.config.device)
                 t = torch.LongTensor(data[2]).to(self.config.device)
                 y = torch.LongTensor(data[3]).to(self.config.device)
-                loss = self.train_step_pointwise(h, r, t, y)
+                loss = self.train_step_pointwise_hyperbolic(h, r, t, y)
             elif self.model.training_strategy == TrainingStrategy.PAIRWISE_BASED:
                 pos_h = torch.LongTensor(data[0]).to(self.config.device)
                 pos_r = torch.LongTensor(data[1]).to(self.config.device)
@@ -303,7 +297,7 @@ def train_model_epoch(self, epoch_idx, tuning=False):
                 r = torch.cat((torch.LongTensor(data[1]).to(self.config.device), torch.LongTensor(data[4]).to(self.config.device)), dim=-1)
                 t = torch.cat((torch.LongTensor(data[2]).to(self.config.device), torch.LongTensor(data[5]).to(self.config.device)), dim=-1)
                 y = torch.cat((torch.ones(np.array(data[0]).shape).to(self.config.device), torch.zeros(np.array(data[3]).shape).to(self.config.device)), dim=-1)
-                loss = self.train_step_hyperbolic(h, r, t, y)
+                loss = self.train_step_pointwise_hyperbolic(h, r, t, y)
             else:
                 raise NotImplementedError("Unknown training strategy: %s" % self.model.training_strategy)
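
The two deleted train steps had identical bodies (forward pass, loss, plus get_reg), so the pointwise and hyperbolic branches of train_model_epoch can safely share the new train_step_pointwise_hyperbolic. A hedged sketch of the merged step, using a dummy model with the same three-call interface the trainer relies on (not the pykg2vec model classes, and the loss here is only a placeholder):

import torch

class _DummyModel(torch.nn.Module):
    # Stand-in exposing forward(h, r, t), loss(preds, target) and get_reg(h, r, t).
    def forward(self, h, r, t):
        return torch.randn(h.shape[0], requires_grad=True)
    def loss(self, preds, target):
        return torch.nn.functional.soft_margin_loss(preds, target.float())
    def get_reg(self, h, r, t):
        return torch.tensor(0.0)

def train_step_pointwise_hyperbolic(model, h, r, t, target):
    # Mirrors the merged trainer method: score the triples, compute the loss,
    # then add the model's regularization term.
    preds = model(h, r, t)
    loss = model.loss(preds, target)
    loss += model.get_reg(h, r, t)
    return loss

model = _DummyModel()
h = r = t = torch.randint(0, 10, (4,))
y = torch.tensor([1, 1, -1, -1])  # illustrative labels only
print(train_step_pointwise_hyperbolic(model, h, r, t, y))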
