diff --git a/pom.xml b/pom.xml index 9fe663f..36d6f94 100644 --- a/pom.xml +++ b/pom.xml @@ -22,8 +22,8 @@ - - 1.0.0-SNAPSHOT + 1.0.0-beta3 + 1.11.0 1.4.3 diff --git a/src/main/java/net/tzolov/cv/mtcnn/MtcnnService.java b/src/main/java/net/tzolov/cv/mtcnn/MtcnnService.java index 29fd3c0..4ec9931 100644 --- a/src/main/java/net/tzolov/cv/mtcnn/MtcnnService.java +++ b/src/main/java/net/tzolov/cv/mtcnn/MtcnnService.java @@ -15,6 +15,12 @@ */ package net.tzolov.cv.mtcnn; +import static net.tzolov.cv.mtcnn.MtcnnUtil.CHANNEL_COUNT; +import static net.tzolov.cv.mtcnn.MtcnnUtil.C_ORDERING; +import static org.nd4j.linalg.indexing.NDArrayIndex.all; +import static org.nd4j.linalg.indexing.NDArrayIndex.interval; +import static org.nd4j.linalg.indexing.NDArrayIndex.point; + import java.awt.image.BufferedImage; import java.io.ByteArrayInputStream; import java.io.IOException; @@ -39,16 +45,9 @@ import org.nd4j.linalg.indexing.SpecifiedIndex; import org.nd4j.linalg.ops.transforms.Transforms; import org.nd4j.tensorflow.conversion.graphrunner.GraphRunner; -import org.tensorflow.framework.ConfigProto; - import org.springframework.core.io.DefaultResourceLoader; import org.springframework.util.Assert; - -import static net.tzolov.cv.mtcnn.MtcnnUtil.CHANNEL_COUNT; -import static net.tzolov.cv.mtcnn.MtcnnUtil.C_ORDERING; -import static org.nd4j.linalg.indexing.NDArrayIndex.all; -import static org.nd4j.linalg.indexing.NDArrayIndex.interval; -import static org.nd4j.linalg.indexing.NDArrayIndex.point; +import org.tensorflow.framework.ConfigProto; /** * @author Christian Tzolov @@ -92,10 +91,18 @@ public MtcnnService(int minFaceSize, double scaleFactor, double[] stepsThreshold this.imageLoader = new Java2DNativeImageLoader(); - this.proposeNetGraphRunner = this.createGraphRunner(TF_PNET_MODEL_URI, "pnet/input"); - this.refineNetGraphRunner = this.createGraphRunner(TF_RNET_MODEL_URI, "rnet/input"); - this.outputNetGraphRunner = this.createGraphRunner(TF_ONET_MODEL_URI, "onet/input"); - + // 
this.proposeNetGraphRunner = this.createGraphRunner(TF_PNET_MODEL_URI, "pnet/input"); + // this.refineNetGraphRunner = this.createGraphRunner(TF_RNET_MODEL_URI, "rnet/input"); + // this.outputNetGraphRunner = this.createGraphRunner(TF_ONET_MODEL_URI, "onet/input"); + + this.proposeNetGraphRunner = + this.createGraphRunner(TF_PNET_MODEL_URI, "pnet/input", "pnet/conv4-2/BiasAdd", "pnet/prob1"); + this.refineNetGraphRunner = + this.createGraphRunner(TF_RNET_MODEL_URI, "rnet/input", "rnet/conv5-2/conv5-2", "rnet/prob1"); + this.outputNetGraphRunner = + this.createGraphRunner(TF_ONET_MODEL_URI, "onet/input", "onet/conv6-2/conv6-2", "onet/conv6-3/conv6-3", + "onet/prob1"); + // Experimental //proposeNetGraph = TFGraphMapper.getInstance().importGraph(new DefaultResourceLoader().getResource(TF_PNET_MODEL_URI).getInputStream()); //refineNetGraph = TFGraphMapper.getInstance().importGraph(new DefaultResourceLoader().getResource(TF_RNET_MODEL_URI).getInputStream()); @@ -104,10 +111,14 @@ public MtcnnService(int minFaceSize, double scaleFactor, double[] stepsThreshold private GraphRunner createGraphRunner(String tensorflowModelUri, String inputLabel) { try { + + ConfigProto configProto = + ConfigProto.newBuilder().setInterOpParallelismThreads(4).setAllowSoftPlacement(true) + .setLogDevicePlacement(true).build(); + return new GraphRunner( IOUtils.toByteArray(new DefaultResourceLoader().getResource(tensorflowModelUri).getInputStream()), - Arrays.asList(inputLabel), - ConfigProto.getDefaultInstance()); + Arrays.asList(inputLabel), configProto); } catch (IOException e) { throw new IllegalStateException(String.format("Failed to load TF model [%s] and input [%s]:", @@ -115,6 +126,22 @@ private GraphRunner createGraphRunner(String tensorflowModelUri, String inputLab } } + private GraphRunner createGraphRunner(String tensorflowModelUri, String inputLabel, String... 
outLabel) { + try { + + ConfigProto cp = + ConfigProto.newBuilder().setInterOpParallelismThreads(Runtime.getRuntime().availableProcessors() * 2) + .setAllowSoftPlacement(true).setLogDevicePlacement(true).build(); + + return new GraphRunner( + IOUtils.toByteArray(new DefaultResourceLoader().getResource(tensorflowModelUri).getInputStream()), + Arrays.asList(inputLabel), Arrays.asList(outLabel), cp); + } catch (IOException e) { + throw new IllegalStateException( + String.format("Failed to load TF model [%s] and input [%s]:", tensorflowModelUri, inputLabel), e); + } + } + /** * Detects faces in an image, and returns bounding boxes and points for them. * @param imageUri Uri of the image to detect diff --git a/src/test/java/net/tzolov/cv/mtcnn/MtcnnServiceTest.java b/src/test/java/net/tzolov/cv/mtcnn/MtcnnServiceTest.java index 0a76e84..9f6ee94 100644 --- a/src/test/java/net/tzolov/cv/mtcnn/MtcnnServiceTest.java +++ b/src/test/java/net/tzolov/cv/mtcnn/MtcnnServiceTest.java @@ -15,6 +15,9 @@ */ package net.tzolov.cv.mtcnn; +import static org.hamcrest.CoreMatchers.equalTo; +import static org.junit.Assert.assertThat; + import java.io.IOException; import java.io.InputStream; @@ -22,13 +25,9 @@ import com.fasterxml.jackson.databind.ObjectMapper; import org.junit.Before; import org.junit.Test; - import org.springframework.core.io.ClassPathResource; import org.springframework.util.StreamUtils; -import static org.hamcrest.CoreMatchers.equalTo; -import static org.junit.Assert.assertThat; - /** * @author Christian Tzolov */ @@ -45,33 +44,16 @@ public void before() { @Test public void testSingeFace() throws IOException { FaceAnnotation[] faceAnnotations = mtcnnService.faceDetection("classpath:/Anthony_Hopkins_0002.jpg"); - assertThat(toJson(faceAnnotations), equalTo("[{\"bbox\":{\"x\":75,\"y\":67,\"w\":95,\"h\":120}," + - "\"confidence\":0.9994938373565674," + - "\"landmarks\":[" + - "{\"type\":\"LEFT_EYE\",\"position\":{\"x\":101,\"y\":113}}," + - 
"{\"type\":\"RIGHT_EYE\",\"position\":{\"x\":147,\"y\":113}}," + - "{\"type\":\"NOSE\",\"position\":{\"x\":124,\"y\":136}}," + - "{\"type\":\"MOUTH_LEFT\",\"position\":{\"x\":105,\"y\":160}}," + - "{\"type\":\"MOUTH_RIGHT\",\"position\":{\"x\":146,\"y\":160}}]}]")); + assertThat(toJson(faceAnnotations), equalTo("[{\"bbox\":{\"x\":72,\"y\":64,\"w\":101,\"h\":124}," + + "\"confidence\":0.9997498393058777," + + "\"landmarks\":" + + "[{\"type\":\"LEFT_EYE\",\"position\":{\"x\":102,\"y\":113}}," + + "{\"type\":\"RIGHT_EYE\",\"position\":{\"x\":149,\"y\":113}}," + + "{\"type\":\"NOSE\",\"position\":{\"x\":125,\"y\":136}}," + + "{\"type\":\"MOUTH_LEFT\",\"position\":{\"x\":104,\"y\":159}}," + + "{\"type\":\"MOUTH_RIGHT\",\"position\":{\"x\":146,\"y\":160}}]}]")); } - - @Test - public void testSingeFace2() throws IOException { - try (InputStream is = new ClassPathResource("/MarkPollack.png").getInputStream()) { - byte[] image = StreamUtils.copyToByteArray(is); - FaceAnnotation[] faceAnnotations = mtcnnService.faceDetection(image); - } - //FaceAnnotation[] faceAnnotations = mtcnnService.faceDetection("classpath:/MarkPollack.png"); - //assertThat(toJson(faceAnnotations), equalTo("[{\"bbox\":{\"x\":75,\"y\":67,\"w\":95,\"h\":120}," + - // "\"confidence\":0.9994938373565674," + - // "\"landmarks\":[" + - // "{\"type\":\"LEFT_EYE\",\"position\":{\"x\":101,\"y\":113}}," + - // "{\"type\":\"RIGHT_EYE\",\"position\":{\"x\":147,\"y\":113}}," + - // "{\"type\":\"NOSE\",\"position\":{\"x\":124,\"y\":136}}," + - // "{\"type\":\"MOUTH_LEFT\",\"position\":{\"x\":105,\"y\":160}}," + - // "{\"type\":\"MOUTH_RIGHT\",\"position\":{\"x\":146,\"y\":160}}]}]")); - } - + @Test public void testFailToDetectFace() throws IOException { FaceAnnotation[] faceAnnotations = mtcnnService.faceDetection("classpath:/broken.png"); @@ -81,22 +63,18 @@ public void testFailToDetectFace() throws IOException { @Test public void testMultiFaces() throws IOException { FaceAnnotation[] faceAnnotations = 
mtcnnService.faceDetection("classpath:/VikiMaxiAdi.jpg"); - assertThat(toJson(faceAnnotations), equalTo("[{\"bbox\":{\"x\":102,\"y\":152,\"w\":70,\"h\":86}," + - "\"confidence\":0.9999865293502808," + - "\"landmarks\":[" + - "{\"type\":\"LEFT_EYE\",\"position\":{\"x\":122,\"y\":189}}," + - "{\"type\":\"RIGHT_EYE\",\"position\":{\"x\":154,\"y\":190}}," + - "{\"type\":\"NOSE\",\"position\":{\"x\":136,\"y\":203}}," + - "{\"type\":\"MOUTH_LEFT\",\"position\":{\"x\":122,\"y\":219}}," + - "{\"type\":\"MOUTH_RIGHT\",\"position\":{\"x\":151,\"y\":220}}]}," + - "{\"bbox\":{\"x\":332,\"y\":94,\"w\":57,\"h\":69}," + - "\"confidence\":0.9992565512657166," + - "\"landmarks\":[" + - "{\"type\":\"LEFT_EYE\",\"position\":{\"x\":346,\"y\":120}}," + - "{\"type\":\"RIGHT_EYE\",\"position\":{\"x\":373,\"y\":121}}," + - "{\"type\":\"NOSE\",\"position\":{\"x\":357,\"y\":134}}," + - "{\"type\":\"MOUTH_LEFT\",\"position\":{\"x\":346,\"y\":147}}," + - "{\"type\":\"MOUTH_RIGHT\",\"position\":{\"x\":370,\"y\":148}}]}]")); + assertThat(faceAnnotations.length, equalTo(2)); + assertThat(toJson(faceAnnotations), equalTo( + "[{\"bbox\":{\"x\":102,\"y\":155,\"w\":69,\"h\":81},\"confidence\":0.9999865293502808," + + "\"landmarks\":[{\"type\":\"LEFT_EYE\",\"position\":{\"x\":121,\"y\":188}}," + + "{\"type\":\"RIGHT_EYE\",\"position\":{\"x\":153,\"y\":190}}," + + "{\"type\":\"NOSE\",\"position\":{\"x\":135,\"y\":204}},{\"type\":\"MOUTH_LEFT\",\"position\"" + + ":{\"x\":120,\"y\":218}},{\"type\":\"MOUTH_RIGHT\",\"position\":{\"x\":148,\"y\":221}}]}," + + "{\"bbox\":{\"x\":333,\"y\":97,\"w\":54,\"h\":65},\"confidence\":0.9999747276306152,\"landmarks\":" + + "[{\"type\":\"LEFT_EYE\",\"position\":{\"x\":346,\"y\":120}},{\"type\":\"RIGHT_EYE\",\"position\"" + + ":{\"x\":372,\"y\":120}},{\"type\":\"NOSE\",\"position\":{\"x\":357,\"y\":133}}," + + "{\"type\":\"MOUTH_LEFT\",\"position\":{\"x\":347,\"y\":147}},{\"type\":\"MOUTH_RIGHT\"," + + "\"position\":{\"x\":369,\"y\":148}}]}]")); } @@ -104,12 +82,17 @@
public void testMultiFaces() throws IOException { public void testFacesAlignment() throws IOException { FaceAnnotation[] faceAnnotations = mtcnnService.faceDetection("classpath:/pivotal-ipo-nyse.jpg"); assertThat(faceAnnotations.length, equalTo(7)); -// assertThat(toJson(faceAnnotations), equalTo("")); } + @Test + public void testFaceDetection() throws IOException { + FaceAnnotation[] faceAnnotations = mtcnnService.faceDetection("classpath:/multiple-faces-5.jpg"); + assertThat(faceAnnotations.length, equalTo(5)); + } + @Test public void testFacesAlignment2() throws IOException { - try (InputStream is = new ClassPathResource("/MarkPollack.png").getInputStream()) { + try (InputStream is = new ClassPathResource("/MarkPollack.jpg").getInputStream()) { byte[] image = StreamUtils.copyToByteArray(is); FaceAnnotation[] faceAnnotations = mtcnnService.faceDetection(image); mtcnnService.faceAlignment(null, faceAnnotations, 44, 160, true); diff --git a/src/test/resources/MarkPollack.jpg b/src/test/resources/MarkPollack.jpg new file mode 100644 index 0000000..a9eebcf Binary files /dev/null and b/src/test/resources/MarkPollack.jpg differ diff --git a/src/test/resources/multiple-faces-5.jpg b/src/test/resources/multiple-faces-5.jpg new file mode 100644 index 0000000..e90dc86 Binary files /dev/null and b/src/test/resources/multiple-faces-5.jpg differ