Skip to content

Commit 0292cb8

Browse files
authored
modify unflatten for vllm (#3297)
1 parent 8496b55 commit 0292cb8

File tree

3 files changed

+88
-10
lines changed

3 files changed

+88
-10
lines changed

test/prototype/safetensors/test_safetensors_support.py

Lines changed: 43 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -74,9 +74,10 @@ def test_safetensors(self, config, act_pre_scale=False):
7474

7575
save_file(tensors_data_dict, f.name, metadata=metadata)
7676
tensors_data_dict, metadata = load_data(file_path=f.name, device="cuda")
77-
reconstructed_dict = unflatten_tensor_state_dict(
77+
reconstructed_dict, leftover_tensor_data_dict = unflatten_tensor_state_dict(
7878
tensors_data_dict, metadata
7979
)
80+
assert not leftover_tensor_data_dict
8081

8182
model = torch.nn.Sequential(
8283
torch.nn.Linear(128, 256, dtype=torch.bfloat16, device="cuda")
@@ -85,6 +86,47 @@ def test_safetensors(self, config, act_pre_scale=False):
8586
output = model(*example_inputs)
8687
assert torch.equal(output, ref_output)
8788

89+
@parametrize(
90+
"config, act_pre_scale",
91+
[
92+
(Float8DynamicActivationFloat8WeightConfig(granularity=PerRow()), False),
93+
(Int4WeightOnlyConfig(), False),
94+
(Int4WeightOnlyConfig(), True),
95+
(Int4WeightOnlyConfig(int4_packing_format="tile_packed_to_4d"), False),
96+
(IntxWeightOnlyConfig(), False),
97+
(Int8DynamicActivationIntxWeightConfig(), False),
98+
],
99+
)
100+
def test_safetensors_sharded(self, config, act_pre_scale=False):
101+
model = torch.nn.Sequential(
102+
torch.nn.Linear(128, 256, dtype=torch.bfloat16, device="cuda")
103+
)
104+
quantize_(model, config)
105+
if act_pre_scale:
106+
model[0].weight.act_pre_scale = torch.ones(
107+
(1), dtype=torch.bfloat16, device="cuda"
108+
)
109+
110+
with tempfile.NamedTemporaryFile() as f:
111+
tensors_data_dict, metadata = flatten_tensor_state_dict(model.state_dict())
112+
save_file(tensors_data_dict, f.name, metadata=metadata)
113+
tensors_data_dict, metadata = load_data(file_path=f.name, device="cuda")
114+
115+
# simulate missing info on future file
116+
if act_pre_scale:
117+
del tensors_data_dict["0._weight_act_pre_scale"] # optional tensor data
118+
else:
119+
del tensors_data_dict["0._weight_qdata"]
120+
121+
reconstructed_dict, leftover_tensor_data_dict = unflatten_tensor_state_dict(
122+
tensors_data_dict, metadata
123+
)
124+
125+
# since qdata is missing, layer 0 should not have been processed
126+
for key in tensors_data_dict.keys():
127+
if key.startswith("0._weight_"):
128+
assert key in leftover_tensor_data_dict
129+
88130

89131
instantiate_parametrized_tests(TestSafeTensors)
90132

torchao/prototype/safetensors/safetensors_support.py

Lines changed: 28 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,8 @@ def unflatten_tensor_state_dict(
3434
'_data': {
3535
'block_size': [1,32],
3636
...
37-
}
37+
},
38+
'_tensor_data_names': ['qdata', 'scale']
3839
}
3940
'0.bias': {
4041
'_type': 'torch.Tensor',
@@ -66,33 +67,51 @@ def unflatten_tensor_state_dict(
6667

6768
tensor_names = json.loads(metadata["tensor_names"])
6869
result = {}
69-
70+
leftover_state_dict = tensors_data_dict.copy()
7071
for tensor_name in tensor_names:
72+
processed_tensors = []
73+
7174
module_fqn, weight_name = tensor_name.rsplit(".", 1)
7275

7376
prefix = f"{module_fqn}._{weight_name}_"
7477
tensor_tensors = {}
78+
7579
for key, value in combined_data.items():
7680
if key.startswith(prefix):
7781
# Remove the prefix
7882
tensor_tensors[key[len(prefix) :]] = value
7983

8084
tensor_metadata = json.loads(metadata.get(tensor_name))
8185
tensor_type = tensor_metadata.get("_type")
86+
complete_tensor_data_names = tensor_metadata.get("_tensor_data_names")
8287

8388
if tensor_type in ALLOWED_TENSORS_SUBCLASSES:
84-
if not tensor_tensors:
85-
# we allow the option of loading in state_dict info for a single tensor
86-
# if tensor state dict info is not loaded in yet, we wait for it to be provided
87-
# in a future call
89+
# if not all tensor data is present (ie missing qdata) we wait for it
90+
# to be loaded in from a future call
91+
if not len(tensor_tensors) is len(complete_tensor_data_names):
8892
continue
8993
tensor_metadata["_data"].update(tensor_tensors)
9094
result[tensor_name] = object_from_dict(tensor_metadata)
95+
96+
for suffix in complete_tensor_data_names:
97+
processed_tensors.append(prefix + suffix)
9198
elif tensor_type == torch.Tensor.__name__:
99+
# we allow the option of loading in state_dict info for a single tensor
100+
# if tensor state dict info is not loaded in yet, we wait for it to be provided
101+
# in a future call
102+
if tensor_name not in tensors_data_dict.keys():
103+
continue
92104
result[tensor_name] = tensors_data_dict[tensor_name]
105+
processed_tensors.append(
106+
tensor_name
107+
) # add here because key for torch.Tensor has no prefix
93108
else:
94109
raise ValueError(f"Unsupported tensor type: {tensor_type}")
95-
return result
110+
111+
for tensor_name in processed_tensors:
112+
del leftover_state_dict[tensor_name]
113+
114+
return result, leftover_state_dict
96115

97116

98117
def flatten_tensor_state_dict(
@@ -125,7 +144,8 @@ def flatten_tensor_state_dict(
125144
'_data': {
126145
'block_size': [1,32],
127146
...
128-
}
147+
},
148+
'_tensor_data_names': ['qdata', 'scale']
129149
}
130150
'0.bias': {
131151
'_type': 'torch.Tensor',

torchao/prototype/safetensors/safetensors_utils.py

Lines changed: 17 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -60,7 +60,23 @@ def default(self, o):
6060
encoded_attribute = self.encode_value(attribute)
6161
tensor_attr_dict[tensor_attribute_name] = encoded_attribute
6262

63-
return {"_type": o.__class__.__name__, "_data": tensor_attr_dict}
63+
optional_tensor_data_names = (
64+
o.optional_tensor_data_names
65+
if hasattr(o, "optional_tensor_data_names")
66+
else []
67+
)
68+
all_tensor_data_names = optional_tensor_data_names + o.tensor_data_names
69+
70+
_tensor_data_names = []
71+
for tensor_data_name in all_tensor_data_names:
72+
if getattr(o, tensor_data_name) is not None:
73+
_tensor_data_names.append(tensor_data_name)
74+
75+
return {
76+
"_type": o.__class__.__name__,
77+
"_data": tensor_attr_dict,
78+
"_tensor_data_names": _tensor_data_names,
79+
}
6480

6581
if hasattr(o, "_fields") and hasattr(
6682
o, "_asdict"

0 commit comments

Comments
 (0)