@@ -37,14 +37,17 @@ def _make_functional_regularized_model(distance_config):
37
37
def _make_unregularized_model (inputs , num_classes ):
38
38
"""Makes standard 1 layer MLP with logistic regression."""
39
39
x = tf .keras .layers .Dense (16 , activation = 'relu' )(inputs )
40
- return tf .keras .Model (inputs , outputs = tf .keras .layers .Dense (num_classes )(x ))
40
+ model = tf .keras .Model (inputs , tf .keras .layers .Dense (num_classes )(x ))
41
+ return model
41
42
42
43
# Each example has 4 features and 2 neighbors, each with an edge weight.
43
44
inputs = (tf .keras .Input (shape = (4 ,), dtype = tf .float32 , name = 'features' ),
44
45
tf .keras .Input (shape = (2 , 4 ), dtype = tf .float32 , name = 'neighbors' ),
45
46
tf .keras .Input (
46
47
shape = (2 , 1 ), dtype = tf .float32 , name = 'neighbor_weights' ))
47
48
features , neighbors , neighbor_weights = inputs
49
+ neighbors = tf .reshape (neighbors , (- 1 ,) + tuple (features .shape [1 :]))
50
+ neighbor_weights = tf .reshape (neighbor_weights , [- 1 , 1 ])
48
51
unregularized_model = _make_unregularized_model (features , 3 )
49
52
logits = unregularized_model (features )
50
53
model = tf .keras .Model (inputs = inputs , outputs = logits )
0 commit comments