diff --git a/keras/src/layers/merging/dot.py b/keras/src/layers/merging/dot.py index 358ee768a04..b49b965828c 100644 --- a/keras/src/layers/merging/dot.py +++ b/keras/src/layers/merging/dot.py @@ -41,6 +41,7 @@ def batch_dot(x, y, axes=None): axes: Tuple or list of integers with target dimensions, or single integer. The sizes of `x.shape[axes[0]]` and `y.shape[axes[1]]` should be equal. + Note that axis `0` (the batch axis) cannot be included. Returns: A tensor with shape equal to the concatenation of `x`'s shape @@ -226,7 +227,8 @@ class Dot(Merge): take the dot product. If a tuple, should be two integers corresponding to the desired axis from the first input and the desired axis from the second input, respectively. Note that the - size of the two selected axes must match. + size of the two selected axes must match, and that + axis `0` (the batch axis) cannot be included. normalize: Whether to L2-normalize samples along the dot product axis before taking the dot product. If set to `True`, then the output of the dot product is the cosine proximity @@ -363,6 +365,7 @@ def dot(inputs, axes=-1, **kwargs): inputs: A list of input tensors (at least 2). axes: Integer or tuple of integers, axis or axes along which to take the dot product. + Note that axis `0` (the batch axis) cannot be included. normalize: Whether to L2-normalize samples along the dot product axis before taking the dot product. If set to `True`, then the output of the dot product