Skip to content

Commit 14bf59e

Browse files
arjungtensorflow-copybara
authored andcommitted
Change the default value of keep_rank and flip its semantics when unpacking neighbor features. The point of reference is the (pre-batch) rank of tensors for the corresponding sample features.
PiperOrigin-RevId: 315433595
1 parent 01babca commit 14bf59e

File tree

3 files changed

+23
-19
lines changed

3 files changed

+23
-19
lines changed

neural_structured_learning/keras/layers/neighbor_features.py

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -99,7 +99,7 @@ def __init__(self,
9999
self._feature_names = (
100100
feature_names if feature_names is None else set(feature_names))
101101

102-
def call(self, inputs, keep_rank=False):
102+
def call(self, inputs, keep_rank=True):
103103
"""Extracts neighbor features and weights from a dictionary of inputs.
104104
105105
This function is a wrapper around `utils.unpack_neighbor_features`. See
@@ -109,9 +109,10 @@ def call(self, inputs, keep_rank=False):
109109
Args:
110110
inputs: Dictionary of `tf.Tensor` features with keys for neighbors and
111111
weights described by `neighbor_config`.
112-
keep_rank: Defaults to `False`. If `True`, each value of
113-
`neighbor_features` will have an extra neighborhood size dimension at
114-
axis 1.
112+
keep_rank: Boolean indicating whether each value of `neighbor_features`
113+
retains the rank from the corresponding value in `sample_features`
114+
by merging the neighborhood size with the batch_size dimension, or
115+
contains an extra neighborhood dimension at axis 1. Defaults to `True`.
115116
116117
Returns:
117118
A tuple (sample_features, neighbor_features, neighbor_weights) of tensors.

neural_structured_learning/keras/layers/neighbor_features_test.py

Lines changed: 10 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,9 @@ def _make_model(neighbor_config, inputs, keep_rank, weight_dtype=None):
3535
Args:
3636
neighbor_config: An instance of `configs.GraphNeighborConfig`.
3737
inputs: A `tf.keras.Input` or a nested structure of `tf.keras.Input`s.
38-
keep_rank: Whether to keep the extra neighborhood size dimention.
38+
keep_rank: Whether to retain the rank of the original input tensors by
39+
merging the neighborhood size with the batch_size dimension, or add an
40+
extra neighborhood size dimension.
3941
weight_dtype: Optional `tf.DType` for weights.
4042
4143
Returns:
@@ -105,15 +107,15 @@ def testDense(self, keep_rank):
105107
# Check that neighbors and weights are grouped together for each sample.
106108
for i in range(batch_size):
107109
self.assertAllEqual(
108-
neighbors[i] if keep_rank else
109-
neighbors[(i * num_neighbors):((i + 1) * num_neighbors)],
110+
neighbors[(i * num_neighbors):((i + 1) * num_neighbors)]
111+
if keep_rank else neighbors[i],
110112
np.stack([
111113
features['NL_nbr_0_image'][i],
112114
features['NL_nbr_1_image'][i],
113115
features['NL_nbr_2_image'][i],
114116
]))
115117
self.assertAllEqual(
116-
weights[i] if keep_rank else np.split(weights, batch_size)[i],
118+
np.split(weights, batch_size)[i] if keep_rank else weights[i],
117119
np.stack([
118120
features['NL_nbr_0_weight'][i],
119121
features['NL_nbr_1_weight'][i],
@@ -160,8 +162,8 @@ def testSparse(self, keep_rank):
160162
self.assertAllClose(
161163
weights,
162164
np.array([0.9, 0.25, 0.3, 0., 0.6, 0.75, 0.,
163-
0.]).reshape((batch_size, 2,
164-
1) if keep_rank else (batch_size * 2, 1)))
165+
0.]).reshape((batch_size * 2,
166+
1) if keep_rank else (batch_size, 2, 1)))
165167
# Check that neighbors are grouped together.
166168
dense_neighbors = self.evaluate(tf.sparse.to_dense(neighbors['input'], -1.))
167169
neighbor0 = self.evaluate(
@@ -170,8 +172,8 @@ def testSparse(self, keep_rank):
170172
tf.sparse.to_dense(features['NL_nbr_1_input'], -1))
171173
for i in range(batch_size):
172174
actual = (
173-
dense_neighbors[i]
174-
if keep_rank else np.split(dense_neighbors, batch_size)[i])
175+
np.split(dense_neighbors, batch_size)[i]
176+
if keep_rank else dense_neighbors[i])
175177
self.assertAllEqual(actual, np.stack([neighbor0[i], neighbor1[i]]))
176178

177179

neural_structured_learning/lib/utils.py

Lines changed: 8 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -443,15 +443,15 @@ def _interleave_and_merge(tensors,
443443
# This is the equivalent of tf.stack() for sparse tensors.
444444
concatenated_tensors = tf.sparse.concat(
445445
axis=1, sp_inputs=[tf.sparse.expand_dims(t, 1) for t in tensors])
446-
return (concatenated_tensors if keep_rank else tf.sparse.reshape(
447-
concatenated_tensors, shape=merged_shape))
446+
return (tf.sparse.reshape(concatenated_tensors, shape=merged_shape)
447+
if keep_rank else concatenated_tensors)
448448
else:
449449
stacked_tensors = tf.stack(tensors, axis=1)
450-
return (stacked_tensors if keep_rank else tf.reshape(
451-
stacked_tensors, shape=merged_shape))
450+
return (tf.reshape(stacked_tensors, shape=merged_shape)
451+
if keep_rank else stacked_tensors)
452452

453453

454-
def unpack_neighbor_features(features, neighbor_config, keep_rank=False):
454+
def unpack_neighbor_features(features, neighbor_config, keep_rank=True):
455455
"""Extracts sample features, neighbor features, and neighbor weights.
456456
457457
For example, suppose `features` contains a single sample feature named
@@ -520,8 +520,9 @@ def unpack_neighbor_features(features, neighbor_config, keep_rank=False):
520520
1]`, where `B` is the batch size. Neighbor weight tensors cannot be sparse
521521
tensors.
522522
neighbor_config: An instance of `nsl.configs.GraphNeighborConfig`.
523-
keep_rank: Whether to preserve the neighborhood size dimension. Defaults to
524-
`False`.
523+
keep_rank: Boolean indicating whether to retain the rank from the input or
524+
to introduce a new dimension for the neighborhood size (axis 1). Defaults
525+
to `True`.
525526
526527
Returns:
527528
sample_features: a dictionary mapping feature names to tensors. The shape

0 commit comments

Comments
 (0)