diff --git a/auto_round/compressors/base.py b/auto_round/compressors/base.py
index 01546034..c8bba384 100644
--- a/auto_round/compressors/base.py
+++ b/auto_round/compressors/base.py
@@ -1949,6 +1949,7 @@ def _set_layerwise_config(self, layer_config: dict) -> bool:
         """
         # Get the names of layers in quantization blocks
         supported_types = self.supported_types
+        regex_config = {}
         layers_in_blocks = get_layer_names_in_block(
             self.model, supported_types, self.quant_block_list, self.inner_supported_types
         )
@@ -1977,6 +1978,7 @@ def _set_layerwise_config(self, layer_config: dict) -> bool:
                     matched_names.append(layer_name)
             if len(matched_names) > 0:
                 val = layer_config[name]
+                regex_config[name] = val  # keep regex config
                 layer_config.pop(name)
                 for match_name in matched_names:
                     layer_config[match_name] = val
@@ -2067,6 +2069,7 @@ def _set_layerwise_config(self, layer_config: dict) -> bool:
         if need_to_quantize_lm_head:
             has_qlayer_outside_block = True
 
+        self.regex_config = regex_config
         # Return whether there are quantized layers outside the blocks
         return has_qlayer_outside_block
 
@@ -3160,6 +3163,7 @@ def save_quantized(
             "act_data_type",
             "super_bits",
             "super_group_size",
+            "regex_config",
         ]
         if isinstance(self.dataset, str):
             serialization_keys.append("dataset")
diff --git a/auto_round/export/export_to_autogptq/export.py b/auto_round/export/export_to_autogptq/export.py
index beeb3af1..ee46a69f 100644
--- a/auto_round/export/export_to_autogptq/export.py
+++ b/auto_round/export/export_to_autogptq/export.py
@@ -17,6 +17,7 @@
 import json
 import os
 from concurrent.futures import ThreadPoolExecutor
+from typing import Any, Dict
 
 import threadpoolctl as tctl
 
@@ -48,6 +49,13 @@
 import auto_round.export.export_to_autogptq.qlinear_triton
 from auto_round.export.utils import save_model
 
+
+GPTQ_REQUIRED_CONFIG_KEYS = (
+    "bits",
+    "group_size",
+    "sym",
+)
+
 from auto_round.logger import logger
 from auto_round.utils import (
     SUPPORTED_LAYER_TYPES,
@@ -57,7 +65,9 @@
     get_autogptq_packing_qlinear,
     get_block_names,
     get_module,
+    json_serialize,
     set_module,
+    to_standard_regex,
 )
 
 BLOCK_PATTERNS = [ ## copy from transformers optimum
@@ -68,6 +78,31 @@
 ]
 
+
+def convert_to_autogptq_dynamic(regex_config: Dict[str, Dict[str, Any]]) -> Dict[str, Dict[str, Any]]:
+    """
+    Convert AutoRound-style regex_config into AutoGPTQ-style QuantizerConfig.dynamic.
+
+    Rules:
+    - bits < 16  -> quantize          -> positive match `+:regex`
+    - bits >= 16 -> skip quantization -> negative match `-:regex`
+    """
+    converted = {}
+    for name, cfg in regex_config.items():
+        bits = cfg.get("bits")
+        regex = to_standard_regex(name)
+
+        if bits is None:
+            continue  # ignore invalid entries
+        elif bits < 16:
+            converted[f"+:{regex}"] = {"bits": bits}
+            for key in GPTQ_REQUIRED_CONFIG_KEYS:  # only save the keys GPTQ supports
+                converted[f"+:{regex}"][key] = regex_config[name][key]
+        else:
+            # skip quantization
+            converted[f"-:{regex}"] = {}
+    return converted
+
+
 def pack_layer(name, model, backend, device=None):
     if name == "lm_head": ##dese not support lm-head
         return
@@ -156,7 +191,8 @@ def save_quantized_as_autogptq(output_dir, inplace=True, backend="auto_gptq:exll
            logger.error("auto-gptq format may not support loading this quantized model")
        quantization_config["block_name_to_quantize"] = common_prefix
        quantization_config.pop("to_quant_block_names", None)
-
+    regex_config = quantization_config.pop("regex_config")
+    quantization_config["dynamic"] = convert_to_autogptq_dynamic(regex_config)
     ## as layers maybe already packed, we need to check in layer_config
     layer_config = kwargs["layer_config"]
     for n, m in model.named_modules():
diff --git a/auto_round/export/export_to_autoround/export.py b/auto_round/export/export_to_autoround/export.py
index 1e4036b6..0de6a12b 100644
--- a/auto_round/export/export_to_autoround/export.py
+++ b/auto_round/export/export_to_autoround/export.py
@@ -42,6 +42,7 @@
     is_nv_fp,
     is_standard_fp,
     set_module,
+    to_standard_regex,
 )
 
@@ -332,6 +333,10 @@ def save_quantized_as_autoround(output_dir, inplace=True, backend="auto_round:ex
             extra_config[layer_name]["data_type"] = layer_config[layer_name]["data_type"]
             extra_config[layer_name]["group_size"] = layer_config[layer_name]["group_size"]
             extra_config[layer_name]["sym"] = layer_config[layer_name]["sym"]
+            extra_config[layer_name]["act_bits"] = layer_config[layer_name]["act_bits"]
+            extra_config[layer_name]["act_data_type"] = layer_config[layer_name]["act_data_type"]
+            extra_config[layer_name]["act_group_size"] = layer_config[layer_name]["act_group_size"]
+            extra_config[layer_name]["act_sym"] = layer_config[layer_name]["act_sym"]
         elif layer_config[layer_name]["in_blocks"] or (
             block_name_to_quantize is not None and check_start_with_block_name(layer_name, block_name_to_quantize)
         ):
@@ -343,6 +348,13 @@ def save_quantized_as_autoround(output_dir, inplace=True, backend="auto_round:ex
             for key in neq_keys:
                 if layer_config[layer_name][key] is not None:
                     extra_config[layer_name][key] = layer_config[layer_name][key]
+
+    regex_config = quantization_config.pop("regex_config")
+    if regex_config is not None:
+        for name in regex_config.keys():
+            regex_name = to_standard_regex(name)
+            extra_config[regex_name] = {**{k: regex_config[name][k] for k in REQUIRED_CONFIG_KEYS}}
+
     if len(extra_config) > 0:
         quantization_config["extra_config"] = extra_config
     names = list(layer_config.keys())
diff --git a/auto_round/export/export_to_autoround/export_to_nvfp_mxfp.py b/auto_round/export/export_to_autoround/export_to_nvfp_mxfp.py
index e5653169..349ff2b5 100644
--- a/auto_round/export/export_to_autoround/export_to_nvfp_mxfp.py
+++ b/auto_round/export/export_to_autoround/export_to_nvfp_mxfp.py
@@ -39,6 +39,7 @@
     is_nv_fp,
     set_amax_for_all_moe_layers,
     set_module,
+    to_standard_regex,
 )
 from auto_round.wrapper import WrapperWALayer
 
@@ -215,6 +216,13 @@ def save_quantized_as_fp(output_dir, inplace=True, **kwargs):
             for key in neq_keys:
                 if layer_config[layer_name][key] is not None:
                     extra_config[layer_name][key] = layer_config[layer_name][key]
+
+    regex_config = quantization_config.pop("regex_config")
+    if regex_config is not None:
+        for name in regex_config.keys():
+            regex_name = to_standard_regex(name)
+            extra_config[regex_name] = {**{k: regex_config[name][k] for k in REQUIRED_CONFIG_KEYS}}
+
     if len(extra_config) > 0:
         quantization_config["extra_config"] = extra_config
     names = list(layer_config.keys())
diff --git a/auto_round/export/export_to_awq/export.py b/auto_round/export/export_to_awq/export.py
index 6da6bec3..46ba11b3 100644
--- a/auto_round/export/export_to_awq/export.py
+++ b/auto_round/export/export_to_awq/export.py
@@ -137,6 +137,7 @@ def wrapper(name):
         return model
 
     quantization_config = kwargs["serialization_dict"]
+    quantization_config.pop("regex_config")  # as AWQ does not support saving mixed-bit configs
 
     if output_dir is None:
         return compressed_model
diff --git a/auto_round/export/export_to_llmcompressor/export_to_fp.py b/auto_round/export/export_to_llmcompressor/export_to_fp.py
index 37899238..f2fd0515 100644
--- a/auto_round/export/export_to_llmcompressor/export_to_fp.py
+++ b/auto_round/export/export_to_llmcompressor/export_to_fp.py
@@ -25,6 +25,7 @@
 from tqdm import tqdm
 
 from auto_round.export.export_to_autoround.qlinear_fp import QuantLinear
+from auto_round.export.export_to_llmcompressor.utils import generate_ignore_regex_list
 from auto_round.export.utils import save_model
 from auto_round.logger import logger
 from auto_round.utils import (
@@ -114,9 +115,8 @@ def pack_layer(name, model, backend, device=None):
     scale = layer.scale
     global_scale = getattr(layer, "weight_global_scale", None)
     input_global_scale = getattr(layer, "input_global_scale", None)
-    # zero = layer.zp
+    # zero = layer.zp  # no zeros to handle, as MXFP does not support asymmetric quantization
     qlayer.pack(layer, scale, global_scale=global_scale, input_global_scale=input_global_scale, device=device)
-    ## no zeros to handle, as mxfp not support asym quantization
     qlayer.to(orig_device)
 
@@ -155,6 +155,9 @@ def save_quantized_as_fp(output_dir, inplace=True, **kwargs):
     device = kwargs.get("device", None)
     tokenizer = kwargs.get("tokenizer", None)
     processor = kwargs.get("processor", None)
+    ar_quantization_config = kwargs["serialization_dict"]
+    regex_config = ar_quantization_config.pop("regex_config")
+
     layer_config = kwargs["layer_config"]
     extra_config = {}
     if act_bits <= 8:
@@ -199,12 +202,7 @@ def wrapper(name):
         for _ in executor.map(wrapper, names):
             pass
 
-    # TODO fix the ignore re match issue, compile with fp8 & int8 config
-    ignore = ["lm_head"]
-    for layer_name in layer_config:
-        if layer_config[layer_name]["bits"] > 8: ## find ignore layers
-            ignore.append(layer_name)
-    ignore = list(set(ignore))
+    ignore = generate_ignore_regex_list(regex_config=regex_config, layer_config=layer_config)
 
     # get llm-compressor format config
     check_compressed_tensors_supported()
diff --git a/auto_round/export/export_to_llmcompressor/utils.py b/auto_round/export/export_to_llmcompressor/utils.py
new file mode 100644
index 00000000..304a720b
--- /dev/null
+++ b/auto_round/export/export_to_llmcompressor/utils.py
@@ -0,0 +1,51 @@
+# Copyright (c) 2025 Intel Corporation
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from typing import Dict, List
+
+from auto_round.utils import matches_any_regex, to_standard_regex
+
+
+def generate_ignore_regex_list(regex_config: Dict[str, Dict], layer_config: Dict[str, Dict]) -> List[str]:
+    """
+    Generate the ignore regex list for llm_compressor based on regex_config and layer_config.
+
+    Rules:
+    1. Any regex entry in regex_config with bits > 8 (i.e. not quantized) is ignored.
+    2. Any layer in layer_config with bits > 8 is ignored as well.
+    3. Patterns from regex_config are normalized to llm_compressor's 're:...' style; full layer names are kept as-is.
+
+    Args:
+        regex_config (Dict[str, Dict]): dynamic quantization config
+        layer_config (Dict[str, Dict]): layer-wise quantization config
+
+    Returns:
+        List[str]: List of regex patterns to ignore during quantization.
+    """
+    prefix = "re:"
+    ignore_regex: List[str] = []
+
+    # Step 1: add regex_config entries with bits > 8
+    for key, cfg in regex_config.items():
+        bits = cfg.get("bits")
+        if bits > 8:
+            ignore_regex.append(prefix + to_standard_regex(key))
+
+    # Step 2: add fully named layers from layer_config with bits > 8
+    for key, cfg in layer_config.items():
+        bits = cfg.get("bits")
+        if bits > 8:
+            ignore_regex.append(key)
+
+    return ignore_regex
diff --git a/auto_round/utils.py b/auto_round/utils.py
index 9af09758..90f1f9f0 100644
--- a/auto_round/utils.py
+++ b/auto_round/utils.py
@@ -24,7 +24,7 @@
 from enum import Enum
 from functools import lru_cache
 from pathlib import Path
-from typing import Any, Callable, Tuple, Union
+from typing import Any, Callable, List, Tuple, Union
 
 import cpuinfo
 import torch
@@ -2755,3 +2755,79 @@ def is_mllm_model(model_or_path: Union[str, torch.nn.Module]):
             return True
 
     return False
+
+
+def to_standard_regex(pattern: str) -> str:
+    """
+    Convert a user-specified string into a standardized regex for layer matching.
+
+    Rules:
+    - If the pattern already contains regex tokens ('.*', '^', '$', etc.),
+      keep them as-is.
+    - Otherwise, wrap the pattern with `.*` on both sides to allow substring matching.
+    - Always ensure the returned regex is valid (compilable by re).
+
+    Examples:
+        >>> to_standard_regex("model.embed_tokens")
+        '.*model\\.embed_tokens.*'
+        >>> to_standard_regex("mlp.gate")
+        '.*mlp\\.gate.*'
+        >>> to_standard_regex("mlp.gate$")
+        '.*mlp\\.gate$'
+        >>> to_standard_regex("mlp.*gate")
+        '.*mlp.*gate.*'
+    """
+    # Heuristic: if pattern contains regex meta characters, assume partial regex
+    meta_chars = {".*", "^", "$", "|", "(", ")", "[", "]", "?", "+"}
+    has_regex = any(tok in pattern for tok in meta_chars)
+    if not has_regex:
+        # Escape literal dots, etc., and wrap with .* for substring matching
+        pattern = re.escape(pattern)
+        regex = f".*{pattern}.*"
+    else:
+        # Only escape bare dots that are not already part of regex constructs
+        # Avoid double escaping .* sequences
+        tmp = []
+        i = 0
+        while i < len(pattern):
+            if pattern[i] == ".":
+                if i + 1 < len(pattern) and pattern[i + 1] == "*":
+                    tmp.append(".*")  # keep regex token
+                    i += 2
+                    continue
+                else:
+                    tmp.append("\\.")  # escape bare dot
+            else:
+                tmp.append(pattern[i])
+            i += 1
+        regex = "".join(tmp)
+        # If no anchors are provided, allow substring matching
+        if not regex.startswith("^") and not regex.startswith(".*"):
+            regex = ".*" + regex
+        if not regex.endswith("$") and not regex.endswith(".*"):
+            regex = regex + ".*"
+    # Validate regex
+    try:
+        re.compile(regex)
+    except re.error as e:
+        raise ValueError(f"Invalid regex generated from pattern '{pattern}': {e}")
+    return regex
+
+
+def matches_any_regex(layer_name: str, regex_list: List[str], prefix="re:") -> bool:
+    """
+    Check if layer_name matches any regex pattern in regex_list.
+    """
+    for pattern in regex_list:
+        # Remove 're:' prefix for matching
+        pat = pattern.removeprefix(prefix)
+        if re.fullmatch(pat, layer_name):
+            return True
+    return False
+
+
+def json_serialize(obj: Any):
+    """Convert non-JSON-serializable objects into JSON-friendly formats."""
+    if isinstance(obj, torch.dtype):
+        return str(obj).split(".")[-1]  # e.g., torch.float16 -> "float16"
+    raise TypeError(f"Object of type {type(obj).__name__} is not JSON serializable")
diff --git a/test/test_cpu/test_mix_bits.py b/test/test_cpu/test_mix_bits.py
new file mode 100644
index 00000000..d7a5c55d
--- /dev/null
+++ b/test/test_cpu/test_mix_bits.py
@@ -0,0 +1,112 @@
+import os
+import shutil
+import sys
+import unittest
+
+from parameterized import parameterized
+
+sys.path.insert(0, "../..")
+import torch
+from transformers import AutoModelForCausalLM, AutoRoundConfig, AutoTokenizer
+
+from auto_round import AutoRound
+from auto_round.testing_utils import require_gptqmodel
+
+
+def _get_folder_size(path: str) -> float:
+    """Return folder size in GB."""
+    total_size = 0
+    for dirpath, _, filenames in os.walk(path):
+        for f in filenames:
+            fp = os.path.join(dirpath, f)
+            if os.path.isfile(fp):
+                total_size += os.path.getsize(fp)
+    return total_size / (1024**3)  # convert to GB
+
+
+class LLMDataLoader:
+    def __init__(self):
+        self.batch_size = 1
+
+    def __iter__(self):
+        for i in range(2):
+            yield torch.ones([1, 10], dtype=torch.long)
+
+
+class TestAutoRound(unittest.TestCase):
+    @classmethod
+    def setUpClass(self):
+        model_name = "/tf_dataset/auto_round/models/facebook/opt-125m"
+        self.save_dir = "./saved"
+        self.model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype="auto", trust_remote_code=True)
+        self.tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
+        self.llm_dataloader = LLMDataLoader()
+
+    @classmethod
+    def tearDownClass(self):
+        shutil.rmtree("./saved", ignore_errors=True)
+        shutil.rmtree("runs", ignore_errors=True)
+
+    @require_gptqmodel
+    def test_mixed_gptqmodel(self):
+        bits, sym, group_size = 4, True, 128
+        model_name = "/tf_dataset/auto_round/models/facebook/opt-125m"
+        layer_config = {
+            "k_proj": {"bits": 8},
+            "lm_head": {"bits": 16},
+            "fc1": {"bits": 16},
+        }
+        autoround = AutoRound(
+            model=model_name,
+            bits=bits,
+            group_size=group_size,
+            sym=sym,
+            iters=2,
+            seqlen=2,
+            layer_config=layer_config,
+            dataset=self.llm_dataloader,
+        )
+        quantized_model_path = "./saved"
+        autoround.quantize_and_save(output_dir=quantized_model_path, format="auto_gptq")
+        from gptqmodel import GPTQModel
+
+        model = GPTQModel.load(quantized_model_path)
+        assert model.model.model.decoder.layers[0].self_attn.k_proj.bits == 8
+        assert model.model.model.decoder.layers[0].self_attn.q_proj.bits == 4
+        result = model.generate("Uncovering deep insights begins with")[0]  # tokens
+        assert "!!!" not in model.tokenizer.decode(result)  # string output
+        shutil.rmtree(quantized_model_path, ignore_errors=True)
+
+    def test_mixed_autoround_format(self):
+        bits, sym, group_size = 4, True, 128
+        model_name = "/tf_dataset/auto_round/models/facebook/opt-125m"
+        layer_config = {
+            "k_proj": {"bits": 8},
+            "q_proj": {"bits": 3},
+            "lm_head": {"bits": 16},
+            "fc1": {"bits": 16},
+        }
+        autoround = AutoRound(
+            model=model_name,
+            bits=bits,
+            group_size=group_size,
+            sym=sym,
+            iters=2,
+            seqlen=2,
+            dataset=self.llm_dataloader,
+            layer_config=layer_config,
+        )
+        quantized_model_path = "./saved"
+        autoround.quantize_and_save(output_dir=quantized_model_path, format="auto_round")
+        model = AutoModelForCausalLM.from_pretrained(quantized_model_path, device_map="cpu")
+        assert model.model.decoder.layers[0].self_attn.k_proj.bits == 8
+        assert model.model.decoder.layers[0].self_attn.q_proj.bits == 3
+        tokenizer = AutoTokenizer.from_pretrained(quantized_model_path)
+        text = "There is a girl who likes adventure,"
+        inputs = tokenizer(text, return_tensors="pt").to(model.device)
+        print(tokenizer.decode(model.generate(**inputs, max_new_tokens=50)[0]))
+        shutil.rmtree(quantized_model_path, ignore_errors=True)
+
+
+if __name__ == "__main__":
+    unittest.main()
diff --git a/test/test_cuda/test_mix_bits.py b/test/test_cuda/test_mix_bits.py
new file mode 100644
index 00000000..7353af6b
--- /dev/null
+++ b/test/test_cuda/test_mix_bits.py
@@ -0,0 +1,185 @@
+import os
+import shutil
+import sys
+import unittest
+
+from parameterized import parameterized
+
+sys.path.insert(0, "../..")
+import torch
+from transformers import AutoModelForCausalLM, AutoRoundConfig, AutoTokenizer
+
+from auto_round import AutoRound
+from auto_round.testing_utils import require_gptqmodel
+
+
+class LLMDataLoader:
+    def __init__(self):
+        self.batch_size = 1
+
+    def __iter__(self):
+        for i in range(2):
+            yield torch.ones([1, 10], dtype=torch.long)
+
+
+class TestAutoRound(unittest.TestCase):
+    @classmethod
+    def setUpClass(self):
+        model_name = "/tf_dataset/auto_round/models/facebook/opt-125m"
+        self.save_dir = "./saved"
+        self.model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype="auto", trust_remote_code=True)
+        self.tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
+        self.llm_dataloader = LLMDataLoader()
+
+    @classmethod
+    def tearDownClass(self):
+        shutil.rmtree("./saved", ignore_errors=True)
+        shutil.rmtree("runs", ignore_errors=True)
+
+    @require_gptqmodel
+    def test_mixed_gptqmodel(self):
+        bits, sym, group_size = 4, True, 128
+        model_name = "/tf_dataset/auto_round/models/facebook/opt-125m"
+        layer_config = {
{"bits": 8}, + "lm_head": {"bits": 16}, + "fc1": {"bits": 16}, + } + autoround = AutoRound( + model=model_name, + bits=bits, + group_size=group_size, + sym=sym, + iters=2, + seqlen=2, + layer_config=layer_config, + dataset=self.llm_dataloader, + ) + quantized_model_path = "./saved" + autoround.quantize_and_save(output_dir=quantized_model_path, format="auto_gptq") + from gptqmodel import GPTQModel + + model = GPTQModel.load(quantized_model_path) + assert model.model.model.decoder.layers[0].self_attn.k_proj.bits == 8 + assert model.model.model.decoder.layers[0].self_attn.q_proj.bits == 4 + result = model.generate("Uncovering deep insights begins with")[0] # tokens + assert "!!!" not in model.tokenizer.decode(result) # string output + shutil.rmtree(quantized_model_path, ignore_errors=True) + + def test_mixed_autoround_format(self): + bits, sym, group_size = 4, True, 128 + model_name = "/tf_dataset/auto_round/models/facebook/opt-125m" + layer_config = { + "k_proj": {"bits": 8}, + "q_proj": {"bits": 3}, + "lm_head": {"bits": 16}, + "fc1": {"bits": 16}, + } + autoround = AutoRound( + model=model_name, + bits=bits, + group_size=group_size, + sym=sym, + iters=2, + seqlen=2, + dataset=self.llm_dataloader, + layer_config=layer_config, + ) + quantized_model_path = "./saved" + autoround.quantize_and_save(output_dir=quantized_model_path, format="auto_round") + model = AutoModelForCausalLM.from_pretrained(quantized_model_path, device_map="cpu") + assert model.model.decoder.layers[0].self_attn.k_proj.bits == 8 + assert model.model.decoder.layers[0].self_attn.q_proj.bits == 3 + tokenizer = AutoTokenizer.from_pretrained(quantized_model_path) + text = "There is a girl who likes adventure," + inputs = tokenizer(text, return_tensors="pt").to(model.device) + print(tokenizer.decode(model.generate(**inputs, max_new_tokens=50)[0])) + shutil.rmtree(quantized_model_path, ignore_errors=True) + + def test_mixed_autoround_format_vllm(self): + layer_config = { + "self_attn": {"bits": 8}, + "lm_head": {"bits": 16}, + } + autoround = AutoRound( + self.model, + self.tokenizer, + scheme="W4A16", + iters=2, + seqlen=2, + dataset=self.llm_dataloader, + layer_config=layer_config, + ) + autoround.quantize() + quantized_model_path = self.save_dir + autoround.save_quantized(output_dir=quantized_model_path, inplace=False, format="auto_round") + + from vllm import LLM, SamplingParams + + # Sample prompts. + prompts = [ + "The capital of France is", + "The future of AI is", + ] + # Create a sampling params object. + sampling_params = SamplingParams(temperature=0.8, top_p=0.95) + # Create an LLM. + QUANTIZATION = "auto-round" # quantized_model_path + llm = LLM(model=quantized_model_path, quantization=QUANTIZATION, trust_remote_code=True, tensor_parallel_size=1) + outputs = llm.generate(prompts, sampling_params) + # Print the outputs. + for output in outputs: + prompt = output.prompt + generated_text = output.outputs[0].text + # if "France" in prompt: + assert "!!!" 
+            assert "!!!" not in generated_text
+            print(f"{prompt}: {generated_text}")
+        shutil.rmtree(quantized_model_path, ignore_errors=True)
+
+    def test_mixed_llmcompressor_format_vllm(self):
+        model_name = "/tf_dataset/auto_round/models/facebook/opt-125m"
+        layer_config = {
+            "self_attn": {"bits": 16, "act_bits": 16, "data_type": "float"},
+            "lm_head": {"bits": 16, "act_bits": 16, "data_type": "float"},
+            "fc1": {
+                "bits": 16,
+                "act_bits": 16,
+                "data_type": "float",
+            },
+        }
+        autoround = AutoRound(
+            model_name,
+            scheme="NVFP4",
+            iters=2,
+            seqlen=2,
+            dataset=self.llm_dataloader,
+            layer_config=layer_config,
+        )
+        quantized_model_path = self.save_dir
+        compressed, _ = autoround.quantize_and_save(
+            output_dir=quantized_model_path, inplace=False, format="llm_compressor"
+        )
+        from vllm import LLM, SamplingParams
+
+        # Sample prompts.
+        prompts = [
+            "The capital of France is",
+            "The future of AI is",
+        ]
+        # Create a sampling params object.
+        sampling_params = SamplingParams(temperature=0.8, top_p=0.95)
+        # Create an LLM.
+        QUANTIZATION = "auto-round"  # quantized_model_path
+        llm = LLM(model=quantized_model_path, trust_remote_code=True, tensor_parallel_size=1)
+        outputs = llm.generate(prompts, sampling_params)
+        # Print the outputs.
+        for output in outputs:
+            prompt = output.prompt
+            generated_text = output.outputs[0].text
+            print(f"{prompt}: {generated_text}")
+            assert "!!!" not in generated_text
+        shutil.rmtree(quantized_model_path, ignore_errors=True)
+
+
+if __name__ == "__main__":
+    unittest.main()
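
Reviewer note: a minimal usage sketch of the new helpers, assuming `to_standard_regex` and `convert_to_autogptq_dynamic` import cleanly from the locations added in this diff; the sample `regex_config` values below are illustrative only, not taken from the PR.

    # Hypothetical example: a regex-style layer_config entry is first normalized
    # into a standard regex, then converted into the AutoGPTQ-style "dynamic" dict.
    from auto_round.utils import to_standard_regex
    from auto_round.export.export_to_autogptq.export import convert_to_autogptq_dynamic

    regex_config = {
        "k_proj": {"bits": 8, "group_size": 128, "sym": True},  # quantized at 8 bits -> "+:" entry
        "fc1": {"bits": 16, "group_size": 128, "sym": True},    # kept at 16 bits -> "-:" (skip) entry
    }

    print(to_standard_regex("k_proj"))
    # '.*k_proj.*'
    print(convert_to_autogptq_dynamic(regex_config))
    # {'+:.*k_proj.*': {'bits': 8, 'group_size': 128, 'sym': True}, '-:.*fc1.*': {}}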