89 changes: 49 additions & 40 deletions src/transformers/tokenization_mistral_common.py
@@ -115,11 +115,6 @@
of returning overflowing tokens.
return_special_tokens_mask (`bool`, *optional*, defaults to `False`):
Whether or not to return special tokens mask information.
return_offsets_mapping (`bool`, *optional*, defaults to `False`):
Whether or not to return `(char_start, char_end)` for each token.

This is only available on fast tokenizers inheriting from [`PreTrainedTokenizerFast`], if using
Python's tokenizer, this method will raise `NotImplementedError`.
return_length (`bool`, *optional*, defaults to `False`):
Whether or not to return the lengths of the encoded inputs.
verbose (`bool`, *optional*, defaults to `True`):
@@ -176,6 +171,7 @@ class MistralCommonTokenizer(PushToHubMixin):
Supports the following methods from the `PreTrainedTokenizerBase` class:

- [`~MistralCommonTokenizer.get_vocab`]: Returns the vocabulary as a dictionary of token to index.
This is a lossy conversion for Tekkenizer, as byte-level tokens that fail to decode cleanly are collapsed into the same string representation.
- [`~MistralCommonTokenizer.encode`]: Encode a string to a list of integers.
- [`~MistralCommonTokenizer.decode`]: Decode a list of integers to a string.
- [`~MistralCommonTokenizer.batch_decode`]: Decode a batch of list of integers to a list of strings.
@@ -354,9 +350,13 @@ def get_vocab(self) -> dict[str, int]:
`Dict[str, int]`: The vocabulary.
"""
if self._cache_get_vocab is None:
self._cache_get_vocab = {
token: idx for idx, token in enumerate(self.tokenizer.instruct_tokenizer.tokenizer.vocab())
}
# We reverse the order to make sure that the first token is the one to be returned when there are multiple tokens with the same string representation.
vocab = self.tokenizer.instruct_tokenizer.tokenizer.vocab()
self._cache_get_vocab = {token: self._piece_to_id(token, False) for token in vocab}
# Order the dict.
self._cache_get_vocab = dict(
sorted(((k, v) for k, v in self._cache_get_vocab.items()), key=lambda x: x[1])
)
return self._cache_get_vocab

def __len__(self):
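A minimal sketch of the resulting `get_vocab` behavior, not part of the diff; the checkpoint name is an assumption and the import path simply mirrors this module:

from transformers.tokenization_mistral_common import MistralCommonTokenizer

# Assumed checkpoint; any repo shipping a Tekken tokenizer would do.
tok = MistralCommonTokenizer.from_pretrained("mistralai/Mistral-Nemo-Instruct-2407")
vocab = tok.get_vocab()

# Entries are now ordered by token id.
assert list(vocab.values()) == sorted(vocab.values())
# Lossy for Tekkenizer: byte-level tokens that decode to the same string collapse into
# a single entry, so len(vocab) can be smaller than the tokenizer's full vocab size.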
@@ -517,7 +517,7 @@ def batch_decode(

def _is_control_token(self, token_id: int) -> bool:
if self._tokenizer_type == MistralTokenizerType.spm:
return token_id in self.tokenizer.instruct_tokenizer.tokenizer._control_tokens()
return token_id in self.tokenizer.instruct_tokenizer.tokenizer._control_tokens
elif self._tokenizer_type == MistralTokenizerType.tekken:
return token_id < self.tokenizer.instruct_tokenizer.tokenizer.num_special_tokens
else:
@@ -563,15 +563,27 @@ def convert_ids_to_tokens(
return tokens[0]
return tokens

def _piece_to_id(self, piece: str) -> int:
def _tekken_piece_to_id(self, piece: str, warn: bool) -> int:
tekken_tokenizer = self.tokenizer.instruct_tokenizer.tokenizer
assert isinstance(tekken_tokenizer, Tekkenizer), type(tekken_tokenizer)

piece_bytes = piece.encode("utf-8")
shift = tekken_tokenizer.num_special_tokens
try:
return shift + tekken_tokenizer._tekken_token2id_nospecial[piece_bytes]
except KeyError:
piece_str = piece_bytes.decode("utf-8")
if piece_str in tekken_tokenizer._special_tokens_reverse_vocab:
return tekken_tokenizer._special_tokens_reverse_vocab[piece_str]
if warn:
logger.warning("Failed to convert token %s to id, replacing with <unk>", piece_bytes)
return tekken_tokenizer.unk_id

def _piece_to_id(self, piece: str, warn: bool) -> int:
if self._tokenizer_type == MistralTokenizerType.spm:
return self.tokenizer.instruct_tokenizer.tokenizer._model.piece_to_id(piece)
elif self._tokenizer_type == MistralTokenizerType.tekken:
pieces = self.tokenizer.instruct_tokenizer.tokenizer._model.encode(
piece, allowed_special="all", disallowed_special=set()
)
assert len(pieces) == 1, f"Expected to decode 1 token, got {len(pieces)}"
return pieces[0]
return self._tekken_piece_to_id(piece, warn)
else:
raise ValueError(f"Unknown tokenizer type: {self._tokenizer_type}")

@@ -595,7 +607,7 @@ def convert_tokens_to_ids(self, tokens: Union[str, list[str]]) -> Union[int, lis

ids: list[int] = []
for token in tokens:
ids.append(self._piece_to_id(token))
ids.append(self._piece_to_id(token, True))

if one_token:
return ids[0]
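A hedged sketch of the new piece-to-id fallback, not part of the diff; the checkpoint name is an assumption:

from transformers.tokenization_mistral_common import MistralCommonTokenizer

tok = MistralCommonTokenizer.from_pretrained("mistralai/Mistral-Nemo-Instruct-2407")  # assumed checkpoint

ids = tok.convert_tokens_to_ids(["<s>", "definitely-not-in-the-vocab"])
# "<s>" resolves through the special-token reverse vocab; the unknown piece logs a
# warning and falls back to the unk id instead of failing the old single-token assertion.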
@@ -647,13 +659,7 @@ def _encode_plus(
return_special_tokens_mask: bool = False,
return_length: bool = False,
verbose: bool = True,
**kwargs,
) -> BatchEncoding:
if kwargs:
raise ValueError(
f"Kwargs {list(kwargs.keys())} are not supported by `MistralCommonTokenizer._encode_plus`."
)

def get_input_ids(text):
if isinstance(text, str):
return self._text_to_ids(text, add_special_tokens)
@@ -699,10 +705,8 @@ def _batch_encode_plus(
return_attention_mask: Optional[bool] = None,
return_overflowing_tokens: bool = False,
return_special_tokens_mask: bool = False,
return_offsets_mapping: bool = False,
return_length: bool = False,
verbose: bool = True,
**kwargs,
) -> BatchEncoding:
def get_input_ids(text):
if isinstance(text, str):
@@ -712,13 +716,6 @@ def get_input_ids(text):
else:
raise ValueError("Input is not valid. Should be a string or a list/tuple of integers.")

if return_offsets_mapping:
raise NotImplementedError(
"return_offset_mapping is not available when using Python tokenizers. "
"To use this feature, change your tokenizer to one deriving from "
"transformers.PreTrainedTokenizerFast."
)

input_ids = []
for ids in batch_text:
input_ids.append(get_input_ids(ids))
@@ -746,7 +743,7 @@ def _all_special_ids(self) -> set[int]:
if self._tokenizer_type == MistralTokenizerType.tekken:
return {t["rank"] for t in self.tokenizer.instruct_tokenizer.tokenizer._all_special_tokens}
elif self._tokenizer_type == MistralTokenizerType.spm:
return self.tokenizer.instruct_tokenizer.tokenizer._control_tokens()
return self.tokenizer.instruct_tokenizer.tokenizer._control_tokens
else:
raise ValueError(f"Unknown tokenizer type: {self._tokenizer_type}")

@@ -966,9 +963,7 @@ def _get_padding_truncation_strategies(
logger.warning(
"Truncation was not explicitly activated but `max_length` is provided a specific value, please"
" use `truncation=True` to explicitly truncate examples to max length. Defaulting to"
" 'longest_first' truncation strategy. If you encode pairs of sequences (GLUE-style) with the"
" tokenizer you can select this strategy more precisely by providing a specific strategy to"
" `truncation`."
" 'longest_first' truncation strategy."
)
self.deprecation_warnings["Truncation-not-explicitly-activated"] = True
truncation = "longest_first"
@@ -1376,6 +1371,7 @@ def apply_chat_template(
self,
conversation: Union[list[dict[str, str]], list[list[dict[str, str]]]],
tools: Optional[list[Union[dict, Callable]]] = None,
add_generation_prompt: bool = False,
continue_final_message: bool = False,
tokenize: bool = True,
padding: Union[bool, str, PaddingStrategy] = False,
@@ -1398,6 +1394,9 @@
giving the name, description and argument types for the tool. See our
[chat templating guide](https://huggingface.co/docs/transformers/main/en/chat_templating#automated-function-conversion-for-tool-use)
for more information.
add_generation_prompt (`bool`, *optional*):
This argument is a no-op for `MistralCommonTokenizer`. However, to keep the API consistent, it cannot be used together with `continue_final_message`, and
an error is raised if any conversation already ends with an assistant message. In that case, use `continue_final_message` instead.
continue_final_message (bool, *optional*):
If this is set, the chat will be formatted so that the final
message in the chat is open-ended, without any EOS tokens. The model will continue this message
@@ -1442,6 +1441,9 @@
if not isinstance(truncation, bool):
raise TypeError("`truncation` must be a boolean for `apply_chat_template` method.")

if add_generation_prompt and continue_final_message:
raise ValueError("Cannot use both `add_generation_prompt` and `continue_final_message`.")

if isinstance(conversation, (list, tuple)) and (
isinstance(conversation[0], (list, tuple)) or hasattr(conversation[0], "messages")
):
@@ -1451,6 +1453,14 @@
conversations = [conversation]
is_batched = False

if add_generation_prompt:
for conversation in conversations:
last_message = conversation[-1]
if last_message.get("role") == "assistant":
raise ValueError(
"The last message in the conversation is already an assistant message. Consider using `continue_final_message` instead."
)

def _maybe_adapt_message(message: dict[str, Any]) -> None:
"""Adapt message to `mistral-common` format and leave validation to `mistral-common`."""
if not isinstance(message, dict):
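A hedged usage sketch of the new `add_generation_prompt` handling, not part of the diff; checkpoint name and message contents are assumptions:

from transformers.tokenization_mistral_common import MistralCommonTokenizer

tok = MistralCommonTokenizer.from_pretrained("mistralai/Mistral-Nemo-Instruct-2407")  # assumed checkpoint

chat = [{"role": "user", "content": "Write a haiku about tokenizers."}]
ids = tok.apply_chat_template(chat, add_generation_prompt=True)  # accepted, behaves as a no-op

prefill = chat + [{"role": "assistant", "content": "Tokens drift like"}]
# tok.apply_chat_template(prefill, add_generation_prompt=True)   -> ValueError, use continue_final_message
# tok.apply_chat_template(prefill, continue_final_message=True)  -> last message kept open-ended, no EOS
# tok.apply_chat_template(chat, add_generation_prompt=True, continue_final_message=True)  -> ValueError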
@@ -1543,7 +1553,7 @@ def _maybe_adapt_message(message: dict[str, Any]) -> None:
"Unable to convert output to PyTorch tensors format, PyTorch is not installed."
)

pixel_values = torch.tensor(images)
pixel_values = torch.from_numpy(np.stack(images))
elif return_tensors == "np":
pixel_values = np.array(images)
elif return_tensors is None:
@@ -1667,7 +1677,6 @@ def _is_valid_text_input(t):
return_special_tokens_mask=return_special_tokens_mask,
return_length=return_length,
verbose=verbose,
**kwargs,
)
else:
return self._encode_plus(
@@ -1685,7 +1694,6 @@ def _is_valid_text_input(t):
return_special_tokens_mask=return_special_tokens_mask,
return_length=return_length,
verbose=verbose,
**kwargs,
)

@classmethod
@@ -1760,9 +1768,10 @@ def from_pretrained(
raise ValueError("`init_inputs` are not supported by `MistralCommonTokenizer.from_pretrained`.")

# Handle kwargs and AutoTokenizer case
if kwargs and not set(kwargs.keys()).issubset({"_from_auto", "trust_remote_code"}):
ignore_subset = {"_from_auto", "trust_remote_code"}
if kwargs and not (set_kwargs := set(kwargs.keys())).issubset(ignore_subset):
raise ValueError(
f"Kwargs {list(kwargs.keys())} are not supported by `MistralCommonTokenizer.from_pretrained`."
f"Kwargs {list(set_kwargs - ignore_subset)} are not supported by `MistralCommonTokenizer.from_pretrained`."
)

if not os.path.isdir(pretrained_model_name_or_path):
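A hedged sketch of the tightened kwargs check in `from_pretrained`, not part of the diff; the checkpoint name is an assumption:

from transformers.tokenization_mistral_common import MistralCommonTokenizer

# Only keys from {"_from_auto", "trust_remote_code"} are tolerated (AutoTokenizer path).
tok = MistralCommonTokenizer.from_pretrained(
    "mistralai/Mistral-Nemo-Instruct-2407", trust_remote_code=False
)

# Any other key raises, and the error now lists only the offending kwargs:
# MistralCommonTokenizer.from_pretrained("mistralai/Mistral-Nemo-Instruct-2407", use_fast=True)
#   -> ValueError: Kwargs ['use_fast'] are not supported by `MistralCommonTokenizer.from_pretrained`.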