|
| 1 | +package gate.plugin.learningframework.engines; |
| 2 | + |
| 3 | +import cc.mallet.types.FeatureVector; |
| 4 | +import cc.mallet.types.Instance; |
| 5 | +import gate.Annotation; |
| 6 | +import gate.AnnotationSet; |
| 7 | +import gate.lib.interaction.data.SparseDoubleVector; |
| 8 | +import gate.lib.interaction.process.Process4JsonStream; |
| 9 | +import gate.lib.interaction.process.Process4ObjectStream; |
| 10 | +import gate.lib.interaction.process.ProcessBase; |
| 11 | +import gate.lib.interaction.process.ProcessSimple; |
| 12 | +import gate.plugin.learningframework.EvaluationMethod; |
| 13 | +import gate.plugin.learningframework.Exporter; |
| 14 | +import gate.plugin.learningframework.GateClassification; |
| 15 | +import gate.plugin.learningframework.Globals; |
| 16 | +import gate.plugin.learningframework.data.CorpusRepresentationMalletTarget; |
| 17 | +import gate.plugin.learningframework.mallet.LFPipe; |
| 18 | +import gate.util.GateRuntimeException; |
| 19 | +import java.io.File; |
| 20 | +import java.io.FileInputStream; |
| 21 | +import java.io.InputStreamReader; |
| 22 | +import java.util.ArrayList; |
| 23 | +import java.util.Collections; |
| 24 | +import java.util.HashMap; |
| 25 | +import java.util.List; |
| 26 | +import java.util.Map; |
| 27 | +import org.yaml.snakeyaml.Yaml; |
| 28 | + |
| 29 | +/** |
| 30 | + * An engine that represents Python Scikit Learn through an external process. |
| 31 | + * |
| 32 | + * This requires that the user configures the location of where sklearn-wrapper is installed. |
| 33 | + * This can be done by setting the environment variable SKLEARN_WRAPPER_HOME, the Java property |
| 34 | + * gate.plugin.learningframework.sklearnwrapper.home or by adding another yaml file "sklearn.yaml" |
| 35 | + * to the data directory which contains the setting sklearnwrapper.home. |
| 36 | + * If the configured path is absolute |
| 37 | + * it is used as is, otherwise the path is resolved relative to the |
| 38 | + * data directory. |
| 39 | + * |
| 40 | + * |
| 41 | + * @author Johann Petrak |
| 42 | + */ |
| 43 | +public class EngineSklearnExternal extends Engine { |
| 44 | + |
| 45 | + ProcessBase process; |
| 46 | + |
| 47 | + /** |
| 48 | + * Try to find the script running the sklearn-Wrapper command. |
| 49 | + * If apply is true, the executable for application is searched, |
| 50 | + * otherwise the one for training. |
| 51 | + * This checks the following settings (increasing priority): |
| 52 | + * environment variable SKLEARN_WRAPPER_HOME, |
| 53 | + * java property gate.plugin.learningframework.sklearnwrapper.home and |
| 54 | + * the setting "sklearnwrapper.home" in file "sklearn.yaml" in the data directory, |
| 55 | + * if it exists. |
| 56 | + * The setting for the sklearn wrapper home can be relative in which case it |
| 57 | + * will be resolved relative to the dataDirectory |
| 58 | + * @param dataDirectory |
| 59 | + * @return |
| 60 | + */ |
| 61 | + private File findWrapperCommand(File dataDirectory, boolean apply) { |
| 62 | + String homeDir = System.getenv("SKLEARN_WRAPPER_HOME"); |
| 63 | + String tmp = System.getProperty("gate.plugin.learningframework.sklearnwrapper.home"); |
| 64 | + if(tmp!=null) homeDir = tmp; |
| 65 | + File sklearnInfoFile = new File(dataDirectory,"sklearn.yaml"); |
| 66 | + if(sklearnInfoFile.exists()) { |
| 67 | + Yaml yaml = new Yaml(); |
| 68 | + Object obj; |
| 69 | + try { |
| 70 | + obj = yaml.load(new InputStreamReader(new FileInputStream(sklearnInfoFile),"UTF-8")); |
| 71 | + } catch (Exception ex) { |
| 72 | + throw new GateRuntimeException("Could not load yaml file "+sklearnInfoFile,ex); |
| 73 | + } |
| 74 | + tmp = null; |
| 75 | + if(obj instanceof Map) { |
| 76 | + Map map = (Map)obj; |
| 77 | + tmp = (String)map.get("sklearnwrapper.home"); |
| 78 | + } else { |
| 79 | + throw new GateRuntimeException("Info file has strange format: "+sklearnInfoFile.getAbsolutePath()); |
| 80 | + } |
| 81 | + if(tmp == null) { |
| 82 | + System.err.println("sklearn.yaml file present but does not contain sklearnwrapper.home setting"); |
| 83 | + } else { |
| 84 | + homeDir = tmp; |
| 85 | + } |
| 86 | + } |
| 87 | + if(homeDir == null) { |
| 88 | + throw new GateRuntimeException("SklearnWrapper home not set, please see https://github.com/GateNLP/gateplugin-LearningFramework/wiki/UsingSklearn"); |
| 89 | + } |
| 90 | + File wrapperHome = new File(homeDir); |
| 91 | + if(!wrapperHome.isAbsolute()) { |
| 92 | + wrapperHome = new File(dataDirectory,homeDir); |
| 93 | + } |
| 94 | + if(!wrapperHome.isDirectory()) { |
| 95 | + throw new GateRuntimeException("SklearnWrapper home is not a directory: "+wrapperHome.getAbsolutePath()); |
| 96 | + } |
| 97 | + // Now, depending on the operating system, and on train/apply, |
| 98 | + // find the correct script to execute |
| 99 | + File commandFile; |
| 100 | + // we use the simple heuristic that if the file separator is "/" |
| 101 | + // we assume we can use the bash script, if it is "\" we use the windows |
| 102 | + // script and otherwise we give up |
| 103 | + boolean linuxLike = System.getProperty("file.separator").equals("/"); |
| 104 | + boolean windowsLike = System.getProperty("file.separator").equals("\\"); |
| 105 | + if(linuxLike) { |
| 106 | + if(apply) |
| 107 | + commandFile = new File(new File(wrapperHome,"bin"),"sklearnWrapperApply.sh"); |
| 108 | + else |
| 109 | + commandFile = new File(new File(wrapperHome,"bin"),"sklearnWrapperTrain.sh"); |
| 110 | + } else if(windowsLike) { |
| 111 | + if(apply) |
| 112 | + commandFile = new File(new File(wrapperHome,"bin"),"sklearnWrapperApply.cmd"); |
| 113 | + else |
| 114 | + commandFile = new File(new File(wrapperHome,"bin"),"sklearnWrapperTrain.cmd"); |
| 115 | + } else { |
| 116 | + throw new GateRuntimeException("It appears this OS is not supported"); |
| 117 | + } |
| 118 | + commandFile = commandFile.isAbsolute() ? |
| 119 | + commandFile : |
| 120 | + new File(dataDirectory,commandFile.getPath()); |
| 121 | + if(!commandFile.canExecute()) { |
| 122 | + throw new GateRuntimeException("Not an executable file or not found: "+commandFile+" please see https://github.com/GateNLP/gateplugin-LearningFramework/wiki/UsingSklearn"); |
| 123 | + } |
| 124 | + return commandFile; |
| 125 | + } |
| 126 | + |
| 127 | + |
| 128 | + @Override |
| 129 | + protected void loadModel(File directory, String parms) { |
| 130 | + // Instead of loading a model, this establishes a connection with the |
| 131 | + // external sklearn process. |
| 132 | + File commandFile = findWrapperCommand(directory, true); |
| 133 | + String modelFileName = new File(directory,"sklmodel").getAbsolutePath(); |
| 134 | + String finalCommand = commandFile.getAbsolutePath()+" "+modelFileName; |
| 135 | + System.err.println("Running: "+finalCommand); |
| 136 | + // Create a fake Model jsut to make LF_Apply... happy which checks if this is null |
| 137 | + model = "ExternalSklearnWrapperModel"; |
| 138 | + process = new Process4JsonStream(directory,finalCommand); |
| 139 | + } |
| 140 | + |
| 141 | + @Override |
| 142 | + protected void saveModel(File directory) { |
| 143 | + // NOTE: we do not need to save the model here because the external |
| 144 | + // sklearnWrapper command does this. |
| 145 | + // However we still need to make sure a usable info file is saved! |
| 146 | + info.engineClass = EngineSklearnExternal.class.getName(); |
| 147 | + info.save(directory); |
| 148 | + } |
| 149 | + |
| 150 | + @Override |
| 151 | + public void trainModel(File dataDirectory, String instanceType, String parms) { |
| 152 | + // invoke the sklearn wrapper for training |
| 153 | + // NOTE: for this the first word in parms must be the full sklearn class name, the rest are parms |
| 154 | + if(parms == null || parms.isEmpty()) { |
| 155 | + throw new GateRuntimeException("Cannot train using SklearnWrapper, algorithmParameter must contain fulle SciKit Learn algorithm class name as first word"); |
| 156 | + } |
| 157 | + String sklearnClass = null; |
| 158 | + String sklearnParms = ""; |
| 159 | + int spaceIdx = parms.indexOf(" "); |
| 160 | + if(spaceIdx<0) { |
| 161 | + sklearnClass = parms; |
| 162 | + } else { |
| 163 | + sklearnClass = parms.substring(0,spaceIdx); |
| 164 | + sklearnParms = parms.substring(spaceIdx).trim(); |
| 165 | + } |
| 166 | + File commandFile = findWrapperCommand(dataDirectory, false); |
| 167 | + // Export the data |
| 168 | + // Note: any scaling was already done in the PR before calling this method! |
| 169 | + // find out if we train classification or regression |
| 170 | + // TODO: NOTE: not sure if classification/regression matters here as long as |
| 171 | + // the actual exporter class does the right thing based on the corpus representation! |
| 172 | + Exporter.export(getCorpusRepresentationMallet(), |
| 173 | + Exporter.EXPORTER_MATRIXMARKET2_CLASS, dataDirectory, instanceType, parms); |
| 174 | + String dataFileName = dataDirectory.getAbsolutePath()+File.separator; |
| 175 | + String modelFileName = new File(dataDirectory, "sklmodel").getAbsolutePath(); |
| 176 | + String finalCommand = commandFile.getAbsolutePath()+" "+dataFileName+" "+modelFileName+" "+sklearnClass+" "+sklearnParms; |
| 177 | + System.err.println("Running: "+finalCommand); |
| 178 | + // Create a fake Model jsut to make LF_Apply... happy which checks if this is null |
| 179 | + model = "ExternalSklearnWrapperModel"; |
| 180 | + |
| 181 | + process = new ProcessSimple(dataDirectory,finalCommand); |
| 182 | + process.waitFor(); |
| 183 | + } |
| 184 | + |
| 185 | + @Override |
| 186 | + public EvaluationResult evaluate(String algorithmParameters, EvaluationMethod evaluationMethod, int numberOfFolds, double trainingFraction, int numberOfRepeats) { |
| 187 | + throw new UnsupportedOperationException("Not supported yet."); //To change body of generated methods, choose Tools | Templates. |
| 188 | + } |
| 189 | + |
| 190 | + @Override |
| 191 | + public List<GateClassification> classify(AnnotationSet instanceAS, AnnotationSet inputAS, AnnotationSet sequenceAS, String parms) { |
| 192 | + CorpusRepresentationMalletTarget data = (CorpusRepresentationMalletTarget)corpusRepresentationMallet; |
| 193 | + data.stopGrowth(); |
| 194 | + int nrCols = data.getPipe().getDataAlphabet().size(); |
| 195 | + //System.err.println("Running EngineSklearn.classify on document "+instanceAS.getDocument().getName()); |
| 196 | + List<GateClassification> gcs = new ArrayList<GateClassification>(); |
| 197 | + LFPipe pipe = (LFPipe)data.getRepresentationMallet().getPipe(); |
| 198 | + ArrayList<String> classList = null; |
| 199 | + // If we have a classification problem, pre-calculate the class label list |
| 200 | + if(pipe.getTargetAlphabet() != null) { |
| 201 | + classList = new ArrayList<String>(); |
| 202 | + for(int i = 0; i<pipe.getTargetAlphabet().size(); i++) { |
| 203 | + String labelstr = (String) pipe.getTargetAlphabet().lookupObject(i); |
| 204 | + classList.add(labelstr); |
| 205 | + } |
| 206 | + } |
| 207 | + // create the datastructure we need for the application script: |
| 208 | + // a map that contains the following fields: |
| 209 | + // - cmd: either STOP or CSR1 |
| 210 | + // - values: the non-zero values, for increasing rows and increasing cols within rows |
| 211 | + // - rowinds: for the k-th value which row number it is in |
| 212 | + // - colinds: for the k-th value which column number (location index) it is in |
| 213 | + // - shaperows: number of rows in total |
| 214 | + // - shapecols: maximum number of cols in a vector |
| 215 | + Map map = new HashMap<String,Object>(); |
| 216 | + map.put("cmd", "CSR1"); |
| 217 | + ArrayList<Double> values = new ArrayList<Double>(); |
| 218 | + ArrayList<Integer> rowinds = new ArrayList<Integer>(); |
| 219 | + ArrayList<Integer> colinds = new ArrayList<Integer>(); |
| 220 | + int rowIndex = 0; |
| 221 | + List<Annotation> instances = instanceAS.inDocumentOrder(); |
| 222 | + for(Annotation instAnn : instances) { |
| 223 | + Instance inst = data.extractIndependentFeatures(instAnn, inputAS); |
| 224 | + |
| 225 | + //FeatureVector fv = (FeatureVector)inst.getData(); |
| 226 | + //System.out.println("Mallet instance, fv: "+fv.toString(true)+", len="+fv.numLocations()); |
| 227 | + inst = pipe.instanceFrom(inst); |
| 228 | + |
| 229 | + FeatureVector fv = (FeatureVector)inst.getData(); |
| 230 | + //System.out.println("Mallet instance, fv: "+fv.toString(true)+", len="+fv.numLocations()); |
| 231 | + |
| 232 | + // Convert to the sparse vector we use to send to the weka process |
| 233 | + int locs = fv.numLocations(); |
| 234 | + SparseDoubleVector sdv = new SparseDoubleVector(locs); |
| 235 | + for(int i=0;i<locs;i++) { |
| 236 | + int index = fv.indexAtLocation(i); |
| 237 | + values.add(fv.value(index)); |
| 238 | + rowinds.add(rowIndex); |
| 239 | + colinds.add(index); |
| 240 | + } |
| 241 | + rowIndex++; |
| 242 | + } |
| 243 | + // send the matrix data over to the weka process |
| 244 | + map.put("values", values); |
| 245 | + map.put("rowinds", rowinds); |
| 246 | + map.put("colinds",colinds); |
| 247 | + map.put("shaperows", rowIndex); |
| 248 | + map.put("shapecols",nrCols); |
| 249 | + process.writeObject(map); |
| 250 | + // get the result back |
| 251 | + Object ret = process.readObject(); |
| 252 | + Map<String,Object> response = null; |
| 253 | + if(ret instanceof Map) { |
| 254 | + response = (Map)ret; |
| 255 | + } |
| 256 | + if(response == null) { |
| 257 | + throw new RuntimeException("Got a response from Sklearn process which cannot be used: "+response); |
| 258 | + } |
| 259 | + // the response has the following format: |
| 260 | + // - status: should be "OK" or an error message |
| 261 | + // - targets: a vector of target indices/values |
| 262 | + // - probas: if probabilities are supported, a vector of vectors of class probabilities, otherwise null |
| 263 | + |
| 264 | + String status = (String)response.get("status"); |
| 265 | + if(status == null || !status.equals("OK")) { |
| 266 | + throw new RuntimeException("Status of response is not OK but "+status); |
| 267 | + } |
| 268 | + ArrayList<Double> targets = (ArrayList<Double>)response.get("targets"); |
| 269 | + ArrayList<ArrayList<Double>> probas = (ArrayList<ArrayList<Double>>)response.get("probas"); |
| 270 | + |
| 271 | + GateClassification gc = null; |
| 272 | + |
| 273 | + // now check if the mallet representation and the weka process agree |
| 274 | + // on if we have regression or classification |
| 275 | + if(pipe.getTargetAlphabet() == null) { |
| 276 | + // we expect a regression result, i.e probas should be null |
| 277 | + if(probas != null) { |
| 278 | + throw new RuntimeException("We think we have regression but the Sklearn process sent probabilities"); |
| 279 | + } |
| 280 | + } |
| 281 | + // now go through all the instances again and do the target assignment from the vector(s) we got |
| 282 | + int instNr = 0; |
| 283 | + for(Annotation instAnn : instances) { |
| 284 | + if(pipe.getTargetAlphabet() == null) { // we have regression |
| 285 | + gc = new GateClassification(instAnn, targets.get(instNr)); |
| 286 | + } else { |
| 287 | + int bestlabel = targets.get(instNr).intValue(); |
| 288 | + String cl |
| 289 | + = (String) pipe.getTargetAlphabet().lookupObject(bestlabel); |
| 290 | + double bestprob = Double.NaN; |
| 291 | + if(probas != null) { |
| 292 | + bestprob = Collections.max(probas.get(instNr)); |
| 293 | + } |
| 294 | + gc = new GateClassification( |
| 295 | + instAnn, cl, bestprob, classList, probas.get(instNr)); |
| 296 | + } |
| 297 | + gcs.add(gc); |
| 298 | + instNr++; |
| 299 | + } |
| 300 | + data.startGrowth(); |
| 301 | + return gcs; |
| 302 | + } |
| 303 | + |
| 304 | + @Override |
| 305 | + public void initializeAlgorithm(Algorithm algorithm, String parms) { |
| 306 | + // do not do anything |
| 307 | + } |
| 308 | + |
| 309 | + @Override |
| 310 | + protected void loadMalletCorpusRepresentation(File directory) { |
| 311 | + corpusRepresentationMallet = CorpusRepresentationMalletTarget.load(directory); |
| 312 | + } |
| 313 | + |
| 314 | +} |
0 commit comments