Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 15 additions & 9 deletions mlir/include/mlir/Dialect/Quant/IR/QuantTypes.h
Original file line number Diff line number Diff line change
Expand Up @@ -97,6 +97,12 @@ class QuantizedType : public Type {
return -getDefaultMaximumForF8E5M2();
}

/// Default upper bound of the storage range used when the storage type is
/// Float4E2M1FN.
static constexpr int64_t getDefaultMaximumForF4E2M1FN() {
  const int64_t maxStorageValue = 6;
  return maxStorageValue;
}

/// Default lower bound of the storage range used when the storage type is
/// Float4E2M1FN; computed as the negation of the default maximum so the
/// range stays symmetric around zero.
static constexpr int64_t getDefaultMinimumForF4E2M1FN() {
  const int64_t upperBound = getDefaultMaximumForF4E2M1FN();
  return -upperBound;
}

/// Gets the original expressed type that this quantized type approximates.
/// Note that this presumes that the quantized type was always derived from
/// a floating point type, which in the broadest definition, is not true (i.e.
Expand Down Expand Up @@ -267,7 +273,7 @@ class AnyQuantizedType
/// Per-layer, optional parameters omitted:
/// !quant<uniform[StorageType]{Scale}>
///
/// StorageType: 'i'|'u' NumBits
/// StorageType: 'i'|'u' NumBits, 'f4', 'hf8', 'bf8'
/// ExpressedType: 'f16', 'f32', 'bf16', 'f64'
/// Scale: A legal double value
/// ZeroPoint: An integer value
Expand Down Expand Up @@ -327,7 +333,7 @@ class UniformQuantizedType
/// Per-axis, optional parameters omitted:
/// !quant<uniform[StorageType]{Scale}>
///
/// StorageType: 'i'|'u' NumBits
/// StorageType: 'i'|'u' NumBits, 'f4', 'hf8', 'bf8'
/// ExpressedType: 'f16', 'f32', 'bf16', 'f64'
/// QuantizedDim: An integer value
/// QuantParams: (Scale ':' ZeroPoint)+
Expand Down Expand Up @@ -414,7 +420,7 @@ class UniformQuantizedPerAxisType
/// ScaleZeroList ::= ScaleZero (',' ScaleZero)*
/// ScaleZero ::= Scale (':' ZeroPoint)?
///
/// StorageType: 'i'|'u' NumBits
/// StorageType: 'i'|'u' NumBits, 'f4', 'hf8', 'bf8'
/// ExpressedType: 'f16', 'f32', 'bf16', 'f64'
/// AxisSpec: An integer value
/// BlockSizeSpec: An integer value
Expand Down Expand Up @@ -533,16 +539,16 @@ class UniformQuantizedSubChannelType

/// QuantileQuantizedType derives from UniformQuantizedType and adds to it a
/// look up table array of quantile values. The type of the data in the look up table is determined by
/// the quantileType member: supported quantileType types are integer/unsigned/hf8/bf8/f16/bf16/f32/f64.
/// the quantileType member: supported quantileType types are integer/unsigned/f4/hf8/bf8/f16/bf16/f32/f64.
///
/// Syntax synopsis:
/// Per-layer, all parameters expressed:
/// !quant<quantile[StorageType:QuantileType:ExpressedType]{Quantiles}:{Scale:ZeroPoint}>
/// Per-layer, optional parameters omitted:
/// !quant<quantile[StorageType:QuantileType]{Quantiles}:{Scale}>
///
/// StorageType: 'i'|'u' NumBits
/// QuantileType: 'i'|'u' NumBits, 'hf8', 'bf8', 'f16', 'bf16', 'f32', 'f64'
/// StorageType: 'i'|'u' NumBits, 'f4', 'hf8', 'bf8'
/// QuantileType: 'i'|'u' NumBits, 'f4', 'hf8', 'bf8', 'f16', 'bf16', 'f32', 'f64'
/// ExpressedType: 'f16', 'f32', 'bf16', 'f64'
/// Quantiles: Quantile+
/// Quantile: A legal double value
Expand Down Expand Up @@ -600,16 +606,16 @@ class QuantileQuantizedType

