Skip to content

Commit 93a827e

Browse files
committed
Map experimental C (actually C++) API for gradient tape
1 parent 0d73a9b commit 93a827e

File tree

88 files changed

+2504
-69
lines changed

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

88 files changed

+2504
-69
lines changed

tensorflow-core/pom.xml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -61,7 +61,7 @@
6161
<javacpp.platform.macosx-x86_64.extension>macosx-x86_64${javacpp.platform.extension}</javacpp.platform.macosx-x86_64.extension>
6262
<javacpp.platform.windows-x86.extension>windows-x86${javacpp.platform.extension}</javacpp.platform.windows-x86.extension>
6363
<javacpp.platform.windows-x86_64.extension>windows-x86_64${javacpp.platform.extension}</javacpp.platform.windows-x86_64.extension>
64-
<javacpp.version>1.5.4</javacpp.version>
64+
<javacpp.version>1.5.5</javacpp.version>
6565
</properties>
6666

6767
<profiles>

tensorflow-core/tensorflow-core-api/pom.xml

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -141,6 +141,19 @@
141141
</execution>
142142
</executions>
143143
</plugin>
144+
<plugin>
145+
<artifactId>maven-resources-plugin</artifactId>
146+
<version>3.1.0</version>
147+
<executions>
148+
<execution>
149+
<id>javacpp-parser</id>
150+
<phase>generate-sources</phase>
151+
<goals>
152+
<goal>resources</goal>
153+
</goals>
154+
</execution>
155+
</executions>
156+
</plugin>
144157
<plugin>
145158
<artifactId>maven-compiler-plugin</artifactId>
146159
<version>3.8.0</version>
@@ -209,7 +222,15 @@
209222
<classPath>${project.build.outputDirectory}</classPath>
210223
<includePaths>
211224
<includePath>${project.basedir}/</includePath>
225+
<includePath>${project.basedir}/bazel-bin/external/llvm-project/llvm/include/</includePath>
226+
<includePath>${project.basedir}/bazel-bin/external/org_tensorflow/</includePath>
227+
<includePath>${project.basedir}/bazel-${project.artifactId}/external/eigen_archive/</includePath>
228+
<includePath>${project.basedir}/bazel-${project.artifactId}/external/com_google_absl/</includePath>
229+
<includePath>${project.basedir}/bazel-${project.artifactId}/external/com_google_protobuf/src/</includePath>
230+
<includePath>${project.basedir}/bazel-${project.artifactId}/external/farmhash_archive/src/</includePath>
231+
<includePath>${project.basedir}/bazel-${project.artifactId}/external/llvm-project/llvm/include/</includePath>
212232
<includePath>${project.basedir}/bazel-${project.artifactId}/external/org_tensorflow/</includePath>
233+
<includePath>${project.basedir}/target/classes/org/tensorflow/internal/c_api/include/</includePath>
213234
</includePaths>
214235
<linkPaths>
215236
<linkPath>${project.basedir}/bazel-bin/external/llvm_openmp/</linkPath>
@@ -315,6 +336,10 @@
315336
<outputDirectory>${project.build.directory}/native/org/tensorflow/internal/c_api/${native.classifier}/</outputDirectory>
316337
<skip>${javacpp.compiler.skip}</skip>
317338
<classOrPackageName>org.tensorflow.internal.c_api.**</classOrPackageName>
339+
<compilerOptions>
340+
<!-- TODO: Remove files from here as they get integrated into the Bazel build -->
341+
<compilerOption>${project.basedir}/bazel-${project.artifactId}/external/org_tensorflow/tensorflow/c/eager/gradients.cc</compilerOption>
342+
</compilerOptions>
318343
<copyLibs>true</copyLibs>
319344
<copyResources>true</copyResources>
320345
</configuration>
Lines changed: 49 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,49 @@
1+
// Targeted by JavaCPP version 1.5.5: DO NOT EDIT THIS FILE
2+
3+
package org.tensorflow.internal.c_api;
4+
5+
import java.nio.*;
6+
import org.bytedeco.javacpp.*;
7+
import org.bytedeco.javacpp.annotation.*;
8+
9+
import static org.tensorflow.internal.c_api.global.tensorflow.*;
10+
11+
12+
// Abstract interface to a context.
13+
//
14+
// This serves as a factory for creating `AbstractOperation`s and for
15+
// registering traced functions.
16+
// Operations creation within a context can only be executed in that context
17+
// (for now at least).
18+
// Implementations of the context may contain some state e.g. an execution
19+
// environment, a traced representation etc.
20+
@Namespace("tensorflow") @NoOffset @Properties(inherit = org.tensorflow.internal.c_api.presets.tensorflow.class)
21+
public class AbstractContext extends Pointer {
22+
static { Loader.load(); }
23+
/** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */
24+
public AbstractContext(Pointer p) { super(p); }
25+
26+
public native int getKind();
27+
28+
// Release any underlying resources, including the interface object.
29+
//
30+
// WARNING: The destructor of this class is marked as protected to disallow
31+
// clients from directly destroying this object since it may manage it's own
32+
// lifetime through ref counting. Thus clients MUST call Release() in order to
33+
// destroy an instance of this class.
34+
public native void Release();
35+
36+
// Creates an operation builder and ties it to this context.
37+
// The returned object can be used for setting operation's attributes,
38+
// adding inputs and finally executing (immediately or lazily as in tracing)
39+
// it in this context.
40+
public native AbstractOperation CreateOperation();
41+
42+
// Registers a function with this context, after this the function is
43+
// available to be called/referenced by its name in this context.
44+
public native @ByVal Status RegisterFunction(AbstractFunction arg0);
45+
// Remove a function. 'func' argument is the name of a previously added
46+
// FunctionDef. The name is in fdef.signature.name.
47+
public native @ByVal Status RemoveFunction(@StdString BytePointer func);
48+
public native @ByVal Status RemoveFunction(@StdString String func);
49+
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,30 @@
1+
// Targeted by JavaCPP version 1.5.5: DO NOT EDIT THIS FILE
2+
3+
package org.tensorflow.internal.c_api;
4+
5+
import java.nio.*;
6+
import org.bytedeco.javacpp.*;
7+
import org.bytedeco.javacpp.annotation.*;
8+
9+
import static org.tensorflow.internal.c_api.global.tensorflow.*;
10+
11+
@Namespace("tensorflow::internal") @Properties(inherit = org.tensorflow.internal.c_api.presets.tensorflow.class)
12+
public class AbstractContextDeleter extends Pointer {
13+
static { Loader.load(); }
14+
/** Default native constructor. */
15+
public AbstractContextDeleter() { super((Pointer)null); allocate(); }
16+
/** Native array allocator. Access with {@link Pointer#position(long)}. */
17+
public AbstractContextDeleter(long size) { super((Pointer)null); allocateArray(size); }
18+
/** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */
19+
public AbstractContextDeleter(Pointer p) { super(p); }
20+
private native void allocate();
21+
private native void allocateArray(long size);
22+
@Override public AbstractContextDeleter position(long position) {
23+
return (AbstractContextDeleter)super.position(position);
24+
}
25+
@Override public AbstractContextDeleter getPointer(long i) {
26+
return new AbstractContextDeleter((Pointer)this).position(position + i);
27+
}
28+
29+
public native @Name("operator ()") void apply(AbstractContext p);
30+
}
Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,27 @@
1+
// Targeted by JavaCPP version 1.5.5: DO NOT EDIT THIS FILE
2+
3+
package org.tensorflow.internal.c_api;
4+
5+
import java.nio.*;
6+
import org.bytedeco.javacpp.*;
7+
import org.bytedeco.javacpp.annotation.*;
8+
9+
import static org.tensorflow.internal.c_api.global.tensorflow.*;
10+
11+
12+
// A traced function: this hides the complexity of converting the serialized
13+
// representation between various supported formats e.g. FunctionDef and Mlir
14+
// function.
15+
@Namespace("tensorflow") @NoOffset @Properties(inherit = org.tensorflow.internal.c_api.presets.tensorflow.class)
16+
public class AbstractFunction extends Pointer {
17+
static { Loader.load(); }
18+
/** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */
19+
public AbstractFunction(Pointer p) { super(p); }
20+
21+
// Returns which subclass is this instance of.
22+
public native int getKind();
23+
24+
// Returns the AbstractFunction as a FunctionDef.
25+
public native @ByVal Status GetFunctionDef(@Cast("tensorflow::FunctionDef**") PointerPointer arg0);
26+
public native @ByVal Status GetFunctionDef(@Cast("tensorflow::FunctionDef**") @ByPtrPtr Pointer arg0);
27+
}

0 commit comments

Comments
 (0)
pFad - Phonifier reborn

Pfad - The Proxy pFad of © 2024 Garber Painting. All rights reserved.

Note: This service is not intended for secure transactions such as banking, social media, email, or purchasing. Use at your own risk. We assume no liability whatsoever for broken pages.


Alternative Proxies:

Alternative Proxy

pFad Proxy

pFad v3 Proxy

pFad v4 Proxy