Commits (80)
ebed0ef
text2sql eval and ft tools
Jun 25, 2025
edcf746
script and README update
jeffxtang Jun 25, 2025
76a8caf
script to create reasoning dataset; llama_text2sql.py and requirement…
jeffxtang Jun 26, 2025
ab7df10
train loss png
jeffxtang Jun 26, 2025
2e8b278
README update
jeffxtang Jun 26, 2025
2c514b1
README update
jeffxtang Jun 26, 2025
5a18b6b
README cleanup
jeffxtang Jun 26, 2025
0033fc9
README update
jeffxtang Jun 26, 2025
3997357
README update - overview and quick start
jeffxtang Jun 26, 2025
44aa896
README update - next steps and creating reasoning dataset
jeffxtang Jun 27, 2025
46d3245
README and create_reasoning_dataset.py and trl_sft.py update
jeffxtang Jun 27, 2025
b89d945
readme and requirements updated based on PR feedback
jeffxtang Jun 27, 2025
3731175
readme update based on PR feeedback
jeffxtang Jun 27, 2025
094ab01
readme
jeffxtang Jun 27, 2025
6d76ea0
updated create_reasoning_dataset, llama_text2sql.py and README for fi…
jeffxtang Jun 28, 2025
cf54eb4
README format fix
jeffxtang Jun 28, 2025
e182902
train loss cot; README
jeffxtang Jun 28, 2025
79945b6
parent README update
jeffxtang Jun 28, 2025
0aa42d8
adding text2sql colab notebook for fine-tuning
ghd3v Jul 1, 2025
4cdd5f6
README and requirements.txt update
jeffxtang Jul 1, 2025
5e8a7b0
remove needless words in README
jeffxtang Jul 1, 2025
71ca0ae
README update
jeffxtang Jul 1, 2025
b17f90b
README update
jeffxtang Jul 1, 2025
6815255
folder struc refactoring
jeffxtang Jul 2, 2025
03ba7d5
quickstart folder
jeffxtang Jul 2, 2025
11a4a64
quickstart readme and colab link
jeffxtang Jul 2, 2025
99ead57
4 READMEs; requirements
jeffxtang Jul 2, 2025
ee1fc97
png fix
jeffxtang Jul 2, 2025
9c294df
eval/llama_eval.sh data path
jeffxtang Jul 3, 2025
ef6bbb2
main README and FT README update
jeffxtang Jul 8, 2025
caf98ec
unified script supporting 6 FT configs - quantized or not, peft/fft, …
jeffxtang Jul 10, 2025
4107171
Merge branch 'text2sql' of https://github.com/meta-llama/llama-cookbo…
jeffxtang Jul 10, 2025
b02334a
updated fine-tuning README
jeffxtang Jul 10, 2025
f7c68c1
bug fix on loading PEFT and FFT model for eval
jeffxtang Jul 10, 2025
4037737
READMEs update based on the new FT results etc
jeffxtang Jul 11, 2025
f07da72
finetuning README
jeffxtang Jul 11, 2025
57ffb74
finetuning README
jeffxtang Jul 11, 2025
a6f7d02
train loss update
jeffxtang Jul 11, 2025
7a4ae9f
ft readme typo fix
jeffxtang Jul 12, 2025
9ac5dd1
make max_tokens same - 10240
jeffxtang Jul 12, 2025
6b92409
dynamically calculating max tokens param; it was unused before
ghd3v Jul 12, 2025
c4573ba
fixing github we viewing
ghd3v Jul 12, 2025
7b508ec
testing fix of github notebook viewing
ghd3v Jul 12, 2025
3c23112
fixing github web rendering of the notebook
ghd3v Jul 12, 2025
cc93b73
rm unsloth quantization notebook
ghd3v Jul 12, 2025
6269c15
fine-tuning README update with latest result
jeffxtang Jul 12, 2025
2cdfbf0
READMEs update
jeffxtang Jul 14, 2025
4bb7faa
READMEs update
jeffxtang Jul 14, 2025
2bd662c
adding vllm eval files and updating requirements.txt
ghd3v Jul 18, 2025
b574c6d
Updated README.md
ghd3v Jul 22, 2025
58ea6cb
Update README.md
ghd3v Jul 22, 2025
33ac1ab
Update README.md
ghd3v Jul 22, 2025
e10ddda
some refactoring and cleaning
ghd3v Jul 23, 2025
5baa1e3
vllm enabled eval for HF and fine-tuned models; code cleanup and refa…
jeffxtang Jul 23, 2025
f894d26
vllm enabled eval for HF and fine-tuned models; code cleanup and refa…
jeffxtang Jul 23, 2025
ad48509
trl import
ghd3v Jul 23, 2025
e059899
Update the eval section using vllm for fine-tuning README.md
jeffxtang Jul 23, 2025
f80e7bf
Update fine-tuning README.md
jeffxtang Jul 23, 2025
b630735
Update fine-tuning README.md
jeffxtang Jul 23, 2025
1ac67d9
Update eval README.md for vllm based HF model
jeffxtang Jul 23, 2025
1b802d3
Update FT README.md
jeffxtang Jul 23, 2025
df598c4
Update eval README.md
jeffxtang Jul 23, 2025
cb8b0bd
Update eval README.md
jeffxtang Jul 23, 2025
77d3544
batch processing and vllm llama call in parallel; clean progress show…
jeffxtang Jul 24, 2025
deca42c
Merge branch 'text2sql' of https://github.com/meta-llama/llama-cookbo…
jeffxtang Jul 24, 2025
12a6dfa
code cleanup and refactoring; cloud llama response generation in tqdm…
jeffxtang Jul 24, 2025
799dee6
some cleanup and typo fix
jeffxtang Jul 25, 2025
6501cf4
FT readme update; removed old vllm py and sh files
jeffxtang Jul 25, 2025
82bb008
3 READMEs update; fine-tuning requirements update with vllm etc
jeffxtang Jul 28, 2025
27a23af
main README
jeffxtang Jul 28, 2025
e38abf1
Update eval README.md
jeffxtang Aug 1, 2025
fc80546
Update FT EADME.md
jeffxtang Aug 1, 2025
be4817c
Update FT README.md with llama 3.3 70b on multiple gpus
jeffxtang Aug 13, 2025
af3ea4f
script to save dev set in pandas csv format
jeffxtang Aug 22, 2025
0c7b348
restored llama_eval.sh
jeffxtang Aug 22, 2025
54e49bc
added steps to run create_bird_eval_dataset.py
jeffxtang Aug 22, 2025
c88e10f
grpo llama 3.2 3b with 3 reward functions
jeffxtang Oct 7, 2025
57c0517
added llm as a judge reward func
jeffxtang Oct 8, 2025
7edf3d8
llm as a judge running now
jeffxtang Oct 9, 2025
8989e69
README for grpo
jeffxtang Oct 10, 2025
51 changes: 30 additions & 21 deletions end-to-end-use-cases/coding/text2sql/README.md
@@ -1,30 +1,39 @@
## Text2SQL: Natural Language to SQL Interface
# Improving Llama Text2SQL performance with CoT Fine-tuning

This project provides a set of scripts to convert natural language queries into SQL statements using Meta's Llama model. The goal is to enable users to interact with databases using natural language inputs, making it easier for non-technical users to access and analyze data.
This recipe is a step-by-step guide to improving Llama performance on Text2SQL, measured with the popular [BIRD](https://bird-bench.github.io) benchmark. We generate a synthetic Chain of Thought (CoT) dataset and fine-tune Llama models on it.

For detailed instructions on setting up the environment, creating a database, and executing natural language queries using the Text2SQL interface, please refer to the quickstart.ipynb notebook.
Results:

### Structure:
| Fine-tuning Combination | Accuracy |
|-----------------------------|-------------------------------|
| baseline | 39.47% |
| CoT, PEFT | 43.35% |
| CoT, FFT | 42.44% (3 epochs) |
| CoT, FFT | 43.87% (10 epochs) |

- quickstart.ipynb: A Quick Demo of Text2SQL Using Llama 3.3. This Jupyter Notebook includes examples of how to use the interface to execute natural language queries on the sample data. It uses Llama 3.3 to answer questions about a SQLite database using LangChain and the Llama cloud provider Together.ai.
- nba.txt: A text file containing NBA roster information, which is used as sample data for demonstration purposes.
- txt2csv.py: A script that converts text data into a CSV format. This script is used to preprocess the input data before it is fed into csv2db.py.
- csv2db.py: A script that imports data from a CSV file into a SQLite database. This script is used to populate the database with sample data.
- nba_roster.db: A SQLite database file created from the nba.txt data, used to test the Text2SQL interface.
The complete steps are:

### Detailed steps on running the notebook:
1. Pre-processing the [BIRD](https://bird-bench.github.io) TRAIN dataset by converting text, schema, external knowledge, and SQL statements into the conversation format (see the sketch after this list).

- Before getting started, please make sure to setup Together.ai and get an API key from [here](https://www.together.ai/).
2. Using Llama-3.3-70B to add CoT to the conversation format dataset.

- First, please install the requirements from [here](https://github.com/meta-llama/llama-cookbook/blob/main/end-to-end-use-cases/coding/text2sql/requirements.txt) by running inside the folder:
3. Fine-tuning Llama-3.1-8B on the CoT dataset from step 2.

```
git clone https://github.com/meta-llama/llama-cookbook.git
cd llama-cookbook/end-to-end-use-cases/coding/text2sql/
pip install -r requirements.txt
```
4. Running the BIRD DEV eval benchmark on the fine-tuned models and comparing them with the out-of-the-box model.
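To make the conversation format in step 1 concrete, here is a minimal sketch, assuming a BIRD TRAIN item with `question`, `evidence`, and `SQL` fields; the exact system prompt and schema text used by the repo's scripts may differ:

```python
# Sketch of step 1: map one BIRD TRAIN item into the conversation format
# used for fine-tuning. Field names (question, evidence, SQL) follow the
# BIRD JSON; the real prompts in the fine-tune scripts may differ.
def to_conversation(item: dict, schema_prompt: str) -> dict:
    user_prompt = (
        f"{schema_prompt}\n\n"
        f"-- External Knowledge: {item['evidence']}\n\n"
        f"-- Question: {item['question']}"
    )
    return {
        "messages": [
            {"role": "system", "content": "You are a text to SQL query translator."},
            {"role": "user", "content": user_prompt},
            {"role": "assistant", "content": item["SQL"]},
        ]
    }
```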

### Contributing
Contributions are welcome! If you'd like to add new features or improve existing ones, please submit a pull request. We encourage contributions in the following areas:
- Adding support for additional databases
- Developing new interfaces or applications that use the Text2SQL interface
## Folder Structure

- quickstart folder: contains a notebook to ask Llama 3.3 to convert natural language queries into SQL queries.
- data folder: contains scripts to download the BIRD TRAIN and DEV datasets;
- fine-tune folder: contains scripts to generate a CoT dataset based on the BIRD TRAIN set and to run supervised fine-tuning of Llama models on that dataset, with different SFT options (quantized or not, full fine-tuning or parameter-efficient fine-tuning);
- eval folder: contains scripts to evaluate Llama models (original and fine-tuned) on the BIRD dataset.

We also experimented with supervised fine-tuning (SFT) without CoT, which resulted in slightly lower accuracy.

## Next Steps

1. Hyper-parameter tuning of the current SFT scripts.
2. Try GRPO reinforcement learning to further improve the accuracy.
3. Fine-tune Llama 3.3 70B and Llama 4 models.
4. Try agentic workflow.
5. Expand the eval to support other enterprise databases.
**Contributor:** would it be dataset vs database?

**Contributor Author:** It's the "database" - to expand the eval to go beyond the current sqlite and include Oracle, etc.

@@ -0,0 +1,9 @@
wget https://bird-bench.oss-cn-beijing.aliyuncs.com/dev.zip
unzip dev.zip
rm dev.zip
rm -rf __MACOSX
cd dev_20240627
unzip dev_databases.zip
rm dev_databases.zip
rm -rf __MACOSX
cd ..
@@ -0,0 +1,9 @@
wget https://bird-bench.oss-cn-beijing.aliyuncs.com/train.zip
UNZIP_DISABLE_ZIPBOMB_DETECTION=TRUE unzip train.zip
rm train.zip
rm -rf __MACOSX
cd train
unzip train_databases.zip
rm train_databases.zip
rm -rf __MACOSX
cd ..
85 changes: 85 additions & 0 deletions end-to-end-use-cases/coding/text2sql/eval/README.md
@@ -0,0 +1,85 @@
# Llama Text2SQL Evaluation

We have updated and simplified the original eval scripts from the BIRD [repo](https://github.com/AlibabaResearch/DAMO-ConvAI/tree/main/bird) into 3 simple steps for Llama 3 & 4 models hosted via Meta's [Llama API](https://llama.developer.meta.com), as well as for Llama 3.1 8B on Hugging Face and its fine-tuned models.

## Evaluation Results

Below are the results of the Llama models we have evaluated on the BIRD DEV dataset:

| Model | Llama API Accuracy |
|------------------------|--------------------|
| Llama 3.1 8B            | 39.47% (*)         |
| Llama 3.3 70B           | 54.11%             |
| Llama 4 Scout | 44.39% |
| Llama 4 Maverick | 44.00% |

(*) Since the Llama API does not offer a Llama 3.1 8B model, we use the Hugging Face weights and vLLM to run it locally.

## Quick Start with Llama Models via Llama API

Follow the steps below to evaluate Llama 3 & 4 models on Text2SQL using the BIRD benchmark:

1. Run the commands below to create a new Conda environment and install all the required packages for Text2SQL evaluation:

```
conda create -n llama-text2sql python=3.10
conda activate llama-text2sql
git clone https://github.com/meta-llama/llama-cookbook
cd llama-cookbook
git checkout text2sql # to be removed after the PR merge
cd end-to-end-use-cases/coding/text2sql/eval
pip install -r requirements.txt
```

2. Get the DEV dataset:
```
cd ../data
sh download_dev_unzip.sh
cd ../eval
```

3. Open `llama_eval.sh` and set `YOUR_API_KEY` to your [Llama API](https://llama.developer.meta.com/) key, then uncomment the line that starts with `model=` for the Llama model you want to use in the text2sql eval.

4. Run the evaluation script `sh llama_eval.sh`, which will use the BIRD DEV dataset (1534 examples in total) with external knowledge turned on to run the Llama model on each text question and compare the generated SQL with the gold SQL.

If your API key or model name is incorrect, the script will exit with an authentication or model-not-supported error.

After the script completes, you'll see the accuracy of the Llama model on the BIRD DEV text2sql task. For example, the total accuracy is about 54.24% with `YOUR_API_KEY` set to your Llama API key and `model='Llama-3.3-70B-Instruct'`.

To compare the evaluated accuracy of your selected Llama model with other results on the BIRD DEV leaderboard, click [here](https://bird-bench.github.io/).

## Evaluation with Hugging Face or Fine-tuned Llama Models

We use the vLLM OpenAI-compatible server to run Llama 3.1 8B from Hugging Face (steps below) or its fine-tuned models (steps [here](../fine-tuning/#evaluating-the-fine-tuned-model)) for eval:

1. Uncomment the last two lines in `requirements.txt`, then run `pip install -r requirements.txt`:
```
# vllm==0.9.2
# openai==1.90.0
```

2. Uncomment these two lines in `llama_eval.sh`:
```
YOUR_API_KEY='huggingface'
model='meta-llama/Llama-3.1-8B-Instruct'
```

3. Start the vllm server:
```
vllm serve meta-llama/Llama-3.1-8B-Instruct --tensor-parallel-size 1 --max-num-batched-tokens 8192 --max-num-seqs 64
```
Or, if you have multiple GPUs and want to speed up inference and eval, set `--tensor-parallel-size` to the number of available GPUs, e.g.:
```
vllm serve meta-llama/Llama-3.1-8B-Instruct --tensor-parallel-size 8 --max-num-batched-tokens 8192 --max-num-seqs 64
```

Then run `sh llama_eval.sh`.
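Under the hood, the eval client sends requests to this server through the OpenAI-compatible API. A minimal sketch, assuming the server runs on vLLM's default port 8000 (the actual request parameters in `llama_text2sql.py` may differ):

```python
# Minimal sketch: query the local vllm OpenAI-compatible server started above.
from openai import OpenAI

client = OpenAI(base_url="http://localhost:8000/v1", api_key="huggingface")
response = client.chat.completions.create(
    model="meta-llama/Llama-3.1-8B-Instruct",
    messages=[{"role": "user", "content": "-- DB Schema: ...\n\n-- Question: ..."}],
    temperature=0,  # deterministic output for eval
)
print(response.choices[0].message.content)
```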

## Evaluation Process

1. **SQL Generation**: `llama_text2sql.py` sends natural language questions to the specified Llama model and collects the generated SQL queries.

2. **SQL Execution**: `text2sql_eval.py` executes both the generated SQL and ground truth SQL against the corresponding databases, then continues with steps 3 and 4 below.

3. **Result Comparison**: The results from executing the generated SQL are compared ([source code](text2sql_eval.py#L29)) with the results from the ground truth SQL to determine correctness.

4. **Accuracy Calculation**: Accuracy scores are calculated overall and broken down by difficulty levels (simple, moderate, challenging).
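At its core, step 3 is an execution match: both queries run against the same SQLite database and their result sets are compared. A minimal sketch of this check, which may differ in detail from the actual code in `text2sql_eval.py`:

```python
import sqlite3

def execution_match(predicted_sql: str, gold_sql: str, db_path: str) -> bool:
    """Return True if both queries produce the same result set."""
    conn = sqlite3.connect(db_path)
    try:
        predicted = conn.execute(predicted_sql).fetchall()
        gold = conn.execute(gold_sql).fetchall()
    finally:
        conn.close()
    # Compare as sets: row order is unspecified without an ORDER BY clause
    return set(predicted) == set(gold)
```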
161 changes: 161 additions & 0 deletions end-to-end-use-cases/coding/text2sql/eval/create_bird_eval_dataset.py
@@ -0,0 +1,161 @@
import argparse
import json
import os
import sqlite3

import pandas as pd

# from datasets import Dataset
from tqdm import tqdm


def new_directory(path):
if not os.path.exists(path):
os.makedirs(path)


def nice_look_table(column_names: list, values: list):
rows = []
# Determine the maximum width of each column
widths = [
max(len(str(value[i])) for value in values + [column_names])
for i in range(len(column_names))
]

# Print the column names
header = "".join(
f"{column.rjust(width)} " for column, width in zip(column_names, widths)
)
# print(header)
# Print the values
for value in values:
row = "".join(f"{str(v).rjust(width)} " for v, width in zip(value, widths))
rows.append(row)
rows = "\n".join(rows)
final_output = header + "\n" + rows
return final_output


def generate_schema_prompt(db_path, num_rows=None):
    """
    Extract the CREATE TABLE DDLs from the SQLite database at db_path and
    build a schema prompt; if num_rows is given, append that many example
    rows per table.

    :param db_path: path to the SQLite database file
    :param num_rows: optional number of example rows to include per table
    :return: the schema prompt string
    """
full_schema_prompt_list = []
conn = sqlite3.connect(db_path)
# Create a cursor object
cursor = conn.cursor()
cursor.execute("SELECT name FROM sqlite_master WHERE type='table'")
tables = cursor.fetchall()
schemas = {}
for table in tables:
if table == "sqlite_sequence":
continue
cursor.execute(
"SELECT sql FROM sqlite_master WHERE type='table' AND name='{}';".format(
table[0]
)
)
create_prompt = cursor.fetchone()[0]
schemas[table[0]] = create_prompt
if num_rows:
cur_table = table[0]
if cur_table in ["order", "by", "group"]:
cur_table = "`{}`".format(cur_table)

cursor.execute("SELECT * FROM {} LIMIT {}".format(cur_table, num_rows))
column_names = [description[0] for description in cursor.description]
values = cursor.fetchall()
rows_prompt = nice_look_table(column_names=column_names, values=values)
verbose_prompt = "/* \n {} example rows: \n SELECT * FROM {} LIMIT {}; \n {} \n */".format(
num_rows, cur_table, num_rows, rows_prompt
)
schemas[table[0]] = "{} \n {}".format(create_prompt, verbose_prompt)

for k, v in schemas.items():
full_schema_prompt_list.append(v)

schema_prompt = "-- DB Schema: " + "\n\n".join(full_schema_prompt_list)

return schema_prompt


def generate_comment_prompt(question, knowledge=None):
knowledge_prompt = "-- External Knowledge: {}".format(knowledge)
question_prompt = "-- Question: {}".format(question)

result_prompt = knowledge_prompt + "\n\n" + question_prompt

return result_prompt


def generate_combined_prompts_one(db_path, question, knowledge=None):
schema_prompt = generate_schema_prompt(db_path, num_rows=None)
comment_prompt = generate_comment_prompt(question, knowledge)

combined_prompts = schema_prompt + "\n\n" + comment_prompt

return combined_prompts


def create_conversation(sample):
return {
"messages": [
{"role": "system", "content": sample["messages"][0]["content"]},
{"role": "user", "content": sample["messages"][1]["content"]},
{"role": "assistant", "content": sample["messages"][2]["content"]},
]
}


def create_bird_eval_dataset(input_json, db_root_path):
SYSTEM_PROMPT = (
"You are a text to SQL query translator. Using the SQLite DB Schema and the "
"External Knowledge, translate the following text question into a SQLite SQL "
"select statement."
)
data = []

for i, item in tqdm(enumerate(input_json)):
print(f"processing #{i+1}")
db_id = item["db_id"]
question = item["question"]
external_knowledge = item["evidence"]
SQL = item["SQL"]
db_path = db_root_path + "/" + db_id + "/" + db_id + ".sqlite"
print(f"{db_path=}")
prompt = generate_combined_prompts_one(
db_path,
question,
knowledge=external_knowledge,
)

data.append(
{
"prompt": SYSTEM_PROMPT + "\n\n" + prompt,
"gold_sql": SQL,
"db_id": db_id,
}
)

df = pd.DataFrame(data)
df.to_csv("bird_dev_set_eval.csv", index=False)
print(f"Dataset saved as bird_dev_set_eval.csv with {len(df)} rows")


if __name__ == "__main__":
args_parser = argparse.ArgumentParser()
args_parser.add_argument("--input_json", type=str, required=True)
args_parser.add_argument("--db_root_path", type=str, required=True)
args = args_parser.parse_args()

    with open(args.input_json, "r") as f:
        input_json = json.load(f)
db_root_path = args.db_root_path

create_bird_eval_dataset(input_json, db_root_path)

# follow steps 1 and 2 here https://github.com/meta-llama/llama-cookbook/tree/text2sql/end-to-end-use-cases/coding/text2sql/eval#quick-start-with-llama-models-via-llama-api
# then run:
# python3 create_bird_eval_dataset.py --input_json ../data/dev_20240627/dev.json --db_root_path ../data/dev_20240627/dev_databases
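# example downstream usage (a sketch; the columns are those written above):
#   import pandas as pd
#   df = pd.read_csv("bird_dev_set_eval.csv")
#   print(df.columns.tolist())  # ['prompt', 'gold_sql', 'db_id']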
49 changes: 49 additions & 0 deletions end-to-end-use-cases/coding/text2sql/eval/llama_eval.sh
@@ -0,0 +1,49 @@
# Set to "true" to enable debug mode with detailed prints
DEBUG_MODE="false"

eval_path='../data/dev_20240627/dev.json'
db_root_path='../data/dev_20240627/dev_databases/'
ground_truth_path='../data/'

# Llama models on Llama API
# YOUR_API_KEY='YOUR_LLAMA_API_KEY'
# model='Llama-3.3-8B-Instruct'
#model='Llama-3.3-70B-Instruct'
#model='Llama-4-Maverick-17B-128E-Instruct-FP8'
#model='Llama-4-Scout-17B-16E-Instruct-FP8'

# Llama model on Hugging Face Hub https://huggingface.co/meta-llama/Llama-3.1-8B-Instruct
# YOUR_API_KEY='huggingface'
# model='meta-llama/Llama-3.1-8B-Instruct'

# Fine-tuned Llama models locally
YOUR_API_KEY='finetuned'
model='../fine-tuning/llama31-8b-text2sql-fft-nonquantized-cot-epochs-3'

data_output_path="./output/$model/"

echo "Text2SQL using $model"
python3 -u llama_text2sql.py --db_root_path ${db_root_path} --api_key ${YOUR_API_KEY} \
--model ${model} --eval_path ${eval_path} --data_output_path ${data_output_path}

# Capture llama_text2sql.py's exit code before any other command overwrites $?
exit_code=$?

# Check if llama_text2sql.py exited successfully
if [ $exit_code -eq 0 ]; then
echo "llama_text2sql.py completed successfully. Proceeding with evaluation..."

# Add --debug flag if DEBUG_MODE is true
if [ "$DEBUG_MODE" = "true" ]; then
python3 -u text2sql_eval.py --db_root_path ${db_root_path} --predicted_sql_path ${data_output_path} \
--ground_truth_path ${ground_truth_path} \
--diff_json_path ${eval_path} --debug
else
python3 -u text2sql_eval.py --db_root_path ${db_root_path} --predicted_sql_path ${data_output_path} \
--ground_truth_path ${ground_truth_path} \
--diff_json_path ${eval_path}
fi

echo "Done evaluating $model."

else
echo "Error: llama_text2sql.py failed with exit code $?. Skipping evaluation."
exit 1
fi