The argument numTokens will impact inputLayer's output shape #8322

@linus87

System information

  • Have I written custom code (as opposed to using a stock example script provided in TensorFlow.js): Yes
  • OS Platform and Distribution (e.g., Linux Ubuntu 16.04): macOS 14.2.1 (23C71)
  • Mobile device (e.g. iPhone 8, Pixel 2, Samsung Galaxy) if the issue happens on mobile device:
  • TensorFlow.js installed from (npm or script link): script link
  • TensorFlow.js version (use command below): 4.20.0
  • Browser version: Chrome 126.0.6478.127
  • Tensorflow.js Converter Version:

Describe the current behavior
The numTokens argument of tf.layers.categoryEncoding affects inputLayer's output shape: the input shape is changed to [..., numTokens].

Describe the expected behavior
InputLayer's output shape should be [..., sampleLength]. Because the samples are all integers in the range [0, numTokens), CategoryEncoding should be able to produce the correct output without placing any other constraint on inputLayer.
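
To make the expected behavior concrete, here is a minimal plain-JavaScript sketch of count encoding with no TensorFlow.js dependency. `countEncode` is a hypothetical helper written for illustration, not a library function, and it assumes tokens are integers in [0, numTokens). The point is that the sample length and numTokens are independent: a sample of any length maps to an output row of length numTokens.

```javascript
// Hypothetical helper (not part of TensorFlow.js): count-encode a batch of
// integer samples. Rows may have any length; every output row has
// exactly numTokens entries.
function countEncode(batch, numTokens) {
  return batch.map((sample) => {
    const counts = new Array(numTokens).fill(0);
    for (const token of sample) {
      // Assumption: valid tokens are integers in [0, numTokens).
      if (!Number.isInteger(token) || token < 0 || token >= numTokens) {
        throw new RangeError(`token ${token} is outside [0, ${numTokens})`);
      }
      counts[token] += 1;
    }
    return counts;
  });
}

// One sample of length 3 (sampleLength = 3) with numTokens = 10.
const encoded = countEncode([[1, 1, 7]], 10);
console.log(encoded); // [[0, 2, 0, 0, 0, 0, 0, 1, 0, 0]]
```

Here the input shape is [1, 3] and the output shape is [1, 10], so the input's last dimension does not need to equal numTokens.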

Standalone code to reproduce the issue
Provide a reproducible test case that is the bare minimum necessary to generate
the problem. If possible, please share a link to Colab/CodePen/any notebook.

```js
// Tiny TFJS train / predict example.
var numFeatures = 1;

// Define the model architecture
var model = tf.sequential();
model.add(tf.layers.inputLayer({inputShape: [numFeatures]}));
model.add(tf.layers.categoryEncoding({numTokens: 10, outputMode: "count"}));

model.summary();
// Requires tfjs-vis to be loaded alongside TensorFlow.js
tfvis.show.modelSummary({name: 'Model Summary'}, model);

// Generate some synthetic data
// const numbers = tf.range(0, 10, 1); // The integers 0 to 9 in order
var numbers = tf.rand([10], () => Math.floor(Math.random() * 10), 'int32'); // 10 random integers from 0 to 9
numbers.print();

// Shape [numFeatures, 10] = [1, 10]: one sample holding 10 values
var input = tf.reshape(numbers, [numFeatures, 10]);
input.print();

model.predict(input).print();
```

Other info / logs Include any logs or source code that would be helpful to
diagnose the problem. If including tracebacks, please include the full
traceback. Large logs and files should be attached.
