Skip to content

Commit 9419f97

Browse files
authored
Merge pull request #202 from baxtree/development-criterion
group loss functions into the single utitilty class
2 parents f35f379 + c1eacbd commit 9419f97

File tree

7 files changed

+216
-112
lines changed

7 files changed

+216
-112
lines changed

pykg2vec/models/KGMeta.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,12 +26,17 @@ def forward(self, h, r, t):
2626
raise NotImplementedError
2727

2828
def load_params(self, param_list, kwargs):
29+
"""Function to load the hyperparameters"""
2930
for param_name in param_list:
3031
if param_name not in kwargs:
3132
raise Exception("hyperparameter %s not found!" % param_name)
3233
self.database[param_name] = kwargs[param_name]
3334
return self.database
3435

36+
def get_reg(self, h, r, t, **kwargs):
37+
"""Function to override if regularization is needed"""
38+
return 0.0
39+
3540

3641
class PairwiseModel(nn.Module, Model):
3742
""" Meta Class for KGE models with translational distance"""
@@ -74,6 +79,7 @@ def __init__(self, model_name):
7479
self.training_strategy = TrainingStrategy.PROJECTION_BASED
7580
self.database = {} # dict to store model-specific hyperparameter
7681

82+
7783
class HyperbolicSpaceModel(nn.Module, Model):
7884
""" Meta Class for KGE models of hyperbolic space"""
7985

pykg2vec/models/hyperbolic.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44

55
from pykg2vec.models.KGMeta import HyperbolicSpaceModel
66
from pykg2vec.models.Domain import NamedEmbedding
7+
from pykg2vec.utils.criterion import Criterion
78

89

910
class MuRP(HyperbolicSpaceModel):
@@ -50,6 +51,8 @@ def __init__(self, **kwargs):
5051
self.rel_embeddings,
5152
]
5253

54+
self.loss = Criterion.bce
55+
5356
def embed(self, h, r, t):
5457
"""Function to get the embedding value.
5558

pykg2vec/models/pairwise.py

Lines changed: 39 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
import numpy as np
77
from pykg2vec.models.KGMeta import PairwiseModel
88
from pykg2vec.models.Domain import NamedEmbedding
9+
from pykg2vec.utils.criterion import Criterion
910

1011

1112
class TransE(PairwiseModel):
@@ -58,6 +59,8 @@ def __init__(self, **kwargs):
5859
self.rel_embeddings,
5960
]
6061

62+
self.loss = Criterion.pairwise_hinge
63+
6164
def forward(self, h, r, t):
6265
"""Function to get the embedding value.
6366
@@ -151,6 +154,8 @@ def __init__(self, **kwargs):
151154
self.w,
152155
]
153156

157+
self.loss = Criterion.pairwise_hinge
158+
154159
def forward(self, h, r, t):
155160
h_e, r_e, t_e = self.embed(h, r, t)
156161

@@ -243,6 +248,8 @@ def __init__(self, **kwargs):
243248
self.rel_mappings,
244249
]
245250

251+
self.loss = Criterion.pairwise_hinge
252+
246253
def embed(self, h, r, t):
247254
"""Function to get the embedding value.
248255
@@ -345,6 +352,8 @@ def __init__(self, **kwargs):
345352
self.rel_embeddings,
346353
]
347354

355+
self.loss = Criterion.pairwise_hinge
356+
348357
def forward(self, h, r, t):
349358
"""Function to get the embedding value.
350359
@@ -431,6 +440,8 @@ def __init__(self, **kwargs):
431440
self.rel_matrix,
432441
]
433442

443+
self.loss = Criterion.pairwise_hinge
444+
434445
def transform(self, e, matrix):
435446
matrix = matrix.view(-1, self.ent_hidden_size, self.rel_hidden_size)
436447
if e.shape[0] != matrix.shape[0]:
@@ -541,6 +552,8 @@ def __init__(self, **kwargs):
541552
self.mr2,
542553
]
543554

555+
self.loss = Criterion.pairwise_hinge
556+
544557
def embed(self, h, r, t):
545558
"""Function to get the embedding value.
546559
@@ -639,6 +652,8 @@ def __init__(self, **kwargs):
639652
self.bv,
640653
]
641654

655+
self.loss = Criterion.pairwise_hinge
656+
642657
def embed(self, h, r, t):
643658
"""Function to get the embedding value.
644659
@@ -665,9 +680,9 @@ def _gu_linear(self, h, r):
665680
Returns:
666681
Tensors: Returns the bilinear loss.
667682
"""
668-
mu1h = torch.matmul(self.mu1.weight, self.transpose(h)) # [k, b]
669-
mu2r = torch.matmul(self.mu2.weight, self.transpose(r)) # [k, b]
670-
return self.transpose(mu1h + mu2r + self.bu.weight) # [b, k]
683+
mu1h = torch.matmul(self.mu1.weight, h.T) # [k, b]
684+
mu2r = torch.matmul(self.mu2.weight, r.T) # [k, b]
685+
return (mu1h + mu2r + self.bu.weight).T # [b, k]
671686

672687
def _gv_linear(self, r, t):
673688
"""Function to calculate linear loss.
@@ -679,9 +694,9 @@ def _gv_linear(self, r, t):
679694
Returns:
680695
Tensors: Returns the bilinear loss.
681696
"""
682-
mv1t = torch.matmul(self.mv1.weight, self.transpose(t)) # [k, b]
683-
mv2r = torch.matmul(self.mv2.weight, self.transpose(r)) # [k, b]
684-
return self.transpose(mv1t + mv2r + self.bv.weight) # [b, k]
697+
mv1t = torch.matmul(self.mv1.weight, t.T) # [k, b]
698+
mv2r = torch.matmul(self.mv2.weight, r.T) # [k, b]
699+
return (mv1t + mv2r + self.bv.weight).T # [b, k]
685700

686701
def forward(self, h, r, t):
687702
"""Function to that performs semanting matching.
@@ -701,11 +716,6 @@ def forward(self, h, r, t):
701716

