diff --git a/.github/actions/setup-env/action.yml b/.github/actions/setup-env/action.yml
index c4209bd..46fc2e3 100644
--- a/.github/actions/setup-env/action.yml
+++ b/.github/actions/setup-env/action.yml
@@ -25,4 +25,3 @@ runs:
       run: |
         python -m pip install --upgrade pip
         pip install ${{ inputs.package-root-dir }}[test]
-
diff --git a/.github/workflows/ci-lint.yml b/.github/workflows/ci-lint.yml
new file mode 100644
index 0000000..dede434
--- /dev/null
+++ b/.github/workflows/ci-lint.yml
@@ -0,0 +1,21 @@
+name: pre-commit
+
+on:
+  pull_request:
+  push:
+    branches: [master]
+
+jobs:
+  pre-commit:
+    runs-on: ubuntu-latest
+    steps:
+      - uses: actions/checkout@v4.1.7
+        with:
+          # Ensure the full history is fetched
+          # This is required to run pre-commit on a specific set of commits
+          # TODO: Remove this when all the pre-commit issues are fixed
+          fetch-depth: 0
+      - uses: actions/setup-python@v5.1.1
+        with:
+          python-version: 3.13
+      - uses: pre-commit/action@v3.0.1
diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml
new file mode 100644
index 0000000..387a3ef
--- /dev/null
+++ b/.pre-commit-config.yaml
@@ -0,0 +1,39 @@
+# pre-commit is a tool to perform a predefined set of tasks manually and/or
+# automatically before git commits are made.
+#
+# Config reference: https://pre-commit.com/#pre-commit-configyaml---top-level
+#
+# Common tasks
+#
+# - Register git hooks: pre-commit install --install-hooks
+# - Run on all files: pre-commit run --all-files
+#
+# These pre-commit hooks are run as CI.
+#
+# NOTE: if it can be avoided, add configs/args in pyproject.toml or below instead of creating a new `.config.file`.
+# https://pre-commit.ci/#configuration
+ci:
+  autoupdate_schedule: monthly
+  autofix_commit_msg: |
+    [pre-commit.ci] Apply automatic pre-commit fixes
+
+repos:
+  # general
+  - repo: https://github.com/pre-commit/pre-commit-hooks
+    rev: v4.6.0
+    hooks:
+      - id: end-of-file-fixer
+        exclude: '\.svg$'
+      - id: trailing-whitespace
+        exclude: '\.svg$'
+      - id: check-json
+      - id: check-yaml
+        args: [--allow-multiple-documents, --unsafe]
+      - id: check-toml
+
+  - repo: https://github.com/astral-sh/ruff-pre-commit
+    rev: v0.5.6
+    hooks:
+      - id: ruff
+        args: ["--fix"]
+      - id: ruff-format
diff --git a/docs/index.md b/docs/index.md
index 78e960f..f51569b 100644
--- a/docs/index.md
+++ b/docs/index.md
@@ -64,7 +64,7 @@ options {
 -   ![Fairness Indicators on the TensorFlow blog](images/tf_full_color_primary_icon.svg)

     ### [Fairness Indicators on the TensorFlow blog](https://blog.tensorflow.org/2019/12/fairness-indicators-fair-ML-systems.html)
-
+
     ---
     [Read on the TensorFlow blog](https://blog.tensorflow.org/2019/12/fairness-indicators-fair-ML-systems.html)
diff --git a/docs/javascripts/mathjax.js b/docs/javascripts/mathjax.js
index 0be88e0..7e48906 100644
--- a/docs/javascripts/mathjax.js
+++ b/docs/javascripts/mathjax.js
@@ -11,7 +11,7 @@ window.MathJax = {
   }
 };

-document$.subscribe(() => { 
+document$.subscribe(() => {
   MathJax.startup.output.clearCache()
   MathJax.typesetClear()
   MathJax.texReset()
diff --git a/docs/tutorials/_toc.yaml b/docs/tutorials/_toc.yaml
index cea869b..33ae209 100644
--- a/docs/tutorials/_toc.yaml
+++ b/docs/tutorials/_toc.yaml
@@ -13,4 +13,3 @@ toc:
     path: /responsible_ai/fairness_indicators/tutorials/Fairness_Indicators_Pandas_Case_Study
   - title: FaceSSD example Colab
     path: /responsible_ai/fairness_indicators/tutorials/Facessd_Fairness_Indicators_Example_Colab
-
diff --git a/fairness_indicators/example_model.py b/fairness_indicators/example_model.py
index d1f3dfd..f7dc553 100644
--- a/fairness_indicators/example_model.py +++ b/fairness_indicators/example_model.py @@ -21,15 +21,15 @@ from typing import Any -from fairness_indicators import fairness_indicators_metrics # pylint: disable=unused-import -from tensorflow import keras import tensorflow.compat.v1 as tf import tensorflow_model_analysis as tfma +from tensorflow import keras +from fairness_indicators import fairness_indicators_metrics # noqa: F401 -TEXT_FEATURE = 'comment_text' -LABEL = 'toxicity' -SLICE = 'slice' +TEXT_FEATURE = "comment_text" +LABEL = "toxicity" +SLICE = "slice" FEATURE_MAP = { LABEL: tf.io.FixedLenFeature([], tf.float32), TEXT_FEATURE: tf.io.FixedLenFeature([], tf.string), @@ -38,74 +38,75 @@ class ExampleParser(keras.layers.Layer): - """A Keras layer that parses the tf.Example.""" + """A Keras layer that parses the tf.Example.""" + + def __init__(self, input_feature_key): + self._input_feature_key = input_feature_key + self.input_spec = keras.layers.InputSpec(shape=(1,), dtype=tf.string) + super().__init__() - def __init__(self, input_feature_key): - self._input_feature_key = input_feature_key - self.input_spec = keras.layers.InputSpec(shape=(1,), dtype=tf.string) - super().__init__() + def compute_output_shape(self, input_shape: Any): + return [1, 1] - def compute_output_shape(self, input_shape: Any): - return [1, 1] + def call(self, serialized_examples): + def get_feature(serialized_example): + parsed_example = tf.io.parse_single_example( + serialized_example, features=FEATURE_MAP + ) + return parsed_example[self._input_feature_key] - def call(self, serialized_examples): - def get_feature(serialized_example): - parsed_example = tf.io.parse_single_example( - serialized_example, features=FEATURE_MAP - ) - return parsed_example[self._input_feature_key] - serialized_examples = tf.cast(serialized_examples, tf.string) - return tf.map_fn(get_feature, serialized_examples) + serialized_examples = tf.cast(serialized_examples, tf.string) + return tf.map_fn(get_feature, serialized_examples) class Reshaper(keras.layers.Layer): - """A Keras layer that reshapes the input.""" + """A Keras layer that reshapes the input.""" - def call(self, inputs): - return tf.reshape(inputs, (1, 32)) + def call(self, inputs): + return tf.reshape(inputs, (1, 32)) class Caster(keras.layers.Layer): - """A Keras layer that reshapes the input.""" + """A Keras layer that reshapes the input.""" - def call(self, inputs): - return tf.cast(inputs, tf.float32) + def call(self, inputs): + return tf.cast(inputs, tf.float32) def get_example_model(input_feature_key: str): - """Returns a Keras model for testing.""" - parser = ExampleParser(input_feature_key) - text_vectorization = keras.layers.TextVectorization( - max_tokens=32, - output_mode='int', - output_sequence_length=32, - ) - text_vectorization.adapt( - ['nontoxic', 'toxic comment', 'test comment', 'abc', 'abcdef', 'random'] - ) - dense1 = keras.layers.Dense( - 32, - activation=None, - use_bias=True, - kernel_initializer='glorot_uniform', - bias_initializer='zeros', - ) - dense2 = keras.layers.Dense( - 1, - activation=None, - use_bias=False, - kernel_initializer='glorot_uniform', - bias_initializer='zeros', - ) - - inputs = tf.keras.Input(shape=(), dtype=tf.string) - parsed_example = parser(inputs) - text_vector = text_vectorization(parsed_example) - text_vector = Reshaper()(text_vector) - text_vector = Caster()(text_vector) - output1 = dense1(text_vector) - output2 = dense2(output1) - return tf.keras.Model(inputs=inputs, outputs=output2) + """Returns a Keras model for 
testing.""" + parser = ExampleParser(input_feature_key) + text_vectorization = keras.layers.TextVectorization( + max_tokens=32, + output_mode="int", + output_sequence_length=32, + ) + text_vectorization.adapt( + ["nontoxic", "toxic comment", "test comment", "abc", "abcdef", "random"] + ) + dense1 = keras.layers.Dense( + 32, + activation=None, + use_bias=True, + kernel_initializer="glorot_uniform", + bias_initializer="zeros", + ) + dense2 = keras.layers.Dense( + 1, + activation=None, + use_bias=False, + kernel_initializer="glorot_uniform", + bias_initializer="zeros", + ) + + inputs = tf.keras.Input(shape=(), dtype=tf.string) + parsed_example = parser(inputs) + text_vector = text_vectorization(parsed_example) + text_vector = Reshaper()(text_vector) + text_vector = Caster()(text_vector) + output1 = dense1(text_vector) + output2 = dense2(output1) + return tf.keras.Model(inputs=inputs, outputs=output2) def evaluate_model( @@ -114,23 +115,23 @@ def evaluate_model( tfma_eval_result_path, eval_config, ): - """Evaluate Model using Tensorflow Model Analysis. - - Args: - classifier_model_path: Trained classifier model to be evaluted. - validate_tf_file_path: File containing validation TFRecordDataset. - tfma_eval_result_path: Path to export tfma-related eval path. - eval_config: tfma eval_config. - """ - - eval_shared_model = tfma.default_eval_shared_model( - eval_saved_model_path=classifier_model_path, eval_config=eval_config - ) - - # Run the fairness evaluation. - tfma.run_model_analysis( - eval_shared_model=eval_shared_model, - data_location=validate_tf_file_path, - output_path=tfma_eval_result_path, - eval_config=eval_config, - ) + """Evaluate Model using Tensorflow Model Analysis. + + Args: + ---- + classifier_model_path: Trained classifier model to be evaluted. + validate_tf_file_path: File containing validation TFRecordDataset. + tfma_eval_result_path: Path to export tfma-related eval path. + eval_config: tfma eval_config. + """ + eval_shared_model = tfma.default_eval_shared_model( + eval_saved_model_path=classifier_model_path, eval_config=eval_config + ) + + # Run the fairness evaluation. + tfma.run_model_analysis( + eval_shared_model=eval_shared_model, + data_location=validate_tf_file_path, + output_path=tfma_eval_result_path, + eval_config=eval_config, + ) diff --git a/fairness_indicators/example_model_test.py b/fairness_indicators/example_model_test.py index 3d3e936..cbc3a8a 100644 --- a/fairness_indicators/example_model_test.py +++ b/fairness_indicators/example_model_test.py @@ -18,87 +18,83 @@ model. 
""" -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function - import datetime import os import tempfile -from fairness_indicators import example_model import numpy as np import six -from tensorflow import keras import tensorflow.compat.v1 as tf import tensorflow_model_analysis as tfma - from google.protobuf import text_format +from tensorflow import keras + +from fairness_indicators import example_model tf.compat.v1.enable_eager_execution() class ExampleModelTest(tf.test.TestCase): - - def setUp(self): - super(ExampleModelTest, self).setUp() - self._base_dir = tempfile.gettempdir() - - self._model_dir = os.path.join( - self._base_dir, 'train', - datetime.datetime.now().strftime('%Y%m%d-%H%M%S')) - - def _create_example(self, comment_text, label, slice_value): - example = tf.train.Example() - example.features.feature[example_model.TEXT_FEATURE].bytes_list.value[:] = [ - six.ensure_binary(comment_text, 'utf8') - ] - example.features.feature[example_model.SLICE].bytes_list.value[:] = [ - six.ensure_binary(slice_value, 'utf8') - ] - example.features.feature[example_model.LABEL].float_list.value[:] = [label] - return example - - def _create_data(self): - examples = [] - examples.append(self._create_example('test comment', 0.0, 'slice1')) - examples.append(self._create_example('toxic comment', 1.0, 'slice1')) - examples.append(self._create_example('non-toxic comment', 0.0, 'slice1')) - examples.append(self._create_example('test comment', 1.0, 'slice2')) - examples.append(self._create_example('non-toxic comment', 0.0, 'slice2')) - examples.append(self._create_example('test comment', 0.0, 'slice3')) - examples.append(self._create_example('toxic comment', 1.0, 'slice3')) - examples.append(self._create_example('toxic comment', 1.0, 'slice3')) - examples.append( - self._create_example('non toxic comment', 0.0, 'slice3')) - examples.append(self._create_example('abc', 0.0, 'slice1')) - examples.append(self._create_example('abcdef', 0.0, 'slice3')) - examples.append(self._create_example('random', 0.0, 'slice1')) - return examples - - def _write_tf_records(self, examples): - data_location = os.path.join(self._base_dir, 'input_data.rio') - with tf.io.TFRecordWriter(data_location) as writer: - for example in examples: - writer.write(example.SerializeToString()) - return data_location - - def test_example_model(self): - data = self._create_data() - classifier = example_model.get_example_model(example_model.TEXT_FEATURE) - classifier.compile(optimizer=keras.optimizers.Adam(), loss='mse') - classifier.fit( - tf.constant([e.SerializeToString() for e in data]), - np.array([ - e.features.feature[example_model.LABEL].float_list.value[:][0] - for e in data - ]), - batch_size=1, - ) - tf.saved_model.save(classifier, self._model_dir) - - eval_config = text_format.Parse( - """ + def setUp(self): + super(ExampleModelTest, self).setUp() + self._base_dir = tempfile.gettempdir() + + self._model_dir = os.path.join( + self._base_dir, "train", datetime.datetime.now().strftime("%Y%m%d-%H%M%S") + ) + + def _create_example(self, comment_text, label, slice_value): + example = tf.train.Example() + example.features.feature[example_model.TEXT_FEATURE].bytes_list.value[:] = [ + six.ensure_binary(comment_text, "utf8") + ] + example.features.feature[example_model.SLICE].bytes_list.value[:] = [ + six.ensure_binary(slice_value, "utf8") + ] + example.features.feature[example_model.LABEL].float_list.value[:] = [label] + return example + + def _create_data(self): + examples = [] + 
examples.append(self._create_example("test comment", 0.0, "slice1")) + examples.append(self._create_example("toxic comment", 1.0, "slice1")) + examples.append(self._create_example("non-toxic comment", 0.0, "slice1")) + examples.append(self._create_example("test comment", 1.0, "slice2")) + examples.append(self._create_example("non-toxic comment", 0.0, "slice2")) + examples.append(self._create_example("test comment", 0.0, "slice3")) + examples.append(self._create_example("toxic comment", 1.0, "slice3")) + examples.append(self._create_example("toxic comment", 1.0, "slice3")) + examples.append(self._create_example("non toxic comment", 0.0, "slice3")) + examples.append(self._create_example("abc", 0.0, "slice1")) + examples.append(self._create_example("abcdef", 0.0, "slice3")) + examples.append(self._create_example("random", 0.0, "slice1")) + return examples + + def _write_tf_records(self, examples): + data_location = os.path.join(self._base_dir, "input_data.rio") + with tf.io.TFRecordWriter(data_location) as writer: + for example in examples: + writer.write(example.SerializeToString()) + return data_location + + def test_example_model(self): + data = self._create_data() + classifier = example_model.get_example_model(example_model.TEXT_FEATURE) + classifier.compile(optimizer=keras.optimizers.Adam(), loss="mse") + classifier.fit( + tf.constant([e.SerializeToString() for e in data]), + np.array( + [ + e.features.feature[example_model.LABEL].float_list.value[:][0] + for e in data + ] + ), + batch_size=1, + ) + tf.saved_model.save(classifier, self._model_dir) + + eval_config = text_format.Parse( + """ model_specs { signature_name: "serving_default" prediction_key: "predictions" # placeholder @@ -117,49 +113,47 @@ def test_example_model(self): } } """, - tfma.EvalConfig(), - ) - - validate_tf_file_path = self._write_tf_records(data) - tfma_eval_result_path = os.path.join(self._model_dir, 'tfma_eval_result') - example_model.evaluate_model( - self._model_dir, - validate_tf_file_path, - tfma_eval_result_path, - eval_config, - ) - - evaluation_results = tfma.load_eval_result(tfma_eval_result_path) - - expected_slice_keys = [ - (), - (('slice', 'slice1'),), - (('slice', 'slice2'),), - (('slice', 'slice3'),), - ] - slice_keys = [ - slice_key for slice_key, _ in evaluation_results.slicing_metrics - ] - self.assertEqual(set(expected_slice_keys), set(slice_keys)) - # Verify part of the metrics of fairness indicators - metric_values = dict(evaluation_results.slicing_metrics)[( - ('slice', 'slice1'), - )][''][''] - self.assertEqual(metric_values['example_count'], {'doubleValue': 5.0}) - - self.assertEqual( - metric_values['fairness_indicators_metrics/false_positive_rate@0.1'], - {'doubleValue': 0.0}, - ) - self.assertEqual( - metric_values['fairness_indicators_metrics/false_negative_rate@0.1'], - {'doubleValue': 1.0}, - ) - self.assertEqual( - metric_values['fairness_indicators_metrics/true_positive_rate@0.1'], - {'doubleValue': 0.0}, - ) - self.assertEqual( - metric_values['fairness_indicators_metrics/true_negative_rate@0.1'], - {'doubleValue': 1.0}, - ) + tfma.EvalConfig(), + ) + + validate_tf_file_path = self._write_tf_records(data) + tfma_eval_result_path = os.path.join(self._model_dir, "tfma_eval_result") + example_model.evaluate_model( + self._model_dir, + validate_tf_file_path, + tfma_eval_result_path, + eval_config, + ) + + evaluation_results = tfma.load_eval_result(tfma_eval_result_path) + + expected_slice_keys = [ + (), + (("slice", "slice1"),), + (("slice", "slice2"),), + (("slice", "slice3"),), + ] 
+ slice_keys = [slice_key for slice_key, _ in evaluation_results.slicing_metrics] + self.assertEqual(set(expected_slice_keys), set(slice_keys)) + # Verify part of the metrics of fairness indicators + metric_values = dict(evaluation_results.slicing_metrics)[ + (("slice", "slice1"),) + ][""][""] + self.assertEqual(metric_values["example_count"], {"doubleValue": 5.0}) + + self.assertEqual( + metric_values["fairness_indicators_metrics/false_positive_rate@0.1"], + {"doubleValue": 0.0}, + ) + self.assertEqual( + metric_values["fairness_indicators_metrics/false_negative_rate@0.1"], + {"doubleValue": 1.0}, + ) + self.assertEqual( + metric_values["fairness_indicators_metrics/true_positive_rate@0.1"], + {"doubleValue": 0.0}, + ) + self.assertEqual( + metric_values["fairness_indicators_metrics/true_negative_rate@0.1"], + {"doubleValue": 1.0}, + ) diff --git a/fairness_indicators/fairness_indicators_metrics.py b/fairness_indicators/fairness_indicators_metrics.py index 94b785c..2b7aec2 100644 --- a/fairness_indicators/fairness_indicators_metrics.py +++ b/fairness_indicators/fairness_indicators_metrics.py @@ -16,193 +16,193 @@ import collections from typing import Any, Dict, List, Optional, Sequence -from tensorflow_model_analysis.metrics import binary_confusion_matrices -from tensorflow_model_analysis.metrics import metric_types -from tensorflow_model_analysis.metrics import metric_util +from tensorflow_model_analysis.metrics import ( + binary_confusion_matrices, + metric_types, + metric_util, +) from tensorflow_model_analysis.proto import config_pb2 -FAIRNESS_INDICATORS_METRICS_NAME = 'fairness_indicators_metrics' +FAIRNESS_INDICATORS_METRICS_NAME = "fairness_indicators_metrics" FAIRNESS_INDICATORS_SUB_METRICS = ( - 'false_positive_rate', - 'false_negative_rate', - 'true_positive_rate', - 'true_negative_rate', - 'positive_rate', - 'negative_rate', - 'false_discovery_rate', - 'false_omission_rate', - 'precision', - 'recall', + "false_positive_rate", + "false_negative_rate", + "true_positive_rate", + "true_negative_rate", + "positive_rate", + "negative_rate", + "false_discovery_rate", + "false_omission_rate", + "precision", + "recall", ) DEFAULT_THRESHOLDS = (0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9) class FairnessIndicators(metric_types.Metric): - """Fairness indicators metrics.""" - - def computations_with_logging(self): - """Add streamz logging for fairness indicators.""" - - computations_fn = metric_util.merge_per_key_computations( - _fairness_indicators_metrics_at_thresholds - ) - - def merge_and_log_computations_fn( - eval_config: Optional[config_pb2.EvalConfig] = None, - # A tf metadata schema. - schema: Optional[Any] = None, - model_names: Optional[List[str]] = None, - output_names: Optional[List[str]] = None, - sub_keys: Optional[List[Optional[metric_types.SubKey]]] = None, - aggregation_type: Optional[metric_types.AggregationType] = None, - class_weights: Optional[Dict[int, float]] = None, - example_weighted: bool = False, - query_key: Optional[str] = None, - **kwargs + """Fairness indicators metrics.""" + + def computations_with_logging(self): + """Add streamz logging for fairness indicators.""" + computations_fn = metric_util.merge_per_key_computations( + _fairness_indicators_metrics_at_thresholds + ) + + def merge_and_log_computations_fn( + eval_config: Optional[config_pb2.EvalConfig] = None, + # A tf metadata schema. 
+ schema: Optional[Any] = None, + model_names: Optional[List[str]] = None, + output_names: Optional[List[str]] = None, + sub_keys: Optional[List[Optional[metric_types.SubKey]]] = None, + aggregation_type: Optional[metric_types.AggregationType] = None, + class_weights: Optional[Dict[int, float]] = None, + example_weighted: bool = False, + query_key: Optional[str] = None, + **kwargs, + ): + return computations_fn( + eval_config, + schema, + model_names, + output_names, + sub_keys, + aggregation_type, + class_weights, + example_weighted, + query_key, + **kwargs, + ) + + return merge_and_log_computations_fn + + def __init__( + self, + thresholds: Sequence[float] = DEFAULT_THRESHOLDS, + name: str = FAIRNESS_INDICATORS_METRICS_NAME, ): - return computations_fn( - eval_config, - schema, - model_names, - output_names, - sub_keys, - aggregation_type, - class_weights, - example_weighted, - query_key, - **kwargs - ) - - return merge_and_log_computations_fn - - def __init__( - self, - thresholds: Sequence[float] = DEFAULT_THRESHOLDS, - name: str = FAIRNESS_INDICATORS_METRICS_NAME, - ): - """Initializes fairness indicators metrics. - - Args: - thresholds: Thresholds to use for fairness metrics. - name: Metric name. - """ - super().__init__( - self.computations_with_logging(), thresholds=thresholds, name=name - ) + """Initializes fairness indicators metrics. + + Args: + ---- + thresholds: Thresholds to use for fairness metrics. + name: Metric name. + """ + super().__init__( + self.computations_with_logging(), thresholds=thresholds, name=name + ) def calculate_digits(thresholds): - digits = [len(str(t)) - 2 for t in thresholds] - return max(max(digits), 1) + digits = [len(str(t)) - 2 for t in thresholds] + return max(max(digits), 1) def _fairness_indicators_metrics_at_thresholds( thresholds: List[float], name: str = FAIRNESS_INDICATORS_METRICS_NAME, eval_config: Optional[config_pb2.EvalConfig] = None, - model_name: str = '', - output_name: str = '', + model_name: str = "", + output_name: str = "", aggregation_type: Optional[metric_types.AggregationType] = None, sub_key: Optional[metric_types.SubKey] = None, class_weights: Optional[Dict[int, float]] = None, example_weighted: bool = False, ) -> metric_types.MetricComputations: - """Returns computations for fairness metrics at thresholds.""" - metric_key_by_name_by_threshold = collections.defaultdict(dict) - keys = [] - digits_num = calculate_digits(thresholds) - for t in thresholds: - for m in FAIRNESS_INDICATORS_SUB_METRICS: - key = metric_types.MetricKey( - name='%s/%s@%.*f' - % ( - name, - m, - digits_num, - t, - ), # e.g. "fairness_indicators_metrics/positive_rate@0.5" - model_name=model_name, - output_name=output_name, - sub_key=sub_key, - example_weighted=example_weighted, - ) - keys.append(key) - metric_key_by_name_by_threshold[t][m] = key - - # Make sure matrices are calculated. 
- computations = binary_confusion_matrices.binary_confusion_matrices( - eval_config=eval_config, - model_name=model_name, - output_name=output_name, - sub_key=sub_key, - aggregation_type=aggregation_type, - class_weights=class_weights, - example_weighted=example_weighted, - thresholds=thresholds, - ) - confusion_matrices_key = computations[-1].keys[-1] - - def result( - metrics: Dict[metric_types.MetricKey, Any], - ) -> Dict[metric_types.MetricKey, Any]: - """Returns fairness metrics values.""" - metric = metrics[confusion_matrices_key] - output = {} - - for i, threshold in enumerate(thresholds): - num_positives = metric.tp[i] + metric.fn[i] - num_negatives = metric.tn[i] + metric.fp[i] - - tpr = metric.tp[i] / (num_positives or float('nan')) - tnr = metric.tn[i] / (num_negatives or float('nan')) - fpr = metric.fp[i] / (num_negatives or float('nan')) - fnr = metric.fn[i] / (num_positives or float('nan')) - pr = (metric.tp[i] + metric.fp[i]) / ( - (num_positives + num_negatives) or float('nan') - ) - nr = (metric.tn[i] + metric.fn[i]) / ( - (num_positives + num_negatives) or float('nan') - ) - precision = metric.tp[i] / ((metric.tp[i] + metric.fp[i]) or float('nan')) - recall = metric.tp[i] / ((metric.tp[i] + metric.fn[i]) or float('nan')) - - fdr = metric.fp[i] / ((metric.fp[i] + metric.tp[i]) or float('nan')) - fomr = metric.fn[i] / ((metric.fn[i] + metric.tn[i]) or float('nan')) - - output[ - metric_key_by_name_by_threshold[threshold]['false_positive_rate'] - ] = fpr - output[ - metric_key_by_name_by_threshold[threshold]['false_negative_rate'] - ] = fnr - output[ - metric_key_by_name_by_threshold[threshold]['true_positive_rate'] - ] = tpr - output[ - metric_key_by_name_by_threshold[threshold]['true_negative_rate'] - ] = tnr - output[metric_key_by_name_by_threshold[threshold]['positive_rate']] = pr - output[metric_key_by_name_by_threshold[threshold]['negative_rate']] = nr - output[ - metric_key_by_name_by_threshold[threshold]['false_discovery_rate'] - ] = fdr - output[ - metric_key_by_name_by_threshold[threshold]['false_omission_rate'] - ] = fomr - output[metric_key_by_name_by_threshold[threshold]['precision']] = ( - precision - ) - output[metric_key_by_name_by_threshold[threshold]['recall']] = recall - - return output - - derived_computation = metric_types.DerivedMetricComputation( - keys=keys, result=result - ) - - computations.append(derived_computation) - return computations + """Returns computations for fairness metrics at thresholds.""" + metric_key_by_name_by_threshold = collections.defaultdict(dict) + keys = [] + digits_num = calculate_digits(thresholds) + for t in thresholds: + for m in FAIRNESS_INDICATORS_SUB_METRICS: + key = metric_types.MetricKey( + name="%s/%s@%.*f" + % ( + name, + m, + digits_num, + t, + ), # e.g. "fairness_indicators_metrics/positive_rate@0.5" + model_name=model_name, + output_name=output_name, + sub_key=sub_key, + example_weighted=example_weighted, + ) + keys.append(key) + metric_key_by_name_by_threshold[t][m] = key + + # Make sure matrices are calculated. 
+ computations = binary_confusion_matrices.binary_confusion_matrices( + eval_config=eval_config, + model_name=model_name, + output_name=output_name, + sub_key=sub_key, + aggregation_type=aggregation_type, + class_weights=class_weights, + example_weighted=example_weighted, + thresholds=thresholds, + ) + confusion_matrices_key = computations[-1].keys[-1] + + def result( + metrics: Dict[metric_types.MetricKey, Any], + ) -> Dict[metric_types.MetricKey, Any]: + """Returns fairness metrics values.""" + metric = metrics[confusion_matrices_key] + output = {} + + for i, threshold in enumerate(thresholds): + num_positives = metric.tp[i] + metric.fn[i] + num_negatives = metric.tn[i] + metric.fp[i] + + tpr = metric.tp[i] / (num_positives or float("nan")) + tnr = metric.tn[i] / (num_negatives or float("nan")) + fpr = metric.fp[i] / (num_negatives or float("nan")) + fnr = metric.fn[i] / (num_positives or float("nan")) + pr = (metric.tp[i] + metric.fp[i]) / ( + (num_positives + num_negatives) or float("nan") + ) + nr = (metric.tn[i] + metric.fn[i]) / ( + (num_positives + num_negatives) or float("nan") + ) + precision = metric.tp[i] / ((metric.tp[i] + metric.fp[i]) or float("nan")) + recall = metric.tp[i] / ((metric.tp[i] + metric.fn[i]) or float("nan")) + + fdr = metric.fp[i] / ((metric.fp[i] + metric.tp[i]) or float("nan")) + fomr = metric.fn[i] / ((metric.fn[i] + metric.tn[i]) or float("nan")) + + output[ + metric_key_by_name_by_threshold[threshold]["false_positive_rate"] + ] = fpr + output[ + metric_key_by_name_by_threshold[threshold]["false_negative_rate"] + ] = fnr + output[metric_key_by_name_by_threshold[threshold]["true_positive_rate"]] = ( + tpr + ) + output[metric_key_by_name_by_threshold[threshold]["true_negative_rate"]] = ( + tnr + ) + output[metric_key_by_name_by_threshold[threshold]["positive_rate"]] = pr + output[metric_key_by_name_by_threshold[threshold]["negative_rate"]] = nr + output[ + metric_key_by_name_by_threshold[threshold]["false_discovery_rate"] + ] = fdr + output[ + metric_key_by_name_by_threshold[threshold]["false_omission_rate"] + ] = fomr + output[metric_key_by_name_by_threshold[threshold]["precision"]] = precision + output[metric_key_by_name_by_threshold[threshold]["recall"]] = recall + + return output + + derived_computation = metric_types.DerivedMetricComputation( + keys=keys, result=result + ) + + computations.append(derived_computation) + return computations metric_types.register_metric(FairnessIndicators) diff --git a/fairness_indicators/remediation/weight_utils.py b/fairness_indicators/remediation/weight_utils.py index 5c5abbc..deb5893 100644 --- a/fairness_indicators/remediation/weight_utils.py +++ b/fairness_indicators/remediation/weight_utils.py @@ -1,95 +1,104 @@ """Utilities to suggest weights based on model analysis results.""" -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function - -from typing import Any, Dict, Mapping, Text +from typing import Any, Dict, Mapping import tensorflow_model_analysis as tfma def create_percentage_difference_dictionary( - eval_result: tfma.EvalResult, - baseline_name: Text, metric_name: Text) -> Dict[Text, Any]: - """Creates dictionary of a % difference between a baseline and other slices. - - Args: - eval_result: Loaded eval result from running TensorFlow Model Analysis. - baseline_name: Name of the baseline slice, 'Overall' or a specified tuple. - metric_name: Name of the metric on which to perform comparisons. 
- - Returns: - Dictionary mapping slices to percentage difference from the baseline slice. - """ - baseline_value = get_baseline_value(eval_result, baseline_name, metric_name) - difference = {} - for metrics_tuple in eval_result.slicing_metrics: - slice_key = metrics_tuple[0] - metrics = metrics_tuple[1] - # Concatenate feature name/values for intersectional features. - column = '-'.join([elem[0] for elem in slice_key]) - feature_val = '-'.join([elem[1] for elem in slice_key]) - if column not in difference: - difference[column] = {} - difference[column][feature_val] = (_get_metric_value(metrics, metric_name) - - baseline_value) / baseline_value - return difference + eval_result: tfma.EvalResult, baseline_name: str, metric_name: str +) -> Dict[str, Any]: + """Creates dictionary of a % difference between a baseline and other slices. + + Args: + ---- + eval_result: Loaded eval result from running TensorFlow Model Analysis. + baseline_name: Name of the baseline slice, 'Overall' or a specified tuple. + metric_name: Name of the metric on which to perform comparisons. + + Returns: + ------- + Dictionary mapping slices to percentage difference from the baseline slice. + """ + baseline_value = get_baseline_value(eval_result, baseline_name, metric_name) + difference = {} + for metrics_tuple in eval_result.slicing_metrics: + slice_key = metrics_tuple[0] + metrics = metrics_tuple[1] + # Concatenate feature name/values for intersectional features. + column = "-".join([elem[0] for elem in slice_key]) + feature_val = "-".join([elem[1] for elem in slice_key]) + if column not in difference: + difference[column] = {} + difference[column][feature_val] = ( + _get_metric_value(metrics, metric_name) - baseline_value + ) / baseline_value + return difference def _get_metric_value( - nested_dict: Mapping[Text, Mapping[Text, Any]], metric_name: Text) -> float: - """Returns the value of the named metric from a slice's metrics. - - Args: - nested_dict: Dictionary of metrics from slice. - metric_name: Value to return from the metric slice. - - Returns: - Percentage value of the baseline slice name requested. - - Raises: - KeyError: If the metric name isn't found in the metrics dictionary or if the - input metrics dictionary is empty. - TypeError: If an unsupported value type is found within dictionary slice. - passed. - """ - for value in nested_dict.values(): - if metric_name in value['']: - typed_value = value[''][metric_name] - if 'doubleValue' in typed_value: - return typed_value['doubleValue'] - if 'boundedValue' in typed_value: - return typed_value['boundedValue']['value'] - raise TypeError('Unsupported value type: %s' % typed_value) - else: - raise KeyError('Key %s not found in %s' % - (metric_name, list(value[''].keys()))) - raise KeyError( - 'Unable to return a metric value because the dictionary passed is empty.') + nested_dict: Mapping[str, Mapping[str, Any]], metric_name: str +) -> float: + """Returns the value of the named metric from a slice's metrics. + + Args: + ---- + nested_dict: Dictionary of metrics from slice. + metric_name: Value to return from the metric slice. + + Returns: + ------- + Percentage value of the baseline slice name requested. + + Raises: + ------ + KeyError: If the metric name isn't found in the metrics dictionary or if the + input metrics dictionary is empty. + TypeError: If an unsupported value type is found within dictionary slice. + passed. 
+ """ + for value in nested_dict.values(): + if metric_name in value[""]: + typed_value = value[""][metric_name] + if "doubleValue" in typed_value: + return typed_value["doubleValue"] + if "boundedValue" in typed_value: + return typed_value["boundedValue"]["value"] + raise TypeError("Unsupported value type: %s" % typed_value) + else: + raise KeyError( + "Key %s not found in %s" % (metric_name, list(value[""].keys())) + ) + raise KeyError( + "Unable to return a metric value because the dictionary passed is empty." + ) def get_baseline_value( - eval_result: tfma.EvalResult, - baseline_name: Text, metric_name: Text) -> float: - """Looks through the evaluation result for the value of the baseline slice. - - Args: - eval_result: Loaded eval result from running TensorFlow Model Analysis. - baseline_name: Name of the baseline slice, 'Overall' or a specified tuple. - metric_name: Name of the metric on which to perform comparisons. - - Returns: - Percentage value of the baseline slice name requested. - - Raises: - Value error if the baseline slice is not found in eval_results. - """ - for metrics_tuple in eval_result.slicing_metrics: - slice_tuple = metrics_tuple[0] - if baseline_name == 'Overall' and not slice_tuple: - return _get_metric_value(metrics_tuple[1], metric_name) - if baseline_name == slice_tuple: - return _get_metric_value(metrics_tuple[1], metric_name) - raise ValueError('Could not find baseline %s in eval_result: %s' % - (baseline_name, eval_result)) + eval_result: tfma.EvalResult, baseline_name: str, metric_name: str +) -> float: + """Looks through the evaluation result for the value of the baseline slice. + + Args: + ---- + eval_result: Loaded eval result from running TensorFlow Model Analysis. + baseline_name: Name of the baseline slice, 'Overall' or a specified tuple. + metric_name: Name of the metric on which to perform comparisons. + + Returns: + ------- + Percentage value of the baseline slice name requested. + + Raises: + ------ + Value error if the baseline slice is not found in eval_results. 
+ """ + for metrics_tuple in eval_result.slicing_metrics: + slice_tuple = metrics_tuple[0] + if baseline_name == "Overall" and not slice_tuple: + return _get_metric_value(metrics_tuple[1], metric_name) + if baseline_name == slice_tuple: + return _get_metric_value(metrics_tuple[1], metric_name) + raise ValueError( + "Could not find baseline %s in eval_result: %s" % (baseline_name, eval_result) + ) diff --git a/fairness_indicators/remediation/weight_utils_test.py b/fairness_indicators/remediation/weight_utils_test.py index 4d7cdea..b50fe7d 100644 --- a/fairness_indicators/remediation/weight_utils_test.py +++ b/fairness_indicators/remediation/weight_utils_test.py @@ -1,217 +1,238 @@ """Tests for fairness_indicators.remediation.weight_utils.""" -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function - import collections -from fairness_indicators.remediation import weight_utils import tensorflow.compat.v1 as tf +from fairness_indicators.remediation import weight_utils -EvalResult = collections.namedtuple('EvalResult', ['slicing_metrics']) +EvalResult = collections.namedtuple("EvalResult", ["slicing_metrics"]) class WeightUtilsTest(tf.test.TestCase): - - def create_eval_result(self): - return EvalResult(slicing_metrics=[ - ((), { - '': { - '': { - 'post_export_metrics/negative_rate@0.10': { - 'doubleValue': 0.08 + def create_eval_result(self): + return EvalResult( + slicing_metrics=[ + ( + (), + { + "": { + "": { + "post_export_metrics/negative_rate@0.10": { + "doubleValue": 0.08 + }, + "accuracy": {"doubleValue": 0.444}, + } + } }, - 'accuracy': { - 'doubleValue': 0.444 - } - } - } - }), - ((('gender', 'female'),), { - '': { - '': { - 'post_export_metrics/negative_rate@0.10': { - 'doubleValue': 0.09 + ), + ( + (("gender", "female"),), + { + "": { + "": { + "post_export_metrics/negative_rate@0.10": { + "doubleValue": 0.09 + }, + "accuracy": {"doubleValue": 0.333}, + } + } }, - 'accuracy': { - 'doubleValue': 0.333 - } - } - } - }), - (((u'gender', u'female'), - (u'sexual_orientation', u'homosexual_gay_or_lesbian')), { - '': { - '': { - 'post_export_metrics/negative_rate@0.10': { - 'doubleValue': 0.1 - }, - 'accuracy': { - 'doubleValue': 0.222 - } - } - } - }), - ]) - - def create_bounded_result(self): - return EvalResult(slicing_metrics=[ - ((), { - '': { - '': { - 'post_export_metrics/negative_rate@0.10': { - 'boundedValue': { - 'lowerBound': 0.07, - 'upperBound': 0.09, - 'value': 0.08, - 'methodology': 'POISSON_BOOTSTRAP' + ), + ( + ( + ("gender", "female"), + ("sexual_orientation", "homosexual_gay_or_lesbian"), + ), + { + "": { + "": { + "post_export_metrics/negative_rate@0.10": { + "doubleValue": 0.1 + }, + "accuracy": {"doubleValue": 0.222}, + } } }, - 'accuracy': { - 'boundedValue': { - 'lowerBound': 0.07, - 'upperBound': 0.09, - 'value': 0.444, - 'methodology': 'POISSON_BOOTSTRAP' + ), + ] + ) + + def create_bounded_result(self): + return EvalResult( + slicing_metrics=[ + ( + (), + { + "": { + "": { + "post_export_metrics/negative_rate@0.10": { + "boundedValue": { + "lowerBound": 0.07, + "upperBound": 0.09, + "value": 0.08, + "methodology": "POISSON_BOOTSTRAP", + } + }, + "accuracy": { + "boundedValue": { + "lowerBound": 0.07, + "upperBound": 0.09, + "value": 0.444, + "methodology": "POISSON_BOOTSTRAP", + } + }, + } } - } - } - } - }), - ((('gender', 'female'),), { - '': { - '': { - 'post_export_metrics/negative_rate@0.10': { - 'boundedValue': { - 'lowerBound': 0.07, - 'upperBound': 0.09, - 'value': 0.09, - 'methodology': 
'POISSON_BOOTSTRAP' + }, + ), + ( + (("gender", "female"),), + { + "": { + "": { + "post_export_metrics/negative_rate@0.10": { + "boundedValue": { + "lowerBound": 0.07, + "upperBound": 0.09, + "value": 0.09, + "methodology": "POISSON_BOOTSTRAP", + } + }, + "accuracy": { + "boundedValue": { + "lowerBound": 0.07, + "upperBound": 0.09, + "value": 0.333, + "methodology": "POISSON_BOOTSTRAP", + } + }, + } } }, - 'accuracy': { - 'boundedValue': { - 'lowerBound': 0.07, - 'upperBound': 0.09, - 'value': 0.333, - 'methodology': 'POISSON_BOOTSTRAP' + ), + ( + ( + ("gender", "female"), + ("sexual_orientation", "homosexual_gay_or_lesbian"), + ), + { + "": { + "": { + "post_export_metrics/negative_rate@0.10": { + "boundedValue": { + "lowerBound": 0.07, + "upperBound": 0.09, + "value": 0.1, + "methodology": "POISSON_BOOTSTRAP", + } + }, + "accuracy": { + "boundedValue": { + "lowerBound": 0.07, + "upperBound": 0.09, + "value": 0.222, + "methodology": "POISSON_BOOTSTRAP", + } + }, + } } - } - } - } - }), - (((u'gender', u'female'), - (u'sexual_orientation', u'homosexual_gay_or_lesbian')), { - '': { - '': { - 'post_export_metrics/negative_rate@0.10': { - 'boundedValue': { - 'lowerBound': 0.07, - 'upperBound': 0.09, - 'value': 0.1, - 'methodology': 'POISSON_BOOTSTRAP' - } - }, - 'accuracy': { - 'boundedValue': { - 'lowerBound': 0.07, - 'upperBound': 0.09, - 'value': 0.222, - 'methodology': 'POISSON_BOOTSTRAP' - } - } - } - } - }), - ]) - - def test_baseline(self): - test_eval_result = self.create_eval_result() - self.assertEqual( - 0.08, - weight_utils.get_baseline_value( - test_eval_result, 'Overall', - 'post_export_metrics/negative_rate@0.10')) - self.assertEqual( - 0.09, - weight_utils.get_baseline_value( - test_eval_result, (('gender', 'female'),), - 'post_export_metrics/negative_rate@0.10')) - # Test 'accuracy'. - self.assertEqual( - 0.444, - weight_utils.get_baseline_value(test_eval_result, 'Overall', - 'accuracy')) - # Test intersectional metrics. - self.assertEqual( - 0.222, - weight_utils.get_baseline_value( + }, + ), + ] + ) + + def test_baseline(self): + test_eval_result = self.create_eval_result() + self.assertEqual( + 0.08, + weight_utils.get_baseline_value( + test_eval_result, "Overall", "post_export_metrics/negative_rate@0.10" + ), + ) + self.assertEqual( + 0.09, + weight_utils.get_baseline_value( + test_eval_result, + (("gender", "female"),), + "post_export_metrics/negative_rate@0.10", + ), + ) + # Test 'accuracy'. + self.assertEqual( + 0.444, + weight_utils.get_baseline_value(test_eval_result, "Overall", "accuracy"), + ) + # Test intersectional metrics. + self.assertEqual( + 0.222, + weight_utils.get_baseline_value( + test_eval_result, + ( + ("gender", "female"), + ("sexual_orientation", "homosexual_gay_or_lesbian"), + ), + "accuracy", + ), + ) + with self.assertRaises(ValueError): + # Test slice not found. + weight_utils.get_baseline_value( + test_eval_result, (("nonexistant", "slice"),), "accuracy" + ) + with self.assertRaises(KeyError): + # Test metric not found. 
+ weight_utils.get_baseline_value( + test_eval_result, (("gender", "female"),), "nonexistent_metric" + ) + + def test_get_metric_value_raise_key_error(self): + input_dict = {"": {"": {"accuracy": 0.1}}} + metric_name = "nonexistent_metric" + with self.assertRaises(KeyError): + weight_utils._get_metric_value(input_dict, metric_name) + + def test_get_metric_value_raise_unsupported_value(self): + input_dict = {"": {"": {"accuracy": {"boundedValue": {1}}}}} + metric_name = "accuracy" + with self.assertRaises(TypeError): + weight_utils._get_metric_value(input_dict, metric_name) + + def test_get_metric_value_raise_empty_dict(self): + with self.assertRaises(KeyError): + weight_utils._get_metric_value({}, "metric_name") + + def test_create_difference_dictionary(self): + test_eval_result = self.create_eval_result() + res = weight_utils.create_percentage_difference_dictionary( + test_eval_result, "Overall", "post_export_metrics/negative_rate@0.10" + ) + self.assertEqual(3, len(res)) + self.assertIn("gender-sexual_orientation", res) + self.assertIn("gender", res) + self.assertAlmostEqual(res["gender"]["female"], 0.125) + self.assertAlmostEqual(res[""][""], 0) + + def test_create_difference_dictionary_baseline(self): + test_eval_result = self.create_eval_result() + res = weight_utils.create_percentage_difference_dictionary( test_eval_result, - ((u'gender', u'female'), - (u'sexual_orientation', u'homosexual_gay_or_lesbian')), - 'accuracy')) - with self.assertRaises(ValueError): - # Test slice not found. - weight_utils.get_baseline_value(test_eval_result, - (('nonexistant', 'slice'),), 'accuracy') - with self.assertRaises(KeyError): - # Test metric not found. - weight_utils.get_baseline_value(test_eval_result, (('gender', 'female'),), - 'nonexistent_metric') - - def test_get_metric_value_raise_key_error(self): - input_dict = {'': {'': {'accuracy': 0.1}}} - metric_name = 'nonexistent_metric' - with self.assertRaises(KeyError): - weight_utils._get_metric_value(input_dict, metric_name) - - def test_get_metric_value_raise_unsupported_value(self): - input_dict = { - '': { - '': { - 'accuracy': { - 'boundedValue': {1} - } - } - } - } - metric_name = 'accuracy' - with self.assertRaises(TypeError): - weight_utils._get_metric_value(input_dict, metric_name) - - def test_get_metric_value_raise_empty_dict(self): - with self.assertRaises(KeyError): - weight_utils._get_metric_value({}, 'metric_name') - - def test_create_difference_dictionary(self): - test_eval_result = self.create_eval_result() - res = weight_utils.create_percentage_difference_dictionary( - test_eval_result, 'Overall', 'post_export_metrics/negative_rate@0.10') - self.assertEqual(3, len(res)) - self.assertIn('gender-sexual_orientation', res) - self.assertIn('gender', res) - self.assertAlmostEqual(res['gender']['female'], 0.125) - self.assertAlmostEqual(res[''][''], 0) - - def test_create_difference_dictionary_baseline(self): - test_eval_result = self.create_eval_result() - res = weight_utils.create_percentage_difference_dictionary( - test_eval_result, (('gender', 'female'),), - 'post_export_metrics/negative_rate@0.10') - self.assertEqual(3, len(res)) - self.assertIn('gender-sexual_orientation', res) - self.assertIn('gender', res) - self.assertAlmostEqual(res['gender']['female'], 0) - self.assertAlmostEqual(res[''][''], -0.11111111) - - def test_create_difference_dictionary_bounded_metrics(self): - test_eval_result = self.create_bounded_result() - res = weight_utils.create_percentage_difference_dictionary( - test_eval_result, 'Overall', 
'post_export_metrics/negative_rate@0.10') - self.assertEqual(3, len(res)) - self.assertIn('gender-sexual_orientation', res) - self.assertIn('gender', res) - self.assertAlmostEqual(res['gender']['female'], 0.125) - self.assertAlmostEqual(res[''][''], 0) + (("gender", "female"),), + "post_export_metrics/negative_rate@0.10", + ) + self.assertEqual(3, len(res)) + self.assertIn("gender-sexual_orientation", res) + self.assertIn("gender", res) + self.assertAlmostEqual(res["gender"]["female"], 0) + self.assertAlmostEqual(res[""][""], -0.11111111) + + def test_create_difference_dictionary_bounded_metrics(self): + test_eval_result = self.create_bounded_result() + res = weight_utils.create_percentage_difference_dictionary( + test_eval_result, "Overall", "post_export_metrics/negative_rate@0.10" + ) + self.assertEqual(3, len(res)) + self.assertIn("gender-sexual_orientation", res) + self.assertIn("gender", res) + self.assertAlmostEqual(res["gender"]["female"], 0.125) + self.assertAlmostEqual(res[""][""], 0) diff --git a/fairness_indicators/tutorial_utils/__init__.py b/fairness_indicators/tutorial_utils/__init__.py index e1b9090..206a8d4 100644 --- a/fairness_indicators/tutorial_utils/__init__.py +++ b/fairness_indicators/tutorial_utils/__init__.py @@ -12,5 +12,8 @@ # See the License for the specific language governing permissions and # limitations under the License. """Init file for fairness_indicators.tutorial_utils.""" -from fairness_indicators.tutorial_utils.util import convert_comments_data -from fairness_indicators.tutorial_utils.util import get_eval_results + +from fairness_indicators.tutorial_utils.util import ( + convert_comments_data, + get_eval_results, +) diff --git a/fairness_indicators/tutorial_utils/util.py b/fairness_indicators/tutorial_utils/util.py index 43aaec5..c7d3fd1 100644 --- a/fairness_indicators/tutorial_utils/util.py +++ b/fairness_indicators/tutorial_utils/util.py @@ -23,152 +23,175 @@ import tensorflow_model_analysis as tfma from google.protobuf import text_format -TEXT_FEATURE = 'comment_text' -LABEL = 'toxicity' +TEXT_FEATURE = "comment_text" +LABEL = "toxicity" SEXUAL_ORIENTATION_COLUMNS = [ - 'heterosexual', 'homosexual_gay_or_lesbian', 'bisexual', - 'other_sexual_orientation' + "heterosexual", + "homosexual_gay_or_lesbian", + "bisexual", + "other_sexual_orientation", ] -GENDER_COLUMNS = ['male', 'female', 'transgender', 'other_gender'] +GENDER_COLUMNS = ["male", "female", "transgender", "other_gender"] RELIGION_COLUMNS = [ - 'christian', 'jewish', 'muslim', 'hindu', 'buddhist', 'atheist', - 'other_religion' + "christian", + "jewish", + "muslim", + "hindu", + "buddhist", + "atheist", + "other_religion", ] -RACE_COLUMNS = ['black', 'white', 'asian', 'latino', 'other_race_or_ethnicity'] +RACE_COLUMNS = ["black", "white", "asian", "latino", "other_race_or_ethnicity"] DISABILITY_COLUMNS = [ - 'physical_disability', 'intellectual_or_learning_disability', - 'psychiatric_or_mental_illness', 'other_disability' + "physical_disability", + "intellectual_or_learning_disability", + "psychiatric_or_mental_illness", + "other_disability", ] IDENTITY_COLUMNS = { - 'gender': GENDER_COLUMNS, - 'sexual_orientation': SEXUAL_ORIENTATION_COLUMNS, - 'religion': RELIGION_COLUMNS, - 'race': RACE_COLUMNS, - 'disability': DISABILITY_COLUMNS + "gender": GENDER_COLUMNS, + "sexual_orientation": SEXUAL_ORIENTATION_COLUMNS, + "religion": RELIGION_COLUMNS, + "race": RACE_COLUMNS, + "disability": DISABILITY_COLUMNS, } _THRESHOLD = 0.5 def convert_comments_data(input_filename, output_filename=None): - 
"""Convert the public civil comments data. + """Convert the public civil comments data. + + In the orginal dataset + https://www.kaggle.com/c/jigsaw-unintended-bias-in-toxicity-classification/data + for each indentity annotation columns, the value comes + from percent of raters thought the comment referenced the identity. When + processing the raw data, the threshold 0.5 is chosen and the identity terms + are grouped together by their categories. For example if one comment has { + male: 0.3, female: 1.0, transgender: 0.0, heterosexual: 0.8, + homosexual_gay_or_lesbian: 1.0 }. After the processing, the data will be { + gender: [female], sexual_orientation: [heterosexual, + homosexual_gay_or_lesbian] }. + + Args: + ---- + input_filename: The path to the raw civil comments data, with extension + 'tfrecord' or 'csv'. + output_filename: The path to write the processed civil comments data. + + Returns: + ------- + The file path to the converted dataset. + + Raises: + ------ + ValueError: If the input_filename does not have a supported extension. + """ + extension = os.path.splitext(input_filename)[1][1:] + + if not output_filename: + output_filename = os.path.join(tempfile.mkdtemp(), "output." + extension) + + if extension == "tfrecord": + return _convert_comments_data_tfrecord(input_filename, output_filename) + elif extension == "csv": + return _convert_comments_data_csv(input_filename, output_filename) + + raise ValueError( + "input_filename must have supported file extension csv or tfrecord, " + f"given: {input_filename}" + ) - In the orginal dataset - https://www.kaggle.com/c/jigsaw-unintended-bias-in-toxicity-classification/data - for each indentity annotation columns, the value comes - from percent of raters thought the comment referenced the identity. When - processing the raw data, the threshold 0.5 is chosen and the identity terms - are grouped together by their categories. For example if one comment has { - male: 0.3, female: 1.0, transgender: 0.0, heterosexual: 0.8, - homosexual_gay_or_lesbian: 1.0 }. After the processing, the data will be { - gender: [female], sexual_orientation: [heterosexual, - homosexual_gay_or_lesbian] }. - Args: - input_filename: The path to the raw civil comments data, with extension - 'tfrecord' or 'csv'. - output_filename: The path to write the processed civil comments data. - - Returns: - The file path to the converted dataset. 
+def _convert_comments_data_tfrecord(input_filename, output_filename=None): + """Convert the public civil comments data, for tfrecord data.""" + with tf.io.TFRecordWriter(output_filename) as writer: + for serialized in tf.data.TFRecordDataset(filenames=[input_filename]): + example = tf.train.Example() + example.ParseFromString(serialized.numpy()) + if not example.features.feature[TEXT_FEATURE].bytes_list.value: + continue + + new_example = tf.train.Example() + new_example.features.feature[TEXT_FEATURE].bytes_list.value.extend( + example.features.feature[TEXT_FEATURE].bytes_list.value + ) + new_example.features.feature[LABEL].float_list.value.append( + 1 + if example.features.feature[LABEL].float_list.value[0] >= _THRESHOLD + else 0 + ) + + for identity_category, identity_list in IDENTITY_COLUMNS.items(): + grouped_identity = [] + for identity in identity_list: + if ( + example.features.feature[identity].float_list.value + and example.features.feature[identity].float_list.value[0] + >= _THRESHOLD + ): + grouped_identity.append(identity.encode()) + new_example.features.feature[identity_category].bytes_list.value.extend( + grouped_identity + ) + writer.write(new_example.SerializeToString()) + + return output_filename - Raises: - ValueError: If the input_filename does not have a supported extension. - """ - extension = os.path.splitext(input_filename)[1][1:] - if not output_filename: - output_filename = os.path.join(tempfile.mkdtemp(), 'output.' + extension) +def _convert_comments_data_csv(input_filename, output_filename=None): + """Convert the public civil comments data, for csv data.""" + df = pd.read_csv(input_filename) - if extension == 'tfrecord': - return _convert_comments_data_tfrecord(input_filename, output_filename) - elif extension == 'csv': - return _convert_comments_data_csv(input_filename, output_filename) + # Filter out rows with empty comment text values. + df = df[df[TEXT_FEATURE].ne("")] + df = df[df[TEXT_FEATURE].notnull()] - raise ValueError( - 'input_filename must have supported file extension csv or tfrecord, ' - 'given: {}'.format(input_filename)) + new_df = pd.DataFrame() + new_df[TEXT_FEATURE] = df[TEXT_FEATURE] + # Reduce the label to value 0 or 1. + new_df[LABEL] = df[LABEL].ge(_THRESHOLD).astype(int) -def _convert_comments_data_tfrecord(input_filename, output_filename=None): - """Convert the public civil comments data, for tfrecord data.""" - with tf.io.TFRecordWriter(output_filename) as writer: - for serialized in tf.data.TFRecordDataset(filenames=[input_filename]): - example = tf.train.Example() - example.ParseFromString(serialized.numpy()) - if not example.features.feature[TEXT_FEATURE].bytes_list.value: - continue - - new_example = tf.train.Example() - new_example.features.feature[TEXT_FEATURE].bytes_list.value.extend( - example.features.feature[TEXT_FEATURE].bytes_list.value) - new_example.features.feature[LABEL].float_list.value.append( - 1 if example.features.feature[LABEL].float_list.value[0] >= _THRESHOLD - else 0) - - for identity_category, identity_list in IDENTITY_COLUMNS.items(): - grouped_identity = [] + # Extract the list of all identity terms that exceed the threshold. 
+ def identity_conditions(df, identity_list): + group = [] for identity in identity_list: - if (example.features.feature[identity].float_list.value and - example.features.feature[identity].float_list.value[0] >= - _THRESHOLD): - grouped_identity.append(identity.encode()) - new_example.features.feature[identity_category].bytes_list.value.extend( - grouped_identity) - writer.write(new_example.SerializeToString()) - - return output_filename - - -def _convert_comments_data_csv(input_filename, output_filename=None): - """Convert the public civil comments data, for csv data.""" - df = pd.read_csv(input_filename) - - # Filter out rows with empty comment text values. - df = df[df[TEXT_FEATURE].ne('')] - df = df[df[TEXT_FEATURE].notnull()] - - new_df = pd.DataFrame() - new_df[TEXT_FEATURE] = df[TEXT_FEATURE] - - # Reduce the label to value 0 or 1. - new_df[LABEL] = df[LABEL].ge(_THRESHOLD).astype(int) - - # Extract the list of all identity terms that exceed the threshold. - def identity_conditions(df, identity_list): - group = [] - for identity in identity_list: - if df[identity] >= _THRESHOLD: - group.append(identity) - return group - - for identity_category, identity_list in IDENTITY_COLUMNS.items(): - new_df[identity_category] = df.apply( - identity_conditions, args=((identity_list),), axis=1) - - new_df.to_csv( - output_filename, - header=[TEXT_FEATURE, LABEL, *IDENTITY_COLUMNS.keys()], - index=False) - - return output_filename - - -def get_eval_results(model_location, - eval_result_path, - validate_tfrecord_file, - slice_selection='religion', - thresholds=None, - compute_confidence_intervals=True): - """Get Fairness Indicators eval results.""" - if thresholds is None: - thresholds = [0.4, 0.4125, 0.425, 0.4375, 0.45, 0.4675, 0.475, 0.4875, 0.5] - - # Define slices that you want the evaluation to run on. - eval_config = text_format.Parse( - """ + if df[identity] >= _THRESHOLD: + group.append(identity) + return group + + for identity_category, identity_list in IDENTITY_COLUMNS.items(): + new_df[identity_category] = df.apply( + identity_conditions, args=((identity_list),), axis=1 + ) + + new_df.to_csv( + output_filename, + header=[TEXT_FEATURE, LABEL, *IDENTITY_COLUMNS.keys()], + index=False, + ) + + return output_filename + + +def get_eval_results( + model_location, + eval_result_path, + validate_tfrecord_file, + slice_selection="religion", + thresholds=None, + compute_confidence_intervals=True, +): + """Get Fairness Indicators eval results.""" + if thresholds is None: + thresholds = [0.4, 0.4125, 0.425, 0.4375, 0.45, 0.4675, 0.475, 0.4875, 0.5] + + # Define slices that you want the evaluation to run on. + eval_config = text_format.Parse( + """ model_specs { label_key: '%s' } @@ -189,18 +212,26 @@ def get_eval_results(model_location, compute_confidence_intervals { value: %s } disabled_outputs{values: "analysis"} } - """ % (LABEL, thresholds, - slice_selection, 'true' if compute_confidence_intervals else 'false'), - tfma.EvalConfig()) - - eval_shared_model = tfma.default_eval_shared_model( - eval_saved_model_path=model_location, tags=[tf.saved_model.SERVING]) - - # Run the fairness evaluation. 
- return tfma.run_model_analysis( - eval_shared_model=eval_shared_model, - data_location=validate_tfrecord_file, - file_format='tfrecords', - eval_config=eval_config, - output_path=eval_result_path, - extractors=None) + """ + % ( + LABEL, + thresholds, + slice_selection, + "true" if compute_confidence_intervals else "false", + ), + tfma.EvalConfig(), + ) + + eval_shared_model = tfma.default_eval_shared_model( + eval_saved_model_path=model_location, tags=[tf.saved_model.SERVING] + ) + + # Run the fairness evaluation. + return tfma.run_model_analysis( + eval_shared_model=eval_shared_model, + data_location=validate_tfrecord_file, + file_format="tfrecords", + eval_config=eval_config, + output_path=eval_result_path, + extractors=None, + ) diff --git a/fairness_indicators/tutorial_utils/util_test.py b/fairness_indicators/tutorial_utils/util_test.py index ef2ec44..36666c6 100644 --- a/fairness_indicators/tutorial_utils/util_test.py +++ b/fairness_indicators/tutorial_utils/util_test.py @@ -14,26 +14,23 @@ # ============================================================================== """Tests for fairness_indicators.tutorial_utils.util.""" -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function - import csv import os import tempfile from unittest import mock -from fairness_indicators.tutorial_utils import util + import pandas as pd import tensorflow as tf import tensorflow_model_analysis as tfma from google.protobuf import text_format +from fairness_indicators.tutorial_utils import util -class UtilTest(tf.test.TestCase): - def _create_example_tfrecord(self): - example = text_format.Parse( - """ +class UtilTest(tf.test.TestCase): + def _create_example_tfrecord(self): + example = text_format.Parse( + """ features { feature { key: "comment_text" value { bytes_list { value: [ "comment 1" ] }} @@ -85,39 +82,43 @@ def _create_example_tfrecord(self): value { float_list { value: [ 1.0 ] }} } } - """, tf.train.Example()) - empty_comment_example = text_format.Parse( - """ + """, + tf.train.Example(), + ) + empty_comment_example = text_format.Parse( + """ features { feature { key: "comment_text" value { bytes_list {} } } feature { key: "toxicity" value { float_list { value: [ 0.1 ] }}} } - """, tf.train.Example()) - return [example, empty_comment_example] + """, + tf.train.Example(), + ) + return [example, empty_comment_example] - def _write_tf_records(self, examples): - filename = os.path.join(tempfile.mkdtemp(), 'input.tfrecord') - with tf.io.TFRecordWriter(filename) as writer: - for e in examples: - writer.write(e.SerializeToString()) - return filename + def _write_tf_records(self, examples): + filename = os.path.join(tempfile.mkdtemp(), "input.tfrecord") + with tf.io.TFRecordWriter(filename) as writer: + for e in examples: + writer.write(e.SerializeToString()) + return filename - def test_convert_data_tfrecord(self): - input_file = self._write_tf_records(self._create_example_tfrecord()) - output_file = util.convert_comments_data(input_file) - output_example_list = [] - for serialized in tf.data.TFRecordDataset(filenames=[output_file]): - output_example = tf.train.Example() - output_example.ParseFromString(serialized.numpy()) - output_example_list.append(output_example) + def test_convert_data_tfrecord(self): + input_file = self._write_tf_records(self._create_example_tfrecord()) + output_file = util.convert_comments_data(input_file) + output_example_list = [] + for serialized in tf.data.TFRecordDataset(filenames=[output_file]): + output_example = 
tf.train.Example() + output_example.ParseFromString(serialized.numpy()) + output_example_list.append(output_example) - self.assertEqual(len(output_example_list), 1) - self.assertEqual( - output_example_list[0], - text_format.Parse( - """ + self.assertEqual(len(output_example_list), 1) + self.assertEqual( + output_example_list[0], + text_format.Parse( + """ features { feature { key: "comment_text" value { bytes_list {value: [ "comment 1" ] }} @@ -143,157 +144,170 @@ def test_convert_data_tfrecord(self): "other_disability"] }} } } - """, tf.train.Example())) + """, + tf.train.Example(), + ), + ) - def _create_example_csv(self, use_fake_embedding=False): - header = [ - 'comment_text', - 'toxicity', - 'heterosexual', - 'homosexual_gay_or_lesbian', - 'bisexual', - 'other_sexual_orientation', - 'male', - 'female', - 'transgender', - 'other_gender', - 'christian', - 'jewish', - 'muslim', - 'hindu', - 'buddhist', - 'atheist', - 'other_religion', - 'black', - 'white', - 'asian', - 'latino', - 'other_race_or_ethnicity', - 'physical_disability', - 'intellectual_or_learning_disability', - 'psychiatric_or_mental_illness', - 'other_disability', - ] - example = [ - 'comment 1' if not use_fake_embedding else 0.35, - 0.1, - # sexual orientation - 0.1, - 0.1, - 0.5, - 0.1, - # gender - 0.1, - 0.2, - 0.3, - 0.4, - # religion - 0.0, - 0.1, - 0.2, - 0.3, - 0.4, - 0.5, - 0.6, - # race or ethnicity - 0.1, - 0.2, - 0.3, - 0.4, - 0.5, - # disability - 0.6, - 0.7, - 0.8, - 1.0, - ] - empty_comment_example = [ - '' if not use_fake_embedding else 0.35, - 0.1, - 0.1, - 0.1, - 0.5, - 0.1, - 0.1, - 0.2, - 0.3, - 0.4, - 0.0, - 0.1, - 0.2, - 0.3, - 0.4, - 0.5, - 0.6, - 0.1, - 0.2, - 0.3, - 0.4, - 0.5, - 0.6, - 0.7, - 0.8, - 1.0, - ] - return [header, example, empty_comment_example] + def _create_example_csv(self, use_fake_embedding=False): + header = [ + "comment_text", + "toxicity", + "heterosexual", + "homosexual_gay_or_lesbian", + "bisexual", + "other_sexual_orientation", + "male", + "female", + "transgender", + "other_gender", + "christian", + "jewish", + "muslim", + "hindu", + "buddhist", + "atheist", + "other_religion", + "black", + "white", + "asian", + "latino", + "other_race_or_ethnicity", + "physical_disability", + "intellectual_or_learning_disability", + "psychiatric_or_mental_illness", + "other_disability", + ] + example = [ + "comment 1" if not use_fake_embedding else 0.35, + 0.1, + # sexual orientation + 0.1, + 0.1, + 0.5, + 0.1, + # gender + 0.1, + 0.2, + 0.3, + 0.4, + # religion + 0.0, + 0.1, + 0.2, + 0.3, + 0.4, + 0.5, + 0.6, + # race or ethnicity + 0.1, + 0.2, + 0.3, + 0.4, + 0.5, + # disability + 0.6, + 0.7, + 0.8, + 1.0, + ] + empty_comment_example = [ + "" if not use_fake_embedding else 0.35, + 0.1, + 0.1, + 0.1, + 0.5, + 0.1, + 0.1, + 0.2, + 0.3, + 0.4, + 0.0, + 0.1, + 0.2, + 0.3, + 0.4, + 0.5, + 0.6, + 0.1, + 0.2, + 0.3, + 0.4, + 0.5, + 0.6, + 0.7, + 0.8, + 1.0, + ] + return [header, example, empty_comment_example] - def _write_csv(self, examples): - filename = os.path.join(tempfile.mkdtemp(), 'input.csv') - with open(filename, 'w', newline='') as csvfile: - csvwriter = csv.writer(csvfile, delimiter=',') - for example in examples: - csvwriter.writerow(example) + def _write_csv(self, examples): + filename = os.path.join(tempfile.mkdtemp(), "input.csv") + with open(filename, "w", newline="") as csvfile: + csvwriter = csv.writer(csvfile, delimiter=",") + for example in examples: + csvwriter.writerow(example) - return filename + return filename - def test_convert_data_csv(self): - input_file = 
self._write_csv(self._create_example_csv()) - output_file = util.convert_comments_data(input_file) + def test_convert_data_csv(self): + input_file = self._write_csv(self._create_example_csv()) + output_file = util.convert_comments_data(input_file) - # Remove the quotes around identity terms list that read_csv injects. - df = pd.read_csv(output_file).replace("'", '', regex=True) + # Remove the quotes around identity terms list that read_csv injects. + df = pd.read_csv(output_file).replace("'", "", regex=True) - expected_df = pd.DataFrame() - expected_df = pd.concat([expected_df, pd.DataFrame.from_dict( - { - 'comment_text': - ['comment 1'], - 'toxicity': - [0.0], - 'gender': [[]], - 'sexual_orientation': [['bisexual']], - 'race': [['other_race_or_ethnicity']], - 'religion': [['atheist', 'other_religion']], - 'disability': [[ - 'physical_disability', 'intellectual_or_learning_disability', - 'psychiatric_or_mental_illness', 'other_disability' - ]] - })], - ignore_index=True) + expected_df = pd.DataFrame() + expected_df = pd.concat( + [ + expected_df, + pd.DataFrame.from_dict( + { + "comment_text": ["comment 1"], + "toxicity": [0.0], + "gender": [[]], + "sexual_orientation": [["bisexual"]], + "race": [["other_race_or_ethnicity"]], + "religion": [["atheist", "other_religion"]], + "disability": [ + [ + "physical_disability", + "intellectual_or_learning_disability", + "psychiatric_or_mental_illness", + "other_disability", + ] + ], + } + ), + ], + ignore_index=True, + ) - self.assertEqual( - df.reset_index(drop=True, inplace=True), - expected_df.reset_index(drop=True, inplace=True)) + self.assertEqual( + df.reset_index(drop=True, inplace=True), + expected_df.reset_index(drop=True, inplace=True), + ) - # TODO(b/172260507): we should also look into testing the e2e call with tfma. - @mock.patch( - 'tensorflow_model_analysis.default_eval_shared_model', autospec=True) - @mock.patch('tensorflow_model_analysis.run_model_analysis', autospec=True) - def test_get_eval_results_called_correclty(self, mock_run_model_analysis, - mock_shared_model): - mock_model = 'model' - mock_shared_model.return_value = mock_model + # TODO(b/172260507): we should also look into testing the e2e call with tfma. 
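For context on the mocked test that follows, a direct (non-mocked) call to the helper under test would look roughly like the sketch below; the paths are placeholders and the keyword values simply mirror the defaults of `util.get_eval_results` shown above.

```python
# Illustrative only: the paths are placeholders, not real artifacts.
from fairness_indicators.tutorial_utils import util

eval_result = util.get_eval_results(
    model_location="/tmp/saved_model",
    eval_result_path="/tmp/eval_results",
    validate_tfrecord_file="/tmp/validate.tfrecord",
    slice_selection="religion",  # slice on one of the grouped identity columns
    compute_confidence_intervals=True,
)
# The return value is whatever tfma.run_model_analysis produces, so its
# slicing_metrics can be fed to the Fairness Indicators widget or the
# TensorBoard plugin.
```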
+ @mock.patch("tensorflow_model_analysis.default_eval_shared_model", autospec=True) + @mock.patch("tensorflow_model_analysis.run_model_analysis", autospec=True) + def test_get_eval_results_called_correclty( + self, mock_run_model_analysis, mock_shared_model + ): + mock_model = "model" + mock_shared_model.return_value = mock_model - model_location = 'saved_model' - eval_results_path = 'eval_results' - data_file = 'data' - util.get_eval_results(model_location, eval_results_path, data_file) + model_location = "saved_model" + eval_results_path = "eval_results" + data_file = "data" + util.get_eval_results(model_location, eval_results_path, data_file) - mock_shared_model.assert_called_once_with( - eval_saved_model_path=model_location, tags=[tf.saved_model.SERVING]) + mock_shared_model.assert_called_once_with( + eval_saved_model_path=model_location, tags=[tf.saved_model.SERVING] + ) - expected_eval_config = text_format.Parse( - """ + expected_eval_config = text_format.Parse( + """ model_specs { label_key: 'toxicity' } @@ -314,12 +328,15 @@ def test_get_eval_results_called_correclty(self, mock_run_model_analysis, compute_confidence_intervals { value: true } disabled_outputs{values: "analysis"} } - """, tfma.EvalConfig()) + """, + tfma.EvalConfig(), + ) - mock_run_model_analysis.assert_called_once_with( - eval_shared_model=mock_model, - data_location=data_file, - file_format='tfrecords', - eval_config=expected_eval_config, - output_path=eval_results_path, - extractors=None) + mock_run_model_analysis.assert_called_once_with( + eval_shared_model=mock_model, + data_location=data_file, + file_format="tfrecords", + eval_config=expected_eval_config, + output_path=eval_results_path, + extractors=None, + ) diff --git a/fairness_indicators/version.py b/fairness_indicators/version.py index 6882198..0f30712 100644 --- a/fairness_indicators/version.py +++ b/fairness_indicators/version.py @@ -14,4 +14,4 @@ """Contains the version string of Fairness Indicators.""" # Note that setup.py uses this version. -__version__ = '0.49.0.dev' +__version__ = "0.49.0.dev" diff --git a/pyproject.toml b/pyproject.toml new file mode 100644 index 0000000..a862155 --- /dev/null +++ b/pyproject.toml @@ -0,0 +1,115 @@ +# Copyright 2020 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +[build-system] +requires = [ + "setuptools", + "wheel", +] + +[tool.ruff] +line-length = 88 + +[tool.ruff.lint] +select = [ + # pycodestyle + "E", + "W", + # Pyflakes + "F", + # pyupgrade + "UP", + # flake8-bugbear + "B", + # flake8-simplify + "SIM", + # isort + "I", + # pep8 naming + "N", + # pydocstyle + "D", + # annotations + "ANN", + # debugger + "T10", + # flake8-pytest + "PT", + # flake8-return + "RET", + # flake8-unused-arguments + "ARG", + # flake8-fixme + "FIX", + # flake8-eradicate + "ERA", + # pandas-vet + "PD", + # numpy-specific rules + "NPY", +] + +ignore = [ + "D104", # Missing docstring in public package + "D100", # Missing docstring in public module + "D211", # No blank line before class + "PD901", # Avoid using 'df' for pandas dataframes. Perfectly fine in functions with limited scope + "ANN201", # Missing return type annotation for public function (makes no sense for NoneType return types...) + "ANN101", # Missing type annotation for `self` + "ANN204", # Missing return type annotation for special method + "ANN002", # Missing type annotation for `*args` + "ANN003", # Missing type annotation for `**kwargs` + "D105", # Missing docstring in magic method + "D203", # 1 blank line before after class docstring + "D204", # 1 blank line required after class docstring + "D413", # 1 blank line after parameters + "SIM108", # Simplify if/else to one line; not always clearer + "D206", # Docstrings should be indented with spaces; unnecessary when running ruff-format + "E501", # Line length too long; unnecessary when running ruff-format + "W191", # Indentation contains tabs; unnecessary when running ruff-format + + # REMOVE THESE AS FIXED + "ANN001", # Missing type annotation for function argument + "ANN202", # Missing return type annotation for private function + "ANN401", # Dynamically typed expressions (typing.Any) are disallowed + "ARG001", # Unused function argument + "ARG002", # Unused method argument + "B018", # Found useless expression + "D101", # Missing docstring in public class + "D102", # Missing docstring in public method + "D103", # Missing docstring in public function + "D107", # Missing docstring in `__init__` + "D401", # First line of docstring should be in imperative mood + "ERA001", # Found commented-out code + "FIX002", # Line contains TODO + "N802", # Function name should be lowercase + "PD002", # `inplace=True` should be avoided + "PD004", # `.notna` is preferred to `.notnull` + "PT009", # Use a regular `assert` instead of unittest-style + "PT027", # Use `pytest.raises` instead of unittest-style `assertRaises` + "RET505", # Unnecessary `elif` after `return` statement + "RET506", # Unnecessary `else` after `raise` statement + "SIM105", # Use `contextlib.suppress` instead of `try`-`except`-`pass` + "UP008", # Use `super()` instead of `super(__class__, self)` + "UP031", # Use format specifiers instead of percent format +] + + +[tool.ruff.lint.per-file-ignores] +"__init__.py" = ["F401"] + +[tool.pytest.ini_options] +addopts = "--import-mode=importlib" +testpaths = ["fairness_indicators"] +python_files = ["*_test.py"] diff --git a/pytest.ini b/pytest.ini deleted file mode 100644 index d73a83e..0000000 --- a/pytest.ini +++ /dev/null @@ -1,4 +0,0 @@ -[pytest] -addopts = "--import-mode=importlib" -testpaths = "fairness_indicators" -python_files = "*_test.py" \ No newline at end of file diff --git a/setup.py b/setup.py index 2c4fc40..d01b6b1 100644 --- a/setup.py +++ b/setup.py @@ -15,34 +15,35 @@ """Setup to install Fairness Indicators.""" import os -from pathlib import 
Path import sys +from pathlib import Path import setuptools - if sys.version_info >= (3, 11): - sys.exit('Sorry, Python >= 3.11 is not supported') + sys.exit("Sorry, Python >= 3.11 is not supported") def select_constraint(default, nightly=None, git_master=None): - """Select dependency constraint based on TFX_DEPENDENCY_SELECTOR env var.""" - selector = os.environ.get('TFX_DEPENDENCY_SELECTOR') - if selector == 'UNCONSTRAINED': - return '' - elif selector == 'NIGHTLY' and nightly is not None: - return nightly - elif selector == 'GIT_MASTER' and git_master is not None: - return git_master - else: - return default + """Select dependency constraint based on TFX_DEPENDENCY_SELECTOR env var.""" + selector = os.environ.get("TFX_DEPENDENCY_SELECTOR") + if selector == "UNCONSTRAINED": + return "" + elif selector == "NIGHTLY" and nightly is not None: + return nightly + elif selector == "GIT_MASTER" and git_master is not None: + return git_master + else: + return default + + REQUIRED_PACKAGES = [ - 'tensorflow>=2.17,<2.18', - 'tensorflow-hub>=0.16.1,<1.0.0', - 'tensorflow-data-validation>=1.17.0,<2.0.0', - 'tensorflow-model-analysis>=0.48.0,<0.49.0', - 'witwidget>=1.4.4,<2', - 'protobuf>=4.21.6,<6.0.0', + "tensorflow>=2.17,<2.18", + "tensorflow-hub>=0.16.1,<1.0.0", + "tensorflow-data-validation>=1.17.0,<2.0.0", + "tensorflow-model-analysis>=0.48.0,<0.49.0", + "witwidget>=1.4.4,<2", + "protobuf>=4.21.6,<6.0.0", ] TEST_PACKAGES = [ @@ -53,53 +54,49 @@ def select_constraint(default, nightly=None, git_master=None): DOCS_PACKAGES = [req.strip() for req in f.readlines()] # Get version from version module. -with open('fairness_indicators/version.py') as fp: - globals_dict = {} - exec(fp.read(), globals_dict) # pylint: disable=exec-used -__version__ = globals_dict['__version__'] -with open('README.md', 'r', encoding='utf-8') as fh: - long_description = fh.read() +with open("fairness_indicators/version.py") as fp: + globals_dict = {} + exec(fp.read(), globals_dict) # pylint: disable=exec-used +__version__ = globals_dict["__version__"] +with open("README.md", encoding="utf-8") as fh: + long_description = fh.read() setuptools.setup( - name='fairness_indicators', + name="fairness_indicators", version=__version__, - description='Fairness Indicators', + description="Fairness Indicators", long_description=long_description, - long_description_content_type='text/markdown', - url='https://github.com/tensorflow/fairness-indicators', - author='Google LLC', - author_email='packages@tensorflow.org', - packages=setuptools.find_packages(exclude=['tensorboard_plugin']), + long_description_content_type="text/markdown", + url="https://github.com/tensorflow/fairness-indicators", + author="Google LLC", + author_email="packages@tensorflow.org", + packages=setuptools.find_packages(exclude=["tensorboard_plugin"]), package_data={ - 'fairness_indicators': ['documentation/*'], + "fairness_indicators": ["documentation/*"], }, - python_requires='>=3.9,<4', + python_requires=">=3.9,<4", install_requires=REQUIRED_PACKAGES, tests_require=REQUIRED_PACKAGES, - extras_require={ - "docs": DOCS_PACKAGES, - "test": TEST_PACKAGES, - }, + extras_require={"docs": DOCS_PACKAGES, "test": TEST_PACKAGES, "dev": "pre-commit"}, # PyPI package information. 
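Stepping back to `select_constraint` above: the constraint string it returns is chosen by the `TFX_DEPENDENCY_SELECTOR` environment variable, falling back to the default when the variable is unset or unrecognized. A small illustration, with made-up constraint strings:

```python
# Illustrative only; the constraint strings are made-up examples.
import os

os.environ["TFX_DEPENDENCY_SELECTOR"] = "NIGHTLY"
select_constraint(default=">=0.48,<0.49", nightly=">=0.49.0.dev")  # -> ">=0.49.0.dev"

os.environ["TFX_DEPENDENCY_SELECTOR"] = "UNCONSTRAINED"
select_constraint(default=">=0.48,<0.49", nightly=">=0.49.0.dev")  # -> ""

del os.environ["TFX_DEPENDENCY_SELECTOR"]
select_constraint(default=">=0.48,<0.49", nightly=">=0.49.0.dev")  # -> ">=0.48,<0.49"
```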
classifiers=[ - 'Development Status :: 4 - Beta', - 'Intended Audience :: Developers', - 'Intended Audience :: Education', - 'Intended Audience :: Science/Research', - 'License :: OSI Approved :: Apache Software License', - 'Operating System :: OS Independent', - 'Programming Language :: Python :: 3', - 'Programming Language :: Python :: 3.9', - 'Programming Language :: Python :: 3 :: Only', - 'Topic :: Scientific/Engineering', - 'Topic :: Scientific/Engineering :: Mathematics', - 'Topic :: Scientific/Engineering :: Artificial Intelligence', - 'Topic :: Software Development', - 'Topic :: Software Development :: Libraries', - 'Topic :: Software Development :: Libraries :: Python Modules', + "Development Status :: 4 - Beta", + "Intended Audience :: Developers", + "Intended Audience :: Education", + "Intended Audience :: Science/Research", + "License :: OSI Approved :: Apache Software License", + "Operating System :: OS Independent", + "Programming Language :: Python :: 3", + "Programming Language :: Python :: 3.9", + "Programming Language :: Python :: 3 :: Only", + "Topic :: Scientific/Engineering", + "Topic :: Scientific/Engineering :: Mathematics", + "Topic :: Scientific/Engineering :: Artificial Intelligence", + "Topic :: Software Development", + "Topic :: Software Development :: Libraries", + "Topic :: Software Development :: Libraries :: Python Modules", ], - license='Apache 2.0', + license="Apache 2.0", keywords=( - 'tensorflow model analysis fairness indicators tensorboard machine' - ' learning' + "tensorflow model analysis fairness indicators tensorboard machine" " learning" ), ) diff --git a/tensorboard_plugin/pytest.ini b/tensorboard_plugin/pytest.ini index bde2b78..61d2b34 100644 --- a/tensorboard_plugin/pytest.ini +++ b/tensorboard_plugin/pytest.ini @@ -1,4 +1,4 @@ [pytest] addopts = "--import-mode=importlib" testpaths = "tensorboard_plugin_fairness_indicators" -python_files = "*_test.py" \ No newline at end of file +python_files = "*_test.py" diff --git a/tensorboard_plugin/setup.py b/tensorboard_plugin/setup.py index d87b926..b19ef09 100644 --- a/tensorboard_plugin/setup.py +++ b/tensorboard_plugin/setup.py @@ -14,96 +14,91 @@ # ============================================================================== """Setup to install Fairness Indicators Tensorboard plugin.""" -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function - import os import sys -from setuptools import find_packages -from setuptools import setup - +from setuptools import find_packages, setup if sys.version_info >= (3, 11): - sys.exit('Sorry, Python >= 3.11 is not supported') + sys.exit("Sorry, Python >= 3.11 is not supported") def select_constraint(default, nightly=None, git_master=None): - """Select dependency constraint based on TFX_DEPENDENCY_SELECTOR env var.""" - selector = os.environ.get('TFX_DEPENDENCY_SELECTOR') - if selector == 'UNCONSTRAINED': - return '' - elif selector == 'NIGHTLY' and nightly is not None: - return nightly - elif selector == 'GIT_MASTER' and git_master is not None: - return git_master - else: - return default + """Select dependency constraint based on TFX_DEPENDENCY_SELECTOR env var.""" + selector = os.environ.get("TFX_DEPENDENCY_SELECTOR") + if selector == "UNCONSTRAINED": + return "" + elif selector == "NIGHTLY" and nightly is not None: + return nightly + elif selector == "GIT_MASTER" and git_master is not None: + return git_master + else: + return default + REQUIRED_PACKAGES = [ - 'protobuf>=4.21.6,<6.0.0', - 
'tensorboard>=2.17.0,<2.18.0', - 'tensorflow>=2.17,<2.18', - 'tf-keras>=2.17,<2.18', - 'tensorflow-model-analysis>=0.48,<0.49', - 'werkzeug<2', + "protobuf>=4.21.6,<6.0.0", + "tensorboard>=2.17.0,<2.18.0", + "tensorflow>=2.17,<2.18", + "tf-keras>=2.17,<2.18", + "tensorflow-model-analysis>=0.48,<0.49", + "werkzeug<2", ] TEST_PACKAGES = [ - 'pytest>=8.3.0,<9', + "pytest>=8.3.0,<9", ] -with open('README.md', 'r', encoding='utf-8') as fh: - long_description = fh.read() +with open("README.md", encoding="utf-8") as fh: + long_description = fh.read() # Get version from version module. -with open('tensorboard_plugin_fairness_indicators/version.py') as fp: - globals_dict = {} - exec(fp.read(), globals_dict) # pylint: disable=exec-used -__version__ = globals_dict['__version__'] +with open("tensorboard_plugin_fairness_indicators/version.py") as fp: + globals_dict = {} + exec(fp.read(), globals_dict) # pylint: disable=exec-used +__version__ = globals_dict["__version__"] setup( - name='tensorboard_plugin_fairness_indicators', + name="tensorboard_plugin_fairness_indicators", version=__version__, - description='Fairness Indicators TensorBoard Plugin', + description="Fairness Indicators TensorBoard Plugin", long_description=long_description, - long_description_content_type='text/markdown', - url='https://github.com/tensorflow/fairness-indicators', - author='Google LLC', - author_email='packages@tensorflow.org', + long_description_content_type="text/markdown", + url="https://github.com/tensorflow/fairness-indicators", + author="Google LLC", + author_email="packages@tensorflow.org", packages=find_packages(), package_data={ - 'tensorboard_plugin_fairness_indicators': ['static/**'], + "tensorboard_plugin_fairness_indicators": ["static/**"], }, entry_points={ - 'tensorboard_plugins': [ - 'fairness_indicators = tensorboard_plugin_fairness_indicators.plugin:FairnessIndicatorsPlugin', + "tensorboard_plugins": [ + "fairness_indicators = tensorboard_plugin_fairness_indicators.plugin:FairnessIndicatorsPlugin", ], }, - python_requires='>=3.9,<4', + python_requires=">=3.9,<4", install_requires=REQUIRED_PACKAGES, tests_require=REQUIRED_PACKAGES, extras_require={ - 'test': TEST_PACKAGES, + "test": TEST_PACKAGES, }, classifiers=[ - 'Development Status :: 4 - Beta', - 'Intended Audience :: Developers', - 'Intended Audience :: Education', - 'Intended Audience :: Science/Research', - 'License :: OSI Approved :: Apache Software License', - 'Operating System :: OS Independent', - 'Programming Language :: Python :: 3', - 'Programming Language :: Python :: 3.9', - 'Programming Language :: Python :: 3 :: Only', - 'Topic :: Scientific/Engineering', - 'Topic :: Scientific/Engineering :: Mathematics', - 'Topic :: Scientific/Engineering :: Artificial Intelligence', - 'Topic :: Software Development', - 'Topic :: Software Development :: Libraries', - 'Topic :: Software Development :: Libraries :: Python Modules', + "Development Status :: 4 - Beta", + "Intended Audience :: Developers", + "Intended Audience :: Education", + "Intended Audience :: Science/Research", + "License :: OSI Approved :: Apache Software License", + "Operating System :: OS Independent", + "Programming Language :: Python :: 3", + "Programming Language :: Python :: 3.9", + "Programming Language :: Python :: 3 :: Only", + "Topic :: Scientific/Engineering", + "Topic :: Scientific/Engineering :: Mathematics", + "Topic :: Scientific/Engineering :: Artificial Intelligence", + "Topic :: Software Development", + "Topic :: Software Development :: Libraries", + "Topic :: 
Software Development :: Libraries :: Python Modules", ], - license='Apache 2.0', - keywords='tensorflow model analysis fairness indicators tensorboard machine learning', + license="Apache 2.0", + keywords="tensorflow model analysis fairness indicators tensorboard machine learning", ) diff --git a/tensorboard_plugin/tensorboard_plugin_fairness_indicators/demo.py b/tensorboard_plugin/tensorboard_plugin_fairness_indicators/demo.py index 0c61354..d78fdd1 100644 --- a/tensorboard_plugin/tensorboard_plugin_fairness_indicators/demo.py +++ b/tensorboard_plugin/tensorboard_plugin_fairness_indicators/demo.py @@ -14,34 +14,31 @@ # ============================================================================== """Fairness Indicators Plugin Demo.""" -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function - -from absl import app -from absl import flags -from tensorboard_plugin_fairness_indicators import summary_v2 import tensorflow.compat.v1 as tf import tensorflow.compat.v2 as tf2 +from absl import app, flags + +from tensorboard_plugin_fairness_indicators import summary_v2 tf.enable_eager_execution() tf = tf2 FLAGS = flags.FLAGS -flags.DEFINE_string('eval_result_output_dir', '', - 'Log dir containing evaluation results.') +flags.DEFINE_string( + "eval_result_output_dir", "", "Log dir containing evaluation results." +) -flags.DEFINE_string('logdir', '', 'Log dir where demo logs will be written.') +flags.DEFINE_string("logdir", "", "Log dir where demo logs will be written.") def main(unused_argv): - writer = tf.summary.create_file_writer(FLAGS.logdir) + writer = tf.summary.create_file_writer(FLAGS.logdir) - with writer.as_default(): - summary_v2.FairnessIndicators(FLAGS.eval_result_output_dir, step=1) - writer.close() + with writer.as_default(): + summary_v2.FairnessIndicators(FLAGS.eval_result_output_dir, step=1) + writer.close() -if __name__ == '__main__': - app.run(main) +if __name__ == "__main__": + app.run(main) diff --git a/tensorboard_plugin/tensorboard_plugin_fairness_indicators/metadata.py b/tensorboard_plugin/tensorboard_plugin_fairness_indicators/metadata.py index 071f2d8..c3fcc5f 100644 --- a/tensorboard_plugin/tensorboard_plugin_fairness_indicators/metadata.py +++ b/tensorboard_plugin/tensorboard_plugin_fairness_indicators/metadata.py @@ -14,17 +14,13 @@ # ============================================================================== """Plugin-specific global metadata.""" -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function - from tensorboard.compat.proto import summary_pb2 PLUGIN_NAME = "fairness_indicators" def CreateSummaryMetadata(description=None): - return summary_pb2.SummaryMetadata( - summary_description=description, - plugin_data=summary_pb2.SummaryMetadata.PluginData( - plugin_name=PLUGIN_NAME)) + return summary_pb2.SummaryMetadata( + summary_description=description, + plugin_data=summary_pb2.SummaryMetadata.PluginData(plugin_name=PLUGIN_NAME), + ) diff --git a/tensorboard_plugin/tensorboard_plugin_fairness_indicators/metadata_test.py b/tensorboard_plugin/tensorboard_plugin_fairness_indicators/metadata_test.py index 16e3a3c..cb5cb04 100644 --- a/tensorboard_plugin/tensorboard_plugin_fairness_indicators/metadata_test.py +++ b/tensorboard_plugin/tensorboard_plugin_fairness_indicators/metadata_test.py @@ -14,23 +14,17 @@ # ============================================================================== """Tests for util function to create plugin metadata.""" -from __future__ 
import absolute_import -from __future__ import division -from __future__ import print_function +import tensorflow.compat.v1 as tf from tensorboard_plugin_fairness_indicators import metadata -import tensorflow.compat.v1 as tf class MetadataTest(tf.test.TestCase): + def testCreateSummaryMetadata(self): + summary_metadata = metadata.CreateSummaryMetadata("description") + self.assertEqual(metadata.PLUGIN_NAME, summary_metadata.plugin_data.plugin_name) + self.assertEqual("description", summary_metadata.summary_description) - def testCreateSummaryMetadata(self): - summary_metadata = metadata.CreateSummaryMetadata('description') - self.assertEqual(metadata.PLUGIN_NAME, - summary_metadata.plugin_data.plugin_name) - self.assertEqual('description', summary_metadata.summary_description) - - def testCreateSummaryMetadata_withoutDescription(self): - summary_metadata = metadata.CreateSummaryMetadata() - self.assertEqual(metadata.PLUGIN_NAME, - summary_metadata.plugin_data.plugin_name) + def testCreateSummaryMetadata_withoutDescription(self): + summary_metadata = metadata.CreateSummaryMetadata() + self.assertEqual(metadata.PLUGIN_NAME, summary_metadata.plugin_data.plugin_name) diff --git a/tensorboard_plugin/tensorboard_plugin_fairness_indicators/plugin.py b/tensorboard_plugin/tensorboard_plugin_fairness_indicators/plugin.py index b03a695..a633582 100644 --- a/tensorboard_plugin/tensorboard_plugin_fairness_indicators/plugin.py +++ b/tensorboard_plugin/tensorboard_plugin_fairness_indicators/plugin.py @@ -14,78 +14,75 @@ # ============================================================================== """TensorBoard Fairnss Indicators plugin.""" -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function - import os from typing import Any, Union -from absl import logging -from tensorboard_plugin_fairness_indicators import metadata import six import tensorflow as tf import tensorflow_model_analysis as tfma -from werkzeug import wrappers - +from absl import logging from google.protobuf import json_format from tensorboard.backend import http_util from tensorboard.plugins import base_plugin +from werkzeug import wrappers +from tensorboard_plugin_fairness_indicators import metadata _TEMPLATE_LOCATION = os.path.normpath( os.path.join( - __file__, '../../' - 'tensorflow_model_analysis/static/vulcanized_tfma.js')) + __file__, "../../" "tensorflow_model_analysis/static/vulcanized_tfma.js" + ) +) def stringify_slice_key_value( slice_key: tfma.slicer.slicer_lib.SliceKeyType, ) -> str: - """Stringifies a slice key value. - - The string representation of a SingletonSliceKeyType is "feature:value". This - function returns value. - - When - multiple columns / features are specified, the string representation of a - SliceKeyType's value is "v1_X_v2_X_..." where v1, v2, ... are values. For - example, - ('gender, 'f'), ('age', 5) becomes f_X_5. If no columns / feature - specified, return "Overall". - - Note that we do not perform special escaping for slice values that contain - '_X_'. This stringified representation is meant to be human-readbale rather - than a reversible encoding. - - The columns will be in the same order as in SliceKeyType. If they are - generated using SingleSliceSpec.generate_slices, they will be in sorted order, - ascending. - - Technically float values are not supported, but we don't check for them here. - - Args: - slice_key: Slice key to stringify. The constituent SingletonSliceKeyTypes - should be sorted in ascending order. 
- - Returns: - String representation of the slice key's value. - """ - if not slice_key: - return 'Overall' - - # Since this is meant to be a human-readable string, we assume that the - # feature values are valid UTF-8 strings (might not be true in cases where - # people store serialised protos in the features for instance). - # We need to call as_str_any to convert non-string (e.g. integer) values to - # string first before converting to text. - # We use u'{}' instead of '{}' here to avoid encoding a unicode character with - # ascii codec. - values = [ - '{}'.format(tf.compat.as_text(tf.compat.as_str_any(value))) - for _, value in slice_key - ] - return '_X_'.join(values) + """Stringifies a slice key value. + + The string representation of a SingletonSliceKeyType is "feature:value". This + function returns value. + + When + multiple columns / features are specified, the string representation of a + SliceKeyType's value is "v1_X_v2_X_..." where v1, v2, ... are values. For + example, + ('gender, 'f'), ('age', 5) becomes f_X_5. If no columns / feature + specified, return "Overall". + + Note that we do not perform special escaping for slice values that contain + '_X_'. This stringified representation is meant to be human-readbale rather + than a reversible encoding. + + The columns will be in the same order as in SliceKeyType. If they are + generated using SingleSliceSpec.generate_slices, they will be in sorted order, + ascending. + + Technically float values are not supported, but we don't check for them here. + + Args: + ---- + slice_key: Slice key to stringify. The constituent SingletonSliceKeyTypes + should be sorted in ascending order. + + Returns: + ------- + String representation of the slice key's value. + """ + if not slice_key: + return "Overall" + + # Since this is meant to be a human-readable string, we assume that the + # feature values are valid UTF-8 strings (might not be true in cases where + # people store serialised protos in the features for instance). + # We need to call as_str_any to convert non-string (e.g. integer) values to + # string first before converting to text. + # We use u'{}' instead of '{}' here to avoid encoding a unicode character with + # ascii codec. + values = [ + f"{tf.compat.as_text(tf.compat.as_str_any(value))}" for _, value in slice_key + ] + return "_X_".join(values) def _add_cross_slice_key_data( @@ -93,31 +90,35 @@ def _add_cross_slice_key_data( metrics: tfma.view.view_types.MetricsByTextKey, data: list[Any], ): - """Adds data for cross slice key. - - Baseline and comparison slice keys are joined by '__XX__'. - Args: - slice_key: Cross slice key. - metrics: Metrics data for the cross slice key. - data: List where UI data is to be appended. - """ - baseline_key = slice_key[0] - comparison_key = slice_key[1] - stringify_slice_value = ( - stringify_slice_key_value(baseline_key) - + '__XX__' - + stringify_slice_key_value(comparison_key) - ) - stringify_slice = ( - tfma.slicer.slicer_lib.stringify_slice_key(baseline_key) - + '__XX__' - + tfma.slicer.slicer_lib.stringify_slice_key(comparison_key) - ) - data.append({ - 'sliceValue': stringify_slice_value, - 'slice': stringify_slice, - 'metrics': metrics, - }) + """Adds data for cross slice key. + + Baseline and comparison slice keys are joined by '__XX__'. + + Args: + ---- + slice_key: Cross slice key. + metrics: Metrics data for the cross slice key. + data: List where UI data is to be appended. 
+ """ + baseline_key = slice_key[0] + comparison_key = slice_key[1] + stringify_slice_value = ( + stringify_slice_key_value(baseline_key) + + "__XX__" + + stringify_slice_key_value(comparison_key) + ) + stringify_slice = ( + tfma.slicer.slicer_lib.stringify_slice_key(baseline_key) + + "__XX__" + + tfma.slicer.slicer_lib.stringify_slice_key(comparison_key) + ) + data.append( + { + "sliceValue": stringify_slice_value, + "slice": stringify_slice, + "metrics": metrics, + } + ) def convert_slicing_metrics_to_ui_input( @@ -129,185 +130,195 @@ def convert_slicing_metrics_to_ui_input( ], slicing_column: Union[str, None] = None, slicing_spec: Union[tfma.slicer.slicer_lib.SingleSliceSpec, None] = None, - output_name: str = '', - multi_class_key: str = '', + output_name: str = "", + multi_class_key: str = "", ) -> Union[list[dict[str, Any]], None]: - """Renders the Fairness Indicator view. - - Args: - slicing_metrics: tfma.EvalResult.slicing_metrics. - slicing_column: The slicing column to to filter results. If both - slicing_column and slicing_spec are None, show all eval results. - slicing_spec: The slicing spec to filter results. If both slicing_column and - slicing_spec are None, show all eval results. - output_name: The output name associated with metric (for multi-output - models). - multi_class_key: The multi-class key associated with metric (for multi-class - models). - - Returns: - A list of dicts for each slice, where each dict contains keys 'sliceValue', - 'slice', and 'metrics'. - - Raises: - ValueError if no related eval result found or both slicing_column and - slicing_spec are not None. - """ - if slicing_column and slicing_spec: - raise ValueError( - 'Only one of the "slicing_column" and "slicing_spec" parameters ' - 'can be set.' - ) - if slicing_column: - slicing_spec = tfma.slicer.slicer_lib.SingleSliceSpec( - columns=[slicing_column] - ) - - data = [] - for slice_key, metric_value in slicing_metrics: - if ( - metric_value is not None - and output_name in metric_value - and multi_class_key in metric_value[output_name] - ): - metrics = metric_value[output_name][multi_class_key] - # To add evaluation data for cross slice comparison. - if tfma.slicer.slicer_lib.is_cross_slice_key(slice_key): - _add_cross_slice_key_data(slice_key, metrics, data) - # To add evaluation data for regular slices. - elif ( - slicing_spec is None - or not slice_key - or slicing_spec.is_slice_applicable(slice_key) - ): - data.append({ - 'sliceValue': stringify_slice_key_value(slice_key), - 'slice': tfma.slicer.slicer_lib.stringify_slice_key(slice_key), - 'metrics': metrics, - }) - if not data: - raise ValueError( - 'No eval result found for output_name:"%s" and ' - 'multi_class_key:"%s" and slicing_column:"%s" and slicing_spec:"%s".' - % (output_name, multi_class_key, slicing_column, slicing_spec) - ) - return data - - -class FairnessIndicatorsPlugin(base_plugin.TBPlugin): - """A plugin to visualize Fairness Indicators.""" - - plugin_name = metadata.PLUGIN_NAME - - def __init__(self, context): - """Instantiates plugin via TensorBoard core. + """Renders the Fairness Indicator view. Args: - context: A base_plugin.TBContext instance. A magic container that - TensorBoard uses to make objects available to the plugin. - """ - self._multiplexer = context.multiplexer - - def get_plugin_apps(self): - """Gets all routes offered by the plugin. - - This method is called by TensorBoard when retrieving all the - routes offered by the plugin. + ---- + slicing_metrics: tfma.EvalResult.slicing_metrics. 
+ slicing_column: The slicing column to to filter results. If both + slicing_column and slicing_spec are None, show all eval results. + slicing_spec: The slicing spec to filter results. If both slicing_column and + slicing_spec are None, show all eval results. + output_name: The output name associated with metric (for multi-output + models). + multi_class_key: The multi-class key associated with metric (for multi-class + models). Returns: - A dictionary mapping URL path to route that handles it. + ------- + A list of dicts for each slice, where each dict contains keys 'sliceValue', + 'slice', and 'metrics'. + + Raises: + ------ + ValueError if no related eval result found or both slicing_column and + slicing_spec are not None. """ - return { - '/get_evaluation_result': - self._get_evaluation_result, - '/get_evaluation_result_from_remote_path': - self._get_evaluation_result_from_remote_path, - '/index.js': - self._serve_js, - '/vulcanized_tfma.js': - self._serve_vulcanized_js, - } - - def frontend_metadata(self): - return base_plugin.FrontendMetadata( - es_module_path='/index.js', - disable_reload=False, - tab_name='Fairness Indicators', - remove_dom=False, - element_name=None) - - def is_active(self): - """Determines whether this plugin is active. - - This plugin is only active if TensorBoard sampled any summaries - relevant to the plugin. - - Returns: - Whether this plugin is active. - """ - return bool( - self._multiplexer.PluginRunToTagToContent( - FairnessIndicatorsPlugin.plugin_name)) - - # pytype: disable=wrong-arg-types - @wrappers.Request.application - def _serve_js(self, request): - filepath = os.path.join(os.path.dirname(__file__), 'static', 'index.js') - with open(filepath) as infile: - contents = infile.read() - return http_util.Respond( - request, contents, content_type='application/javascript') - - @wrappers.Request.application - def _serve_vulcanized_js(self, request): - with open(_TEMPLATE_LOCATION) as infile: - contents = infile.read() - return http_util.Respond( - request, contents, content_type='application/javascript') - - @wrappers.Request.application - def _get_evaluation_result(self, request): - run = request.args.get('run') - try: - run = six.ensure_text(run) - except (UnicodeDecodeError, AttributeError): - pass + if slicing_column and slicing_spec: + raise ValueError( + 'Only one of the "slicing_column" and "slicing_spec" parameters ' + "can be set." + ) + if slicing_column: + slicing_spec = tfma.slicer.slicer_lib.SingleSliceSpec(columns=[slicing_column]) data = [] - try: - eval_result_output_dir = six.ensure_text( - self._multiplexer.Tensors(run, FairnessIndicatorsPlugin.plugin_name) - [0].tensor_proto.string_val[0]) - eval_result = tfma.load_eval_result(output_path=eval_result_output_dir) - # TODO(b/141283811): Allow users to choose different model output names - # and class keys in case of multi-output and multi-class model. 
- data = convert_slicing_metrics_to_ui_input(eval_result.slicing_metrics) - except (KeyError, json_format.ParseError) as error: - logging.info('Error while fetching evaluation data, %s', error) - return http_util.Respond(request, data, content_type='application/json') - - def _get_output_file_format(self, evaluation_output_path): - file_format = os.path.splitext(evaluation_output_path)[1] - if file_format: - return file_format[1:] - - return '' - - @wrappers.Request.application - def _get_evaluation_result_from_remote_path(self, request): - evaluation_output_path = request.args.get('evaluation_output_path') - try: - evaluation_output_path = six.ensure_text(evaluation_output_path) - except (UnicodeDecodeError, AttributeError): - pass - try: - eval_result = tfma.load_eval_result( - os.path.dirname(evaluation_output_path), - output_file_format=self._get_output_file_format( - evaluation_output_path)) - data = convert_slicing_metrics_to_ui_input(eval_result.slicing_metrics) - except (KeyError, json_format.ParseError) as error: - logging.info('Error while fetching evaluation data, %s', error) - data = [] - return http_util.Respond(request, data, content_type='application/json') - # pytype: enable=wrong-arg-types + for slice_key, metric_value in slicing_metrics: + if ( + metric_value is not None + and output_name in metric_value + and multi_class_key in metric_value[output_name] + ): + metrics = metric_value[output_name][multi_class_key] + # To add evaluation data for cross slice comparison. + if tfma.slicer.slicer_lib.is_cross_slice_key(slice_key): + _add_cross_slice_key_data(slice_key, metrics, data) + # To add evaluation data for regular slices. + elif ( + slicing_spec is None + or not slice_key + or slicing_spec.is_slice_applicable(slice_key) + ): + data.append( + { + "sliceValue": stringify_slice_key_value(slice_key), + "slice": tfma.slicer.slicer_lib.stringify_slice_key(slice_key), + "metrics": metrics, + } + ) + if not data: + raise ValueError( + 'No eval result found for output_name:"%s" and ' + 'multi_class_key:"%s" and slicing_column:"%s" and slicing_spec:"%s".' + % (output_name, multi_class_key, slicing_column, slicing_spec) + ) + return data + + +class FairnessIndicatorsPlugin(base_plugin.TBPlugin): + """A plugin to visualize Fairness Indicators.""" + + plugin_name = metadata.PLUGIN_NAME + + def __init__(self, context): + """Instantiates plugin via TensorBoard core. + + Args: + ---- + context: A base_plugin.TBContext instance. A magic container that + TensorBoard uses to make objects available to the plugin. + """ + self._multiplexer = context.multiplexer + + def get_plugin_apps(self): + """Gets all routes offered by the plugin. + + This method is called by TensorBoard when retrieving all the + routes offered by the plugin. + + Returns + ------- + A dictionary mapping URL path to route that handles it. + """ + return { + "/get_evaluation_result": self._get_evaluation_result, + "/get_evaluation_result_from_remote_path": self._get_evaluation_result_from_remote_path, + "/index.js": self._serve_js, + "/vulcanized_tfma.js": self._serve_vulcanized_js, + } + + def frontend_metadata(self): + return base_plugin.FrontendMetadata( + es_module_path="/index.js", + disable_reload=False, + tab_name="Fairness Indicators", + remove_dom=False, + element_name=None, + ) + + def is_active(self): + """Determines whether this plugin is active. + + This plugin is only active if TensorBoard sampled any summaries + relevant to the plugin. + + Returns + ------- + Whether this plugin is active. 
+ """ + return bool( + self._multiplexer.PluginRunToTagToContent( + FairnessIndicatorsPlugin.plugin_name + ) + ) + + # pytype: disable=wrong-arg-types + @wrappers.Request.application + def _serve_js(self, request): + filepath = os.path.join(os.path.dirname(__file__), "static", "index.js") + with open(filepath) as infile: + contents = infile.read() + return http_util.Respond( + request, contents, content_type="application/javascript" + ) + + @wrappers.Request.application + def _serve_vulcanized_js(self, request): + with open(_TEMPLATE_LOCATION) as infile: + contents = infile.read() + return http_util.Respond( + request, contents, content_type="application/javascript" + ) + + @wrappers.Request.application + def _get_evaluation_result(self, request): + run = request.args.get("run") + try: + run = six.ensure_text(run) + except (UnicodeDecodeError, AttributeError): + pass + + data = [] + try: + eval_result_output_dir = six.ensure_text( + self._multiplexer.Tensors(run, FairnessIndicatorsPlugin.plugin_name)[ + 0 + ].tensor_proto.string_val[0] + ) + eval_result = tfma.load_eval_result(output_path=eval_result_output_dir) + # TODO(b/141283811): Allow users to choose different model output names + # and class keys in case of multi-output and multi-class model. + data = convert_slicing_metrics_to_ui_input(eval_result.slicing_metrics) + except (KeyError, json_format.ParseError) as error: + logging.info("Error while fetching evaluation data, %s", error) + return http_util.Respond(request, data, content_type="application/json") + + def _get_output_file_format(self, evaluation_output_path): + file_format = os.path.splitext(evaluation_output_path)[1] + if file_format: + return file_format[1:] + + return "" + + @wrappers.Request.application + def _get_evaluation_result_from_remote_path(self, request): + evaluation_output_path = request.args.get("evaluation_output_path") + try: + evaluation_output_path = six.ensure_text(evaluation_output_path) + except (UnicodeDecodeError, AttributeError): + pass + try: + eval_result = tfma.load_eval_result( + os.path.dirname(evaluation_output_path), + output_file_format=self._get_output_file_format(evaluation_output_path), + ) + data = convert_slicing_metrics_to_ui_input(eval_result.slicing_metrics) + except (KeyError, json_format.ParseError) as error: + logging.info("Error while fetching evaluation data, %s", error) + data = [] + return http_util.Respond(request, data, content_type="application/json") + + # pytype: enable=wrong-arg-types diff --git a/tensorboard_plugin/tensorboard_plugin_fairness_indicators/plugin_test.py b/tensorboard_plugin/tensorboard_plugin_fairness_indicators/plugin_test.py index 9c82c5c..d993580 100644 --- a/tensorboard_plugin/tensorboard_plugin_fairness_indicators/plugin_test.py +++ b/tensorboard_plugin/tensorboard_plugin_fairness_indicators/plugin_test.py @@ -13,94 +13,94 @@ # limitations under the License. 
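Before moving on to the plugin tests, a worked illustration of the helpers defined in `plugin.py` above; the slice key and metric values are toy data, and the stringified slice shown in the comment is only an approximation of tfma's own formatting.

```python
# Toy illustration of the plugin.py helpers above; values are made up.
from tensorboard_plugin_fairness_indicators.plugin import stringify_slice_key_value

slice_key = (("gender", "f"), ("age", 5))
print(stringify_slice_key_value(slice_key))  # -> "f_X_5"
print(stringify_slice_key_value(()))         # -> "Overall"

# convert_slicing_metrics_to_ui_input (same module) flattens
# EvalResult.slicing_metrics into one dict per slice, each carrying
# 'sliceValue', 'slice', and 'metrics' keys, e.g.
# [{"sliceValue": "f_X_5", "slice": "<stringified slice key>", "metrics": {...}}]
```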
# ============================================================================== """Tests the Tensorboard Fairness Indicators plugin.""" -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function -from collections import abc import os -import pytest import shutil +from collections import abc from unittest import mock -from tensorboard_plugin_fairness_indicators import plugin -from tensorboard_plugin_fairness_indicators import summary_v2 +import pytest import six import tensorflow.compat.v1 as tf import tensorflow.compat.v2 as tf2 import tensorflow_model_analysis as tfma +from google.protobuf import text_format +from tensorboard.backend import application +from tensorboard.backend.event_processing import ( + plugin_event_multiplexer as event_multiplexer, +) +from tensorboard.plugins import base_plugin from tensorflow_model_analysis.utils import example_keras_model from werkzeug import test as werkzeug_test from werkzeug import wrappers -from google.protobuf import text_format -from tensorboard.backend import application -from tensorboard.backend.event_processing import plugin_event_multiplexer as event_multiplexer -from tensorboard.plugins import base_plugin +from tensorboard_plugin_fairness_indicators import plugin, summary_v2 tf.enable_eager_execution() tf = tf2 class PluginTest(tf.test.TestCase): - """Tests for Fairness Indicators plugin server.""" - - def setUp(self): - super(PluginTest, self).setUp() - # Log dir to save temp events into. - self._log_dir = self.get_temp_dir() - self._eval_result_output_dir = os.path.join(self.get_temp_dir(), - "eval_result") - if not os.path.isdir(self._eval_result_output_dir): - os.mkdir(self._eval_result_output_dir) - - writer = tf.summary.create_file_writer(self._log_dir) - - with writer.as_default(): - summary_v2.FairnessIndicators(self._eval_result_output_dir, step=1) - writer.close() - - # Start a server that will receive requests. 
- self._multiplexer = event_multiplexer.EventMultiplexer({ - ".": self._log_dir, - }) - self._context = base_plugin.TBContext( - logdir=self._log_dir, multiplexer=self._multiplexer) - self._plugin = plugin.FairnessIndicatorsPlugin(self._context) - self._multiplexer.Reload() - wsgi_app = application.TensorBoardWSGI([self._plugin]) - self._server = werkzeug_test.Client(wsgi_app, wrappers.Response) - self._routes = self._plugin.get_plugin_apps() - - def tearDown(self): - super(PluginTest, self).tearDown() - shutil.rmtree(self._log_dir, ignore_errors=True) - - def _export_keras_model(self, classifier): - temp_eval_export_dir = os.path.join(self.get_temp_dir(), "eval_export_dir") - classifier.compile(optimizer=tf.keras.optimizers.Adam(), loss="mse") - tf.saved_model.save(classifier, temp_eval_export_dir) - return temp_eval_export_dir - - def _write_tf_examples_to_tfrecords(self, examples): - data_location = os.path.join(self.get_temp_dir(), "input_data.rio") - with tf.io.TFRecordWriter(data_location) as writer: - for example in examples: - writer.write(example.SerializeToString()) - return data_location - - def _make_example(self, age, language, label): - example = tf.train.Example() - example.features.feature["age"].float_list.value[:] = [age] - example.features.feature["language"].bytes_list.value[:] = [ - six.ensure_binary(language, "utf8") - ] - example.features.feature["label"].float_list.value[:] = [label] - return example - - def _make_eval_config(self): - return text_format.Parse( - """ + """Tests for Fairness Indicators plugin server.""" + + def setUp(self): + super(PluginTest, self).setUp() + # Log dir to save temp events into. + self._log_dir = self.get_temp_dir() + self._eval_result_output_dir = os.path.join(self.get_temp_dir(), "eval_result") + if not os.path.isdir(self._eval_result_output_dir): + os.mkdir(self._eval_result_output_dir) + + writer = tf.summary.create_file_writer(self._log_dir) + + with writer.as_default(): + summary_v2.FairnessIndicators(self._eval_result_output_dir, step=1) + writer.close() + + # Start a server that will receive requests. 
+ self._multiplexer = event_multiplexer.EventMultiplexer( + { + ".": self._log_dir, + } + ) + self._context = base_plugin.TBContext( + logdir=self._log_dir, multiplexer=self._multiplexer + ) + self._plugin = plugin.FairnessIndicatorsPlugin(self._context) + self._multiplexer.Reload() + wsgi_app = application.TensorBoardWSGI([self._plugin]) + self._server = werkzeug_test.Client(wsgi_app, wrappers.Response) + self._routes = self._plugin.get_plugin_apps() + + def tearDown(self): + super(PluginTest, self).tearDown() + shutil.rmtree(self._log_dir, ignore_errors=True) + + def _export_keras_model(self, classifier): + temp_eval_export_dir = os.path.join(self.get_temp_dir(), "eval_export_dir") + classifier.compile(optimizer=tf.keras.optimizers.Adam(), loss="mse") + tf.saved_model.save(classifier, temp_eval_export_dir) + return temp_eval_export_dir + + def _write_tf_examples_to_tfrecords(self, examples): + data_location = os.path.join(self.get_temp_dir(), "input_data.rio") + with tf.io.TFRecordWriter(data_location) as writer: + for example in examples: + writer.write(example.SerializeToString()) + return data_location + + def _make_example(self, age, language, label): + example = tf.train.Example() + example.features.feature["age"].float_list.value[:] = [age] + example.features.feature["language"].bytes_list.value[:] = [ + six.ensure_binary(language, "utf8") + ] + example.features.feature["label"].float_list.value[:] = [label] + return example + + def _make_eval_config(self): + return text_format.Parse( + """ model_specs { signature_name: "serving_default" prediction_key: "predictions" # placeholder @@ -116,116 +116,111 @@ def _make_eval_config(self): } } """, - tfma.EvalConfig(), - ) + tfma.EvalConfig(), + ) + + def testRoutes(self): + self.assertIsInstance(self._routes["/get_evaluation_result"], abc.Callable) + self.assertIsInstance( + self._routes["/get_evaluation_result_from_remote_path"], abc.Callable + ) + self.assertIsInstance(self._routes["/index.js"], abc.Callable) + self.assertIsInstance(self._routes["/vulcanized_tfma.js"], abc.Callable) - def testRoutes(self): - self.assertIsInstance(self._routes["/get_evaluation_result"], - abc.Callable) - self.assertIsInstance( - self._routes["/get_evaluation_result_from_remote_path"], - abc.Callable) - self.assertIsInstance(self._routes["/index.js"], abc.Callable) - self.assertIsInstance(self._routes["/vulcanized_tfma.js"], - abc.Callable) - - @mock.patch.object( - event_multiplexer.EventMultiplexer, - "PluginRunToTagToContent", - return_value={"bar": { - "foo": "".encode("utf-8") - }}, - ) - def testIsActive(self, get_random_stub): # pylint: disable=unused-argument - self.assertTrue(self._plugin.is_active()) - - @mock.patch.object( - event_multiplexer.EventMultiplexer, - "PluginRunToTagToContent", - return_value={}) - def testIsInactive(self, get_random_stub): # pylint: disable=unused-argument - self.assertFalse(self._plugin.is_active()) - - def testIndexJsRoute(self): - """Tests that the /tags route offers the correct run to tag mapping.""" - response = self._server.get("/data/plugin/fairness_indicators/index.js") - self.assertEqual(200, response.status_code) - - @pytest.mark.xfail( - reason=( - "Failing on `master` as of `942b672457e07ac2ac27de0bcc45a4c80276785c`. " - "Please remove once fixed." 
- ) - ) - def testVulcanizedTemplateRoute(self): - """Tests that the /tags route offers the correct run to tag mapping.""" - response = self._server.get( - "/data/plugin/fairness_indicators/vulcanized_tfma.js" + @mock.patch.object( + event_multiplexer.EventMultiplexer, + "PluginRunToTagToContent", + return_value={"bar": {"foo": b""}}, ) - self.assertEqual(200, response.status_code) + def testIsActive(self, get_random_stub): # pylint: disable=unused-argument + self.assertTrue(self._plugin.is_active()) - def testGetEvalResultsRoute(self): - model_location = self._export_keras_model( - example_keras_model.get_example_classifier_model( - input_feature_key="language" - ) + @mock.patch.object( + event_multiplexer.EventMultiplexer, "PluginRunToTagToContent", return_value={} ) - examples = [ - self._make_example(age=3.0, language="english", label=1.0), - self._make_example(age=3.0, language="chinese", label=0.0), - self._make_example(age=4.0, language="english", label=1.0), - self._make_example(age=5.0, language="chinese", label=1.0), - self._make_example(age=5.0, language="hindi", label=1.0), - ] - eval_config = self._make_eval_config() - data_location = self._write_tf_examples_to_tfrecords(examples) - _ = tfma.run_model_analysis( - eval_shared_model=tfma.default_eval_shared_model( - eval_saved_model_path=model_location, eval_config=eval_config - ), - eval_config=eval_config, - data_location=data_location, - output_path=self._eval_result_output_dir, + def testIsInactive(self, get_random_stub): # pylint: disable=unused-argument + self.assertFalse(self._plugin.is_active()) + + def testIndexJsRoute(self): + """Tests that the /tags route offers the correct run to tag mapping.""" + response = self._server.get("/data/plugin/fairness_indicators/index.js") + self.assertEqual(200, response.status_code) + + @pytest.mark.xfail( + reason=( + "Failing on `master` as of `942b672457e07ac2ac27de0bcc45a4c80276785c`. " + "Please remove once fixed." + ) ) + def testVulcanizedTemplateRoute(self): + """Tests that the /tags route offers the correct run to tag mapping.""" + response = self._server.get( + "/data/plugin/fairness_indicators/vulcanized_tfma.js" + ) + self.assertEqual(200, response.status_code) - response = self._server.get( - "/data/plugin/fairness_indicators/get_evaluation_result?run=." - ) - self.assertEqual(200, response.status_code) + def testGetEvalResultsRoute(self): + model_location = self._export_keras_model( + example_keras_model.get_example_classifier_model( + input_feature_key="language" + ) + ) + examples = [ + self._make_example(age=3.0, language="english", label=1.0), + self._make_example(age=3.0, language="chinese", label=0.0), + self._make_example(age=4.0, language="english", label=1.0), + self._make_example(age=5.0, language="chinese", label=1.0), + self._make_example(age=5.0, language="hindi", label=1.0), + ] + eval_config = self._make_eval_config() + data_location = self._write_tf_examples_to_tfrecords(examples) + _ = tfma.run_model_analysis( + eval_shared_model=tfma.default_eval_shared_model( + eval_saved_model_path=model_location, eval_config=eval_config + ), + eval_config=eval_config, + data_location=data_location, + output_path=self._eval_result_output_dir, + ) - def testGetEvalResultsFromURLRoute(self): - model_location = self._export_keras_model( - example_keras_model.get_example_classifier_model( - input_feature_key="language" + response = self._server.get( + "/data/plugin/fairness_indicators/get_evaluation_result?run=." 
) - ) - examples = [ - self._make_example(age=3.0, language="english", label=1.0), - self._make_example(age=3.0, language="chinese", label=0.0), - self._make_example(age=4.0, language="english", label=1.0), - self._make_example(age=5.0, language="chinese", label=1.0), - self._make_example(age=5.0, language="hindi", label=1.0), - ] - eval_config = self._make_eval_config() - data_location = self._write_tf_examples_to_tfrecords(examples) - _ = tfma.run_model_analysis( - eval_shared_model=tfma.default_eval_shared_model( - eval_saved_model_path=model_location, eval_config=eval_config - ), - eval_config=eval_config, - data_location=data_location, - output_path=self._eval_result_output_dir, - ) + self.assertEqual(200, response.status_code) - response = self._server.get( - "/data/plugin/fairness_indicators/" - + "get_evaluation_result_from_remote_path?evaluation_output_path=" - + os.path.join(self._eval_result_output_dir, tfma.METRICS_KEY) - ) - self.assertEqual(200, response.status_code) + def testGetEvalResultsFromURLRoute(self): + model_location = self._export_keras_model( + example_keras_model.get_example_classifier_model( + input_feature_key="language" + ) + ) + examples = [ + self._make_example(age=3.0, language="english", label=1.0), + self._make_example(age=3.0, language="chinese", label=0.0), + self._make_example(age=4.0, language="english", label=1.0), + self._make_example(age=5.0, language="chinese", label=1.0), + self._make_example(age=5.0, language="hindi", label=1.0), + ] + eval_config = self._make_eval_config() + data_location = self._write_tf_examples_to_tfrecords(examples) + _ = tfma.run_model_analysis( + eval_shared_model=tfma.default_eval_shared_model( + eval_saved_model_path=model_location, eval_config=eval_config + ), + eval_config=eval_config, + data_location=data_location, + output_path=self._eval_result_output_dir, + ) - def testGetOutputFileFormat(self): - self.assertEqual("", self._plugin._get_output_file_format("abc_path")) - self.assertEqual( - "tfrecord", self._plugin._get_output_file_format("abc_path.tfrecord") - ) + response = self._server.get( + "/data/plugin/fairness_indicators/" + + "get_evaluation_result_from_remote_path?evaluation_output_path=" + + os.path.join(self._eval_result_output_dir, tfma.METRICS_KEY) + ) + self.assertEqual(200, response.status_code) + + def testGetOutputFileFormat(self): + self.assertEqual("", self._plugin._get_output_file_format("abc_path")) + self.assertEqual( + "tfrecord", self._plugin._get_output_file_format("abc_path.tfrecord") + ) diff --git a/tensorboard_plugin/tensorboard_plugin_fairness_indicators/summary_v2.py b/tensorboard_plugin/tensorboard_plugin_fairness_indicators/summary_v2.py index ebb59a4..42f0d68 100644 --- a/tensorboard_plugin/tensorboard_plugin_fairness_indicators/summary_v2.py +++ b/tensorboard_plugin/tensorboard_plugin_fairness_indicators/summary_v2.py @@ -14,39 +14,39 @@ # ============================================================================== """Summaries for Fairness Indicators plugin.""" -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function +from tensorboard.compat import tf2 as tf from tensorboard_plugin_fairness_indicators import metadata -from tensorboard.compat import tf2 as tf def FairnessIndicators(eval_result_output_dir, step=None, description=None): - """Write a Fairness Indicators summary. + """Write a Fairness Indicators summary. 
- Arguments: - eval_result_output_dir: Directory output created by - tfma.model_eval_lib.ExtractEvaluateAndWriteResults API, which contains - 'metrics' file having MetricsForSlice results. - step: Explicit `int64`-castable monotonic step value for this summary. If - omitted, this defaults to `tf.summary.experimental.get_step()`, which must - not be None. - description: Optional long-form description for this summary, as a constant - `str`. Markdown is supported. Defaults to empty. + Arguments: + --------- + eval_result_output_dir: Directory output created by + tfma.model_eval_lib.ExtractEvaluateAndWriteResults API, which contains + 'metrics' file having MetricsForSlice results. + step: Explicit `int64`-castable monotonic step value for this summary. If + omitted, this defaults to `tf.summary.experimental.get_step()`, which must + not be None. + description: Optional long-form description for this summary, as a constant + `str`. Markdown is supported. Defaults to empty. - Returns: - True on success, or false if no summary was written because no default - summary writer was available. + Returns: + ------- + True on success, or false if no summary was written because no default + summary writer was available. - Raises: - ValueError: if a default writer exists, but no step was provided and - `tf.summary.experimental.get_step()` is None. - """ - with tf.summary.experimental.summary_scope(metadata.PLUGIN_NAME): - return tf.summary.write( - tag=metadata.PLUGIN_NAME, - tensor=tf.constant(eval_result_output_dir), - step=step, - metadata=metadata.CreateSummaryMetadata(description), - ) + Raises: + ------ + ValueError: if a default writer exists, but no step was provided and + `tf.summary.experimental.get_step()` is None. + """ + with tf.summary.experimental.summary_scope(metadata.PLUGIN_NAME): + return tf.summary.write( + tag=metadata.PLUGIN_NAME, + tensor=tf.constant(eval_result_output_dir), + step=step, + metadata=metadata.CreateSummaryMetadata(description), + ) diff --git a/tensorboard_plugin/tensorboard_plugin_fairness_indicators/summary_v2_test.py b/tensorboard_plugin/tensorboard_plugin_fairness_indicators/summary_v2_test.py index 4c50bea..3e48bf1 100644 --- a/tensorboard_plugin/tensorboard_plugin_fairness_indicators/summary_v2_test.py +++ b/tensorboard_plugin/tensorboard_plugin_fairness_indicators/summary_v2_test.py @@ -13,57 +13,54 @@ # limitations under the License. # ============================================================================== """Tests for Fairness Indicators summary.""" -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function import glob import os -from tensorboard_plugin_fairness_indicators import metadata -from tensorboard_plugin_fairness_indicators import summary_v2 import six import tensorflow.compat.v1 as tf from tensorboard.compat import tf2 +from tensorboard_plugin_fairness_indicators import metadata, summary_v2 + try: - tf2.__version__ # Force lazy import to resolve + tf2.__version__ # Force lazy import to resolve except ImportError: - tf2 = None + tf2 = None try: - tf.enable_eager_execution() + tf.enable_eager_execution() except AttributeError: - # TF 2.0 doesn't have this symbol because eager is the default. - pass + # TF 2.0 doesn't have this symbol because eager is the default. 
+ pass class SummaryV2Test(tf.test.TestCase): + def _write_summary(self, eval_result_output_dir): + writer = tf2.summary.create_file_writer(self.get_temp_dir()) + with writer.as_default(): + summary_v2.FairnessIndicators(eval_result_output_dir, step=1) + writer.close() - def _write_summary(self, eval_result_output_dir): - writer = tf2.summary.create_file_writer(self.get_temp_dir()) - with writer.as_default(): - summary_v2.FairnessIndicators(eval_result_output_dir, step=1) - writer.close() - - def _get_event(self): - event_files = sorted(glob.glob(os.path.join(self.get_temp_dir(), '*'))) - self.assertEqual(len(event_files), 1) - events = list(tf.train.summary_iterator(event_files[0])) - # Expect a boilerplate event for the file_version, then the summary one. - self.assertEqual(len(events), 2) - return events[1] + def _get_event(self): + event_files = sorted(glob.glob(os.path.join(self.get_temp_dir(), "*"))) + self.assertEqual(len(event_files), 1) + events = list(tf.train.summary_iterator(event_files[0])) + # Expect a boilerplate event for the file_version, then the summary one. + self.assertEqual(len(events), 2) + return events[1] - def testSummary(self): - self._write_summary('output_dir') - event = self._get_event() + def testSummary(self): + self._write_summary("output_dir") + event = self._get_event() - self.assertEqual(1, event.step) + self.assertEqual(1, event.step) - summary_value = event.summary.value[0] - self.assertEqual(metadata.PLUGIN_NAME, summary_value.tag) - self.assertEqual( - 'output_dir', - six.ensure_text(summary_value.tensor.string_val[0], 'utf-8')) - self.assertEqual(metadata.PLUGIN_NAME, - summary_value.metadata.plugin_data.plugin_name) + summary_value = event.summary.value[0] + self.assertEqual(metadata.PLUGIN_NAME, summary_value.tag) + self.assertEqual( + "output_dir", six.ensure_text(summary_value.tensor.string_val[0], "utf-8") + ) + self.assertEqual( + metadata.PLUGIN_NAME, summary_value.metadata.plugin_data.plugin_name + ) diff --git a/tensorboard_plugin/tensorboard_plugin_fairness_indicators/version.py b/tensorboard_plugin/tensorboard_plugin_fairness_indicators/version.py index 026df51..dde7818 100644 --- a/tensorboard_plugin/tensorboard_plugin_fairness_indicators/version.py +++ b/tensorboard_plugin/tensorboard_plugin_fairness_indicators/version.py @@ -15,4 +15,4 @@ """Contains the version string of Fairness Indicators Tensorboard Plugin.""" # Note that setup.py uses this version. -__version__ = '0.49.0.dev' +__version__ = "0.49.0.dev"
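
# ------------------------------------------------------------------------------
# Reviewer note (not part of the patch): the hunks above are Ruff-driven
# reformatting of `summary_v2.py` and its tests; behavior is unchanged. As a
# reading aid, below is a minimal, hedged usage sketch of the
# `summary_v2.FairnessIndicators` API whose docstring is restyled above. It
# mirrors the pattern in `summary_v2_test.py` and assumes TF 2.x with the
# `tensorboard_plugin_fairness_indicators` package installed; the directory
# paths are placeholders, and a prior TFMA run is assumed to have already
# written results (including the 'metrics' file) to `eval_result_output_dir`.
# ------------------------------------------------------------------------------

```python
# Minimal sketch (assumptions: TF 2.x, tensorboard_plugin_fairness_indicators
# installed, and an existing TFMA output directory; paths are placeholders).
import tensorflow as tf

from tensorboard_plugin_fairness_indicators import summary_v2

eval_result_output_dir = "/tmp/tfma_eval_output"  # hypothetical TFMA output path
log_dir = "/tmp/fairness_indicators_logs"         # hypothetical TensorBoard log dir

writer = tf.summary.create_file_writer(log_dir)
with writer.as_default():
    # Writes a summary whose tensor payload is the TFMA output directory path;
    # the Fairness Indicators TensorBoard plugin later reads the metrics from
    # that directory when rendering the dashboard.
    summary_v2.FairnessIndicators(eval_result_output_dir, step=1)
writer.close()
```

# As the reformatted docstring above states, only the output-directory string is
# recorded in the event file; the metrics themselves stay in the TFMA output and
# are fetched on demand by the plugin routes exercised in `plugin_test.py`
# (e.g. `/get_evaluation_result`). Point TensorBoard at `log_dir` to view them.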