/// Represents per-axis QuantileQuantizedType (also known as per-channel
/// quantization). The type of the data in the look up table is determined by the
/// quantileType member: supported quantileType types are integer/unsigned/hf8/bf8/f16/bf16/f32/f64.
/// quantileType member: supported quantileType types are integer/unsigned/f4/hf8/bf8/f16/bf16/f32/f64.
///
/// Syntax synopsis:
/// Per-axis, all parameters expressed:
/// !quant<quantile[StorageType:QuantileType:ExpressedType:QuantizedDim]{Quantiles}:{QuantParams}>
/// Per-axis, optional parameters omitted:
/// !quant<quantile[StorageType:QuantileType]{Quantiles}:{Scale}>
///
/// StorageType: 'i'|'u' NumBits
/// QuantileType: 'i'|'u' NumBits, 'hf8', 'bf8', 'f16', 'bf16', 'f32', 'f64'
/// StorageType: 'i'|'u' NumBits, 'f4', 'hf8', 'bf8'
/// QuantileType: 'i'|'u' NumBits, 'f4', 'hf8', 'bf8', 'f16', 'bf16', 'f32', 'f64'
/// ExpressedType: 'f16', 'f32', 'bf16', 'f64'
/// QuantizedDim: An integer value
/// Quantiles: Quantile+
Expand Down
27 changes: 14 additions & 13 deletions mlir/lib/Dialect/Quant/IR/QuantTypes.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -67,15 +67,18 @@ QuantizedType::verifyInvariants(function_ref<InFlightDiagnostic()> emitError,
const auto width = llvm::dyn_cast<IntegerType>(storageType).getWidth();
defaultMin = QuantizedType::getDefaultMinimumForInteger(isSigned, width);
defaultMax = QuantizedType::getDefaultMaximumForInteger(isSigned, width);
} else if (storageType.isa<Float8E5M2Type>()) {
} else if (mlir::isa<Float8E5M2Type>(storageType)) {
defaultMin = QuantizedType::getDefaultMinimumForF8E5M2();
defaultMax = QuantizedType::getDefaultMaximumForF8E5M2();
} else if (storageType.isa<Float8E4M3FNType>()) {
} else if (mlir::isa<Float8E4M3FNType>(storageType)) {
defaultMin = QuantizedType::getDefaultMinimumForF8E4M3FN();
defaultMax = QuantizedType::getDefaultMaximumForF8E4M3FN();
} else if (mlir::isa<Float4E2M1FNType>(storageType)) {
defaultMin = QuantizedType::getDefaultMinimumForF4E2M1FN();
defaultMax = QuantizedType::getDefaultMaximumForF4E2M1FN();
} else {
return emitError() << "illegal storage type, supported types are: integral "
"types, Float8E4M3FNType and Float8E5M2Type ";
"types, Float8E4M3FNType, Float8E5M2Type and Float4E2M1FNType ";
}

