Add enforce_utf8_boundaries option to BpeTrainer #1830

base: main
Changes from 2 commits
tokenizers/src/models/bpe/trainer.rs
@@ -4,12 +4,14 @@ use super::{Pair, WithFirstLastIterator, Word, BPE};
 use crate::parallelism::*;
 use crate::tokenizer::{AddedToken, Result, Trainer};
 use crate::utils::progress::{ProgressBar, ProgressStyle};
+use crate::pre_tokenizers::byte_level::bytes_char;
 use ahash::{AHashMap, AHashSet};
 use compact_str::CompactString;
 use dary_heap::OctonaryHeap;
 use serde::{Deserialize, Serialize};
 use std::cmp::Ordering;
 use std::collections::HashSet;
+use std::sync::LazyLock;

 #[derive(Debug, Eq)]
 struct Merge {
@@ -48,6 +50,7 @@ struct Config {
     continuing_subword_prefix: Option<String>,
     end_of_word_suffix: Option<String>,
     max_token_length: Option<usize>,
+    enforce_utf8_boundaries: bool,
 }

 /// A `BpeTrainerBuilder` can be used to create a `BpeTrainer` with a custom

@@ -69,6 +72,7 @@ impl Default for BpeTrainerBuilder {
             continuing_subword_prefix: None,
             end_of_word_suffix: None,
             max_token_length: None,
+            enforce_utf8_boundaries: false,
         },
     }
 }
@@ -144,6 +148,13 @@ impl BpeTrainerBuilder {
         self
     }

+    /// Whether to enforce UTF-8 character boundaries during merges
+    #[must_use]
+    pub fn enforce_utf8_boundaries(mut self, enforce: bool) -> Self {
+        self.config.enforce_utf8_boundaries = enforce;
+        self
+    }
+
     /// Constructs the final BpeTrainer
     pub fn build(self) -> BpeTrainer {
         BpeTrainer {
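For context, opting in goes through the builder like any other option. A minimal sketch of a caller (the `vocab_size` and `show_progress` setters already exist on `BpeTrainerBuilder`; `enforce_utf8_boundaries` is the method added in this diff):

```rust
use tokenizers::models::bpe::BpeTrainer;

fn main() {
    // Build a trainer that refuses merges which would create tokens
    // crossing UTF-8 character boundaries.
    let trainer = BpeTrainer::builder()
        .vocab_size(30_000)
        .show_progress(false)
        .enforce_utf8_boundaries(true)
        .build();
    let _ = trainer;
}
```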
@@ -156,6 +167,7 @@ impl BpeTrainerBuilder {
             continuing_subword_prefix: self.config.continuing_subword_prefix,
             end_of_word_suffix: self.config.end_of_word_suffix,
             max_token_length: self.config.max_token_length,
+            enforce_utf8_boundaries: self.config.enforce_utf8_boundaries,
             words: AHashMap::new(),
         }
     }
@@ -199,6 +211,11 @@ pub struct BpeTrainer {
     pub end_of_word_suffix: Option<String>,
     /// An optional parameter to limit the max length of any single token
     pub max_token_length: Option<usize>,
+    /// Whether to enforce UTF-8 character boundaries during merges. When true, only allows merging:
+    /// 1. Complete UTF-8 characters with each other
+    /// 2. Single bytes that are part of the same UTF-8 character, from left to right
+    /// This is useful to avoid creating tokens that are not valid UTF-8 sequences, at no cost to compression.
+    pub enforce_utf8_boundaries: bool,

     words: AHashMap<CompactString, u64>,
 }
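Concretely: under `ByteLevel`, 🤗 is the four bytes `F0 9F A4 97`, so a constrained trainer may merge `F0`+`9F`, then `(F0 9F)`+`A4`, then `(F0 9F A4)`+`97`, rebuilding the character left to right, but it may never merge, say, a complete space token (`20`) with the stray lead byte `F0`.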
@@ -209,7 +226,12 @@ impl Default for BpeTrainer {
     }
 }

+/// For UTF-8 boundaries, we need to map GPT-2-encoded bytes back.
+static CHAR_BYTES: LazyLock<AHashMap<char, u8>> =
+    LazyLock::new(|| bytes_char().into_iter().map(|(b, c)| (c, b)).collect());
+
 impl BpeTrainer {
     pub fn new(min_frequency: u64, vocab_size: usize) -> Self {
         Self {
             min_frequency,
@@ -270,6 +292,68 @@ impl BpeTrainer {
         }
     }

+    /// Helper for `is_merge_allowed`: recover the original bytes of a part.
+    fn get_original_bytes(&self, part: &str) -> Option<Vec<u8>> {
+        part.chars().map(|c| CHAR_BYTES.get(&c).copied()).collect()
+    }
+
+    /// Determines if a merge is allowed under UTF-8 boundary constraints.
+    ///
+    /// This check is only performed if `enforce_utf8_boundaries` is true.
+    /// A merge is allowed if it meets one of the following criteria:
+    /// 1. Both tokens consist of complete characters.
+    /// 2. Both tokens are part of the same single character, and the second is a single byte.
+    ///    This allows building multi-byte characters from their individual bytes left-to-right.
+    /// All other combinations, such as merging a complete character with a partial byte, are disallowed.
+    ///
+    /// This function is designed to work on the character-mapped output of a `ByteLevel`
+    /// pre-tokenizer by reversing the mapping to check the original bytes.
+    fn is_merge_allowed(&self, pair: &Pair, id_to_word: &[CompactString]) -> bool {
+        if !self.enforce_utf8_boundaries {
+            return true;
+        }
+
+        let part_a = &id_to_word[pair.0 as usize];
+        let part_b = &id_to_word[pair.1 as usize];
+
+        // Get the original bytes by reversing the ByteLevel character mapping.
+        let bytes_a = self.get_original_bytes(part_a.as_ref()).unwrap_or_default();
+        let bytes_b = self.get_original_bytes(part_b.as_ref()).unwrap_or_default();
+
+        // A "complete" token is one whose underlying bytes form a valid UTF-8 string.
+        // For ByteLevel, this means single-byte ASCII chars (like a space) are complete,
+        // but single bytes from a multi-byte sequence (like 0xF0) are not.
+        let is_a_complete = std::str::from_utf8(&bytes_a).is_ok();
+        let is_b_complete = std::str::from_utf8(&bytes_b).is_ok();
+
+        // Rule 1: Allow merging two complete tokens.
+        if is_a_complete && is_b_complete {
+            return true;
+        }
+
+        // Implicit rule: any mix of one complete and one incomplete token is disallowed.
+        if is_a_complete || is_b_complete {
+            return false;
+        }
+
+        // Rule 2: Both tokens are incomplete. Allow the merge only if it builds a valid
+        // UTF-8 prefix by appending a single byte.
+        if bytes_b.len() == 1 {
+            let mut merged = bytes_a;
+            merged.extend_from_slice(&bytes_b);
+            match std::str::from_utf8(&merged) {
+                // The merged bytes form one or more complete characters. Valid.
+                Ok(_) => true,
+                // The merged bytes are an incomplete but valid prefix. Valid.
+                Err(e) => e.error_len().is_none(),
+            }
+        } else {
+            // If part_b is not a single byte, it's not a valid continuation merge.
+            false
+        }
+    }
+
     /// Compute the initial alphabet and limit it if relevant
     fn compute_alphabet(
         &self,
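The continuation rule leans on a detail of `std::str::from_utf8`: a `Utf8Error` whose `error_len()` is `None` means the input ended in the middle of a character (a valid prefix), while `Some(_)` means a genuinely invalid byte. A standalone sketch of that distinction, independent of the trainer:

```rust
/// Returns true if `bytes` is complete UTF-8, or a valid prefix that was
/// merely truncated mid-character.
fn is_valid_utf8_prefix(bytes: &[u8]) -> bool {
    match std::str::from_utf8(bytes) {
        Ok(_) => true,
        // `error_len() == None` means "unexpected end of input": the bytes
        // seen so far could still be completed into valid UTF-8.
        Err(e) => e.error_len().is_none(),
    }
}

fn main() {
    assert!(is_valid_utf8_prefix(b"abc")); // complete ASCII
    assert!(is_valid_utf8_prefix(&[0xF0, 0x9F])); // prefix of a 4-byte char (🤗 is F0 9F A4 97)
    assert!(!is_valid_utf8_prefix(&[0xF0, 0x28])); // 0x28 is not a valid continuation byte
    assert!(!is_valid_utf8_prefix(&[0xFF])); // 0xFF can never appear in UTF-8
}
```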
@@ -455,7 +539,7 @@ impl BpeTrainer {
         let mut queue = OctonaryHeap::with_capacity(pair_counts.len());
         where_to_update.drain().for_each(|(pair, pos)| {
             let count = pair_counts[&pair];
-            if count > 0 {
+            if count > 0 && self.is_merge_allowed(&pair, &id_to_word) {
                 queue.push(Merge {
                     pair,
                     count: count as u64,
@@ -550,13 +634,13 @@ impl BpeTrainer {
                 for ((pair, change), iw) in changes {
                     let count = change * counts[iw] as i32;
                     *pair_counts.entry(pair).or_default() += count;
-                    if change > 0 {
+                    if change > 0 && self.is_merge_allowed(&pair, &id_to_word) {
                         where_to_update.entry(pair).or_default().insert(iw);
                     }
                 }
                 where_to_update.drain().for_each(|(pair, pos)| {
                     let count = pair_counts[&pair];
-                    if count > 0 {
+                    if count > 0 && self.is_merge_allowed(&pair, &id_to_word) {
                         queue.push(Merge {
                             pair,
                             count: count as u64,
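Note that in both update paths the `is_merge_allowed` filter is applied at the moment pairs are (re)inserted into the candidate heap: disallowed pairs still have their counts tracked, but they can never be selected as the next merge.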
@@ -644,8 +728,14 @@ impl Trainer for BpeTrainer {
 #[cfg(test)]
 mod tests {
     use super::{BpeTrainer, Pair, BPE};
+    use crate::pre_tokenizers::byte_level::{bytes_char, ByteLevel};
+    use crate::tokenizer::{
+        OffsetReferential, OffsetType, PreTokenizedString, PreTokenizer, Result, Trainer,
+    };
+    use ahash::AHashMap;
     use compact_str::CompactString;
     use std::collections::HashMap;
+    use std::sync::LazyLock;

     #[test]
     fn test_train() {
@@ -762,6 +852,7 @@ mod tests {
             )
         }
     }
+
     #[test]
     fn bpe_test_max_token_length_direct_assert() {
         /* more direct version of bpe_test_max_token_length test
@@ -831,4 +922,102 @@ mod tests {
             .collect();
         assert_eq!(trained_vocab, expected_vocab)
     }
+
+    // The CHAR_TO_BYTE mapping is kept here *only* for the debug printing helper,
+    // to make the test output readable. It is not used in the core test logic.
+    static BYTE_TO_CHAR: LazyLock<AHashMap<u8, char>> = LazyLock::new(bytes_char);
+    static CHAR_TO_BYTE: LazyLock<AHashMap<char, u8>> =
+        LazyLock::new(|| BYTE_TO_CHAR.iter().map(|(b, c)| (*c, *b)).collect());
+
+    #[test]
+    fn test_bpe_utf8_boundary_enforcement_with_byte_level_pretokenizer() {
+        /// A local helper to print the vocabulary with original hex byte representations for clarity.
+        fn print_vocab_with_hex(vocab: &HashMap<String, u32>, title: &str) {
+            println!("\n--- {} ---", title);
+            let mut vocab_items: Vec<_> = vocab.iter().collect();
+            vocab_items.sort_by_key(|(_, id)| *id);
+            for (token, id) in vocab_items {
+                // De-mangle the token back to its original bytes for printing
+                let bytes: Vec<String> = token
+                    .chars()
+                    .map(|c| format!("{:02X}", CHAR_TO_BYTE.get(&c).unwrap_or(&0)))
+                    .collect();
+                println!(
+                    "ID {:<3} Token: {:<12} Bytes: [{}]",
+                    id,
+                    format!("{:?}", token),
+                    bytes.join(" ")
+                );
+            }
+        }
+
+        // Use the actual ByteLevel pre-tokenizer to process the input string.
+        let byte_level_pretok = ByteLevel::new(false, false, false);
+        let process_fn = |s: &str| -> Result<Vec<String>> {
+            let mut pretokenized = PreTokenizedString::from(s);
+            byte_level_pretok.pre_tokenize(&mut pretokenized)?;
+            Ok(pretokenized
+                .get_splits(OffsetReferential::Original, OffsetType::Byte)
+                .into_iter()
+                .map(|(word, _, _)| word.to_string())
+                .collect())
+        };
+
+        let sequence = " 🤗 🦒 🐹 🦦 🤗 𝟑".to_string();
+        let vocab_size = 25;
+
+        // --- Part 1: Unconstrained BPE ---
+        let mut unconstrained_trainer = BpeTrainer::builder()
+            .vocab_size(vocab_size)
+            .show_progress(false)
+            .enforce_utf8_boundaries(false)
+            .build();
+        unconstrained_trainer
+            .feed(std::iter::once(&sequence), &process_fn)
+            .unwrap();
+        let mut unconstrained_model = BPE::default();
+        unconstrained_trainer
+            .train(&mut unconstrained_model)
+            .unwrap();
+        print_vocab_with_hex(
+            &unconstrained_model.get_vocab(),
+            "Unconstrained Vocabulary",
+        );
+        let invalid_merge_token: String =
+            [BYTE_TO_CHAR[&b' '], BYTE_TO_CHAR[&0xF0]].iter().collect();
+        assert!(
+            unconstrained_model
+                .get_vocab()
+                .contains_key(&invalid_merge_token),
+            "Unconstrained vocab SHOULD contain the top frequency merge (bytes [20 F0])"
+        );
+
+        // --- Part 2: Constrained BPE ---
+        let mut constrained_trainer = BpeTrainer::builder()
+            .vocab_size(vocab_size)
+            .show_progress(false)
+            .enforce_utf8_boundaries(true)
+            .build();
+        constrained_trainer
+            .feed(std::iter::once(&sequence), &process_fn)
+            .unwrap();
+        let mut constrained_model = BPE::default();
+        constrained_trainer.train(&mut constrained_model).unwrap();
+        print_vocab_with_hex(&constrained_model.get_vocab(), "Constrained Vocabulary");
+
+        let valid_merge_token: String =
+            [BYTE_TO_CHAR[&0xF0], BYTE_TO_CHAR[&0x9F]].iter().collect();
+        assert!(
+            !constrained_model
+                .get_vocab()
+                .contains_key(&invalid_merge_token),
+            "Constrained vocab MUST NOT contain the invalid merge (bytes [20 F0])"
+        );
+        assert!(
+            constrained_model
+                .get_vocab()
+                .contains_key(&valid_merge_token),
+            "Constrained vocab SHOULD contain the next valid merge (bytes [F0 9F])"
+        );
+    }
 }
tokenizers/src/pre_tokenizers/byte_level.rs
@@ -12,7 +12,7 @@ use crate::utils::macro_rules_attribute;

 /// Converts bytes to unicode characters.
 /// See https://github.com/openai/gpt-2/blob/master/src/encoder.py#L9
-pub(crate) fn bytes_char() -> AHashMap<u8, char> {
+pub fn bytes_char() -> AHashMap<u8, char> {
     let mut bs: Vec<u8> = vec![];
     bs.extend(b'!'..=b'~');
     bs.extend(b'\xA1'..=b'\xAC');
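For readers unfamiliar with the mapping whose visibility changes here, a self-contained sketch of the GPT-2 byte-to-unicode scheme (same ranges as in the hunk above; plain `HashMap` stands in for the crate's `AHashMap`):

```rust
use std::collections::HashMap;

/// Printable bytes map to themselves; every other byte is shifted into the
/// U+0100+ range so that all 256 bytes render as printable characters.
fn gpt2_bytes_char() -> HashMap<u8, char> {
    let mut bs: Vec<u8> = vec![];
    bs.extend(b'!'..=b'~');
    bs.extend(0xA1u8..=0xAC);
    bs.extend(0xAEu8..=0xFF);

    let mut cs: Vec<u32> = bs.iter().map(|b| *b as u32).collect();
    let mut n = 0;
    for b in 0u8..=255 {
        if !bs.contains(&b) {
            bs.push(b);
            cs.push(256 + n);
            n += 1;
        }
    }
    bs.into_iter()
        .zip(cs.into_iter().map(|c| char::from_u32(c).unwrap()))
        .collect()
}

fn main() {
    let map = gpt2_bytes_char();
    assert_eq!(map[&b'a'], 'a'); // printable ASCII maps to itself
    assert_eq!(map[&b' '], '\u{120}'); // space (0x20) becomes 'Ġ'
    assert_eq!(map[&0xF0], 'ð'); // the lead byte of 🤗 is in the printable Latin-1 range
}
```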
Review comment on the `is_a_complete || is_b_complete` check:

> Can this be changed to use xor? It took me a bit longer than necessary to grasp that the `||` works the same as `^` here, because the `is_a_complete && is_b_complete` check has already been done. With `^` the operator matches the comment more closely.

Reply:

> sure
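For clarity, the change the reviewer is asking for would look like this (equivalent because the both-complete case has already returned by this point):

```rust
// "Exactly one of the two tokens is complete" — the mixed case is disallowed.
if is_a_complete ^ is_b_complete {
    return false;
}
```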