Skip to content

Commit e6a4762

Browse files
authored
Merge pull request #135 from VectorInstitute/add_env_arg
Added additional launch command options: * --env option for adding environment variables * --config option for path to custom config
2 parents 07358cb + 0deaf5f commit e6a4762

File tree

13 files changed

+140
-22
lines changed

13 files changed

+140
-22
lines changed

tests/vec_inf/cli/test_cli.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,7 @@ def test_launch_command_success(runner):
4040
"model_weights_parent_dir": "/model-weights",
4141
"vocab_size": "128000",
4242
"vllm_args": {"max_model_len": 8192},
43+
"env": {"CACHE": "/cache"},
4344
}
4445
mock_client.launch_model.return_value = mock_response
4546

tests/vec_inf/cli/test_helper.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,7 @@ def test_format_table_output(self):
3636
"model_weights_parent_dir": "/model-weights",
3737
"log_dir": "/tmp/logs",
3838
"vllm_args": {"max_model_len": 8192, "enable_prefix_caching": True},
39+
"env": {"CACHE": "/cache"},
3940
}
4041

4142
formatter = LaunchResponseFormatter(model_name, params)
@@ -63,6 +64,7 @@ def test_format_table_output_with_minimal_params(self):
6364
"model_weights_parent_dir": "/weights",
6465
"log_dir": "/logs",
6566
"vllm_args": {},
67+
"env": {},
6668
}
6769

6870
formatter = LaunchResponseFormatter(model_name, params)

