# Split a .safetensors file into shards of a given maximum size.
import argparse
import json
import math
import os

from safetensors import safe_open
from safetensors.torch import save_file
# Command-line interface: positional input path and target shard size (MB).
parser = argparse.ArgumentParser(description="Split .safetensors file into shards")
parser.add_argument("input_file", type=str, help="Path to input file")
parser.add_argument("shard_size", type=int, help="Shard size in megabytes")
args = parser.parse_args()

input_file = args.input_file
input_base, _ = os.path.splitext(input_file)
# Convert megabytes to bytes once, up front.
shard_size = args.shard_size * 1024**2
# Create tensor map | |
def _tsize(st, key): | |
tslice = st.get_slice(key) | |
shape = tslice.get_shape() | |
numel = 1 | |
for x in shape: numel *= x | |
dtype = tslice.get_dtype() | |
del tslice | |
if dtype == "I32": return numel * 4 | |
elif dtype == "I16": return numel * 2 | |
elif dtype == "F16": return numel * 2 | |
elif dtype == "F32": return numel * 4 | |
else: raise ValueError("Unexpected datatype: " + key) | |
# Greedy first-fit packing: walk tensors in file order, appending each key to
# the current shard until adding it would exceed shard_size, then open a new
# shard. A single tensor larger than shard_size gets a shard of its own.
num_files = 0
total_size = 0
tensor_map = []
# Oversized sentinel so the very first tensor always opens shard 1.
current_size = shard_size + 1
print(f" -- Scanning tensors in {input_file}")
with safe_open(input_file, framework="pt", device="cpu") as f:
    for key in f.keys():
        tensor_size = _tsize(f, key)
        total_size += tensor_size
        if current_size + tensor_size > shard_size:
            # Start a fresh shard.
            current_list = []
            tensor_map.append(current_list)
            num_files += 1
            current_size = 0
        current_list.append(key)
        current_size += tensor_size
# Split into output files | |
# Split into output files: load each shard's tensors and save them.
weight_map = {}
# Open the input once for all shards (the original reopened it per shard).
with safe_open(input_file, framework="pt", device="cpu") as f:
    for file_index, keys in enumerate(tensor_map):
        shard = {}
        shard_filename = f"{input_base}-{file_index + 1:05}-of-{num_files:05}.safetensors"
        for key in keys:
            print(f" -- Reading: {key}")
            shard[key] = f.get_tensor(key)
            # Index convention: weight_map values are filenames relative to the
            # index file (written next to input_file), not full paths.
            weight_map[key] = os.path.basename(shard_filename)
        print(f" -- Writing: {shard_filename}")
        save_file(shard, shard_filename)
# Compile index | |
# Compile and write the JSON index mapping each tensor to its shard file.
index_filename = f"{input_file}.index.json"
index = {
    "metadata": {"total_size": total_size},
    "weight_map": weight_map,
}
print(f" -- Writing: {index_filename}")
with open(index_filename, 'w') as f:
    json.dump(index, f, indent=2)
# Done
print(f" -- Done")