-
Notifications
You must be signed in to change notification settings - Fork 2.6k
Text2sql tool - e2e evals and fine-tuning #967
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Open
jeffxtang
wants to merge
80
commits into
main
Choose a base branch
from
text2sql
base: main
Could not load branches
Branch not found: {{ refName }}
Loading
Could not load tags
Nothing to show
Loading
Are you sure you want to change the base?
Some commits from the old base branch may be removed from the timeline,
and old review comments may become outdated.
Open
Changes from all commits
Commits
Show all changes
80 commits
Select commit
Hold shift + click to select a range
ebed0ef
text2sql eval and ft tools
edcf746
script and README update
jeffxtang 76a8caf
script to create reasoning dataset; llama_text2sql.py and requirement…
jeffxtang ab7df10
train loss png
jeffxtang 2e8b278
README update
jeffxtang 2c514b1
README update
jeffxtang 5a18b6b
README cleanup
jeffxtang 0033fc9
README update
jeffxtang 3997357
README update - overview and quick start
jeffxtang 44aa896
README update - next steps and creating reasoning dataset
jeffxtang 46d3245
README and create_reasoning_dataset.py and trl_sft.py update
jeffxtang b89d945
readme and requirements updated based on PR feedback
jeffxtang 3731175
readme update based on PR feeedback
jeffxtang 094ab01
readme
jeffxtang 6d76ea0
updated create_reasoning_dataset, llama_text2sql.py and README for fi…
jeffxtang cf54eb4
README format fix
jeffxtang e182902
train loss cot; README
jeffxtang 79945b6
parent README update
jeffxtang 0aa42d8
adding text2sql colab notebook for fine-tuning
ghd3v 4cdd5f6
README and requirements.txt update
jeffxtang 5e8a7b0
remove needless words in README
jeffxtang 71ca0ae
README update
jeffxtang b17f90b
README update
jeffxtang 6815255
folder struc refactoring
jeffxtang 03ba7d5
quickstart folder
jeffxtang 11a4a64
quickstart readme and colab link
jeffxtang 99ead57
4 READMEs; requirements
jeffxtang ee1fc97
png fix
jeffxtang 9c294df
eval/llama_eval.sh data path
jeffxtang ef6bbb2
main README and FT README update
jeffxtang caf98ec
unified script supporting 6 FT configs - quantized or not, peft/fft, …
jeffxtang 4107171
Merge branch 'text2sql' of https://github.com/meta-llama/llama-cookbo…
jeffxtang b02334a
updated fine-tuning README
jeffxtang f7c68c1
bug fix on loading PEFT and FFT model for eval
jeffxtang 4037737
READMEs update based on the new FT results etc
jeffxtang f07da72
finetuning README
jeffxtang 57ffb74
finetuning README
jeffxtang a6f7d02
train loss update
jeffxtang 7a4ae9f
ft readme typo fix
jeffxtang 9ac5dd1
make max_tokens same - 10240
jeffxtang 6b92409
dynamically calculating max tokens param; it was unused before
ghd3v c4573ba
fixing github we viewing
ghd3v 7b508ec
testing fix of github notebook viewing
ghd3v 3c23112
fixing github web rendering of the notebook
ghd3v cc93b73
rm unsloth quantization notebook
ghd3v 6269c15
fine-tuning README update with latest result
jeffxtang 2cdfbf0
READMEs update
jeffxtang 4bb7faa
READMEs update
jeffxtang 2bd662c
adding vllm eval files and updating requirements.txt
ghd3v b574c6d
Updated README.md
ghd3v 58ea6cb
Update README.md
ghd3v 33ac1ab
Update README.md
ghd3v e10ddda
some refactoring and cleaning
ghd3v 5baa1e3
vllm enabled eval for HF and fine-tuned models; code cleanup and refa…
jeffxtang f894d26
vllm enabled eval for HF and fine-tuned models; code cleanup and refa…
jeffxtang ad48509
trl import
ghd3v e059899
Update the eval section using vllm for fine-tuning README.md
jeffxtang f80e7bf
Update fine-tuning README.md
jeffxtang b630735
Update fine-tuning README.md
jeffxtang 1ac67d9
Update eval README.md for vllm based HF model
jeffxtang 1b802d3
Update FT README.md
jeffxtang df598c4
Update eval README.md
jeffxtang cb8b0bd
Update eval README.md
jeffxtang 77d3544
batch processing and vllm llama call in parallel; clean progress show…
jeffxtang deca42c
Merge branch 'text2sql' of https://github.com/meta-llama/llama-cookbo…
jeffxtang 12a6dfa
code cleanup and refactoring; cloud llama response generation in tqdm…
jeffxtang 799dee6
some cleanup and typo fix
jeffxtang 6501cf4
FT readme update; removed old vllm py and sh files
jeffxtang 82bb008
3 READMEs update; fine-tuning requirements update with vllm etc
jeffxtang 27a23af
main README
jeffxtang e38abf1
Update eval README.md
jeffxtang fc80546
Update FT EADME.md
jeffxtang be4817c
Update FT README.md with llama 3.3 70b on multiple gpus
jeffxtang af3ea4f
script to save dev set in pandas csv format
jeffxtang 0c7b348
restored llama_eval.sh
jeffxtang 54e49bc
added steps to run create_bird_eval_dataset.py
jeffxtang c88e10f
grpo llama 3.2 3b with 3 reward functions
jeffxtang 57c0517
added llm as a judge reward func
jeffxtang 7edf3d8
llm as a judge running now
jeffxtang 8989e69
README for grpo
jeffxtang File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,30 +1,39 @@ | ||
## Text2SQL: Natural Language to SQL Interface | ||
# Improving Llama Text2SQL performance with CoT Fine-tuning | ||
|
||
This project provides a set of scripts to convert natural language queries into SQL statements using Meta's Llama model. The goal is to enable users to interact with databases using natural language inputs, making it easier for non-technical users to access and analyze data. | ||
This recipe is step by step guide to improve Llama performance on Text2SQL measured with the popular [BIRD](https://bird-bench.github.io) benchmark. We generate a synthetic Chain of Thought(CoT) dataset and fine-tune Llama models on it. | ||
|
||
For detailed instructions on setting up the environment, creating a database, and executing natural language queries using the Text2SQL interface, please refer to the quickstart.ipynb notebook. | ||
Results: | ||
|
||
### Structure: | ||
| Fine-tuning Combination | Accuracy | | ||
|-----------------------------|-------------------------------| | ||
| baseline | 39.47% | | ||
| CoT, PEFT | 43.35% | | ||
| CoT, FFT | 42.44% (3 epochs) | | ||
| CoT, FFT | 43.87% (10 epochs) | | ||
|
||
- quickstart.ipynb: A Quick Demo of Text2SQL Using Llama 3.3. This Jupyter Notebook includes examples of how to use the interface to execute natural language queries on the sample data. It uses Llama 3.3 to answer questions about a SQLite database using LangChain and the Llama cloud provider Together.ai. | ||
- nba.txt: A text file containing NBA roster information, which is used as sample data for demonstration purposes. | ||
- txt2csv.py: A script that converts text data into a CSV format. This script is used to preprocess the input data before it is fed into csv2db.py. | ||
- csv2db.py: A script that imports data from a CSV file into a SQLite database. This script is used to populate the database with sample data. | ||
- nba_roster.db: A SQLite database file created from the nba.txt data, used to test the Text2SQL interface. | ||
The complete steps are: | ||
|
||
### Detailed steps on running the notebook: | ||
1. Pre-processing the [BIRD](https://bird-bench.github.io) TRAIN datset by converting text, schema, external knowledge, and SQL statements into the conversation format. | ||
|
||
- Before getting started, please make sure to setup Together.ai and get an API key from [here](https://www.together.ai/). | ||
2. Using Llama-3.3-70B to add CoT to the conversation format dataset. | ||
|
||
- First, please install the requirements from [here](https://github.com/meta-llama/llama-cookbook/blob/main/end-to-end-use-cases/coding/text2sql/requirements.txt) by running inside the folder: | ||
3. Fine-tuning Llama-3.1-8B on the CoT dataset from step 2. | ||
|
||
``` | ||
git clone https://github.com/meta-llama/llama-cookbook.git | ||
cd llama-cookbook/end-to-end-use-cases/coding/text2sql/ | ||
pip install -r requirements.txt | ||
``` | ||
4. Running the BIRD DEV eval benchmark on the fine-tuned models and compare it with out of the model. | ||
|
||
### Contributing | ||
Contributions are welcome! If you'd like to add new features or improve existing ones, please submit a pull request. We encourage contributions in the following areas: | ||
- Adding support for additional databases | ||
- Developing new interfaces or applications that use the Text2SQL interface | ||
## Folder Structure | ||
|
||
- quickstart folder: contains a notebook to ask Llama 3.3 to convert natural language queries into SQL queries. | ||
- data folder: contains scripts to download the BIRD TRAIN and DEV datasets; | ||
- fine-tune folder: contains scripts to generate CoT dataset based on the BIRD TRAIN set and to supervised fine-tune Llama models using the dataset, with different SFT options (quantization or not, full fine-tuning or parameter-efficient fine-tuning); | ||
- eval folder: contains scripts to evaluate Llama models (original and fine-tuned) on the BIRD dataset. | ||
|
||
We also experimented with supervised fine-tuning (SFT) without CoT which resulted in slightly lower accuracy. | ||
|
||
## Next Steps | ||
|
||
1. Hyper-parameter tuning of the current SFT scripts. | ||
2. Try GRPO reinforcement learning to further improve the accuracy. | ||
3. Fine-tune Llama 3.3 70B and Llama 4 models. | ||
4. Try agentic workflow. | ||
5. Expand the eval to support other enterprise databases. | ||
9 changes: 9 additions & 0 deletions
9
end-to-end-use-cases/coding/text2sql/data/download_dev_unzip.sh
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,9 @@ | ||
wget https://bird-bench.oss-cn-beijing.aliyuncs.com/dev.zip | ||
unzip dev.zip | ||
rm dev.zip | ||
rm -rf __MACOSX | ||
cd dev_20240627 | ||
unzip dev_databases.zip | ||
rm dev_databases.zip | ||
rm -rf __MACOSX | ||
cd .. |
9 changes: 9 additions & 0 deletions
9
end-to-end-use-cases/coding/text2sql/data/download_train_unzip.sh
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,9 @@ | ||
wget https://bird-bench.oss-cn-beijing.aliyuncs.com/train.zip | ||
UNZIP_DISABLE_ZIPBOMB_DETECTION=TRUE unzip train.zip | ||
rm train.zip | ||
rm -rf __MACOSX | ||
cd train | ||
unzip train_databases.zip | ||
rm train_databases.zip | ||
rm -rf __MACOSX | ||
cd .. |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,85 @@ | ||
# Llama Text2SQL Evaluation | ||
|
||
We have updated and simplified the original eval scripts from the BIRD [repo](https://github.com/AlibabaResearch/DAMO-ConvAI/tree/main/bird) to 3 simple steps for Llama 3 & 4 models hosted via Meta's [Llama API](https://llama.developer.meta.com), as well as Llama 3.1 8B on Hugging Face and its fine-tuned models. | ||
|
||
## Evaluation Results | ||
|
||
Below are the results of the Llama models we have evaluated on the BIRD DEV dataset: | ||
|
||
| Model | Llama API Accuracy | | ||
|------------------------|--------------------| | ||
| Llama 3.1 8b | 39.47% (*) | | ||
| Llama 3.3 70b | 54.11% | | ||
| Llama 4 Scout | 44.39% | | ||
| Llama 4 Maverick | 44.00% | | ||
|
||
- Since Llama API does not have Llama 3.1 8b model, we use Hugging Face weights and vllm to run locally. | ||
|
||
## Quick Start with Llama Models via Llama API | ||
|
||
Follow the steps below to evaluate Llama 3 & 4 models on Text2SQL using the BIRD benchmark: | ||
|
||
1. Run the commands below to create a new Conda environment and install all the required packages for Text2SQL evaluation: | ||
|
||
``` | ||
conda create -n llama-text2sql python=3.10 | ||
conda activate llama-text2sql | ||
git clone https://github.com/meta-llama/llama-cookbook | ||
git checkout text2sql # to be removed after the PR merge | ||
cd llama-cookbook/end-to-end-use-cases/coding/text2sql/eval | ||
pip install -r requirements.txt | ||
``` | ||
|
||
2. Get the DEV dataset: | ||
``` | ||
cd ../data | ||
sh download_dev_unzip.sh | ||
cd ../eval | ||
``` | ||
|
||
3. Open `llama_eval.sh` and set `YOUR_API_KEY` to your [Llama API](https://llama.developer.meta.com/) key then uncomment a line that starts with `model=` to specify the Llama model to use for the text2sql eval. | ||
|
||
4. Run the evaluation script `sh llama_eval.sh`, which will use the BIRD DEV dataset (1534 examples in total) with external knowledge turned on to run the Llama model on each text question and compare the generated SQL with the gold SQL. | ||
|
||
If your API key or model name is incorrect, the script will exit with an authentication or model not supported error. | ||
|
||
After the script completes, you'll see the accuracy of the Llama model on the BIRD DEV text2sql. For example, the total accuracy is about 54.24% with `YOUR_API_KEY` set to your Llama API key and `model='Llama-3.3-70B-Instruct'` | ||
|
||
To compare your evaluated accuracy of your selected Llama model with other results in the BIRD Dev leaderboard, click [here](https://bird-bench.github.io/). | ||
|
||
## Evaluation with Llama Models on Hugging Face or Fine-tuned | ||
|
||
We use vllm OpenAI compatible server to run Llama 3.1 8B on Hugging Face (steps below) or its fine-tuned models (steps [here](../fine-tuning/#evaluating-the-fine-tuned-model) for eval: | ||
|
||
1. Uncomment the last two lines in requirements.txt then run `pip install -r requirements.txt`: | ||
``` | ||
# vllm==0.9.2 | ||
# openai==1.90.0 | ||
``` | ||
|
||
2. Uncomment in `llama_eval.sh`: | ||
``` | ||
YOUR_API_KEY='huggingface' | ||
model='meta-llama/Llama-3.1-8B-Instruct' | ||
``` | ||
|
||
3. Start the vllm server: | ||
``` | ||
vllm serve meta-llama/Llama-3.1-8B-Instruct --tensor-parallel-size 1 --max-num-batched-tokens 8192 --max-num-seqs 64 | ||
``` | ||
or if you want to speed up the inference and eval and have multiple GPUs, you can set `--tensor-parallel-size` to the number of your available GPUs, e.g.: | ||
``` | ||
vllm serve meta-llama/Llama-3.1-8B-Instruct --tensor-parallel-size 8 --max-num-batched-tokens 8192 --max-num-seqs 64 | ||
``` | ||
|
||
then run `sh llama_eval.sh`. | ||
|
||
## Evaluation Process | ||
|
||
1. **SQL Generation**: `llama_text2sql.py` sends natural language questions to the specified Llama model and collects the generated SQL queries. | ||
|
||
2. **SQL Execution**: `text2sql_eval.py` executes both the generated SQL and ground truth SQL against the corresponding databases, then continues with steps 3 and 4 below. | ||
|
||
3. **Result Comparison**: The results from executing the generated SQL are compared ([source code](text2sql_eval.py#L29)) with the results from the ground truth SQL to determine correctness. | ||
|
||
4. **Accuracy Calculation**: Accuracy scores are calculated overall and broken down by difficulty levels (simple, moderate, challenging). |
161 changes: 161 additions & 0 deletions
161
end-to-end-use-cases/coding/text2sql/eval/create_bird_eval_dataset.py
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,161 @@ | ||
import argparse | ||
import json | ||
import os | ||
import sqlite3 | ||
|
||
import pandas as pd | ||
|
||
# from datasets import Dataset | ||
from tqdm import tqdm | ||
|
||
|
||
def new_directory(path): | ||
if not os.path.exists(path): | ||
os.makedirs(path) | ||
|
||
|
||
def nice_look_table(column_names: list, values: list): | ||
rows = [] | ||
# Determine the maximum width of each column | ||
widths = [ | ||
max(len(str(value[i])) for value in values + [column_names]) | ||
for i in range(len(column_names)) | ||
] | ||
|
||
# Print the column names | ||
header = "".join( | ||
f"{column.rjust(width)} " for column, width in zip(column_names, widths) | ||
) | ||
# print(header) | ||
# Print the values | ||
for value in values: | ||
row = "".join(f"{str(v).rjust(width)} " for v, width in zip(value, widths)) | ||
rows.append(row) | ||
rows = "\n".join(rows) | ||
final_output = header + "\n" + rows | ||
return final_output | ||
|
||
|
||
def generate_schema_prompt(db_path, num_rows=None): | ||
# extract create ddls | ||
""" | ||
:param root_place: | ||
:param db_name: | ||
:return: | ||
""" | ||
full_schema_prompt_list = [] | ||
conn = sqlite3.connect(db_path) | ||
# Create a cursor object | ||
cursor = conn.cursor() | ||
cursor.execute("SELECT name FROM sqlite_master WHERE type='table'") | ||
tables = cursor.fetchall() | ||
schemas = {} | ||
for table in tables: | ||
if table == "sqlite_sequence": | ||
continue | ||
cursor.execute( | ||
"SELECT sql FROM sqlite_master WHERE type='table' AND name='{}';".format( | ||
table[0] | ||
) | ||
) | ||
create_prompt = cursor.fetchone()[0] | ||
schemas[table[0]] = create_prompt | ||
if num_rows: | ||
cur_table = table[0] | ||
if cur_table in ["order", "by", "group"]: | ||
cur_table = "`{}`".format(cur_table) | ||
|
||
cursor.execute("SELECT * FROM {} LIMIT {}".format(cur_table, num_rows)) | ||
column_names = [description[0] for description in cursor.description] | ||
values = cursor.fetchall() | ||
rows_prompt = nice_look_table(column_names=column_names, values=values) | ||
verbose_prompt = "/* \n {} example rows: \n SELECT * FROM {} LIMIT {}; \n {} \n */".format( | ||
num_rows, cur_table, num_rows, rows_prompt | ||
) | ||
schemas[table[0]] = "{} \n {}".format(create_prompt, verbose_prompt) | ||
|
||
for k, v in schemas.items(): | ||
full_schema_prompt_list.append(v) | ||
|
||
schema_prompt = "-- DB Schema: " + "\n\n".join(full_schema_prompt_list) | ||
|
||
return schema_prompt | ||
|
||
|
||
def generate_comment_prompt(question, knowledge=None): | ||
knowledge_prompt = "-- External Knowledge: {}".format(knowledge) | ||
question_prompt = "-- Question: {}".format(question) | ||
|
||
result_prompt = knowledge_prompt + "\n\n" + question_prompt | ||
|
||
return result_prompt | ||
|
||
|
||
def generate_combined_prompts_one(db_path, question, knowledge=None): | ||
schema_prompt = generate_schema_prompt(db_path, num_rows=None) | ||
comment_prompt = generate_comment_prompt(question, knowledge) | ||
|
||
combined_prompts = schema_prompt + "\n\n" + comment_prompt | ||
|
||
return combined_prompts | ||
|
||
|
||
def create_conversation(sample): | ||
return { | ||
"messages": [ | ||
{"role": "system", "content": sample["messages"][0]["content"]}, | ||
{"role": "user", "content": sample["messages"][1]["content"]}, | ||
{"role": "assistant", "content": sample["messages"][2]["content"]}, | ||
] | ||
} | ||
|
||
|
||
def create_bird_eval_dataset(input_json, db_root_path): | ||
SYSTEM_PROMPT = ( | ||
"You are a text to SQL query translator. Using the SQLite DB Schema and the " | ||
"External Knowledge, translate the following text question into a SQLite SQL " | ||
"select statement." | ||
) | ||
data = [] | ||
|
||
for i, item in tqdm(enumerate(input_json)): | ||
print(f"processing #{i+1}") | ||
db_id = item["db_id"] | ||
question = item["question"] | ||
external_knowledge = item["evidence"] | ||
SQL = item["SQL"] | ||
db_path = db_root_path + "/" + db_id + "/" + db_id + ".sqlite" | ||
print(f"{db_path=}") | ||
prompt = generate_combined_prompts_one( | ||
db_path, | ||
question, | ||
knowledge=external_knowledge, | ||
) | ||
|
||
data.append( | ||
{ | ||
"prompt": SYSTEM_PROMPT + "\n\n" + prompt, | ||
"gold_sql": SQL, | ||
"db_id": db_id, | ||
} | ||
) | ||
|
||
df = pd.DataFrame(data) | ||
df.to_csv("bird_dev_set_eval.csv", index=False) | ||
print(f"Dataset saved as bird_dev_set_eval.csv with {len(df)} rows") | ||
|
||
|
||
if __name__ == "__main__": | ||
args_parser = argparse.ArgumentParser() | ||
args_parser.add_argument("--input_json", type=str, required=True) | ||
args_parser.add_argument("--db_root_path", type=str, required=True) | ||
args = args_parser.parse_args() | ||
|
||
input_json = json.load(open(args.input_json, "r")) | ||
db_root_path = args.db_root_path | ||
|
||
create_bird_eval_dataset(input_json, db_root_path) | ||
|
||
# follow steps 1 and 2 here https://github.com/meta-llama/llama-cookbook/tree/text2sql/end-to-end-use-cases/coding/text2sql/eval#quick-start-with-llama-models-via-llama-api | ||
# then run: | ||
# python3 create_bird_eval_dataset.py --input_json ../data/dev_20240627/dev.json --db_root_path ../data/dev_20240627/dev_databases |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,49 @@ | ||
# Set to "true" to enable debug mode with detailed prints | ||
DEBUG_MODE="false" | ||
|
||
eval_path='../data/dev_20240627/dev.json' | ||
db_root_path='../data/dev_20240627/dev_databases/' | ||
ground_truth_path='../data/' | ||
|
||
# Llama models on Llama API | ||
# YOUR_API_KEY='YOUR_LLAMA_API_KEY' | ||
# model='Llama-3.3-8B-Instruct' | ||
#model='Llama-3.3-70B-Instruct' | ||
#model='Llama-4-Maverick-17B-128E-Instruct-FP8' | ||
#model='Llama-4-Scout-17B-16E-Instruct-FP8' | ||
|
||
# Llama model on Hugging Face Hub https://huggingface.co/meta-llama/Llama-3.1-8B-Instruct | ||
# YOUR_API_KEY='huggingface' | ||
# model='meta-llama/Llama-3.1-8B-Instruct' | ||
|
||
# Fine-tuned Llama models locally | ||
YOUR_API_KEY='finetuned' | ||
model='../fine-tuning/llama31-8b-text2sql-fft-nonquantized-cot-epochs-3' | ||
|
||
data_output_path="./output/$model/" | ||
|
||
echo "Text2SQL using $model" | ||
python3 -u llama_text2sql.py --db_root_path ${db_root_path} --api_key ${YOUR_API_KEY} \ | ||
--model ${model} --eval_path ${eval_path} --data_output_path ${data_output_path} | ||
|
||
# Check if llama_text2sql.py exited successfully | ||
if [ $? -eq 0 ]; then | ||
echo "llama_text2sql.py completed successfully. Proceeding with evaluation..." | ||
|
||
# Add --debug flag if DEBUG_MODE is true | ||
if [ "$DEBUG_MODE" = "true" ]; then | ||
python3 -u text2sql_eval.py --db_root_path ${db_root_path} --predicted_sql_path ${data_output_path} \ | ||
--ground_truth_path ${ground_truth_path} \ | ||
--diff_json_path ${eval_path} --debug | ||
else | ||
python3 -u text2sql_eval.py --db_root_path ${db_root_path} --predicted_sql_path ${data_output_path} \ | ||
--ground_truth_path ${ground_truth_path} \ | ||
--diff_json_path ${eval_path} | ||
fi | ||
|
||
echo "Done evaluating $model." | ||
|
||
else | ||
echo "Error: llama_text2sql.py failed with exit code $?. Skipping evaluation." | ||
exit 1 | ||
fi |
Oops, something went wrong.
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
would it be dataset vs database?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
It's the "database" - to expand the eval to go beyond the current sqlite and include Oracle, etc.