6
6
import numpy as np
7
7
from pykg2vec .models .KGMeta import PairwiseModel
8
8
from pykg2vec .models .Domain import NamedEmbedding
9
+ from pykg2vec .utils .criterion import Criterion
9
10
10
11
11
12
class TransE (PairwiseModel ):
@@ -58,6 +59,8 @@ def __init__(self, **kwargs):
58
59
self .rel_embeddings ,
59
60
]
60
61
62
+ self .loss = Criterion .pairwise_hinge
63
+
61
64
def forward (self , h , r , t ):
62
65
"""Function to get the embedding value.
63
66
@@ -151,6 +154,8 @@ def __init__(self, **kwargs):
151
154
self .w ,
152
155
]
153
156
157
+ self .loss = Criterion .pairwise_hinge
158
+
154
159
def forward (self , h , r , t ):
155
160
h_e , r_e , t_e = self .embed (h , r , t )
156
161
@@ -243,6 +248,8 @@ def __init__(self, **kwargs):
243
248
self .rel_mappings ,
244
249
]
245
250
251
+ self .loss = Criterion .pairwise_hinge
252
+
246
253
def embed (self , h , r , t ):
247
254
"""Function to get the embedding value.
248
255
@@ -345,6 +352,8 @@ def __init__(self, **kwargs):
345
352
self .rel_embeddings ,
346
353
]
347
354
355
+ self .loss = Criterion .pairwise_hinge
356
+
348
357
def forward (self , h , r , t ):
349
358
"""Function to get the embedding value.
350
359
@@ -431,6 +440,8 @@ def __init__(self, **kwargs):
431
440
self .rel_matrix ,
432
441
]
433
442
443
+ self .loss = Criterion .pairwise_hinge
444
+
434
445
def transform (self , e , matrix ):
435
446
matrix = matrix .view (- 1 , self .ent_hidden_size , self .rel_hidden_size )
436
447
if e .shape [0 ] != matrix .shape [0 ]:
@@ -541,6 +552,8 @@ def __init__(self, **kwargs):
541
552
self .mr2 ,
542
553
]
543
554
555
+ self .loss = Criterion .pairwise_hinge
556
+
544
557
def embed (self , h , r , t ):
545
558
"""Function to get the embedding value.
546
559
@@ -639,6 +652,8 @@ def __init__(self, **kwargs):
639
652
self .bv ,
640
653
]
641
654
655
+ self .loss = Criterion .pairwise_hinge
656
+
642
657
def embed (self , h , r , t ):
643
658
"""Function to get the embedding value.
644
659
@@ -665,9 +680,9 @@ def _gu_linear(self, h, r):
665
680
Returns:
666
681
Tensors: Returns the bilinear loss.
667
682
"""
668
- mu1h = torch .matmul (self .mu1 .weight , self . transpose ( h )) # [k, b]
669
- mu2r = torch .matmul (self .mu2 .weight , self . transpose ( r )) # [k, b]
670
- return self . transpose (mu1h + mu2r + self .bu .weight ) # [b, k]
683
+ mu1h = torch .matmul (self .mu1 .weight , h . T ) # [k, b]
684
+ mu2r = torch .matmul (self .mu2 .weight , r . T ) # [k, b]
685
+ return (mu1h + mu2r + self .bu .weight ). T # [b, k]
671
686
672
687
def _gv_linear (self , r , t ):
673
688
"""Function to calculate linear loss.
@@ -679,9 +694,9 @@ def _gv_linear(self, r, t):
679
694
Returns:
680
695
Tensors: Returns the bilinear loss.
681
696
"""
682
- mv1t = torch .matmul (self .mv1 .weight , self . transpose ( t )) # [k, b]
683
- mv2r = torch .matmul (self .mv2 .weight , self . transpose ( r )) # [k, b]
684
- return self . transpose (mv1t + mv2r + self .bv .weight ) # [b, k]
697
+ mv1t = torch .matmul (self .mv1 .weight , t . T ) # [k, b]
698
+ mv2r = torch .matmul (self .mv2 .weight , r . T ) # [k, b]
699
+ return (mv1t + mv2r + self .bv .weight ). T # [b, k]
685
700
686
701
def forward (self , h , r , t ):
687
702
"""Function to that performs semanting matching.
@@ -701,11 +716,6 @@ def forward(self, h, r, t):
701
716
702
717
return - torch .sum (self ._gu_linear (norm_h , norm_r ) * self ._gv_linear (norm_r , norm_t ), 1 )
703
718
704
- @staticmethod
705
- def transpose (tensor ):
706
- dims = tuple (range (len (tensor .shape )- 1 , - 1 , - 1 )) # (rank-1...0)
707
- return tensor .permute (dims )
708
-
709
719
710
720
class SME_BL (SME ):
711
721
""" `A Semantic Matching Energy Function for Learning with Multi-relational Data`_
@@ -729,6 +739,7 @@ class SME_BL(SME):
729
739
def __init__ (self , ** kwargs ):
730
740
super (SME_BL , self ).__init__ (** kwargs )
731
741
self .model_name = self .__class__ .__name__ .lower ()
742
+ self .loss = Criterion .pairwise_hinge
732
743
733
744
def _gu_bilinear (self , h , r ):
734
745
"""Function to calculate bilinear loss.
@@ -740,9 +751,9 @@ def _gu_bilinear(self, h, r):
740
751
Returns:
741
752
Tensors: Returns the bilinear loss.
742
753
"""
743
- mu1h = torch .matmul (self .mu1 .weight , self . transpose ( h )) # [k, b]
744
- mu2r = torch .matmul (self .mu2 .weight , self . transpose ( r )) # [k, b]
745
- return self . transpose (mu1h * mu2r + self .bu .weight ) # [b, k]
754
+ mu1h = torch .matmul (self .mu1 .weight , h . T ) # [k, b]
755
+ mu2r = torch .matmul (self .mu2 .weight , r . T ) # [k, b]
756
+ return (mu1h * mu2r + self .bu .weight ). T # [b, k]
746
757
747
758
def _gv_bilinear (self , r , t ):
748
759
"""Function to calculate bilinear loss.
@@ -754,9 +765,9 @@ def _gv_bilinear(self, r, t):
754
765
Returns:
755
766
Tensors: Returns the bilinear loss.
756
767
"""
757
- mv1t = torch .matmul (self .mv1 .weight , self . transpose ( t )) # [k, b]
758
- mv2r = torch .matmul (self .mv2 .weight , self . transpose ( r )) # [k, b]
759
- return self . transpose (mv1t * mv2r + self .bv .weight ) # [b, k]
768
+ mv1t = torch .matmul (self .mv1 .weight , t . T ) # [k, b]
769
+ mv2r = torch .matmul (self .mv2 .weight , r . T ) # [k, b]
770
+ return (mv1t * mv2r + self .bv .weight ). T # [b, k]
760
771
761
772
def forward (self , h , r , t ):
762
773
"""Function to that performs semanting matching.
@@ -821,6 +832,8 @@ def __init__(self, **kwargs):
821
832
self .rel_embeddings ,
822
833
]
823
834
835
+ self .loss = Criterion .pariwise_logistic
836
+
824
837
def embed (self , h , r , t ):
825
838
"""Function to get the embedding value.
826
839
@@ -891,6 +904,8 @@ def __init__(self, **kwargs):
891
904
self .rel_matrices ,
892
905
]
893
906
907
+ self .loss = Criterion .pairwise_hinge
908
+
894
909
def embed (self , h , r , t ):
895
910
""" Function to get the embedding value.
896
911
@@ -987,6 +1002,8 @@ def __init__(self, **kwargs):
987
1002
self .mr ,
988
1003
]
989
1004
1005
+ self .loss = Criterion .pairwise_hinge
1006
+
990
1007
def train_layer (self , h , t ):
991
1008
""" Defines the forward pass training layers of the algorithm.
992
1009
@@ -1030,7 +1047,7 @@ def forward(self, h, r, t):
1030
1047
norm_t = F .normalize (t_e , p = 2 , dim = - 1 )
1031
1048
return - torch .sum (norm_r * self .train_layer (norm_h , norm_t ), - 1 )
1032
1049
1033
- def get_reg (self ):
1050
+ def get_reg (self , h , r , t ):
1034
1051
return self .lmbda * torch .sqrt (sum ([torch .sum (torch .pow (var .weight , 2 )) for var in self .parameter_list ]))
1035
1052
1036
1053
@@ -1095,6 +1112,8 @@ def __init__(self, **kwargs):
1095
1112
min_rel = torch .min (torch .FloatTensor ().new_full (self .rel_embeddings_sigma .weight .shape , self .cmax ), torch .add (self .rel_embeddings_sigma .weight , 1.0 ))
1096
1113
self .rel_embeddings_sigma .weight = nn .Parameter (torch .max (torch .FloatTensor ().new_full (self .rel_embeddings_sigma .weight .shape , self .cmin ), min_rel ))
1097
1114
1115
+ self .loss = Criterion .pairwise_hinge
1116
+
1098
1117
def forward (self , h , r , t ):
1099
1118
h_mu , h_sigma , r_mu , r_sigma , t_mu , t_sigma = self .embed (h , r , t )
1100
1119
return self ._cal_score_kl_divergence (h_mu , h_sigma , r_mu , r_sigma , t_mu , t_sigma )
@@ -1199,6 +1218,8 @@ def __init__(self, **kwargs):
1199
1218
self .rel_embeddings ,
1200
1219
]
1201
1220
1221
+ self .loss = Criterion .pairwise_hinge
1222
+
1202
1223
def forward (self , h , r , t ):
1203
1224
h_e , r_e , t_e = self .embed (h , r , t )
1204
1225
r_e = F .normalize (r_e , p = 2 , dim = - 1 )
0 commit comments