Skip to content

Commit fd265d4

Browse files
zhouhao138Responsible ML Infra Team
authored andcommitted
NA
PiperOrigin-RevId: 713149295
1 parent 6ba5f98 commit fd265d4

File tree

3 files changed

+40
-23
lines changed

3 files changed

+40
-23
lines changed

fairness_indicators/example_model.py

Lines changed: 14 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -83,8 +83,20 @@ def get_example_model(input_feature_key: str):
8383
text_vectorization.adapt(
8484
['nontoxic', 'toxic comment', 'test comment', 'abc', 'abcdef', 'random']
8585
)
86-
dense1 = keras.layers.Dense(32, activation='relu')
87-
dense2 = keras.layers.Dense(1)
86+
dense1 = keras.layers.Dense(
87+
32,
88+
activation=None,
89+
use_bias=True,
90+
kernel_initializer='glorot_uniform',
91+
bias_initializer='zeros',
92+
)
93+
dense2 = keras.layers.Dense(
94+
1,
95+
activation=None,
96+
use_bias=False,
97+
kernel_initializer='glorot_uniform',
98+
bias_initializer='zeros',
99+
)
88100

89101
inputs = tf.keras.Input(shape=(), dtype=tf.string)
90102
parsed_example = parser(inputs)

fairness_indicators/example_model_test.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,11 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414
# ==============================================================================
15-
"""Tests for example_model."""
15+
"""Tests for example_model.py.
16+
17+
It also serves as an example of how to use fairness indicators with a Keras
18+
model.
19+
"""
1620

1721
from __future__ import absolute_import
1822
from __future__ import division
@@ -91,7 +95,7 @@ def test_example_model(self):
9195
]),
9296
batch_size=1,
9397
)
94-
classifier.save(self._model_dir, save_format='tf')
98+
tf.saved_model.save(classifier, self._model_dir)
9599

96100
eval_config = text_format.Parse(
97101
"""

setup.py

Lines changed: 20 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -17,8 +17,7 @@
1717
import os
1818
import sys
1919

20-
from setuptools import find_packages
21-
from setuptools import setup
20+
import setuptools
2221

2322

2423
if sys.version_info >= (3, 11):
@@ -36,32 +35,32 @@ def select_constraint(default, nightly=None, git_master=None):
3635
return git_master
3736
else:
3837
return default
39-
4038
REQUIRED_PACKAGES = [
41-
'tensorflow>=2.15,<2.16',
39+
'tensorflow>=2.16,<2.17',
4240
'tensorflow-hub>=0.16.1,<1.0.0',
43-
'tensorflow-data-validation' + select_constraint(
44-
default='>=1.15.1,<2.0.0',
45-
nightly='>=1.16.0.dev',
46-
git_master='@git+https://github.com/tensorflow/data-validation@master'),
47-
'tensorflow-model-analysis' + select_constraint(
48-
default='>=0.46,<0.47',
49-
nightly='>=0.47.0.dev',
50-
git_master='@git+https://github.com/tensorflow/model-analysis@master'),
41+
'tensorflow-data-validation'
42+
+ select_constraint(
43+
default='>=1.16.1,<2.0.0',
44+
nightly='>=1.17.0.dev',
45+
git_master='@git+https://github.com/tensorflow/data-validation@master',
46+
),
47+
'tensorflow-model-analysis'
48+
+ select_constraint(
49+
default='>=0.47.0,<0.48.0',
50+
nightly='>=0.48.0.dev',
51+
git_master='@git+https://github.com/tensorflow/model-analysis@master',
52+
),
5153
'witwidget>=1.4.4,<2',
5254
'protobuf>=3.20.3,<5',
5355
]
54-
5556
# Get version from version module.
5657
with open('fairness_indicators/version.py') as fp:
5758
globals_dict = {}
5859
exec(fp.read(), globals_dict) # pylint: disable=exec-used
5960
__version__ = globals_dict['__version__']
60-
6161
with open('README.md', 'r', encoding='utf-8') as fh:
6262
long_description = fh.read()
63-
64-
setup(
63+
setuptools.setup(
6564
name='fairness_indicators',
6665
version=__version__,
6766
description='Fairness Indicators',
@@ -70,7 +69,7 @@ def select_constraint(default, nightly=None, git_master=None):
7069
url='https://github.com/tensorflow/fairness-indicators',
7170
author='Google LLC',
7271
author_email='packages@tensorflow.org',
73-
packages=find_packages(exclude=['tensorboard_plugin']),
72+
packages=setuptools.find_packages(exclude=['tensorboard_plugin']),
7473
package_data={
7574
'fairness_indicators': ['documentation/*'],
7675
},
@@ -96,6 +95,8 @@ def select_constraint(default, nightly=None, git_master=None):
9695
'Topic :: Software Development :: Libraries :: Python Modules',
9796
],
9897
license='Apache 2.0',
99-
keywords='tensorflow model analysis fairness indicators tensorboard machine'
100-
' learning',
98+
keywords=(
99+
'tensorflow model analysis fairness indicators tensorboard machine'
100+
' learning'
101+
),
101102
)

0 commit comments

Comments
 (0)