Description
System information
- Have I written custom code (as opposed to using a stock example script provided in TensorFlow.js): Yes
- OS Platform and Distribution (e.g., Linux Ubuntu 16.04): macOS 14.2.1 (23C71)
- Mobile device (e.g. iPhone 8, Pixel 2, Samsung Galaxy) if the issue happens on mobile device:
- TensorFlow.js installed from (npm or script link): script link
- TensorFlow.js version (use command below): 4.20.0
- Browser version: Chrome 126.0.6478.127
- Tensorflow.js Converter Version:
Describe the current behavior
The `numTokens` argument of `tf.layers.categoryEncoding` changes the output shape of the preceding `inputLayer`: its shape is changed to `[..., numTokens]`.
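One way a later layer can alter an earlier layer's reported shape is shape inference that overwrites the last dimension of the array it receives instead of copying it first. This is a hypothesis and has not been verified against the tfjs-layers source. The plain-JavaScript sketch below uses two hypothetical helpers, `outputShapeInPlace` and `outputShapeCopy` (not TensorFlow.js APIs), to show how aliasing the shape array produces this symptom:

```javascript
// Hypothetical shape-inference helpers illustrating array aliasing.
// They are NOT taken from tfjs-layers; the names are made up for this sketch.

// Overwrites the last dimension of the array it receives, so the caller's array changes too.
function outputShapeInPlace(inputShape, numTokens) {
  inputShape[inputShape.length - 1] = numTokens;
  return inputShape;
}

// Copies the array first, so the caller's shape is left untouched.
function outputShapeCopy(inputShape, numTokens) {
  const outputShape = inputShape.slice();
  outputShape[outputShape.length - 1] = numTokens;
  return outputShape;
}

// Stand-in for the input layer's recorded output shape: [batch, numFeatures].
const inputLayerShapeA = [null, 1];
const encodedA = outputShapeInPlace(inputLayerShapeA, 10);
console.log(inputLayerShapeA, encodedA); // [ null, 10 ] [ null, 10 ]  (input shape was changed)

const inputLayerShapeB = [null, 1];
const encodedB = outputShapeCopy(inputLayerShapeB, 10);
console.log(inputLayerShapeB, encodedB); // [ null, 1 ] [ null, 10 ]
```

With the copying variant, only the encoding layer's output carries the trailing `numTokens` dimension.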
Describe the expected behavior
The `InputLayer`'s output shape should stay `[..., sampleLength]`. Since every sample consists of integers in the range `[0, numTokens)`, `categoryEncoding` should be able to produce the correct output without imposing any additional constraint on the input layer.
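For reference, here is a minimal sketch of the expected `count` semantics in plain JavaScript, using a hypothetical `countEncode` helper (not part of TensorFlow.js). Each sample of `sampleLength` integers in `[0, numTokens)` becomes one vector of length `numTokens`, while the input itself keeps the shape `[batch, sampleLength]`. Rejecting out-of-range tokens is a choice made for this sketch only.

```javascript
// Hypothetical reference implementation of "count" mode category encoding,
// written in plain JavaScript to show the expected shapes; not the tfjs API.
// samples: nested array of shape [batch, sampleLength] holding integers in [0, numTokens).
function countEncode(samples, numTokens) {
  return samples.map((sample) => {
    const counts = new Array(numTokens).fill(0);
    for (const token of sample) {
      if (!Number.isInteger(token) || token < 0 || token >= numTokens) {
        throw new RangeError(`token ${token} is outside [0, ${numTokens})`);
      }
      counts[token] += 1;
    }
    return counts;
  });
}

// The input keeps its own shape [batch = 2, sampleLength = 3];
// only the encoded output has the trailing dimension numTokens = 5.
const samples = [
  [0, 2, 2],
  [4, 1, 0],
];
const encoded = countEncode(samples, 5);
console.log(encoded); // [ [ 1, 0, 2, 0, 0 ], [ 1, 1, 0, 0, 1 ] ]
```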
Standalone code to reproduce the issue
```javascript
// Minimal TFJS reproduction: an inputLayer followed by categoryEncoding.
// Assumes tf (TensorFlow.js 4.20.0) and tfvis (tfjs-vis) are loaded via <script> tags.
var numFeatures = 1;
// Define the model architecture
var model = tf.sequential();
model.add(tf.layers.inputLayer({inputShape: [numFeatures]}));
model.add(tf.layers.categoryEncoding({numTokens: 10, outputMode: "count"}));
// The summary reports the input layer's output shape as [..., numTokens].
model.summary();
tfvis.show.modelSummary({name: 'Model Summary'}, model);
// Generate some synthetic integer input
// const numbers = tf.range(0, 10, 1); // Generate numbers from 0 to 9
var numbers = tf.rand([10], () => Math.floor(Math.random() * 10), 'int32'); // Random integers from 0 to 9
numbers.print();
// Reshape to [numFeatures, 10], i.e. [1, 10]
var input = tf.reshape(numbers, [numFeatures, 10]);
input.print();
model.predict(input).print();
```