From 07885d275e1b1faf81874a2ef80dd47dea803172 Mon Sep 17 00:00:00 2001
From: Mostafa
Date: Wed, 6 Aug 2025 18:40:49 +0800
Subject: [PATCH] feat: add CLI for tokenization and training

---
 tokenizers/Cargo.toml                  |   7 ++
 tokenizers/bin/tokenize.rs             |  77 +++++++++++++++++++
 tokenizers/src/models/wordpiece/mod.rs |   2 +-
 tokenizers/tests/cli.rs                | 101 +++++++++++++++++++++++++
 4 files changed, 186 insertions(+), 1 deletion(-)
 create mode 100644 tokenizers/bin/tokenize.rs
 create mode 100644 tokenizers/tests/cli.rs

diff --git a/tokenizers/Cargo.toml b/tokenizers/Cargo.toml
index 6ed8498cf..938fe98cb 100644
--- a/tokenizers/Cargo.toml
+++ b/tokenizers/Cargo.toml
@@ -69,6 +69,7 @@ monostate = "0.1.12"
 ahash = { version = "0.8.11", features = ["serde"] }
 dary_heap = { version = "0.3.6", features = ["serde"] }
 compact_str = { version = "0.9", features = ["serde"] }
+clap = { version = "4.5.4", features = ["derive"] }
 
 [features]
 default = ["progressbar", "onig", "esaxx_fast"]
@@ -79,6 +80,8 @@ unstable_wasm = ["fancy-regex", "getrandom/wasm_js"]
 rustls-tls = ["hf-hub?/rustls-tls"]
 
 [dev-dependencies]
+assert_cmd = "2.0"
+predicates = "3.0"
 criterion = "0.6"
 tempfile = "3.10"
 assert_approx_eq = "1.1"
@@ -92,3 +95,7 @@ lto = "fat"
 name = "encode_batch"
 required-features = ["http"]
 
+[[bin]]
+name = "tokenize"
+path = "bin/tokenize.rs"
+
diff --git a/tokenizers/bin/tokenize.rs b/tokenizers/bin/tokenize.rs
new file mode 100644
index 000000000..0ddf577b7
--- /dev/null
+++ b/tokenizers/bin/tokenize.rs
@@ -0,0 +1,77 @@
+use clap::{Parser, Subcommand};
+use tokenizers::tokenizer::Tokenizer;
+
+#[derive(Parser, Debug)]
+#[command(author, version, about, long_about = None)]
+struct Cli {
+    #[command(subcommand)]
+    command: Commands,
+}
+
+#[derive(Subcommand, Debug)]
+enum Commands {
+    /// Tokenize input text using a model file
+    Tokenize {
+        /// Path to the tokenizer model file (e.g., tokenizer.json)
+        #[arg(long)]
+        model: String,
+        /// Input text to tokenize
+        #[arg(long)]
+        text: String,
+    },
+    /// Train a new BPE tokenizer model
+    Train {
+        /// Input text file(s) for training (repeat --files for multiple inputs)
+        #[arg(long, required = true)]
+        files: Vec<String>,
+        /// Vocabulary size
+        #[arg(long, default_value_t = 30000)]
+        vocab_size: usize,
+        /// Output path for the trained model (e.g., model.json)
+        #[arg(long)]
+        output: String,
+    },
+}
+
+fn main() {
+    let cli = Cli::parse();
+    match cli.command {
+        Commands::Tokenize { model, text } => match Tokenizer::from_file(&model) {
+            Ok(tokenizer) => match tokenizer.encode(text.as_str(), true) {
+                Ok(encoding) => {
+                    println!("Token IDs: {:?}", encoding.get_ids());
+                }
+                Err(e) => {
+                    eprintln!("Failed to encode text: {}", e);
+                    std::process::exit(1);
+                }
+            },
+            Err(e) => {
+                eprintln!("Failed to load tokenizer model: {}", e);
+                std::process::exit(1);
+            }
+        },
+        Commands::Train {
+            files,
+            vocab_size,
+            output,
+        } => {
+            use tokenizers::models::bpe::{BpeTrainer, BPE};
+            use tokenizers::models::ModelWrapper;
+            use tokenizers::models::TrainerWrapper;
+
+            let mut tokenizer = Tokenizer::new(ModelWrapper::BPE(BPE::default()));
+            let mut trainer =
+                TrainerWrapper::BpeTrainer(BpeTrainer::builder().vocab_size(vocab_size).build());
+            if let Err(e) = tokenizer.train_from_files(&mut trainer, files.clone()) {
+                eprintln!("Failed to train tokenizer: {}", e);
+                std::process::exit(1);
+            }
+            if let Err(e) = tokenizer.save(&output, true) {
+                eprintln!("Failed to save trained model: {}", e);
+                std::process::exit(1);
+            }
+            println!("Model trained and saved to {}", output);
+        }
+    }
+}
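
Usage sketch: the new binary can be driven through cargo. The paths below are
illustrative fixtures, not files added by this patch; any tokenizer.json that
Tokenizer::from_file accepts, and any plain-text corpus, should work:

    # Encode text with an existing model file
    cargo run --bin tokenize -- tokenize \
        --model ./data/tokenizer.json \
        --text "Hello world!"
    # prints: Token IDs: [...]

    # Train a BPE model (pass --files once per input file)
    cargo run --bin tokenize -- train \
        --files ./data/small.txt \
        --vocab-size 100 \
        --output ./data/test-model.json
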
diff --git a/tokenizers/src/models/wordpiece/mod.rs b/tokenizers/src/models/wordpiece/mod.rs
index fa5c3e775..61fa44f07 100644
--- a/tokenizers/src/models/wordpiece/mod.rs
+++ b/tokenizers/src/models/wordpiece/mod.rs
@@ -175,7 +175,7 @@ impl WordPiece {
     pub fn read_bytes(vocab: &[u8]) -> Result<Vocab> {
         let file = BufReader::new(vocab);
 
-        let mut vocab = HashMap::new();
+        let mut vocab = AHashMap::new();
         for (index, line) in file.lines().enumerate() {
             let line = line?;
             vocab.insert(line.trim_end().to_owned(), index as u32);
diff --git a/tokenizers/tests/cli.rs b/tokenizers/tests/cli.rs
new file mode 100644
index 000000000..39d340ade
--- /dev/null
+++ b/tokenizers/tests/cli.rs
@@ -0,0 +1,101 @@
+use assert_cmd::Command;
+use predicates::prelude::*;
+use std::fs;
+use std::path::Path;
+
+const BIN_NAME: &str = "tokenize";
+
+#[test]
+fn test_cli_tokenize_success() {
+    // Assumes a minimal tokenizer fixture exists at this path
+    let model_path = "./data/tokenizer.json";
+    let text = "Hello world!";
+    let mut cmd = Command::cargo_bin(BIN_NAME).unwrap();
+    cmd.arg("tokenize")
+        .arg("--model")
+        .arg(model_path)
+        .arg("--text")
+        .arg(text);
+    cmd.assert()
+        .success()
+        .stdout(predicate::str::contains("Token IDs:"));
+}
+
+#[test]
+fn test_cli_tokenize_missing_model() {
+    let mut cmd = Command::cargo_bin(BIN_NAME).unwrap();
+    cmd.arg("tokenize")
+        .arg("--model")
+        .arg("/nonexistent/model.json")
+        .arg("--text")
+        .arg("test");
+    cmd.assert()
+        .failure()
+        .stderr(predicate::str::contains("Failed to load tokenizer model"));
+}
+
+#[test]
+fn test_cli_tokenize_empty_text() {
+    // An empty input should still encode successfully and print an ID list
+    let model_path = "./data/tokenizer.json";
+    let mut cmd = Command::cargo_bin(BIN_NAME).unwrap();
+    cmd.arg("tokenize")
+        .arg("--model")
+        .arg(model_path)
+        .arg("--text")
+        .arg("");
+    cmd.assert()
+        .success()
+        .stdout(predicate::str::contains("Token IDs:"));
+}
+
+#[test]
+fn test_cli_train_success() {
+    // Train on a small fixture file and check the model is written out
+    let train_file = "./data/small.txt";
+    let output_model = "./data/test-model.json";
+    if Path::new(output_model).exists() {
+        fs::remove_file(output_model).unwrap();
+    }
+    let mut cmd = Command::cargo_bin(BIN_NAME).unwrap();
+    cmd.arg("train")
+        .arg("--files")
+        .arg(train_file)
+        .arg("--vocab-size")
+        .arg("100")
+        .arg("--output")
+        .arg(output_model);
+    cmd.assert()
+        .success()
+        .stdout(predicate::str::contains("Model trained and saved to"));
+    assert!(Path::new(output_model).exists());
+    fs::remove_file(output_model).unwrap();
+}
+
+#[test]
+fn test_cli_train_missing_file() {
+    let mut cmd = Command::cargo_bin(BIN_NAME).unwrap();
+    cmd.arg("train")
+        .arg("--files")
+        .arg("/nonexistent/data.txt")
+        .arg("--output")
+        .arg("/tmp/should-not-exist.json");
+    cmd.assert()
+        .failure()
+        .stderr(predicate::str::contains("Failed to train tokenizer"));
+}
+
+#[test]
+fn test_cli_train_invalid_output() {
+    // Writing the model to a directory path should fail
+    let train_file = "./data/small.txt";
+    let mut cmd = Command::cargo_bin(BIN_NAME).unwrap();
+    cmd.arg("train")
+        .arg("--files")
+        .arg(train_file)
+        .arg("--output")
+        .arg("./data/");
+    cmd.assert()
+        .failure()
+        .stderr(predicate::str::contains("Failed to save trained model"));
+}
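
Note on fixtures: the tests above expect ./data/tokenizer.json and
./data/small.txt to exist under the tokenizers crate's data/ directory; this
patch does not create them, so they must be provided before running the suite.
To run only the new CLI integration tests:

    cd tokenizers
    cargo test --test cli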