Skip to content

Commit ecde04f

Browse files
fchollettensorflow-copybara
authored andcommitted
Disable input spec checking in anticipation of future change which starts enforcing stricter assumptions.
PiperOrigin-RevId: 324096175
1 parent da89233 commit ecde04f

File tree

1 file changed

+4
-1
lines changed

1 file changed

+4
-1
lines changed

neural_structured_learning/keras/layers/layers_test.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -37,14 +37,17 @@ def _make_functional_regularized_model(distance_config):
3737
def _make_unregularized_model(inputs, num_classes):
3838
"""Makes standard 1 layer MLP with logistic regression."""
3939
x = tf.keras.layers.Dense(16, activation='relu')(inputs)
40-
return tf.keras.Model(inputs, outputs=tf.keras.layers.Dense(num_classes)(x))
40+
model = tf.keras.Model(inputs, tf.keras.layers.Dense(num_classes)(x))
41+
return model
4142

4243
# Each example has 4 features and 2 neighbors, each with an edge weight.
4344
inputs = (tf.keras.Input(shape=(4,), dtype=tf.float32, name='features'),
4445
tf.keras.Input(shape=(2, 4), dtype=tf.float32, name='neighbors'),
4546
tf.keras.Input(
4647
shape=(2, 1), dtype=tf.float32, name='neighbor_weights'))
4748
features, neighbors, neighbor_weights = inputs
49+
neighbors = tf.reshape(neighbors, (-1,) + tuple(features.shape[1:]))
50+
neighbor_weights = tf.reshape(neighbor_weights, [-1, 1])
4851
unregularized_model = _make_unregularized_model(features, 3)
4952
logits = unregularized_model(features)
5053
model = tf.keras.Model(inputs=inputs, outputs=logits)

0 commit comments

Comments
 (0)