Commit 1db14f8

feat: update translation API to support generation parameters

- Modified request and response formats to include generation parameters for translation.
- Updated documentation to reflect changes in API usage and added details about generation parameters.

1 parent 71012a2 · commit 1db14f8

19 files changed: +173 −98 lines

docs/DOCKERHUB.md

Lines changed: 17 additions & 5 deletions

````diff
@@ -74,20 +74,32 @@ Available on [Docker Hub](https://hub.docker.com/r/ggwozdz/translation-api):
 - Request:

   ```bash
-  curl -X POST "http://localhost:8000/translate" \
-    -F "text=Hello, how are you?" \
-    -F "source_language=en_US"
-    -F "target_language=pl_PL"
+  curl -X 'POST' \
+    'http://127.0.0.1:8000/translate' \
+    -H 'accept: application/json' \
+    -H 'Content-Type: application/json' \
+    -d '{
+      "text_to_translate": "The tower is 324 metres (1,063 ft) tall, about the same height as an 81-storey building, and the tallest structure in Paris. Its base is square, measuring 125 metres (410 ft) on each side. During its construction, the Eiffel Tower surpassed the Washington Monument to become the tallest man-made structure in the world, a title it held for 41 years until the Chrysler Building in New York City was finished in 1930. It was the first structure to reach a height of 300 metres. Due to the addition of a broadcasting aerial at the top of the tower in 1957, it is now taller than the Chrysler Building by 5.2 metres (17 ft). Excluding transmitters, the Eiffel Tower is the second tallest free-standing structure in France after the Millau Viaduct.",
+      "source_language": "en_US",
+      "target_language": "pl_PL",
+      "generation_parameters": { "max_length": 10240, "num_beams": 10 }
+    }'
   ```

 - Response:

   ```json
   {
-    "content": "Cześć, jak się masz?",
+    "translation": "Wieża Eiffla ma wysokość 324 metrów, mniej więcej taką samą wysokość jak 81-piętrowy budynek, i jest najwyższą budowlą w Paryżu. Jego podstawa jest kwadratowa, mierząc 125 metrów na każdej stronie. Podczas jej budowy Wieża Eiffla przekroczyła Pomnik Waszyngtonu, stając się najwyższą budowlą stworzoną przez człowieka na świecie, tytuł utrzymywał przez 41 rok, dopóki budynek Chrysler w Nowym Jorku nie został ukończony w 1930 roku."
   }
   ```

+#### Generation parameters
+
+The `generation_parameters` field in the request body allows you to specify the parameters which are described in the model documentation.
+
+[For Seamless model](https://huggingface.co/docs/transformers/main/en/model_doc/seamless_m4t#transformers.SeamlessM4TForTextToText.generate) and [for mBART model](https://huggingface.co/docs/transformers/main/en/model_doc/mbart#transformers.MBartForConditionalGeneration.generate)
+
 ### Health Check

 - Request:
````
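The same request can be sketched in Python with only the standard library. The endpoint URL, field names, and parameter values below are taken from the curl example in the diff above; the snippet only builds the request object and is not verified against a running server (the actual send is left commented out).

```python
import json
import urllib.request


def build_translate_request(text, source_language, target_language, generation_parameters=None):
    """Build a POST /translate request matching the documented JSON body."""
    payload = {
        "text_to_translate": text,
        "source_language": source_language,
        "target_language": target_language,
        "generation_parameters": generation_parameters or {},
    }
    return urllib.request.Request(
        "http://127.0.0.1:8000/translate",
        data=json.dumps(payload).encode("utf-8"),
        headers={"accept": "application/json", "Content-Type": "application/json"},
        method="POST",
    )


req = build_translate_request(
    "Hello, how are you?",
    "en_US",
    "pl_PL",
    {"max_length": 10240, "num_beams": 10},
)

# To actually send it (requires the API to be running):
# with urllib.request.urlopen(req) as resp:
#     print(json.loads(resp.read())["translation"])
```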

docs/README.md

Lines changed: 18 additions & 5 deletions

````diff
@@ -89,20 +89,32 @@ Choose your preferred distribution:
 - Request:

   ```bash
-  curl -X POST "http://localhost:8000/translate" \
-    -F "text=Hello, how are you?" \
-    -F "source_language=en_US"
-    -F "target_language=pl_PL"
+  curl -X 'POST' \
+    'http://127.0.0.1:8000/translate' \
+    -H 'accept: application/json' \
+    -H 'Content-Type: application/json' \
+    -d '{
+      "text_to_translate": "The tower is 324 metres (1,063 ft) tall, about the same height as an 81-storey building, and the tallest structure in Paris. Its base is square, measuring 125 metres (410 ft) on each side. During its construction, the Eiffel Tower surpassed the Washington Monument to become the tallest man-made structure in the world, a title it held for 41 years until the Chrysler Building in New York City was finished in 1930. It was the first structure to reach a height of 300 metres. Due to the addition of a broadcasting aerial at the top of the tower in 1957, it is now taller than the Chrysler Building by 5.2 metres (17 ft). Excluding transmitters, the Eiffel Tower is the second tallest free-standing structure in France after the Millau Viaduct.",
+      "source_language": "en_US",
+      "target_language": "pl_PL",
+      "generation_parameters": { "max_length": 10240, "num_beams": 10 }
+    }'
   ```

 - Response:

   ```json
   {
-    "content": "Cześć, jak się masz?",
+    "translation": "Wieża Eiffla ma wysokość 324 metrów, mniej więcej taką samą wysokość jak 81-piętrowy budynek, i jest najwyższą budowlą w Paryżu. Jego podstawa jest kwadratowa, mierząc 125 metrów na każdej stronie. Podczas jej budowy Wieża Eiffla przekroczyła Pomnik Waszyngtonu, stając się najwyższą budowlą stworzoną przez człowieka na świecie, tytuł utrzymywał przez 41 rok, dopóki budynek Chrysler w Nowym Jorku nie został ukończony w 1930 roku."
   }
   ```

+#### Generation parameters
+
+The `generation_parameters` field in the request body allows you to specify the parameters which are described in the model documentation.
+
+[For Seamless model](https://huggingface.co/docs/transformers/main/en/model_doc/seamless_m4t#transformers.SeamlessM4TForTextToText.generate) and [for mBART model](https://huggingface.co/docs/transformers/main/en/model_doc/mbart#transformers.MBartForConditionalGeneration.generate)
+
 ### Health Check

 - Request:
@@ -156,6 +168,7 @@ Developer guide is available in [docs/DEVELOPER.md](DEVELOPER.md).
 - [Using Windows Executable](#using-windows-executable)
 - [API Features](#api-features)
 - [Translate Text](#translate-text)
+- [Generation parameters](#generation-parameters)
 - [Health Check](#health-check)
 - [Configuration](#configuration)
 - [Supported Languages](#supported-languages)
````

src/api/dtos/translate_dto.py

Lines changed: 3 additions & 2 deletions

```diff
@@ -1,15 +1,16 @@
 import re
-from typing import Optional
+from typing import Any, Dict, Optional

 from pydantic import BaseModel, field_validator

 from domain.exceptions.invalid_language_format_error import InvalidLanguageFormatError


 class TranslateDTO(BaseModel):
-    text: str
+    text_to_translate: str
     source_language: str
     target_language: str
+    generation_parameters: Dict[str, Any] = {}

     @staticmethod
     def validate_language_format(v: str) -> str:
```
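The new `generation_parameters` field defaults to an empty dict, so clients that omit it keep working. A plain-Python sketch of the same shape is below; note the validation regex is an assumption (the real pattern lives in `validate_language_format`, which this diff truncates), and `default_factory` is used because a bare `= {}` default would be shared between instances in plain Python, whereas pydantic copies field defaults for you.

```python
import re
from dataclasses import dataclass, field
from typing import Any, Dict

# Assumed format: two lowercase letters, underscore, two uppercase letters
# (e.g. "en_US", "pl_PL"). The actual regex in translate_dto.py is not
# shown in this diff.
LANGUAGE_RE = re.compile(r"^[a-z]{2}_[A-Z]{2}$")


@dataclass
class TranslateRequest:
    """Plain-Python stand-in for the pydantic TranslateDTO."""

    text_to_translate: str
    source_language: str
    target_language: str
    # default_factory gives each instance its own dict.
    generation_parameters: Dict[str, Any] = field(default_factory=dict)

    def __post_init__(self) -> None:
        # Reject language tags that do not match the assumed ll_CC format.
        for lang in (self.source_language, self.target_language):
            if not LANGUAGE_RE.match(lang):
                raise ValueError(f"invalid language format: {lang!r}")


req = TranslateRequest("Hello, how are you?", "en_US", "pl_PL")
```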

src/api/dtos/translate_result_dto.py

Lines changed: 1 addition & 1 deletion

```diff
@@ -2,4 +2,4 @@


 class TranslateResultDTO(BaseModel):
-    content: str
+    translation: str
```

src/api/routers/translate_router.py

Lines changed: 6 additions & 5 deletions

```diff
@@ -1,6 +1,6 @@
 from typing import Annotated

-from fastapi import APIRouter, Depends
+from fastapi import APIRouter, Body, Depends

 from api.dtos.translate_dto import TranslateDTO
 from api.dtos.translate_result_dto import TranslateResultDTO
@@ -15,14 +15,15 @@ def __init__(self) -> None:
     async def translate(
         self,
         translate_text_usecase: Annotated[TranslateTextUseCase, Depends()],
-        translate_dto: TranslateDTO = Depends(),
+        translate_dto: TranslateDTO = Body(...),
     ) -> TranslateResultDTO:
-        result = await translate_text_usecase.execute(
-            translate_dto.text,
+        translation = await translate_text_usecase.execute(
+            translate_dto.text_to_translate,
             translate_dto.source_language,
             translate_dto.target_language,
+            translate_dto.generation_parameters,
         )

         return TranslateResultDTO(
-            content=result,
+            translation=translation,
         )
```

src/application/usecases/translate_text_usecase.py

Lines changed: 8 additions & 6 deletions

```diff
@@ -1,4 +1,4 @@
-from typing import Annotated
+from typing import Annotated, Any, Dict

 from fastapi import Depends

@@ -20,20 +20,22 @@ def __init__(

     async def execute(
         self,
-        text: str,
+        text_to_translate: str,
         source_language: str,
         target_language: str,
+        generation_parameters: Dict[str, Any],
     ) -> str:
         self.logger.info(
-            f"Executing translation for text '{text}' from '{source_language}' to '{target_language}'",
+            f"Executing translation for text '{text_to_translate}' from '{source_language}' to '{target_language}'",
         )

-        translation_result: str = self.translation_service.translate_text(
-            text,
+        translation: str = self.translation_service.translate_text(
+            text_to_translate,
             source_language,
             target_language,
+            generation_parameters,
         )

         self.logger.info("Returning translation result")

-        return translation_result
+        return translation
```

src/data/repositories/translation_model_repository_impl.py

Lines changed: 7 additions & 5 deletions

```diff
@@ -1,6 +1,6 @@
 import threading
 import time
-from typing import Annotated, Optional
+from typing import Annotated, Any, Dict, Optional

 from fastapi import Depends

@@ -59,9 +59,10 @@ def _check_idle_timeout(self) -> None:

     def translate(
         self,
-        text: str,
+        text_to_translate: str,
         source_language: str,
         target_language: str,
+        generation_parameters: Dict[str, Any],
     ) -> str:
         with self._lock:
             if not self.worker.is_alive():
@@ -72,10 +73,11 @@ def translate(
                 f"Translating started from source_language: {source_language}, target_language: {target_language}",
             )

-            result: str = self.worker.translate(
-                text,
+            translation: str = self.worker.translate(
+                text_to_translate,
                 source_language,
                 target_language,
+                generation_parameters,
             )

             self.timer.start(
@@ -89,4 +91,4 @@ def translate(
             f"Translating completed from source_language: {source_language}, target_language: {target_language}",
         )

-        return result
+        return translation
```

src/data/workers/mbart_translation_worker.py

Lines changed: 23 additions & 14 deletions

```diff
@@ -3,9 +3,8 @@
 import multiprocessing.synchronize
 from dataclasses import dataclass
 from multiprocessing.sharedctypes import Synchronized
-from typing import Tuple
+from typing import Any, Dict, Tuple

-import torch
 from transformers import AutoModelForSeq2SeqLM, AutoTokenizer

 from data.workers.base_worker import BaseWorker
@@ -22,22 +21,33 @@ class MBartTranslationConfig:


 class MBartTranslationWorker(
     BaseWorker[  # type: ignore
-        Tuple[str, str, str],
+        Tuple[str, str, str, Dict[str, Any]],
         str,
         MBartTranslationConfig,
         Tuple[AutoModelForSeq2SeqLM, AutoTokenizer],
     ],
 ):
     def translate(
         self,
-        text: str,
+        text_to_translate: str,
         source_language: str,
         target_language: str,
+        generation_parameters: Dict[str, Any],
     ) -> str:
         if not self.is_alive():
             raise WorkerNotRunningError()

-        self._pipe_parent.send(("translate", (text, source_language, target_language)))
+        self._pipe_parent.send(
+            (
+                "translate",
+                (
+                    text_to_translate,
+                    source_language,
+                    target_language,
+                    generation_parameters,
+                ),
+            ),
+        )
         result = self._pipe_parent.recv()

         if isinstance(result, Exception):
@@ -62,7 +72,7 @@ def initialize_shared_object(
     def handle_command(
         self,
         command: str,
-        args: Tuple[str, str, str],
+        args: Tuple[str, str, str, Dict[str, Any]],
         shared_object: Tuple[AutoModelForSeq2SeqLM, AutoTokenizer],
         config: MBartTranslationConfig,
         pipe: multiprocessing.connection.Connection,
@@ -74,21 +84,20 @@ def handle_command(
         with processing_lock:
             is_processing.value = True

-            text, source_language, target_language = args
+            text, source_language, target_language, generation_parameters = args
             model, tokenizer = shared_object

             tokenizer.src_lang = source_language
-            inputs = tokenizer([text], truncation=True, padding=True, max_length=1024, return_tensors="pt")
-
-            for key in inputs:
-                inputs[key] = inputs[key].to(config.device)
+            inputs = tokenizer(text, return_tensors="pt").to(config.device)

-            with torch.no_grad():
+            if "forced_bos_token_id" in generation_parameters:
+                kwargs = {"forced_bos_token_id": generation_parameters["forced_bos_token_id"]}
+            else:
                 kwargs = {"forced_bos_token_id": tokenizer.lang_code_to_id[target_language]}

-            translated = model.generate(**inputs, num_beams=5, **kwargs)
+            translation = model.generate(**inputs, **kwargs)

-            output = [tokenizer.decode(t, skip_special_tokens=True) for t in translated]
+            output = [tokenizer.decode(t, skip_special_tokens=True) for t in translation]

             pipe.send("".join(output))
```
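Note that the mBART handler honours only one key of `generation_parameters`, `forced_bos_token_id`, falling back to the tokenizer's code for the target language when the caller does not supply it. The selection logic can be isolated as a small pure function (a sketch; the token ids below are made up for illustration):

```python
from typing import Any, Dict


def resolve_forced_bos(
    generation_parameters: Dict[str, Any],
    lang_code_to_id: Dict[str, int],
    target_language: str,
) -> Dict[str, Any]:
    """Mirror the if/else in handle_command: a caller-supplied
    forced_bos_token_id wins; otherwise derive it from the target language."""
    if "forced_bos_token_id" in generation_parameters:
        return {"forced_bos_token_id": generation_parameters["forced_bos_token_id"]}
    return {"forced_bos_token_id": lang_code_to_id[target_language]}


# Hypothetical language-code-to-token-id mapping, for illustration only.
codes = {"pl_PL": 250_012, "en_US": 250_004}

assert resolve_forced_bos({}, codes, "pl_PL") == {"forced_bos_token_id": 250_012}
assert resolve_forced_bos({"forced_bos_token_id": 7}, codes, "pl_PL") == {"forced_bos_token_id": 7}
```

A consequence of this commit worth flagging: other keys from the request example, such as `num_beams`, are not forwarded to `model.generate` in this worker, while the previously hard-coded `num_beams=5` was removed.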

src/data/workers/seamless_translation_worker.py

Lines changed: 34 additions & 12 deletions

```diff
@@ -3,7 +3,7 @@
 import multiprocessing.synchronize
 from dataclasses import dataclass
 from multiprocessing.sharedctypes import Synchronized
-from typing import Tuple
+from typing import Any, Dict, Tuple

 from transformers import AutoProcessor, SeamlessM4Tv2ForTextToText

@@ -21,22 +21,33 @@ class SeamlessTranslationConfig:


 class SeamlessTranslationWorker(
     BaseWorker[  # type: ignore
-        Tuple[str, str, str],
+        Tuple[str, str, str, Dict[str, Any]],
         str,
         SeamlessTranslationConfig,
         Tuple[SeamlessM4Tv2ForTextToText, AutoProcessor],
     ],
 ):
     def translate(
         self,
-        text: str,
+        text_to_translate: str,
         source_language: str,
         target_language: str,
+        generation_parameters: Dict[str, Any],
     ) -> str:
         if not self.is_alive():
             raise WorkerNotRunningError()

-        self._pipe_parent.send(("translate", (text, source_language, target_language)))
+        self._pipe_parent.send(
+            (
+                "translate",
+                (
+                    text_to_translate,
+                    source_language,
+                    target_language,
+                    generation_parameters,
+                ),
+            ),
+        )
         result = self._pipe_parent.recv()

         if isinstance(result, Exception):
@@ -61,7 +72,7 @@ def initialize_shared_object(
     def handle_command(
         self,
         command: str,
-        args: Tuple[str, str, str],
+        args: Tuple[str, str, str, Dict[str, Any]],
         shared_object: Tuple[SeamlessM4Tv2ForTextToText, AutoProcessor],
         config: SeamlessTranslationConfig,
         pipe: multiprocessing.connection.Connection,
@@ -73,15 +84,26 @@ def handle_command(
         with processing_lock:
             is_processing.value = True

-            text, source_language, target_language = args
+            text, source_language, target_language, generation_parameters = args
             model, processor = shared_object

-            processor.src_lang = source_language
-            input_tokens = processor(text, return_tensors="pt", padding=True).to(config.device)
-
-            output_tokens = model.generate(**input_tokens, tgt_lang=target_language)[0].tolist()
-
-            text_output = processor.decode(output_tokens, skip_special_tokens=True)
+            input_tokens = processor(
+                text,
+                src_lang=source_language,
+                return_tensors="pt",
+                padding=True,
+            ).to(config.device)
+
+            output_tokens = model.generate(
+                **input_tokens,
+                tgt_lang=target_language,
+                **generation_parameters,
+            )[0].tolist()
+
+            text_output = processor.decode(
+                output_tokens,
+                skip_special_tokens=True,
+            )

             pipe.send("".join(text_output))
```
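Unlike the mBART worker, the Seamless handler forwards `generation_parameters` wholesale into `model.generate` via `**`-unpacking. One consequence of that design: Python raises `TypeError` if the dict repeats a keyword the call already fixes, such as `tgt_lang`, so such client input would fail inside the worker. A minimal stand-in demonstrates the call shape (`fake_generate` is hypothetical and only records its keyword arguments; it is not the transformers API):

```python
from typing import Any, Dict


def fake_generate(tgt_lang: str, **generation_kwargs: Any) -> Dict[str, Any]:
    """Stand-in for model.generate: records which kwargs it received."""
    return {"tgt_lang": tgt_lang, **generation_kwargs}


params: Dict[str, Any] = {"max_length": 10240, "num_beams": 10}

# Normal case: user parameters merge alongside the fixed tgt_lang keyword.
out = fake_generate(tgt_lang="pol", **params)
assert out["num_beams"] == 10

# Collision case: a tgt_lang inside generation_parameters duplicates the
# fixed keyword, and the call itself raises TypeError.
try:
    fake_generate(tgt_lang="pol", **{"tgt_lang": "eng"})
except TypeError:
    pass  # got multiple values for keyword argument 'tgt_lang'
```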

0 commit comments