@@ -44,6 +44,12 @@ public class EngineSklearnExternal extends Engine {
44
44
45
45
ProcessBase process ;
46
46
47
+ // These variables get set from the wrapper-specific config file, java properties or
48
+ // environment variables.
49
+ private String shellcmd = null ;
50
+ private String shellparms = null ;
51
+ private String wrapperhome = null ;
52
+
47
53
/**
48
54
* Try to find the script running the sklearn-Wrapper command.
49
55
* If apply is true, the executable for application is searched,
@@ -72,17 +78,20 @@ private File findWrapperCommand(File dataDirectory, boolean apply) {
72
78
throw new GateRuntimeException ("Could not load yaml file " +sklearnInfoFile ,ex );
73
79
}
74
80
tmp = null ;
81
+ Map map = null ;
75
82
if (obj instanceof Map ) {
76
- Map map = (Map )obj ;
83
+ map = (Map )obj ;
77
84
tmp = (String )map .get ("sklearnwrapper.home" );
78
85
} else {
79
86
throw new GateRuntimeException ("Info file has strange format: " +sklearnInfoFile .getAbsolutePath ());
80
87
}
81
- if (tmp == null ) {
82
- System .err .println ("sklearn.yaml file present but does not contain sklearnwrapper.home setting" );
83
- } else {
88
+ if (tmp != null ) {
84
89
homeDir = tmp ;
85
90
}
91
+ // Also get any other settings that may be present:
92
+ // shell command
93
+ shellcmd = (String )map .get ("shellcmd" );
94
+ shellparms = (String )map .get ("shellparms" );
86
95
}
87
96
if (homeDir == null ) {
88
97
throw new GateRuntimeException ("SklearnWrapper home not set, please see https://github.com/GateNLP/gateplugin-LearningFramework/wiki/UsingSklearn" );
@@ -94,6 +103,7 @@ private File findWrapperCommand(File dataDirectory, boolean apply) {
94
103
if (!wrapperHome .isDirectory ()) {
95
104
throw new GateRuntimeException ("SklearnWrapper home is not a directory: " +wrapperHome .getAbsolutePath ());
96
105
}
106
+ wrapperhome = wrapperHome .getAbsolutePath ();
97
107
// Now, depending on the operating system, and on train/apply,
98
108
// find the correct script to execute
99
109
File commandFile ;
@@ -129,9 +139,18 @@ private File findWrapperCommand(File dataDirectory, boolean apply) {
129
139
protected void loadModel (File directory , String parms ) {
130
140
// Instead of loading a model, this establishes a connection with the
131
141
// external sklearn process.
142
+
132
143
File commandFile = findWrapperCommand (directory , true );
133
144
String modelFileName = new File (directory ,"sklmodel" ).getAbsolutePath ();
134
- String finalCommand = commandFile .getAbsolutePath ()+" " +modelFileName ;
145
+ String finalCommand = commandFile .getAbsolutePath ()+" " +wrapperhome +" " +modelFileName ;
146
+ // if we have a shell command prepend that, and if we have shell parms too, include them
147
+ if (shellcmd != null ) {
148
+ String tmp = shellcmd ;
149
+ if (shellparms != null ) {
150
+ shellcmd += " " + shellparms ;
151
+ }
152
+ finalCommand = shellcmd + " " + finalCommand ;
153
+ }
135
154
System .err .println ("Running: " +finalCommand );
136
155
// Create a fake Model jsut to make LF_Apply... happy which checks if this is null
137
156
model = "ExternalSklearnWrapperModel" ;
@@ -173,7 +192,15 @@ public void trainModel(File dataDirectory, String instanceType, String parms) {
173
192
Exporter .EXPORTER_MATRIXMARKET2_CLASS , dataDirectory , instanceType , parms );
174
193
String dataFileName = dataDirectory .getAbsolutePath ()+File .separator ;
175
194
String modelFileName = new File (dataDirectory , "sklmodel" ).getAbsolutePath ();
176
- String finalCommand = commandFile .getAbsolutePath ()+" " +dataFileName +" " +modelFileName +" " +sklearnClass +" " +sklearnParms ;
195
+ String finalCommand = commandFile .getAbsolutePath ()+" " +wrapperhome +" " +dataFileName +" " +modelFileName +" " +sklearnClass +" " +sklearnParms ;
196
+ // if we have a shell command prepend that, and if we have shell parms too, include them
197
+ if (shellcmd != null ) {
198
+ String tmp = shellcmd ;
199
+ if (shellparms != null ) {
200
+ shellcmd += " " + shellparms ;
201
+ }
202
+ finalCommand = shellcmd + " " + finalCommand ;
203
+ }
177
204
System .err .println ("Running: " +finalCommand );
178
205
// Create a fake Model jsut to make LF_Apply... happy which checks if this is null
179
206
model = "ExternalSklearnWrapperModel" ;
@@ -311,4 +338,5 @@ protected void loadMalletCorpusRepresentation(File directory) {
311
338
corpusRepresentationMallet = CorpusRepresentationMalletTarget .load (directory );
312
339
}
313
340
341
+
314
342
}
0 commit comments