702717
return -torch.sum(self._gu_linear(norm_h, norm_r) * self._gv_linear(norm_r, norm_t), 1)
703718

704-
@staticmethod
705-
def transpose(tensor):
706-
dims = tuple(range(len(tensor.shape)-1, -1, -1)) # (rank-1...0)
707-
return tensor.permute(dims)
708-
709719

710720
class SME_BL(SME):
711721
""" `A Semantic Matching Energy Function for Learning with Multi-relational Data`_
@@ -729,6 +739,7 @@ class SME_BL(SME):
729739
def __init__(self, **kwargs):
730740
super(SME_BL, self).__init__(**kwargs)
731741
self.model_name = self.__class__.__name__.lower()
742+
self.loss = Criterion.pairwise_hinge
732743

733744
def _gu_bilinear(self, h, r):
734745
"""Function to calculate bilinear loss.
@@ -740,9 +751,9 @@ def _gu_bilinear(self, h, r):
740751
Returns:
741752
Tensors: Returns the bilinear loss.
742753
"""
743-
mu1h = torch.matmul(self.mu1.weight, self.transpose(h)) # [k, b]
744-
mu2r = torch.matmul(self.mu2.weight, self.transpose(r)) # [k, b]
745-
return self.transpose(mu1h * mu2r + self.bu.weight) # [b, k]
754+
mu1h = torch.matmul(self.mu1.weight, h.T) # [k, b]
755+
mu2r = torch.matmul(self.mu2.weight, r.T) # [k, b]
756+
return (mu1h * mu2r + self.bu.weight).T # [b, k]
746757

747758
def _gv_bilinear(self, r, t):
748759
"""Function to calculate bilinear loss.
@@ -754,9 +765,9 @@ def _gv_bilinear(self, r, t):
754765
Returns:
755766
Tensors: Returns the bilinear loss.
756767
"""
757-
mv1t = torch.matmul(self.mv1.weight, self.transpose(t)) # [k, b]
758-
mv2r = torch.matmul(self.mv2.weight, self.transpose(r)) # [k, b]
759-
return self.transpose(mv1t * mv2r + self.bv.weight) # [b, k]
768+
mv1t = torch.matmul(self.mv1.weight, t.T) # [k, b]
769+
mv2r = torch.matmul(self.mv2.weight, r.T) # [k, b]
770+
return (mv1t * mv2r + self.bv.weight).T # [b, k]
760771

761772
def forward(self, h, r, t):
762773
"""Function to that performs semanting matching.
@@ -821,6 +832,8 @@ def __init__(self, **kwargs):
821832
self.rel_embeddings,
822833
]
823834

835+
self.loss = Criterion.pariwise_logistic
836+
824837
def embed(self, h, r, t):
825838
"""Function to get the embedding value.
826839
@@ -891,6 +904,8 @@ def __init__(self, **kwargs):
891904
self.rel_matrices,
892905
]
893906

907+
self.loss = Criterion.pairwise_hinge
908+
894909
def embed(self, h, r, t):
895910
""" Function to get the embedding value.
896911
@@ -987,6 +1002,8 @@ def __init__(self, **kwargs):
9871002
self.mr,
9881003
]
9891004

1005+
self.loss = Criterion.pairwise_hinge
1006+
9901007
def train_layer(self, h, t):
9911008
""" Defines the forward pass training layers of the algorithm.
9921009
@@ -1030,7 +1047,7 @@ def forward(self, h, r, t):
10301047
norm_t = F.normalize(t_e, p=2, dim=-1)
10311048
return -torch.sum(norm_r*self.train_layer(norm_h, norm_t), -1)
10321049

1033-
def get_reg(self):
1050+
def get_reg(self, h, r, t):
10341051
return self.lmbda*torch.sqrt(sum([torch.sum(torch.pow(var.weight, 2)) for var in self.parameter_list]))
10351052

10361053

@@ -1095,6 +1112,8 @@ def __init__(self, **kwargs):
10951112
min_rel = torch.min(torch.FloatTensor().new_full(self.rel_embeddings_sigma.weight.shape, self.cmax), torch.add(self.rel_embeddings_sigma.weight, 1.0))
10961113
self.rel_embeddings_sigma.weight = nn.Parameter(torch.max(torch.FloatTensor().new_full(self.rel_embeddings_sigma.weight.shape, self.cmin), min_rel))
10971114

1115+
self.loss = Criterion.pairwise_hinge
1116+
10981117
def forward(self, h, r, t):
10991118
h_mu, h_sigma, r_mu, r_sigma, t_mu, t_sigma = self.embed(h, r, t)
11001119
return self._cal_score_kl_divergence(h_mu, h_sigma, r_mu, r_sigma, t_mu, t_sigma)
@@ -1199,6 +1218,8 @@ def __init__(self, **kwargs):
11991218
self.rel_embeddings,
12001219
]
12011220

1221+
self.loss = Criterion.pairwise_hinge
1222+
12021223
def forward(self, h, r, t):
12031224
h_e, r_e, t_e = self.embed(h, r, t)
12041225
r_e = F.normalize(r_e, p=2, dim=-1)

0 commit comments

Comments
 (0)