diff --git a/src/main/java/io/spokestack/spokestack/nlu/tensorflow/TensorflowNLU.java b/src/main/java/io/spokestack/spokestack/nlu/tensorflow/TensorflowNLU.java index 13a11dc..c1f53a2 100644 --- a/src/main/java/io/spokestack/spokestack/nlu/tensorflow/TensorflowNLU.java +++ b/src/main/java/io/spokestack/spokestack/nlu/tensorflow/TensorflowNLU.java @@ -69,6 +69,8 @@ public final class TensorflowNLU implements NLUService { private TensorflowModel nluModel = null; private TFNLUOutput outputParser = null; private int maxTokens; + private float confidenceThreshold; + private String fallbackIntent; private volatile boolean ready = false; @@ -92,6 +94,8 @@ private TensorflowNLU(Builder builder) { this.loadThread.start(); this.padTokenId = this.textEncoder.encodeSingle("[PAD]"); this.sepTokenId = this.textEncoder.encodeSingle("[SEP]"); + this.confidenceThreshold = builder.confidenceThreshold; + this.fallbackIntent = builder.fallbackIntent; } private void initParsers(Map parserClasses) { @@ -218,6 +222,15 @@ private NLUResult tfClassify(String utterance, NLUContext nluContext) { // interpret model outputs Tuple prediction = outputParser.getIntent( this.nluModel.outputs(0)); + float confidence = prediction.second(); + + if (confidence < this.confidenceThreshold) { + return new NLUResult.Builder(utterance) + .withIntent(this.fallbackIntent) + .withConfidence(confidence) + .build(); + } + Metadata.Intent intent = prediction.first(); nluContext.traceDebug("Intent: %s", intent.getName()); @@ -230,7 +243,7 @@ private NLUResult tfClassify(String utterance, NLUContext nluContext) { return new NLUResult.Builder(utterance) .withIntent(intent.getName()) - .withConfidence(prediction.second()) + .withConfidence(confidence) .withSlots(parsedSlots) .build(); } @@ -286,6 +299,8 @@ public static class Builder { private TensorflowModel.Loader modelLoader; private ThreadFactory threadFactory; private TextEncoder textEncoder; + private float confidenceThreshold; + private String fallbackIntent; /** * Creates a new builder instance. @@ -344,6 +359,24 @@ public Builder setTextEncoder(TextEncoder encoder) { return this; } + /** + * Sets a confidence threshold for classification, below which the + * specified fallback intent will be returned. + * + * @param confidence the lowest confidence value that will be accepted + * as a valid classification. + * @param fallback the name of the intent that will be returned if the + * model's confidence is below {@code confidence}. + * @return this + */ + public Builder setConfidenceThreshold(float confidence, + String fallback) { + + this.confidenceThreshold = confidence; + this.fallbackIntent = fallback; + return this; + } + /** * Sets a configuration value. * diff --git a/src/test/java/io/spokestack/spokestack/nlu/tensorflow/TensorflowNLUTest.java b/src/test/java/io/spokestack/spokestack/nlu/tensorflow/TensorflowNLUTest.java index 6c941bf..d181a02 100644 --- a/src/test/java/io/spokestack/spokestack/nlu/tensorflow/TensorflowNLUTest.java +++ b/src/test/java/io/spokestack/spokestack/nlu/tensorflow/TensorflowNLUTest.java @@ -73,7 +73,7 @@ public void classify() throws Exception { StringBuilder tooManyTokens = new StringBuilder(); for (int i = 0; i <= env.nlu.getMaxTokens(); i++) { - tooManyTokens.append("a "); + tooManyTokens.append("a "); } utterance = tooManyTokens.toString(); result = env.classify(utterance).get(); @@ -86,8 +86,9 @@ public void classify() throws Exception { assertTrue(result.getSlots().isEmpty()); utterance = "this code is for test 1"; + float conf = 0.75f; float[] intentResult = - buildIntentResult(2, env.metadata.getIntents().length); + buildIntentResult(2, env.metadata.getIntents().length, conf); float[] tagResult = new float[utterance.split(" ").length * env.metadata.getTags().length]; @@ -104,7 +105,7 @@ public void classify() throws Exception { assertNull(result.getError()); assertEquals("describe_test", result.getIntent()); - assertEquals(10.0, result.getConfidence()); + assertEquals(conf, result.getConfidence()); for (String slotName : slots.keySet()) { assertEquals(slots.get(slotName), result.getSlots().get(slotName)); } @@ -117,9 +118,10 @@ public void classify() throws Exception { // (which is incorrect, but we're just testing the slot extraction // logic here) utterance = "this bad code is for test 1"; - intentResult = buildIntentResult(2, env.metadata.getIntents().length); + intentResult = + buildIntentResult(2, env.metadata.getIntents().length, conf); tagResult = new float[utterance.split(" ").length - * env.metadata.getTags().length]; + * env.metadata.getTags().length]; setTag(tagResult, env.metadata.getTags().length, 0, 1); setTag(tagResult, env.metadata.getTags().length, 2, 1); setTag(tagResult, env.metadata.getTags().length, 6, 3); @@ -133,7 +135,7 @@ public void classify() throws Exception { assertNull(result.getError()); assertEquals("describe_test", result.getIntent()); - assertEquals(10.0, result.getConfidence()); + assertEquals(conf, result.getConfidence()); for (String slotName : slots.keySet()) { assertEquals(slots.get(slotName), result.getSlots().get(slotName)); } @@ -142,9 +144,37 @@ public void classify() throws Exception { assertTrue(result.getContext().isEmpty()); } - private float[] buildIntentResult(int index, int numIntents) { + @Test + public void testConfidenceThreshold() throws Exception { + TestEnv env = new TestEnv(testConfig()); + env.nluBuilder.setConfidenceThreshold(0.5f, "fallback"); + + String utterance = "how far is it to the moon?"; + float conf = 0.3f; + float[] intentResult = + buildIntentResult(1, env.metadata.getIntents().length, conf); + + // include some tags in the result to make sure they're ignored + float[] tagResult = + new float[utterance.split(" ").length + * env.metadata.getTags().length]; + setTag(tagResult, env.metadata.getTags().length, 0, 1); + setTag(tagResult, env.metadata.getTags().length, 1, 2); + env.testModel.setOutputs(intentResult, tagResult); + NLUResult result = env.classify(utterance).get(); + + assertNull(result.getError()); + assertEquals("fallback", result.getIntent()); + assertEquals(conf, result.getConfidence()); + assertTrue(result.getSlots().isEmpty()); + assertEquals(utterance, result.getUtterance()); + assertTrue(result.getContext().isEmpty()); + } + + private float[] buildIntentResult(int index, int numIntents, + float confidence) { float[] result = new float[numIntents]; - result[index] = 10; + result[index] = confidence; return result; }