|
23 | 23 |
|
24 | 24 |
|
25 | 25 | import attr
|
26 |
| - |
27 |
| -from keras.saving import save |
28 |
| -from keras.utils import traceback_utils |
29 |
| - |
30 | 26 | import neural_structured_learning.configs as nsl_configs
|
31 | 27 | from neural_structured_learning.lib import adversarial_neighbor
|
32 | 28 | import six
|
@@ -699,58 +695,11 @@ def call(self, inputs, **kwargs):
|
699 | 695 | scaled_adv_loss, name='scaled_adversarial_loss', aggregation='mean')
|
700 | 696 | return outputs
|
701 | 697 |
|
702 |
| - |
703 |
| - @traceback_utils.filter_traceback |
704 |
| - def save(self, |
705 |
| - filepath, |
706 |
| - overwrite=True, |
707 |
| - include_optimizer=True, |
708 |
| - save_format=None, |
709 |
| - signatures=None, |
710 |
| - options=None, |
711 |
| - save_traces=True): |
712 |
| - # pylint: disable=line-too-long |
713 |
| - """Saves the model to Tensorflow SavedModel or a single HDF5 file. |
714 |
| - Please see `tf.keras.models.save_model` or the |
715 |
| - [Serialization and Saving guide](https://keras.io/guides/serialization_and_saving/) |
716 |
| - for details. |
717 |
| - Args: |
718 |
| - filepath: String, PathLike, path to SavedModel or H5 file to save the |
719 |
| - model. |
720 |
| - overwrite: Whether to silently overwrite any existing file at the |
721 |
| - target location, or provide the user with a manual prompt. |
722 |
| - include_optimizer: If True, save optimizer's state together. |
723 |
| - save_format: Either `'tf'` or `'h5'`, indicating whether to save the |
724 |
| - model to Tensorflow SavedModel or HDF5. Defaults to 'tf' in TF 2.X, |
725 |
| - and 'h5' in TF 1.X. |
726 |
| - signatures: Signatures to save with the SavedModel. Applicable to the |
727 |
| - 'tf' format only. Please see the `signatures` argument in |
728 |
| - `tf.saved_model.save` for details. |
729 |
| - options: (only applies to SavedModel format) |
730 |
| - `tf.saved_model.SaveOptions` object that specifies options for |
731 |
| - saving to SavedModel. |
732 |
| - save_traces: (only applies to SavedModel format) When enabled, the |
733 |
| - SavedModel will store the function traces for each layer. This |
734 |
| - can be disabled, so that only the configs of each layer are stored. |
735 |
| - Defaults to `True`. Disabling this will decrease serialization time |
736 |
| - and reduce file size, but it requires that all custom layers/models |
737 |
| - implement a `get_config()` method. |
738 |
| - Example: |
739 |
| - ```python |
740 |
| - from keras.models import load_model |
741 |
| - model.save('my_model.h5') # creates a HDF5 file 'my_model.h5' |
742 |
| - del model # deletes the existing model |
743 |
| - # returns a compiled model |
744 |
| - # identical to the previous one |
745 |
| - model = load_model('my_model.h5') |
746 |
| - ``` |
747 |
| - """ |
748 |
| - # pylint: enable=line-too-long |
749 |
| - save.save_model(self.base_model, filepath, overwrite, include_optimizer, save_format, |
750 |
| - signatures, options, save_traces) |
751 |
| - |
752 |
| - |
753 |
| - |
| 698 | + def save(self, *args, **kwargs): |
| 699 | + """Saves the base model. See base class for details of the interface.""" |
| 700 | + # Adversarial regularization doesn't introduce new model variables, so |
| 701 | + # saving the base model can capture all variables in the model. |
| 702 | + self.base_model.save(*args, **kwargs) |
754 | 703 |
|
755 | 704 | def perturb_on_batch(self, x, **config_kwargs):
|
756 | 705 | """Perturbs the given input to generates adversarial examples.
|
|
0 commit comments