1 change: 1 addition & 0 deletions bindings/python/src/trainers.rs
@@ -345,6 +345,7 @@ impl PyBpeTrainer {
}
"limit_alphabet" => builder = builder.limit_alphabet(val.extract()?),
"max_token_length" => builder = builder.max_token_length(val.extract()?),
"enforce_utf8_boundaries" => builder = builder.enforce_utf8_boundaries(val.extract()?),
"initial_alphabet" => {
let alphabet: Vec<String> = val.extract()?;
builder = builder.initial_alphabet(
43 changes: 43 additions & 0 deletions bindings/python/tests/bindings/test_trainers.py
@@ -74,6 +74,49 @@ def test_can_pickle(self):
)


def test_enforce_utf8_boundaries(self):
# This input is designed to have a very frequent but invalid merge candidate:
# a space (0x20) followed by the first byte of different 4-byte encodings (0xF0).
# A less frequent but valid candidate is the first two bytes of an emoji (0xF0, 0x9F).
data = [" 🤗"] * 10 + [" 𝟑"] * 9

# Setup a tokenizer with a ByteLevel pre-tokenizer
tokenizer = Tokenizer(models.BPE())
tokenizer.pre_tokenizer = pre_tokenizers.ByteLevel(add_prefix_space=False)

# 1. Train with `enforce_utf8_boundaries=False` (unconstrained)
unconstrained_trainer = trainers.BpeTrainer(
vocab_size=260,
special_tokens=["<unk>"],
enforce_utf8_boundaries=False,
show_progress=False,
)
tokenizer.train_from_iterator(data, trainer=unconstrained_trainer)
vocab = tokenizer.get_vocab()

# The pre-tokenizer maps byte 0x20 to `Ġ` and 0xF0 to `ð`.
# The invalid merge of these two should be present.
invalid_token = "Ġð" # Bytes: [20, F0]
assert invalid_token in vocab, "Unconstrained trainer should learn the invalid merge"

# 2. Train with `enforce_utf8_boundaries=True` (constrained)
# We must re-initialize the tokenizer to start with a fresh model
tokenizer = Tokenizer(models.BPE())
tokenizer.pre_tokenizer = pre_tokenizers.ByteLevel(add_prefix_space=False)

# Train with enforce_utf8_boundaries=True
constrained_trainer = trainers.BpeTrainer(
vocab_size=260,
special_tokens=["<unk>"],
enforce_utf8_boundaries=True,
show_progress=False,
)
tokenizer.train_from_iterator(data, trainer=constrained_trainer)
vocab = tokenizer.get_vocab()

# The invalid merge should not be present when enforcing UTF-8 boundaries
assert invalid_token not in vocab, "Constrained trainer should not learn invalid merges"

class TestWordPieceTrainer:
def test_can_modify(self):
trainer = trainers.WordPieceTrainer(
195 changes: 192 additions & 3 deletions tokenizers/src/models/bpe/trainer.rs
@@ -4,12 +4,14 @@ use super::{Pair, WithFirstLastIterator, Word, BPE};
use crate::parallelism::*;
use crate::tokenizer::{AddedToken, Result, Trainer};
use crate::utils::progress::{ProgressBar, ProgressStyle};
use crate::pre_tokenizers::byte_level::bytes_char;
use ahash::{AHashMap, AHashSet};
use compact_str::CompactString;
use dary_heap::OctonaryHeap;
use serde::{Deserialize, Serialize};
use std::cmp::Ordering;
use std::collections::HashSet;
use std::sync::LazyLock;

#[derive(Debug, Eq)]
struct Merge {
@@ -48,6 +50,7 @@ struct Config {
continuing_subword_prefix: Option<String>,
end_of_word_suffix: Option<String>,
max_token_length: Option<usize>,
enforce_utf8_boundaries: bool,
}

/// A `BpeTrainerBuilder` can be used to create a `BpeTrainer` with a custom
@@ -69,6 +72,7 @@ impl Default for BpeTrainerBuilder {
continuing_subword_prefix: None,
end_of_word_suffix: None,
max_token_length: None,
enforce_utf8_boundaries: false,
},
}
}
@@ -144,6 +148,13 @@ impl BpeTrainerBuilder {
self
}

/// Whether to enforce UTF-8 character boundaries during merges
#[must_use]
pub fn enforce_utf8_boundaries(mut self, enforce: bool) -> Self {
self.config.enforce_utf8_boundaries = enforce;
self
}

/// Constructs the final BpeTrainer
pub fn build(self) -> BpeTrainer {
BpeTrainer {
@@ -156,6 +167,7 @@ impl BpeTrainerBuilder {
continuing_subword_prefix: self.config.continuing_subword_prefix,
end_of_word_suffix: self.config.end_of_word_suffix,
max_token_length: self.config.max_token_length,
enforce_utf8_boundaries: self.config.enforce_utf8_boundaries,
words: AHashMap::new(),
}
}
@@ -199,6 +211,11 @@ pub struct BpeTrainer {
pub end_of_word_suffix: Option<String>,
/// An optional parameter to limit the max length of any single token
pub max_token_length: Option<usize>,
/// Whether to enforce UTF-8 character boundaries during merges. When true, a merge is only allowed between:
/// 1. Tokens that each consist of complete UTF-8 characters, or
/// 2. Bytes belonging to the same UTF-8 character, built up left to right one byte at a time.
/// This avoids creating tokens that are not valid UTF-8 sequences, at no cost to compression.
pub enforce_utf8_boundaries: bool,

words: AHashMap<CompactString, u64>,
}
@@ -209,7 +226,12 @@ impl Default for BpeTrainer {
}
}

/// For UTF-8 boundary checks, we need to map the GPT-2 byte-level encoding's characters back to their original bytes.
static CHAR_BYTES: LazyLock<AHashMap<char, u8>> =
LazyLock::new(|| bytes_char().into_iter().map(|(b, c)| (c, b)).collect());

impl BpeTrainer {
pub fn new(min_frequency: u64, vocab_size: usize) -> Self {
Self {
min_frequency,
@@ -270,6 +292,68 @@ impl BpeTrainer {
}
}

/// Helper for `is_merge_allowed`: maps a ByteLevel-encoded part back to its original bytes.
/// Returns `None` if any character is missing from the mapping.
fn get_original_bytes(&self, part: &str) -> Option<Vec<u8>> {
part.chars().map(|c| CHAR_BYTES.get(&c).copied()).collect()
}

/// Determines if a merge is allowed under UTF-8 boundary constraints.
///
/// This check is only performed if `enforce_utf8_boundaries` is true.
/// A merge is allowed if it meets one of the following criteria:
/// 1. Both tokens consist of complete characters.
/// 2. Both tokens are part of the same single character, and the second is a single byte.
/// This allows building multi-byte characters from their individual bytes left-to-right.
/// All other combinations, such as merging a complete character with a partial byte, are disallowed.
/// This function is designed to work on the character-mapped output of a `ByteLevel`
/// pre-tokenizer by reversing the mapping to check the original bytes.
fn is_merge_allowed(&self, pair: &Pair, id_to_word: &[CompactString]) -> bool {
if !self.enforce_utf8_boundaries {
return true;
}

let part_a = &id_to_word[pair.0 as usize];
let part_b = &id_to_word[pair.1 as usize];

// Get the original bytes by reversing the ByteLevel character mapping.
let bytes_a = self.get_original_bytes(part_a.as_ref()).unwrap_or_default();
let bytes_b = self.get_original_bytes(part_b.as_ref()).unwrap_or_default();

// A "complete" token is one whose underlying bytes form a valid UTF-8 string.
// For ByteLevel, this means single-byte ASCII chars (like a space) are complete,
// but single bytes from a multi-byte sequence (like 0xF0) are not.
let is_a_complete = std::str::from_utf8(&bytes_a).is_ok();
let is_b_complete = std::str::from_utf8(&bytes_b).is_ok();

// Rule 1: Allow merging two complete tokens.
if is_a_complete && is_b_complete {
return true;
}

// Rule 3 (Implicit): Any mix of complete and incomplete is disallowed.
if is_a_complete || is_b_complete {
return false;
}
Contributor: Can this be changed to use xor? It took me a bit longer than necessary to grasp that the || works the same as ^ here, because the is_a_complete && is_b_complete check has already been done. With ^ the operator matches the comment more closely.

Author: sure
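
A minimal sketch of the suggested change (not part of this commit's diff; since the is_a_complete && is_b_complete case has already returned, || and ^ are interchangeable at this point):

// Exactly one of the two tokens is complete: mixing is disallowed.
if is_a_complete ^ is_b_complete {
return false;
}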


// Rule 2: Both tokens are incomplete. Allow merge only if building a valid
// UTF-8 prefix by appending a single byte.
if bytes_b.len() == 1 {
let mut merged = bytes_a;
merged.extend_from_slice(&bytes_b);
match std::str::from_utf8(&merged) {
// The merged bytes form one or more complete characters. Valid.
Ok(_) => true,
// The merged bytes are an incomplete but valid prefix. Valid.
Err(e) => e.error_len().is_none(),
}
} else {
// If part_b is not a single byte, it's not a valid continuation merge.
false
}
}
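
To make Rule 2 concrete, here is a standalone sketch (not part of the diff) of the prefix test above. std::str::from_utf8 reports failures via Utf8Error, whose error_len() is None exactly when decoding failed only because the input ended early, i.e. the bytes seen so far are a valid prefix of a longer UTF-8 sequence:

fn is_valid_utf8_prefix(bytes: &[u8]) -> bool {
match std::str::from_utf8(bytes) {
Ok(_) => true,                     // complete character(s)
Err(e) => e.error_len().is_none(), // incomplete but extendable prefix
}
}

fn main() {
assert!(is_valid_utf8_prefix(&[0xF0]));       // lead byte of a 4-byte char
assert!(is_valid_utf8_prefix(&[0xF0, 0x9F])); // first two bytes of an emoji like 🤗
assert!(!is_valid_utf8_prefix(&[0x9F]));      // lone continuation byte can never be completed
assert!(is_valid_utf8_prefix(&[0x20, 0xF0])); // note: also a valid *prefix*, which is why
// the complete/incomplete rules above must reject "Ġ" + "ð" before this check is reached
}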

/// Compute the initial alphabet and limit it if relevant
fn compute_alphabet(
&self,
@@ -455,7 +539,7 @@ impl BpeTrainer {
let mut queue = OctonaryHeap::with_capacity(pair_counts.len());
where_to_update.drain().for_each(|(pair, pos)| {
let count = pair_counts[&pair];
if count > 0 {
if count > 0 && self.is_merge_allowed(&pair, &id_to_word) {
queue.push(Merge {
pair,
count: count as u64,
@@ -550,13 +634,13 @@ impl BpeTrainer {
for ((pair, change), iw) in changes {
let count = change * counts[iw] as i32;
*pair_counts.entry(pair).or_default() += count;
if change > 0 {
if change > 0 && self.is_merge_allowed(&pair, &id_to_word) {
where_to_update.entry(pair).or_default().insert(iw);
}
}
where_to_update.drain().for_each(|(pair, pos)| {
let count = pair_counts[&pair];
if count > 0 {
if count > 0 && self.is_merge_allowed(&pair, &id_to_word) {
queue.push(Merge {
pair,
count: count as u64,
@@ -644,8 +728,14 @@ impl Trainer for BpeTrainer {
#[cfg(test)]
mod tests {
use super::{BpeTrainer, Pair, BPE};
use crate::pre_tokenizers::byte_level::{bytes_char, ByteLevel};
use crate::tokenizer::{
OffsetReferential, OffsetType, PreTokenizedString, PreTokenizer, Result, Trainer,
};
use ahash::AHashMap;
use compact_str::CompactString;
use std::collections::HashMap;
use std::sync::LazyLock;

#[test]
fn test_train() {
@@ -762,6 +852,7 @@ mod tests {
)
}
}

#[test]
fn bpe_test_max_token_length_direct_assert() {
/* more direct version of bpe_test_max_token_length test
@@ -831,4 +922,102 @@
.collect();
assert_eq!(trained_vocab, expected_vocab)
}

// The CHAR_TO_BYTE mapping is kept here *only* for the debug printing helper,
// to make the test output readable. It is not used in the core test logic.
static BYTE_TO_CHAR: LazyLock<AHashMap<u8, char>> = LazyLock::new(bytes_char);
static CHAR_TO_BYTE: LazyLock<AHashMap<char, u8>> =
LazyLock::new(|| BYTE_TO_CHAR.iter().map(|(b, c)| (*c, *b)).collect());
Contributor: Can you reuse the statics already in the codebase, see other comment?

Author: will do, just missed them I think :)


#[test]
fn test_bpe_utf8_boundary_enforcement_with_byte_level_pretokenizer() {
/// A local helper to print the vocabulary with original hex byte representations for clarity.
fn print_vocab_with_hex(vocab: &HashMap<String, u32>, title: &str) {
println!("\n--- {} ---", title);
let mut vocab_items: Vec<_> = vocab.iter().collect();
vocab_items.sort_by_key(|(_, id)| *id);
for (token, id) in vocab_items {
// De-mangle the token back to its original bytes for printing
let bytes: Vec<String> = token
.chars()
.map(|c| format!("{:02X}", CHAR_TO_BYTE.get(&c).unwrap_or(&0)))
.collect();
println!(
"ID {:<3} Token: {:<12} Bytes: [{}]",
id,
format!("{:?}", token),
bytes.join(" ")
);
}
}

// Use the actual ByteLevel pre-tokenizer to process the input string.
let byte_level_pretok = ByteLevel::new(false, false, false);
let process_fn = |s: &str| -> Result<Vec<String>> {
let mut pretokenized = PreTokenizedString::from(s);
byte_level_pretok.pre_tokenize(&mut pretokenized)?;
Ok(pretokenized
.get_splits(OffsetReferential::Original, OffsetType::Byte)
.into_iter()
.map(|(word, _, _)| word.to_string())
.collect())
};

let sequence = " 🤗 🦒 🐹 🦦 🤗 𝟑".to_string();
let vocab_size = 25;

// --- Part 1: Unconstrained BPE ---
let mut unconstrained_trainer = BpeTrainer::builder()
.vocab_size(vocab_size)
.show_progress(false)
.enforce_utf8_boundaries(false)
.build();
unconstrained_trainer
.feed(std::iter::once(&sequence), &process_fn)
.unwrap();
let mut unconstrained_model = BPE::default();
unconstrained_trainer
.train(&mut unconstrained_model)
.unwrap();
print_vocab_with_hex(
&unconstrained_model.get_vocab(),
"Unconstrained Vocabulary",
);
let invalid_merge_token: String =
[BYTE_TO_CHAR[&b' '], BYTE_TO_CHAR[&0xF0]].iter().collect();
assert!(
unconstrained_model
.get_vocab()
.contains_key(&invalid_merge_token),
"Unconstrained vocab SHOULD contain the top frequency merge (bytes [20 F0])"
);

// --- Part 2: Constrained BPE ---
let mut constrained_trainer = BpeTrainer::builder()
.vocab_size(vocab_size)
.show_progress(false)
.enforce_utf8_boundaries(true)
.build();
constrained_trainer
.feed(std::iter::once(&sequence), &process_fn)
.unwrap();
let mut constrained_model = BPE::default();
constrained_trainer.train(&mut constrained_model).unwrap();
print_vocab_with_hex(&constrained_model.get_vocab(), "Constrained Vocabulary");

let valid_merge_token: String =
[BYTE_TO_CHAR[&0xF0], BYTE_TO_CHAR[&0x9F]].iter().collect();
assert!(
!constrained_model
.get_vocab()
.contains_key(&invalid_merge_token),
"Constrained vocab MUST NOT contain the invalid merge (bytes [20 F0])"
);
assert!(
constrained_model
.get_vocab()
.contains_key(&valid_merge_token),
"Constrained vocab SHOULD contain the next valid merge (bytes [F0 9F])"
);
}
}
2 changes: 1 addition & 1 deletion tokenizers/src/pre_tokenizers/byte_level.rs
@@ -12,7 +12,7 @@ use crate::utils::macro_rules_attribute;

/// Converts bytes to unicode characters.
/// See https://github.com/openai/gpt-2/blob/master/src/encoder.py#L9
pub(crate) fn bytes_char() -> AHashMap<u8, char> {
pub fn bytes_char() -> AHashMap<u8, char> {
Contributor: Why is this made pub? This seems like a mistake. In fact, it seems even the previous pub(crate) is misplaced. Can you make this private and reuse the CHAR_BYTES static in this module by making it pub(crate)?

Author (@sanderland, Aug 3, 2025): the function is used in normalizers/byte_level.rs, but will make it pub(crate) again and use CHAR_BYTES

Contributor: Yes, but normalizers/byte_level.rs only needs the function to build an exact copy of the static BYTE_TO_CHAR. So as a drive-by cleanup, we could make just the statics pub(crate) and reuse them across both pre_tokenizers/byte_level.rs and normalizers/byte_level.rs. We can then make this function private.

let mut bs: Vec<u8> = vec![];
bs.extend(b'!'..=b'~');
bs.extend(b'\xA1'..=b'\xAC');
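
The diff truncates the function body here. As context, the following is a minimal self-contained sketch of the GPT-2 mapping scheme it implements (an assumption based on the encoder.py link above, not a copy of this crate's code): bytes with "printable" codepoints map to themselves, and every other byte is assigned a fresh visible character starting at U+0100.

use std::collections::HashMap;

fn bytes_char_sketch() -> HashMap<u8, char> {
let mut bs: Vec<u8> = vec![];
bs.extend(b'!'..=b'~');
bs.extend(0xA1u8..=0xAC);
bs.extend(0xAEu8..=0xFF);
let mut cs: Vec<u32> = bs.iter().map(|&b| b as u32).collect();
// Every byte outside the "printable" set gets a fresh codepoint >= U+0100.
let mut n = 0;
for b in 0..=255u8 {
if !bs.contains(&b) {
bs.push(b);
cs.push(256 + n);
n += 1;
}
}
bs.into_iter()
.zip(cs.into_iter().map(|c| char::from_u32(c).unwrap()))
.collect()
}

fn main() {
let m = bytes_char_sketch();
assert_eq!(m[&0x20], 'Ġ'); // space → U+0120, as referenced in the tests above
assert_eq!(m[&0xF0], 'ð'); // 0xF0 → U+00F0 (maps to itself)
}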