diff --git a/compiler/src/jdk.graal.compiler/src/jdk/graal/compiler/vector/replacements/vectorapi/nodes/VectorAPICompressExpandOpNode.java b/compiler/src/jdk.graal.compiler/src/jdk/graal/compiler/vector/replacements/vectorapi/nodes/VectorAPICompressExpandOpNode.java index 6d50377dd2c6..ea0910756f73 100644 --- a/compiler/src/jdk.graal.compiler/src/jdk/graal/compiler/vector/replacements/vectorapi/nodes/VectorAPICompressExpandOpNode.java +++ b/compiler/src/jdk.graal.compiler/src/jdk/graal/compiler/vector/replacements/vectorapi/nodes/VectorAPICompressExpandOpNode.java @@ -36,12 +36,17 @@ import jdk.graal.compiler.graph.NodeMap; import jdk.graal.compiler.nodeinfo.NodeInfo; import jdk.graal.compiler.nodes.FrameState; +import jdk.graal.compiler.nodes.NodeView; import jdk.graal.compiler.nodes.ValueNode; +import jdk.graal.compiler.nodes.calc.CompressBitsNode; import jdk.graal.compiler.nodes.spi.Canonicalizable; import jdk.graal.compiler.nodes.spi.CanonicalizerTool; import jdk.graal.compiler.nodes.spi.CoreProviders; import jdk.graal.compiler.replacements.nodes.MacroNode.MacroParams; import jdk.graal.compiler.vector.architecture.VectorArchitecture; +import jdk.graal.compiler.vector.nodes.amd64.IntegerToOpMaskNode; +import jdk.graal.compiler.vector.nodes.amd64.OpMaskToIntegerNode; +import jdk.graal.compiler.vector.nodes.simd.LogicValueStamp; import jdk.graal.compiler.vector.nodes.simd.SimdConstant; import jdk.graal.compiler.vector.nodes.simd.SimdCompressNode; import jdk.graal.compiler.vector.nodes.simd.SimdExpandNode; @@ -103,7 +108,7 @@ private ValueNode mask() { @Override public Iterable vectorInputs() { - return List.of(source(), mask()); + return source().isNullConstant() ? List.of(mask()) : List.of(source(), mask()); } @Override @@ -121,7 +126,7 @@ public Node canonical(CanonicalizerTool tool) { ValueNode[] args = toArgumentArray(); ObjectStamp newSpeciesStamp = improveResultBoxStamp(tool); - SimdStamp newVectorStamp = improveVectorStamp(vectorStamp, args, VCLASS_ARG_INDEX, ECLASS_ARG_INDEX, LENGTH_ARG_INDEX, tool); + SimdStamp newVectorStamp = improveResultStamp(vectorStamp, args, tool); if (newSpeciesStamp != speciesStamp || newVectorStamp != vectorStamp) { return new VectorAPICompressExpandOpNode(copyParamsWithImprovedStamp(newSpeciesStamp), newVectorStamp, null, stateAfter()); } @@ -138,7 +143,7 @@ public boolean canExpand(VectorArchitecture vectorArch, EconomicMap expanded) { int opr = opr().asJavaConstant().asInt(); - ValueNode src = expanded.get(source()); ValueNode mask = expanded.get(mask()); if (opr == COMPRESS_OP) { + ValueNode src = expanded.get(source()); return SimdCompressNode.create(src, mask); - } else { - GraalError.guarantee(opr == EXPAND_OP, "%d", opr); + } else if (opr == EXPAND_OP) { + ValueNode src = expanded.get(source()); return SimdExpandNode.create(src, mask); + } else { + GraalError.guarantee(opr == MASK_COMPRESS_OP, "unexpected opcode %d", opr); + ValueNode maskToInt = OpMaskToIntegerNode.create(mask); + ValueNode compressedInt = new CompressBitsNode(maskToInt, maskToInt); + return new IntegerToOpMaskNode(compressedInt, mask.stamp(NodeView.DEFAULT).unrestricted()); } }