// Verify storageTypeMin and storageTypeMax.
Expand Down Expand Up @@ -574,19 +577,18 @@ LogicalResult QuantileQuantizedType::verifyInvariants(
unsigned typeWidth{};
if (storageType.isa<IntegerType>()) {
typeWidth = llvm::dyn_cast<IntegerType>(storageType).getWidth();
} else if (storageType.isa<Float8E5M2Type>() ||
storageType.isa<Float8E4M3FNType>()) {
// Both Float8E5M2Type and Float8E4M3FNType derive from FloatType.
} else if (mlir::isa<Float8E5M2Type, Float8E4M3FNType, Float4E2M1FNType>(storageType)) {
// Float8E5M2Type, Float8E4M3FNType and Float4E2M1FNType derive from FloatType.
typeWidth = llvm::dyn_cast<FloatType>(storageType).getWidth();
} else {
return emitError() << "illegal storage type, supported types are: integral "
"types, Float8E4M3FNType and Float8E5M2Type ";
"types, Float8E4M3FNType, Float8E5M2Type and Float4E2M1FNType ";
}

const size_t storageTypeRange = storageTypeMax - storageTypeMin + 1;
const size_t typeWidthSize = 1 << typeWidth;
const size_t expectedSize =
(storageTypeRange < typeWidthSize) ? storageTypeRange : typeWidthSize;
(storageTypeRange < typeWidthSize) && !mlir::isa<FloatType>(storageType) ? storageTypeRange : typeWidthSize;

const auto quantileArraySize = quantiles.size();
if (quantileArraySize != expectedSize) {
Expand Down Expand Up @@ -660,19 +662,18 @@ LogicalResult QuantileQuantizedPerAxisType::verifyInvariants(
unsigned typeWidth{};
if (storageType.isa<IntegerType>()) {
typeWidth = llvm::dyn_cast<IntegerType>(storageType).getWidth();
} else if (storageType.isa<Float8E5M2Type>() ||
storageType.isa<Float8E4M3FNType>()) {
// Both Float8E5M2Type and Float8E4M3FNType derive from FloatType.
} else if (mlir::isa<Float8E5M2Type, Float8E4M3FNType, Float4E2M1FNType>(storageType)) {
// Float8E5M2Type, Float8E4M3FNType and Float4E2M1FNType derive from FloatType.
typeWidth = llvm::dyn_cast<FloatType>(storageType).getWidth();
} else {
return emitError() << "illegal storage type, supported types are: integral "
"types, Float8E4M3FNType and Float8E5M2Type ";
"types, Float8E4M3FNType, Float8E5M2Type and Float4E2M1FNType ";
}

const size_t storageTypeRange = storageTypeMax - storageTypeMin + 1;
const size_t typeWidthSize = 1 << typeWidth;
const size_t expectedSize =
(storageTypeRange < typeWidthSize) ? storageTypeRange : typeWidthSize;
(storageTypeRange < typeWidthSize) && !mlir::isa<FloatType>(storageType) ? storageTypeRange : typeWidthSize;

const auto quantileArraySize = quantiles.size();
if (quantileArraySize != expectedSize) {
Expand Down
32 changes: 21 additions & 11 deletions mlir/lib/Dialect/Quant/IR/TypeParser.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -35,9 +35,8 @@ static Type parseStorageType(DialectAsmParser &parser, bool &isSigned) {
if (auto intType = llvm::dyn_cast<IntegerType>(type)) {
isSigned = !intType.isUnsigned();
storageTypeWidth = intType.getWidth();
} else if (llvm::dyn_cast<Float8E5M2Type>(type) ||
llvm::dyn_cast<Float8E4M3FNType>(type)) {
storageTypeWidth = 8;
} else if (mlir::isa<Float8E5M2Type, Float8E4M3FNType, Float4E2M1FNType>(type)) {
storageTypeWidth = llvm::dyn_cast<FloatType>(type).getWidth();
isSigned = true;
} else {
parser.emitError(typeLoc, "illegal quantized storage type alias");
Expand Down Expand Up @@ -132,12 +131,15 @@ static ParseResult parseStorageRange(DialectAsmParser &parser, Type storageType,
const auto width = llvm::dyn_cast<IntegerType>(storageType).getWidth();
defaultMin = QuantizedType::getDefaultMinimumForInteger(isSigned, width);
defaultMax = QuantizedType::getDefaultMaximumForInteger(isSigned, width);
} else if (storageType.isa<Float8E5M2Type>()) {
} else if (mlir::isa<Float8E5M2Type>(storageType)) {
defaultMin = QuantizedType::getDefaultMinimumForF8E5M2();
defaultMax = QuantizedType::getDefaultMaximumForF8E5M2();
} else if (storageType.isa<Float8E4M3FNType>()) {
} else if (mlir::isa<Float8E4M3FNType>(storageType)) {
defaultMin = QuantizedType::getDefaultMinimumForF8E4M3FN();
defaultMax = QuantizedType::getDefaultMaximumForF8E4M3FN();
} else if (mlir::isa<Float4E2M1FNType>(storageType)) {
defaultMin = QuantizedType::getDefaultMinimumForF4E2M1FN();
defaultMax = QuantizedType::getDefaultMaximumForF4E2M1FN();
} else {
defaultMin = std::numeric_limits<int64_t>::max();
defaultMax = std::numeric_limits<int64_t>::min();
Expand All @@ -150,7 +152,7 @@ static ParseResult parseStorageRange(DialectAsmParser &parser, Type storageType,
}

// Explicit storage min and storage max.
// F8 min and max values are integers, so parseInteger() is used.
// F8 and F4 min and max values are integers, so parseInteger() is used.
SMLoc minLoc = parser.getCurrentLocation(), maxLoc;
if (parser.parseInteger(storageTypeMin) || parser.parseColon() ||
parser.getCurrentLocation(&maxLoc) ||
Expand Down Expand Up @@ -382,7 +384,7 @@ parseQuantParamListUntilRBrace(DialectAsmParser &parser, Type expressedType,
/// block-size-info `,` scale-zero-tensor `>`
/// storage-spec ::= storage-type (`<` storage-range `>`)?
/// storage-range ::= integer-literal `:` integer-literal
/// storage-type ::= (`i` | `u`) integer-literal
/// storage-type ::= (`i` | `u`) integer-literal | `f8E5M2` | `f8E4M3FN` | `f4E2M1FN`
/// expressed-type-spec ::= `:` `f` integer-literal
/// axis-spec ::= `:` integer-literal
/// scale-zero ::= scale (`:` zero-point)?
Expand All @@ -407,9 +409,9 @@ parseQuantParamListUntilRBrace(DialectAsmParser &parser, Type expressedType,
/// scale-zero-list `>`
/// storage-spec ::= storage-type (`<` storage-range `>`)?
/// storage-range ::= integer-literal `:` integer-literal
/// storage-type ::= (`i` | `u`) integer-literal
/// storage-type ::= (`i` | `u`) integer-literal | `f8E5M2` | `f8E4M3FN` | `f4E2M1FN`
/// quantile-type-spec ::= `:` ((`i` | `u` | `f`) integer-literal | `f8E5M2` |
/// `f8E4M3FN`)
/// `f8E4M3FN` | `f4E2M1FN`)
/// expressed-type-spec ::= `:` `f` integer-literal
/// axis-spec ::= `:` integer-literal
/// quantiles-list ::= `{` quantile (`,` quantile)* `}`
/// scale-zero ::= `:` float-literal `:` integer-literal
Expand Down Expand Up @@ -641,6 +643,8 @@ static void printStorageType(QuantizedType type, DialectAsmPrinter &out) {
out << "f8E5M2";
} else if (type.getStorageType().isa<Float8E4M3FNType>()) {
out << "f8E4M3FN";
} else if (type.getStorageType().isa<Float4E2M1FNType>()) {
out << "f4E2M1FN";
} else if (isSigned) {
out << "i" << storageWidth;
} else {
Expand All @@ -655,7 +659,9 @@ static void printStorageType(QuantizedType type, DialectAsmPrinter &out) {
? QuantizedType::getDefaultMinimumForF8E5M2()
: type.getStorageType().isa<Float8E4M3FNType>()
? QuantizedType::getDefaultMinimumForF8E4M3FN()
: std::numeric_limits<int64_t>::max();
: type.getStorageType().isa<Float4E2M1FNType>()
? QuantizedType::getDefaultMinimumForF4E2M1FN()
: std::numeric_limits<int64_t>::max();

int64_t defaultMax =
type.getStorageType().isa<IntegerType>()
Expand All @@ -664,7 +670,9 @@ static void printStorageType(QuantizedType type, DialectAsmPrinter &out) {
? QuantizedType::getDefaultMaximumForF8E5M2()
: type.getStorageType().isa<Float8E4M3FNType>()
? QuantizedType::getDefaultMaximumForF8E4M3FN()
: std::numeric_limits<int64_t>::min();
: type.getStorageType().isa<Float4E2M1FNType>()
? QuantizedType::getDefaultMaximumForF4E2M1FN()
: std::numeric_limits<int64_t>::min();

if (defaultMin != type.getStorageTypeMin() ||
defaultMax != type.getStorageTypeMax()) {
Expand All @@ -685,6 +693,8 @@ static void printQuantileType(Type quantileType, DialectAsmPrinter &out) {
out << ":f8E5M2";
} else if (quantileType.isa<Float8E4M3FNType>()) {
out << ":f8E4M3FN";
} else if (quantileType.isa<Float4E2M1FNType>()) {
out << ":f4E2M1FN";
} else {
// Float types
out << ":" << quantileType;
Expand Down
10 changes: 10 additions & 0 deletions mlir/test/Dialect/Quant/parse-quantile-invalid.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -126,6 +126,16 @@ func.func @parse() -> !qalias {
// expected-error@+1 {{illegal storage type minimum: -500}}
!qalias = !quant.quantile<f8E4M3FN<-500:448>:f16:f32, {-1.0,1.0}:0.99872:127>

// -----
// Illegal storage min/max: max > defaultMax
// expected-error@+1 {{illegal storage type maximum: 10}}
!qalias = !quant.quantile<f4E2M1FN<-6:10>:f16:f32, {-1.0,1.0}:0.99872:127>

// -----
// Illegal storage min/max: min < defaultMin
// expected-error@+1 {{illegal storage type minimum: -10}}
!qalias = !quant.quantile<f4E2M1FN<-10:6>:f16:f32, {-1.0,1.0}:0.99872:127>

// -----
// Illegal uniform params: invalid scale
// expected-error@+1 {{expected floating point literal}}
Expand Down
18 changes: 18 additions & 0 deletions mlir/test/Dialect/Quant/parse-quantile.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,15 @@ func.func @parse() -> !qalias {
return %0 : !qalias
}

// -----
// Default min/max value optimization for f4E2M1FN.
// CHECK: !quant.quantile<f4E2M1FN:f16:f32, {-1.000000e+00,-8.667000e-01,-7.333000e-01,-6.000000e-01,-4.667000e-01,-3.333000e-01,-2.000000e-01,-0.066699999999999995,0.066699999999999995,2.000000e-01,3.333000e-01,4.667000e-01,6.000000e-01,7.333000e-01,8.667000e-01,1.000000e+00}:9.987200e-01:127>
!qalias = !quant.quantile<f4E2M1FN<-6:6>:f16:f32, {-1.0000,-0.8667,-0.7333,-0.6000,-0.4667,-0.3333,-0.2000,-0.0667,0.0667,0.2000,0.3333,0.4667,0.6000,0.7333,0.8667,1.0000}:0.99872:127 >
func.func @parse() -> !qalias {
%0 = "foo"() : () -> !qalias
return %0 : !qalias
}

// -----
// Required per-layer params specified:
// [unsigned] storageType, expressedType, scale
Expand Down Expand Up @@ -92,6 +101,15 @@ func.func @parse() -> !qalias {
return %0 : !qalias
}

// -----
// Storage type: f4E2M1FN
// CHECK: !quant.quantile<f4E2M1FN:f16:f32, {-1.000000e+00,-8.667000e-01,-7.333000e-01,-6.000000e-01,-4.667000e-01,-3.333000e-01,-2.000000e-01,-0.066699999999999995,0.066699999999999995,2.000000e-01,3.333000e-01,4.667000e-01,6.000000e-01,7.333000e-01,8.667000e-01,1.000000e+00}:2.000000e+02>
!qalias = !quant.quantile<f4E2M1FN:f16:f32, {-1.0000,-0.8667,-0.7333,-0.6000,-0.4667,-0.3333,-0.2000,-0.0667,0.0667,0.2000,0.3333,0.4667,0.6000,0.7333,0.8667,1.0000}:2.0e+2>
func.func @parse() -> !qalias {
%0 = "foo"() : () -> !qalias
return %0 : !qalias
}

// -----
// Expressed type: f32
// CHECK: !quant.quantile<u4:f16:f32, {-1.000000e+00,-8.667000e-01,-7.333000e-01,-6.000000e-01,-4.667000e-01,-3.333000e-01,-2.000000e-01,-0.066699999999999995,0.066699999999999995,2.000000e-01,3.333000e-01,4.667000e-01,6.000000e-01,7.333000e-01,8.667000e-01,1.000000e+00}:2.000000e+02>
Expand Down
10 changes: 10 additions & 0 deletions mlir/test/Dialect/Quant/parse-uniform-invalid.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -100,6 +100,16 @@
// expected-error@+1 {{illegal storage type minimum: -500}}
!qalias = !quant.uniform<f8E4M3FN<-500:448>:f32, 0.99872:127>

// -----
// Illegal storage min/max: max > defaultMax
// expected-error@+1 {{illegal storage type maximum: 10}}
!qalias = !quant.uniform<f4E2M1FN<-6:10>:f32, 0.99872:127>

// -----
// Illegal storage min/max: min < defaultMin
// expected-error@+1 {{illegal storage type minimum: -10}}
!qalias = !quant.uniform<f4E2M1FN<-10:6>:f32, 0.99872:127>

// -----
// Illegal uniform params: invalid scale
// expected-error@+1 {{expected floating point literal}}
Expand Down
18 changes: 18 additions & 0 deletions mlir/test/Dialect/Quant/parse-uniform.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,15 @@ func.func @parse() -> !qalias {
return %0 : !qalias
}

// -----
// Default min/max value optimization for f4E2M1FN.
// CHECK: !quant.uniform<f4E2M1FN:f32, 9.987200e-01:127>
!qalias = !quant.uniform<f4E2M1FN<-6:6>:f32, 0.99872:127 >
func.func @parse() -> !qalias {
%0 = "foo"() : () -> !qalias
return %0 : !qalias
}

// -----
// Required per-layer params specified:
// [unsigned] storageType, expressedType, scale
Expand Down Expand Up @@ -92,6 +101,15 @@ func.func @parse() -> !qalias {
return %0 : !qalias
}

// -----
// Storage type: f4E2M1FN
// CHECK: !quant.uniform<f4E2M1FN:f32, 2.000000e+02>
!qalias = !quant.uniform<f4E2M1FN:f32, 2.0e+2>
func.func @parse() -> !qalias {
%0 = "foo"() : () -> !qalias
return %0 : !qalias
}

// -----
// Storage type: i16
// CHECK: !quant.uniform<i16:f32, 2.000000e+02>
Expand Down
Loading