Skip to content

Commit c9eebf6

Browse files
committed
Add Python SciKit-Learn support.
1 parent a9faef2 commit c9eebf6

File tree

9 files changed

+903
-19
lines changed

9 files changed

+903
-19
lines changed

build/ivy.xml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
-->
1919
<dependency org="cc.mallet" name="mallet" rev="2.0.8" />
2020
<dependency org="org.jdom" name="jdom" rev="1.1"/>
21+
<dependency org="com.fasterxml.jackson.core" name="jackson-databind" rev="2.6.3" />
2122
<!-- for simplicity, this is currently fixed in lib-static
2223
<dependency org="uk.ac.gate.lib" name="interaction" rev="1.0-SNAPSHOT"/>
2324
-->
3.75 KB
Binary file not shown.

src/gate/plugin/learningframework/engines/AlgorithmClassification.java

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,8 @@ public enum AlgorithmClassification implements Algorithm {
3030
// We only add this after figuring out exactly how it needs to get set up!
3131
// MALLET_SEQ_CRF_VGS(EngineMalletSeq.class,null), // ByValueGradients
3232
MALLET_SEQ_MEMM(EngineMalletSeq.class,null),
33-
WEKA_CL_WRAPPER(EngineWekaExternal.class,null);
33+
WEKA_CL_WRAPPER(EngineWekaExternal.class,null),
34+
SKLEARN_CL_WRAPPER(EngineSklearnExternal.class,null);
3435
private AlgorithmClassification() {
3536

3637
}

src/gate/plugin/learningframework/engines/AlgorithmRegression.java

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,8 @@
1212
*/
1313
public enum AlgorithmRegression implements Algorithm {
1414
LIBSVM_RG(EngineLibSVM.class,null),
15-
WEKA_RG_WRAPPER(EngineWekaExternal.class,null);
15+
WEKA_RG_WRAPPER(EngineWekaExternal.class,null),
16+
SKLEARN_RG_WRAPPER(EngineSklearnExternal.class,null);
1617
private AlgorithmRegression() {
1718

1819
}
Lines changed: 314 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,314 @@
1+
package gate.plugin.learningframework.engines;
2+
3+
import cc.mallet.types.FeatureVector;
4+
import cc.mallet.types.Instance;
5+
import gate.Annotation;
6+
import gate.AnnotationSet;
7+
import gate.lib.interaction.data.SparseDoubleVector;
8+
import gate.lib.interaction.process.Process4JsonStream;
9+
import gate.lib.interaction.process.Process4ObjectStream;
10+
import gate.lib.interaction.process.ProcessBase;
11+
import gate.lib.interaction.process.ProcessSimple;
12+
import gate.plugin.learningframework.EvaluationMethod;
13+
import gate.plugin.learningframework.Exporter;
14+
import gate.plugin.learningframework.GateClassification;
15+
import gate.plugin.learningframework.Globals;
16+
import gate.plugin.learningframework.data.CorpusRepresentationMalletTarget;
17+
import gate.plugin.learningframework.mallet.LFPipe;
18+
import gate.util.GateRuntimeException;
19+
import java.io.File;
20+
import java.io.FileInputStream;
21+
import java.io.InputStreamReader;
22+
import java.util.ArrayList;
23+
import java.util.Collections;
24+
import java.util.HashMap;
25+
import java.util.List;
26+
import java.util.Map;
27+
import org.yaml.snakeyaml.Yaml;
28+
29+
/**
30+
* An engine that represents Python Scikit Learn through en external process.
31+
*
32+
* This requires that the user configures the location of where sklearn-wrapper is installed.
33+
* This can be done by setting the environment variable SKLEARN_WRAPPPER_HOME, the Java property
34+
* gate.plugin.learningframework.sklearnwrapper.home or by adding another yaml file "sklearn.yaml"
35+
* to the data directory which contains the setting sklearnwrapper.home.
36+
* If the path starts with a slash
37+
* it is an absolute path, otherwise the path is resolved relative to the
38+
* directory.
39+
*
40+
*
41+
* @author Johann Petrak
42+
*/
43+
public class EngineSklearnExternal extends Engine {
44+
45+
ProcessBase process;
46+
47+
/**
48+
* Try to find the script running the sklearn-Wrapper command.
49+
* If apply is true, the executable for application is searched,
50+
* otherwise the one for training.
51+
* This checks the following settings (increasing priority):
52+
* environment variable SKLEARN_WRAPPER_HOME,
53+
* java property gate.plugin.learningframework.sklearnwrapper.home and
54+
* the setting "sklearnwrapper.home" in file "sklearn.yaml" in the data directory,
55+
* if it exists.
56+
* The setting for the sklearn wrapper home can be relative in which case it
57+
* will be resolved relative to the dataDirectory
58+
* @param dataDirectory
59+
* @return
60+
*/
61+
private File findWrapperCommand(File dataDirectory, boolean apply) {
62+
String homeDir = System.getenv("SKLEARN_WRAPPER_HOME");
63+
String tmp = System.getProperty("gate.plugin.learningframework.sklearnwrapper.home");
64+
if(tmp!=null) homeDir = tmp;
65+
File sklearnInfoFile = new File(dataDirectory,"sklearn.yaml");
66+
if(sklearnInfoFile.exists()) {
67+
Yaml yaml = new Yaml();
68+
Object obj;
69+
try {
70+
obj = yaml.load(new InputStreamReader(new FileInputStream(sklearnInfoFile),"UTF-8"));
71+
} catch (Exception ex) {
72+
throw new GateRuntimeException("Could not load yaml file "+sklearnInfoFile,ex);
73+
}
74+
tmp = null;
75+
if(obj instanceof Map) {
76+
Map map = (Map)obj;
77+
tmp = (String)map.get("sklearnwrapper.home");
78+
} else {
79+
throw new GateRuntimeException("Info file has strange format: "+sklearnInfoFile.getAbsolutePath());
80+
}
81+
if(tmp == null) {
82+
System.err.println("sklearn.yaml file present but does not contain sklearnwrapper.home setting");
83+
} else {
84+
homeDir = tmp;
85+
}
86+
}
87+
if(homeDir == null) {
88+
throw new GateRuntimeException("SklearnWrapper home not set, please see https://github.com/GateNLP/gateplugin-LearningFramework/wiki/UsingSklearn");
89+
}
90+
File wrapperHome = new File(homeDir);
91+
if(!wrapperHome.isAbsolute()) {
92+
wrapperHome = new File(dataDirectory,homeDir);
93+
}
94+
if(!wrapperHome.isDirectory()) {
95+
throw new GateRuntimeException("SklearnWrapper home is not a directory: "+wrapperHome.getAbsolutePath());
96+
}
97+
// Now, depending on the operating system, and on train/apply,
98+
// find the correct script to execute
99+
File commandFile;
100+
// we use the simple heuristic that if the file separator is "/"
101+
// we assume we can use the bash script, if it is "\" we use the windows
102+
// script and otherwise we give up
103+
boolean linuxLike = System.getProperty("file.separator").equals("/");
104+
boolean windowsLike = System.getProperty("file.separator").equals("\\");
105+
if(linuxLike) {
106+
if(apply)
107+
commandFile = new File(new File(wrapperHome,"bin"),"sklearnWrapperApply.sh");
108+
else
109+
commandFile = new File(new File(wrapperHome,"bin"),"sklearnWrapperTrain.sh");
110+
} else if(windowsLike) {
111+
if(apply)
112+
commandFile = new File(new File(wrapperHome,"bin"),"sklearnWrapperApply.cmd");
113+
else
114+
commandFile = new File(new File(wrapperHome,"bin"),"sklearnWrapperTrain.cmd");
115+
} else {
116+
throw new GateRuntimeException("It appears this OS is not supported");
117+
}
118+
commandFile = commandFile.isAbsolute() ?
119+
commandFile :
120+
new File(dataDirectory,commandFile.getPath());
121+
if(!commandFile.canExecute()) {
122+
throw new GateRuntimeException("Not an executable file or not found: "+commandFile+" please see https://github.com/GateNLP/gateplugin-LearningFramework/wiki/UsingSklearn");
123+
}
124+
return commandFile;
125+
}
126+
127+
128+
@Override
129+
protected void loadModel(File directory, String parms) {
130+
// Instead of loading a model, this establishes a connection with the
131+
// external sklearn process.
132+
File commandFile = findWrapperCommand(directory, true);
133+
String modelFileName = new File(directory,"sklmodel").getAbsolutePath();
134+
String finalCommand = commandFile.getAbsolutePath()+" "+modelFileName;
135+
System.err.println("Running: "+finalCommand);
136+
// Create a fake Model jsut to make LF_Apply... happy which checks if this is null
137+
model = "ExternalSklearnWrapperModel";
138+
process = new Process4JsonStream(directory,finalCommand);
139+
}
140+
141+
@Override
142+
protected void saveModel(File directory) {
143+
// NOTE: we do not need to save the model here because the external
144+
// sklearnWrapper command does this.
145+
// However we still need to make sure a usable info file is saved!
146+
info.engineClass = EngineSklearnExternal.class.getName();
147+
info.save(directory);
148+
}
149+
150+
@Override
151+
public void trainModel(File dataDirectory, String instanceType, String parms) {
152+
// invoke the sklearn wrapper for training
153+
// NOTE: for this the first word in parms must be the full sklearn class name, the rest are parms
154+
if(parms == null || parms.isEmpty()) {
155+
throw new GateRuntimeException("Cannot train using SklearnWrapper, algorithmParameter must contain fulle SciKit Learn algorithm class name as first word");
156+
}
157+
String sklearnClass = null;
158+
String sklearnParms = "";
159+
int spaceIdx = parms.indexOf(" ");
160+
if(spaceIdx<0) {
161+
sklearnClass = parms;
162+
} else {
163+
sklearnClass = parms.substring(0,spaceIdx);
164+
sklearnParms = parms.substring(spaceIdx).trim();
165+
}
166+
File commandFile = findWrapperCommand(dataDirectory, false);
167+
// Export the data
168+
// Note: any scaling was already done in the PR before calling this method!
169+
// find out if we train classification or regression
170+
// TODO: NOTE: not sure if classification/regression matters here as long as
171+
// the actual exporter class does the right thing based on the corpus representation!
172+
Exporter.export(getCorpusRepresentationMallet(),
173+
Exporter.EXPORTER_MATRIXMARKET2_CLASS, dataDirectory, instanceType, parms);
174+
String dataFileName = dataDirectory.getAbsolutePath()+File.separator;
175+
String modelFileName = new File(dataDirectory, "sklmodel").getAbsolutePath();
176+
String finalCommand = commandFile.getAbsolutePath()+" "+dataFileName+" "+modelFileName+" "+sklearnClass+" "+sklearnParms;
177+
System.err.println("Running: "+finalCommand);
178+
// Create a fake Model jsut to make LF_Apply... happy which checks if this is null
179+
model = "ExternalSklearnWrapperModel";
180+
181+
process = new ProcessSimple(dataDirectory,finalCommand);
182+
process.waitFor();
183+
}
184+
185+
@Override
186+
public EvaluationResult evaluate(String algorithmParameters, EvaluationMethod evaluationMethod, int numberOfFolds, double trainingFraction, int numberOfRepeats) {
187+
throw new UnsupportedOperationException("Not supported yet."); //To change body of generated methods, choose Tools | Templates.
188+
}
189+
190+
@Override
191+
public List<GateClassification> classify(AnnotationSet instanceAS, AnnotationSet inputAS, AnnotationSet sequenceAS, String parms) {
192+
CorpusRepresentationMalletTarget data = (CorpusRepresentationMalletTarget)corpusRepresentationMallet;
193+
data.stopGrowth();
194+
int nrCols = data.getPipe().getDataAlphabet().size();
195+
//System.err.println("Running EngineSklearn.classify on document "+instanceAS.getDocument().getName());
196+
List<GateClassification> gcs = new ArrayList<GateClassification>();
197+
LFPipe pipe = (LFPipe)data.getRepresentationMallet().getPipe();
198+
ArrayList<String> classList = null;
199+
// If we have a classification problem, pre-calculate the class label list
200+
if(pipe.getTargetAlphabet() != null) {
201+
classList = new ArrayList<String>();
202+
for(int i = 0; i<pipe.getTargetAlphabet().size(); i++) {
203+
String labelstr = (String) pipe.getTargetAlphabet().lookupObject(i);
204+
classList.add(labelstr);
205+
}
206+
}
207+
// create the datastructure we need for the application script:
208+
// a map that contains the following fields:
209+
// - cmd: either STOP or CSR1
210+
// - values: the non-zero values, for increasing rows and increasing cols within rows
211+
// - rowinds: for the k-th value which row number it is in
212+
// - colinds: for the k-th value which column number (location index) it is in
213+
// - shaperows: number of rows in total
214+
// - shapecols: maximum number of cols in a vector
215+
Map map = new HashMap<String,Object>();
216+
map.put("cmd", "CSR1");
217+
ArrayList<Double> values = new ArrayList<Double>();
218+
ArrayList<Integer> rowinds = new ArrayList<Integer>();
219+
ArrayList<Integer> colinds = new ArrayList<Integer>();
220+
int rowIndex = 0;
221+
List<Annotation> instances = instanceAS.inDocumentOrder();
222+
for(Annotation instAnn : instances) {
223+
Instance inst = data.extractIndependentFeatures(instAnn, inputAS);
224+
225+
//FeatureVector fv = (FeatureVector)inst.getData();
226+
//System.out.println("Mallet instance, fv: "+fv.toString(true)+", len="+fv.numLocations());
227+
inst = pipe.instanceFrom(inst);
228+
229+
FeatureVector fv = (FeatureVector)inst.getData();
230+
//System.out.println("Mallet instance, fv: "+fv.toString(true)+", len="+fv.numLocations());
231+
232+
// Convert to the sparse vector we use to send to the weka process
233+
int locs = fv.numLocations();
234+
SparseDoubleVector sdv = new SparseDoubleVector(locs);
235+
for(int i=0;i<locs;i++) {
236+
int index = fv.indexAtLocation(i);
237+
values.add(fv.value(index));
238+
rowinds.add(rowIndex);
239+
colinds.add(index);
240+
}
241+
rowIndex++;
242+
}
243+
// send the matrix data over to the weka process
244+
map.put("values", values);
245+
map.put("rowinds", rowinds);
246+
map.put("colinds",colinds);
247+
map.put("shaperows", rowIndex);
248+
map.put("shapecols",nrCols);
249+
process.writeObject(map);
250+
// get the result back
251+
Object ret = process.readObject();
252+
Map<String,Object> response = null;
253+
if(ret instanceof Map) {
254+
response = (Map)ret;
255+
}
256+
if(response == null) {
257+
throw new RuntimeException("Got a response from Sklearn process which cannot be used: "+response);
258+
}
259+
// the response has the following format:
260+
// - status: should be "OK" or an error message
261+
// - targets: a vector of target indices/values
262+
// - probas: if probabilities are supported, a vector of vectors of class probabilities, otherwise null
263+
264+
String status = (String)response.get("status");
265+
if(status == null || !status.equals("OK")) {
266+
throw new RuntimeException("Status of response is not OK but "+status);
267+
}
268+
ArrayList<Double> targets = (ArrayList<Double>)response.get("targets");
269+
ArrayList<ArrayList<Double>> probas = (ArrayList<ArrayList<Double>>)response.get("probas");
270+
271+
GateClassification gc = null;
272+
273+
// now check if the mallet representation and the weka process agree
274+
// on if we have regression or classification
275+
if(pipe.getTargetAlphabet() == null) {
276+
// we expect a regression result, i.e probas should be null
277+
if(probas != null) {
278+
throw new RuntimeException("We think we have regression but the Sklearn process sent probabilities");
279+
}
280+
}
281+
// now go through all the instances again and do the target assignment from the vector(s) we got
282+
int instNr = 0;
283+
for(Annotation instAnn : instances) {
284+
if(pipe.getTargetAlphabet() == null) { // we have regression
285+
gc = new GateClassification(instAnn, targets.get(instNr));
286+
} else {
287+
int bestlabel = targets.get(instNr).intValue();
288+
String cl
289+
= (String) pipe.getTargetAlphabet().lookupObject(bestlabel);
290+
double bestprob = Double.NaN;
291+
if(probas != null) {
292+
bestprob = Collections.max(probas.get(instNr));
293+
}
294+
gc = new GateClassification(
295+
instAnn, cl, bestprob, classList, probas.get(instNr));
296+
}
297+
gcs.add(gc);
298+
instNr++;
299+
}
300+
data.startGrowth();
301+
return gcs;
302+
}
303+
304+
@Override
305+
public void initializeAlgorithm(Algorithm algorithm, String parms) {
306+
// do not do anything
307+
}
308+
309+
@Override
310+
protected void loadMalletCorpusRepresentation(File directory) {
311+
corpusRepresentationMallet = CorpusRepresentationMalletTarget.load(directory);
312+
}
313+
314+
}

src/gate/plugin/learningframework/engines/EngineWekaExternal.java

Lines changed: 6 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -25,29 +25,18 @@
2525

2626
/**
2727
* An engine that represents Weka through en external process.
28-
* This can only be used for application of the model at the moment.
2928
*
30-
* For now, this engine gets only used for application, so it gets only
31-
* created by Engine.loadEngine. The only methods called in this context
32-
* are loadModel() and loadMalletCorpusRepresentation() and if the user
33-
* also specified an algorithm class in the info file, initializeAlgorithm,
34-
* which at the moment does nothing.
3529
*
36-
* This requires that the user also places a second yaml file into the
37-
* directory: weka.yaml. This file needs to contain the following fields:
38-
* <ul>
39-
* <li>path: the path to a script or program that starts the weka wrapper
40-
* (see the weka-wrapper project on github). If the path starts with a slash
30+
* This requires that the user configures the location of where weka-wrapper is installed.
31+
* This can be done by setting the environment variable WEKA_WRAPPPER_HOME, the Java property
32+
* gate.plugin.learningframework.wekawrapper.home or by adding another yaml file "weka.yaml"
33+
* to the data directory which contains the setting wekawrapper.home.
34+
* If the path starts with a slash
4135
* it is an absolute path, otherwise the path is resolved relative to the
4236
* directory.
43-
* </ul>
4437
*
45-
* The directory also needs to contain files lf.model, pipe.pipe, header.arff
38+
* The data directory also needs to contain files lf.model, pipe.pipe, header.arff
4639
*
47-
* The weka-wrapper command will then get invoked with the first parameter
48-
* being the path to the model and the second parameter being the path to
49-
* the head file.
50-
*
5140
*
5241
* @author Johann Petrak
5342
*/

0 commit comments

Comments
 (0)