tests/vec_inf/client/test_helper.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -124,6 +124,22 @@ def test_process_vllm_args(self, mock_path_exists, mock_load_config, mock_config
124124
assert vllm_args["--quantization"] == "awq"
125125
assert vllm_args["--compilation-config"] == "3"
126126

127+
@patch("vec_inf.client._helper.utils.load_config")
128+
@patch("pathlib.Path.exists")
129+
def test_process_env_vars(self, mock_path_exists, mock_load_config, mock_configs):
130+
"""Test that vars from `--env` flag are parsed correctly."""
131+
mock_load_config.return_value = mock_configs
132+
mock_path_exists.return_value = True
133+
134+
# Get filepath of dummy env file
135+
file_path = Path(__file__).parent / "test_vars.env"
136+
137+
launcher = ModelLauncher("test-model", {})
138+
env_vars = launcher._process_env_vars(f"CACHE_DIR=/cache,{file_path}")
139+
assert env_vars["CACHE_DIR"] == "/cache"
140+
assert env_vars["MY_VAR"] == "5"
141+
assert env_vars["VLLM_CACHE_ROOT"] == "/cache/vllm"
142+
127143
@patch("vec_inf.client._helper.utils.load_config")
128144
def test_get_launch_params_merges_config_and_cli_args(
129145
self, mock_load_config, model_config

tests/vec_inf/client/test_slurm_script_generator.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -55,6 +55,11 @@ def singularity_params(self, basic_params):
5555
{
5656
"venv": "singularity",
5757
"bind": "/scratch:/scratch,/data:/data",
58+
"env": {
59+
"CACHE_DIR": "/cache",
60+
"MY_VAR": "5",
61+
"VLLM_CACHE_ROOT": "/cache/vllm",
62+
},
5863
}
5964
)
6065
return singularity
@@ -74,6 +79,7 @@ def test_init_single_node(self, basic_params):
7479
assert not generator.use_container
7580
assert generator.additional_binds == ""
7681
assert generator.model_weights_path == "/path/to/model_weights/test-model"
82+
assert generator.env_str == ""
7783

7884
def test_init_multinode(self, multinode_params):
7985
"""Test initialization with multi-node configuration."""
@@ -84,6 +90,7 @@ def test_init_multinode(self, multinode_params):
8490
assert not generator.use_container
8591
assert generator.additional_binds == ""
8692
assert generator.model_weights_path == "/path/to/model_weights/test-model"
93+
assert generator.env_str == ""
8794

8895
def test_init_singularity(self, singularity_params):
8996
"""Test initialization with Singularity configuration."""
@@ -94,6 +101,10 @@ def test_init_singularity(self, singularity_params):
94101
assert not generator.is_multinode
95102
assert generator.additional_binds == " --bind /scratch:/scratch,/data:/data"
96103
assert generator.model_weights_path == "/path/to/model_weights/test-model"
104+
assert (
105+
generator.env_str
106+
== "--env CACHE_DIR=/cache,MY_VAR=5,VLLM_CACHE_ROOT=/cache/vllm"
107+
)
97108

98109
def test_init_singularity_no_bind(self, basic_params):
99110
"""Test Singularity initialization without additional binds."""
@@ -106,6 +117,7 @@ def test_init_singularity_no_bind(self, basic_params):
106117
assert not generator.is_multinode
107118
assert generator.additional_binds == ""
108119
assert generator.model_weights_path == "/path/to/model_weights/test-model"
120+
assert generator.env_str == ""
109121

110122
def test_generate_shebang_single_node(self, basic_params):
111123
"""Test shebang generation for single-node setup."""

tests/vec_inf/client/test_vars.env

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,2 @@
1+
MY_VAR=5
2+
VLLM_CACHE_ROOT=/cache/vllm

uv.lock

Lines changed: 16 additions & 16 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

vec_inf/cli/_cli.py

Lines changed: 18 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -131,6 +131,16 @@ def cli() -> None:
131131
is_flag=True,
132132
help="Output in JSON string",
133133
)
134+
@click.option(
135+
"--env",
136+
type=str,
137+
help="Environment variables to be set. Seperate variables with commas. Can also include path to a file containing environment variables seperated by newlines. e.g. --env 'TRITON_CACHE_DIR=/scratch/.cache/triton,my_custom_vars_file.env'",
138+
)
139+
@click.option(
140+
"--config",
141+
type=str,
142+
help="Path to a model config yaml file to use in place of the default",
143+
)
134144
def launch(
135145
model_name: str,
136146
**cli_kwargs: Optional[Union[str, int, float, bool]],
@@ -177,6 +187,10 @@ def launch(
177187
Path to model weights directory
178188
- vllm_args : str, optional
179189
vLLM engine arguments
190+
- env : str, optional
191+
Environment variables
192+
- config : str, optional
193+
Path to custom model config yaml file
180194
- json_mode : bool, optional
181195
Output in JSON format
182196
@@ -212,7 +226,10 @@ def launch(
212226
raise click.ClickException(f"Launch failed: {str(e)}") from e
213227

214228

215-
@cli.command("batch-launch", help="Launch multiple models in a batch.")
229+
@cli.command(
230+
"batch-launch",
231+
help="Launch multiple models in a batch, separate model names with spaces.",
232+
)
216233
@click.argument("model-names", type=str, nargs=-1)
217234
@click.option(
218235
"--batch-config",

vec_inf/cli/_helper.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -87,6 +87,11 @@ def format_table_output(self) -> Table:
8787
for arg, value in self.params["vllm_args"].items():
8888
table.add_row(f" {arg}:", str(value))
8989

90+
# Add Environment Variable Configuration Details
91+
table.add_row("Environment Variables", style="magenta")
92+
for arg, value in self.params["env"].items():
93+
table.add_row(f" {arg}:", str(value))
94+
9095
return table
9196

9297

vec_inf/client/_helper.py

Lines changed: 47 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -61,7 +61,7 @@ def __init__(self, model_name: str, kwargs: Optional[dict[str, Any]]):
6161
self.kwargs = kwargs or {}
6262
self.slurm_job_id = ""
6363
self.slurm_script_path = Path("")
64-
self.model_config = self._get_model_configuration()
64+
self.model_config = self._get_model_configuration(self.kwargs.get("config"))
6565
self.params = self._get_launch_params()
6666

6767
def _warn(self, message: str) -> None:
@@ -74,9 +74,14 @@ def _warn(self, message: str) -> None:
7474
"""
7575
warnings.warn(message, UserWarning, stacklevel=2)
7676

77-
def _get_model_configuration(self) -> ModelConfig:
77+
def _get_model_configuration(self, config_path: str | None = None) -> ModelConfig:
7878
"""Load and validate model configuration.
7979
80+
Parameters
81+
----------
82+
config_path : str | None, optional
83+
Path to a yaml file with custom model config to use in place of the default
84+
8085
Returns
8186
-------
8287
ModelConfig
@@ -89,7 +94,7 @@ def _get_model_configuration(self) -> ModelConfig:
8994
ModelConfigurationError
9095
If model configuration is not found and weights don't exist
9196
"""
92-
model_configs = utils.load_config()
97+
model_configs = utils.load_config(config_path=config_path)
9398
config = next(
9499
(m for m in model_configs if m.model_name == self.model_name), None
95100
)
@@ -158,6 +163,38 @@ def _process_vllm_args(self, arg_string: str) -> dict[str, Any]:
158163
vllm_args[arg.strip()] = True
159164
return vllm_args
160165

166+
def _process_env_vars(self, env_arg: str) -> dict[str, str]:
167+
"""Process the env string into a dictionary of environment variables.
168+
169+
Parameters
170+
----------
171+
env_arg : str
172+
String containing comma separated list of environment variable definitions
173+
(eg. MY_VAR=1), file paths containing environment variable definitions
174+
(separated by newlines), or a combination of both
175+
(eg. 'MY_VAR=5,my_env.env')
176+
177+
Returns
178+
-------
179+
dict[str, str]
180+
Processed environment variables as key-value pairs.
181+
"""
182+
env_vars: dict[str, str] = {}
183+
for arg in env_arg.split(","):
184+
if "=" in arg: # Arg is an env var definition
185+
key, value = arg.split("=")
186+
env_vars[key.strip()] = value.strip()
187+
else: # Arg is a path to a file
188+
with open(arg, "r") as file:
189+
lines = [line.rstrip() for line in file]
190+
for line in lines:
191+
if "=" in line:
192+
key, value = line.split("=")
193+
env_vars[key.strip()] = value.strip()
194+
else:
195+
print(f"WARNING: Could not parse env var: {line}")
196+
return env_vars
197+
161198
def _get_launch_params(self) -> dict[str, Any]:
162199
"""Prepare launch parameters, set log dir, and validate required fields.
163200
@@ -181,6 +218,12 @@ def _get_launch_params(self) -> dict[str, Any]:
181218
params["vllm_args"][key] = value
182219
del self.kwargs["vllm_args"]
183220

221+
if self.kwargs.get("env"):
222+
env_vars = self._process_env_vars(self.kwargs["env"])
223+
for key, value in env_vars.items():
224+
params["env"][key] = str(value)
225+
del self.kwargs["env"]
226+
184227
for key, value in self.kwargs.items():
185228
params[key] = value
186229

@@ -233,7 +276,7 @@ def _get_launch_params(self) -> dict[str, Any]:
233276

234277
# Convert path to string for JSON serialization
235278
for field in params:
236-
if field == "vllm_args":
279+
if field in ["vllm_args", "env"]:
237280
continue
238281
params[field] = str(params[field])
239282

vec_inf/client/_slurm_script_generator.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,15 @@ def __init__(self, params: dict[str, Any]):
4141
self.model_weights_path = str(
4242
Path(self.params["model_weights_parent_dir"], self.params["model_name"])
4343
)
44+
env_dict: dict[str, str] = self.params.get("env", {})
45+
# Create string of environment variables
46+
self.env_str = ""
47+
for key, val in env_dict.items():
48+
if len(self.env_str) == 0:
49+
self.env_str = "--env "
50+
else:
51+
self.env_str += ","
52+
self.env_str += key + "=" + val
4453

4554
def _generate_script_content(self) -> str:
4655
"""Generate the complete Slurm script content.
@@ -100,6 +109,7 @@ def _generate_server_setup(self) -> str:
100109
SLURM_SCRIPT_TEMPLATE["container_command"].format(
101110
model_weights_path=self.model_weights_path,
102111
additional_binds=self.additional_binds,
112+
env_str=self.env_str,
103113
),
104114
)
105115
else:
@@ -132,6 +142,7 @@ def _generate_launch_cmd(self) -> str:
132142
SLURM_SCRIPT_TEMPLATE["container_command"].format(
133143
model_weights_path=self.model_weights_path,
134144
additional_binds=self.additional_binds,
145+
env_str=self.env_str,
135146
)
136147
)
137148
else:

0 commit comments

Comments
 (0)