@@ -1904,10 +1904,14 @@ public class TensorFlow {
19041904 /// - outputNames: [String], The names of the function's outputs. Must either have the same length as `outputs` or be null. In the former case, the names should match the regular expression for ArgDef names - "[a-z][a-z0-9_]*". In the latter case, names for outputs will be generated automatically.
19051905 /// - options: various options for the function, e.g. XLA's inlining control.
19061906 /// - description: optional human-readable description of this function
1907- public func toFunction( _ name: String , appendHashToFunctionName: Bool = false , operations: [ Operation ] , inputs: [ Output ] , outputs: [ Output ] , outputNames: [ String ] , options: OpaquePointer ? = nil , description: String = " " ) throws -> Function {
1908- guard outputs. count == outputNames. count else {
1909- throw Panic . FAULT ( reason: " Output array elements are mismatched with names " )
1910- }
1907+ public func toFunction(
1908+ _ name: String , appendHashToFunctionName: Bool = false ,
1909+ operations: [ Operation ] = [ ] ,
1910+ inputs: [ Output ] = [ ] ,
1911+ outputs: [ Output ] = [ ] ,
1912+ outputNames: [ String ] = [ ] ,
1913+ options: OpaquePointer ? = nil ,
1914+ description: String = " " ) throws -> Function {
19111915 let status = try Status ( )
19121916 let opera : UnsafePointer < OpaquePointer ? > ? = operations. map { $0. operation }
19131917 . withUnsafeBufferPointer { $0. baseAddress }
@@ -1929,7 +1933,8 @@ public class TensorFlow {
19291933 Int32 ( outputs. count > 0 ? outputs. count: 0 ) ,
19301934 outputs. count > 0 ? pOutpus : nil ,
19311935
1932- outputs. count > 0 && outputNames. count == outputs. count ? pOutputNames : nil ,
1936+ outputNames. count > 0
1937+ && outputNames. count == outputs. count ? pOutputNames : nil ,
19331938
19341939 options, description. isEmpty ? nil : description,
19351940
@@ -1941,6 +1946,32 @@ public class TensorFlow {
19411946 return Function ( fun)
19421947 }
19431948
1949+ /// Adds a copy of function `func` and optionally its gradient function `grad`
1950+ /// to `g`. Once `func`/`grad` is added to `g`, it can be called by creating
1951+ /// an operation using the function's name.
1952+ /// Any changes to `func`/`grad` (including deleting it) done after this method
1953+ /// returns, won't affect the copy of `func`/`grad` in `g`.
1954+ /// If `func` or `grad` are already in `g`, TF_GraphCopyFunction has no
1955+ /// effect on them, but can establish the function->gradient relationship
1956+ /// between them if `func` does not already have a gradient. If `func` already
1957+ /// has a gradient different from `grad`, an error is returned.
1958+ /// If `grad` is null and `func` is not in `g`, `func` is added without a
1959+ /// gradient.
1960+ /// If `grad` is null and `func` is in `g`, TF_GraphCopyFunction is a noop.
1961+ /// `grad` must have appropriate signature as described in the doc of
1962+ /// GradientDef in tensorflow/core/framework/function.proto.
1963+ /// - parameters:
1964+ /// - function: function to add
1965+ /// - grad: the gradient function to add with.
1966+ /// - throws: Panic.FAULT
1967+ public func copy( function: Function , grad: Function ? = nil ) throws {
1968+ let status = try Status ( )
1969+ TFLib . GraphCopyFunction ( self . graph, function. ref, grad? . ref, status. status)
1970+ guard status. code == . OK else {
1971+ throw Panic . FAULT ( reason: status. message)
1972+ }
1973+ }
1974+
19441975 /// Function is a grouping of operations with defined inputs and outputs.
19451976 /// Once created and added to graphs, functions can be invoked by creating an
19461977 /// operation whose operation type matches the function name.
@@ -2023,17 +2054,17 @@ public class TensorFlow {
20232054 return nil
20242055 }
20252056 }
2026- }
20272057
2028- /// get definition
2029- public var def : FunctionDef ? {
2030- if let buf = self . buffer, let proto = buf. data {
2031- return try ? FunctionDef ( serializedData: proto)
2032- } else {
2033- return nil
2058+ /// get definition
2059+ public var definition : FunctionDef ? {
2060+ if let buf = self . buffer, let proto = buf. data {
2061+ return try ? FunctionDef ( serializedData: proto)
2062+ } else {
2063+ return nil
2064+ }
20342065 }
2035- }
20362066
2067+ }
20372068 } //end graph
20382069
20392070 /// class wrapper of Graph Definition Options
0 commit comments