This repository was archived by the owner on Oct 9, 2024. It is now read-only.

Commit 96cfc46

add UI (#42)
adds ability to easily play with large or even smaller models :)
1 parent dade1e2 commit 96cfc46

File tree

8 files changed: 378 additions & 101 deletions


Dockerfile

Lines changed: 7 additions & 6 deletions
@@ -26,9 +26,6 @@ FROM conda as conda_env
 # update conda
 RUN conda update -n base -c defaults conda -y

-COPY Makefile Makefile
-COPY LICENSE LICENSE
-
 # necessary stuff
 RUN pip install torch==1.12.1+cu116 --extra-index-url https://download.pytorch.org/whl/cu116 \
     transformers \
@@ -43,9 +40,12 @@ RUN pip install torch==1.12.1+cu116 --extra-index-url https://download.pytorch.org/whl/cu116 \
     grpcio-tools==1.50.0 \
     --no-cache-dir

-# install grpc and compile protos
+# copy the code
 COPY inference_server inference_server
+COPY Makefile Makefile
+COPY LICENSE LICENSE

+# install grpc and compile protos
 RUN make gen-proto

 # clean conda env
@@ -58,9 +58,10 @@ ENV TRANSFORMERS_CACHE=/transformers_cache/ \
     HUGGINGFACE_HUB_CACHE=${TRANSFORMERS_CACHE} \
     HOME=/homedir

-# Runs as arbitrary user in OpenShift
 RUN mkdir ${HOME} && chmod g+wx ${HOME} && \
     mkdir tmp && chmod -R g+w tmp
-# RUN chmod g+w Makefile
+
+# for debugging
+# RUN chmod -R g+w inference_server && chmod g+w Makefile

 CMD make bloom-176b
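The image's default command (`CMD make bloom-176b`) starts the BLOOM 176B generation server, and the updated README below notes that the containerized server listens on port 5000. A minimal sketch of building and running the image (the image tag, GPU flag, and host cache path are illustrative assumptions, not part of this commit):

```shell
# build the image from the repository root; the tag name is arbitrary
docker build -t bloom-inference-server .

# run with GPU access and the default CMD (make bloom-176b);
# /transformers_cache matches the TRANSFORMERS_CACHE set in the Dockerfile,
# while the host-side path is a placeholder
docker run --gpus all -p 5000:5000 \
    -v /path/to/cache:/transformers_cache \
    bloom-inference-server
```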

README.md

Lines changed: 91 additions & 15 deletions
@@ -1,33 +1,109 @@
 # Fast Inference Solutions for BLOOM

-This repo provides demos and packages to perform fast inference solutions for BLOOM. Some of the solutions have their own repos in which case a link to the corresponding repos is provided instead.
+This repo provides demos and packages to perform fast inference solutions for BLOOM. Some of the solutions have their own repos, in which case a link to the [corresponding repos](#Other-inference-solutions) is provided instead.

-Some of the solutions provide both half-precision and int8-quantized solution.

-## Client-side solutions
+# Inference solutions for BLOOM 176B

-Solutions developed to perform large batch inference locally:
+We support HuggingFace accelerate and DeepSpeed Inference for generation.

-Pytorch:
+Install required packages:

-* [Accelerate, DeepSpeed-Inference and DeepSpeed-ZeRO](./bloom-inference-scripts)
+```shell
+pip install flask flask_api gunicorn pydantic accelerate huggingface_hub>=0.9.0 deepspeed>=0.7.3 deepspeed-mii==0.0.2
+```

-* [Custom HF Code](https://github.com/huggingface/transformers_bloom_parallel/).
+Alternatively, you can also install DeepSpeed from source:
+```shell
+git clone https://github.com/microsoft/DeepSpeed
+cd DeepSpeed
+CFLAGS="-I$CONDA_PREFIX/include/" LDFLAGS="-L$CONDA_PREFIX/lib/" TORCH_CUDA_ARCH_LIST="7.0" DS_BUILD_CPU_ADAM=1 DS_BUILD_AIO=1 DS_BUILD_UTILS=1 pip install -e . --global-option="build_ext" --global-option="-j8" --no-cache -v --disable-pip-version-check
+```

-JAX:
+All the provided scripts are tested on 8 A100 80GB GPUs for BLOOM 176B (fp16/bf16) and 4 A100 80GB GPUs for BLOOM 176B (int8). These scripts might not work for other models or a different number of GPUs.

-* [BLOOM Inference in JAX](https://github.com/huggingface/bloom-jax-inference)
+DS inference is deployed using logic borrowed from the DeepSpeed MII library.

+Note: Sometimes GPU memory is not freed when a DS inference deployment crashes. You can free this memory by running `killall python` in the terminal.

+For using BLOOM quantized, use dtype = int8. Also, change the model_name to microsoft/bloom-deepspeed-inference-int8 for DeepSpeed-Inference. For HF accelerate, no change is needed for model_name.

-## Server solutions
+HF accelerate uses [LLM.int8()](https://arxiv.org/abs/2208.07339) and DS-inference uses [ZeroQuant](https://arxiv.org/abs/2206.01861) for post-training quantization.
+
+## BLOOM inference via command-line
+
+This asks for generate_kwargs every time.
+Example: generate_kwargs =
+```json
+{"min_length": 100, "max_new_tokens": 100, "do_sample": false}
+```
+
+1. using HF accelerate
+```shell
+python -m inference_server.cli --model_name bigscience/bloom --model_class AutoModelForCausalLM --dtype bf16 --deployment_framework hf_accelerate --generate_kwargs '{"min_length": 100, "max_new_tokens": 100, "do_sample": false}'
+```
+
+2. using DS inference
+```shell
+python -m inference_server.cli --model_name microsoft/bloom-deepspeed-inference-fp16 --model_class AutoModelForCausalLM --dtype fp16 --deployment_framework ds_inference --generate_kwargs '{"min_length": 100, "max_new_tokens": 100, "do_sample": false}'
+```
+
+## BLOOM server deployment
+
+[make <model_name>](../Makefile) can be used to launch a generation server. Please note that the serving method is synchronous and users have to wait in queue until the preceding requests have been processed. An example to fire server requests is given [here](../server_request.py). Alternatively, a [Dockerfile](./Dockerfile) is also provided which launches a generation server on port 5000.
+
+An interactive UI can be launched via the following command to connect to the generation server. The default URL of the UI is `http://127.0.0.1:5001/`.
+```shell
+python -m ui
+```
+This command launches the following UI to play with generation. Sorry for the crappy design. Unfortunately, my UI skills only go so far. 😅😅😅
+![image](assets/UI.png)
+
+## Benchmark system for BLOOM inference
+
+1. using HF accelerate
+```shell
+python -m inference_server.benchmark --model_name bigscience/bloom --model_class AutoModelForCausalLM --dtype bf16 --deployment_framework hf_accelerate --benchmark_cycles 5
+```
+
+2. using DS inference
+```shell
+deepspeed --num_gpus 8 --module inference_server.benchmark --model_name bigscience/bloom --model_class AutoModelForCausalLM --dtype fp16 --deployment_framework ds_inference --benchmark_cycles 5
+```
+Alternatively, to load the model faster:
+```shell
+deepspeed --num_gpus 8 --module inference_server.benchmark --model_name microsoft/bloom-deepspeed-inference-fp16 --model_class AutoModelForCausalLM --dtype fp16 --deployment_framework ds_inference --benchmark_cycles 5
+```

-Solutions developed to be used in a server mode (i.e. varied batch size, varied request rate):
+3. using DS ZeRO
+```shell
+deepspeed --num_gpus 8 --module inference_server.benchmark --model_name bigscience/bloom --model_class AutoModelForCausalLM --dtype bf16 --deployment_framework ds_zero --benchmark_cycles 5
+```

-Pytorch:
+# Support

-* [Accelerate and DeepSpeed-Inference based solutions](./bloom-inference-server)

-Rust:
+If you run into things not working or have other questions, please open an Issue in the corresponding backend:
+
+- [Accelerate](https://github.com/huggingface/accelerate/issues)
+- [Deepspeed-Inference](https://github.com/microsoft/DeepSpeed/issues)
+- [Deepspeed-ZeRO](https://github.com/microsoft/DeepSpeed/issues)
+
+If there is a specific issue with one of the scripts and not the backend itself, then please open an Issue here and tag [@mayank31398](https://github.com/mayank31398).
+
+
+# Other inference solutions
+## Client-side solutions
+
+Solutions developed to perform large batch inference locally:
+
+* [Custom HF Code](https://github.com/huggingface/transformers_bloom_parallel/).
+
+JAX:
+
+* [BLOOM Inference in JAX](https://github.com/huggingface/bloom-jax-inference)
+
+
+## Server solutions

-* [Bloom-server](https://github.com/Narsil/bloomserver)
+A solution developed to be used in a server mode (i.e. varied batch size, varied request rate) can be found [here](https://github.com/Narsil/bloomserver). This is implemented in Rust.
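The server deployment section above mentions firing requests at the generation server (an example lives in the repo's server_request.py, which is not part of this diff). As a rough sketch, a request can also be sent with curl, assuming the server is listening on port 5000 and exposes the `/generate/` route used by the new UI; the field names mirror the JSON payload built in static/js/index.js further down, and the prompt and sampling values are placeholders:

```shell
# POST a generation request; the response is expected to contain
# text, total_time_taken and num_generated_tokens (see static/js/index.js)
curl -X POST http://127.0.0.1:5000/generate/ \
    -H "Content-Type: application/json" \
    -d '{
          "text": ["DeepSpeed is a machine learning framework"],
          "temperature": 1.0,
          "top_k": 50,
          "top_p": 1.0,
          "max_new_tokens": 40,
          "repetition_penalty": 1.0,
          "do_sample": true,
          "remove_input_from_output": true
        }'
```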

assets/UI.png

212 KB

inference_server/README.md

Lines changed: 0 additions & 80 deletions
This file was deleted.

static/css/style.css

Lines changed: 37 additions & 0 deletions
@@ -0,0 +1,37 @@
+#left-column {
+    width: 80%;
+}
+
+#right-column {
+    width: 18%;
+    float: right;
+    padding-right: 10px;
+}
+
+body {
+    background-color: lightgray;
+    height: auto;
+}
+
+#text-input {
+    width: 100%;
+    float: left;
+    resize: none;
+}
+
+.slider {
+    width: 100%;
+    float: left;
+}
+
+#log-output {
+    width: 100%;
+    float: left;
+    resize: none;
+}
+
+#max-new-tokens-input {
+    width: 30%;
+    float: left;
+    margin-left: 5px;
+}

static/js/index.js

Lines changed: 107 additions & 0 deletions
@@ -0,0 +1,107 @@
+const textGenInput = document.getElementById('text-input');
+const clickButton = document.getElementById('submit-button');
+
+const temperatureSlider = document.getElementById('temperature-slider');
+const temperatureTextBox = document.getElementById('temperature-textbox')
+
+const top_pSlider = document.getElementById('top_p-slider');
+const top_pTextBox = document.getElementById('top_p-textbox');
+
+const top_kSlider = document.getElementById('top_k-slider');
+const top_kTextBox = document.getElementById('top_k-textbox');
+
+const repetition_penaltySlider = document.getElementById('repetition_penalty-slider');
+const repetition_penaltyTextBox = document.getElementById('repetition_penalty-textbox');
+
+const max_new_tokensInput = document.getElementById('max-new-tokens-input');
+
+const textLogOutput = document.getElementById('log-output');
+
+function get_temperature() {
+    return parseFloat(temperatureSlider.value);
+}
+
+temperatureSlider.addEventListener('input', async (event) => {
+    temperatureTextBox.innerHTML = "temperature = " + get_temperature();
+});
+
+function get_top_p() {
+    return parseFloat(top_pSlider.value);
+}
+
+top_pSlider.addEventListener('input', async (event) => {
+    top_pTextBox.innerHTML = "top_p = " + get_top_p();
+});
+
+function get_top_k() {
+    return parseInt(top_kSlider.value);
+}
+
+top_kSlider.addEventListener('input', async (event) => {
+    top_kTextBox.innerHTML = "top_k = " + get_top_k();
+});
+
+function get_repetition_penalty() {
+    return parseFloat(repetition_penaltySlider.value);
+}
+
+repetition_penaltySlider.addEventListener('input', async (event) => {
+    repetition_penaltyTextBox.innerHTML = "repetition_penalty = " + get_repetition_penalty();
+});
+
+function get_max_new_tokens() {
+    return parseInt(max_new_tokensInput.value);
+}
+
+clickButton.addEventListener('click', async (event) => {
+    clickButton.textContent = 'Processing'
+    clickButton.disabled = true;
+
+    var jsonPayload = {
+        text: [textGenInput.value],
+        temperature: get_temperature(),
+        top_k: get_top_k(),
+        top_p: get_top_p(),
+        max_new_tokens: get_max_new_tokens(),
+        repetition_penalty: get_repetition_penalty(),
+        do_sample: true,
+        remove_input_from_output: true
+    };
+
+    if (jsonPayload.temperature == 0) {
+        jsonPayload.do_sample = false;
+    }
+
+    console.log(jsonPayload);
+
+    $.ajax({
+        url: '/generate/',
+        type: 'POST',
+        contentType: "application/json; charset=utf-8",
+        data: JSON.stringify(jsonPayload),
+        headers: { 'Access-Control-Allow-Origin': '*' },
+        success: function (response) {
+            var input_text = textGenInput.value;
+
+            if ("text" in response) {
+                textGenInput.value = input_text + response.text[0];
+
+                textLogOutput.value = 'total_time_taken = ' + response.total_time_taken + "\n";
+                textLogOutput.value += 'num_generated_tokens = ' + response.num_generated_tokens + "\n";
+                textLogOutput.style.backgroundColor = "lightblue";
+            } else {
+                textLogOutput.value = 'total_time_taken = ' + response.total_time_taken + "\n";
+                textLogOutput.value += 'error: ' + response.message;
+                textLogOutput.style.backgroundColor = "#D65235";
+            }
+
+            clickButton.textContent = 'Submit';
+            clickButton.disabled = false;
+        },
+        error: function (error) {
+            console.log(JSON.stringify(error, null, 2));
+            clickButton.textContent = 'Submit'
+            clickButton.disabled = false;
+        }
+    });
+});
