# Uploaded by LRhinehart ("Upload folder using huggingface_hub"), commit 5bd179e.
# Hugging Face file view: raw / history / blame; no virus detected; size 2.3 kB.
import argparse, json, math, os
from safetensors import safe_open
from safetensors.torch import save_file
# Command-line interface: the file to split and the size budget per shard.
parser = argparse.ArgumentParser(
    description="Split .safetensors file into shards",
)
parser.add_argument("input_file", type=str, help="Path to input file")
parser.add_argument("shard_size", type=int, help="Shard size in megabytes")
args = parser.parse_args()

input_file = args.input_file

# Shard filenames are built from the input path with its extension removed.
input_base = os.path.splitext(input_file)[0]

# The size argument is given in megabytes; everything below works in bytes.
shard_size = args.shard_size * (1 << 20)
# Create tensor map
def _tsize(st, key):
tslice = st.get_slice(key)
shape = tslice.get_shape()
numel = 1
for x in shape: numel *= x
dtype = tslice.get_dtype()
del tslice
if dtype == "I32": return numel * 4
elif dtype == "I16": return numel * 2
elif dtype == "F16": return numel * 2
elif dtype == "F32": return numel * 4
else: raise ValueError("Unexpected datatype: " + key)
# Plan the shards: walk the tensors in file order and group their keys so
# that each group stays within the shard size budget.
total_size = 0
tensor_map = []

# Start "over budget" so the very first tensor always opens the first shard.
current_size = shard_size + 1

print(f" -- Scanning tensors in {input_file}")

with safe_open(input_file, framework="pt", device="cpu") as f:
    for key in f.keys():
        tensor_size = _tsize(f, key)
        total_size += tensor_size

        fits = current_size + tensor_size <= shard_size
        if not fits:
            # Open a new shard; a tensor larger than the budget still gets
            # a shard of its own.
            current_list = []
            tensor_map.append(current_list)
            current_size = 0

        current_list.append(key)
        current_size += tensor_size

# One output file per planned group of keys.
num_files = len(tensor_map)
# Split into output files

# Maps every tensor key to the shard file that contains it.
weight_map = {}

for file_index, keys in enumerate(tensor_map):
    shard = {}
    shard_filename = f"{input_base}-{file_index + 1:05}-of-{num_files:05}.safetensors"

    # The index file is created alongside the shards, so index entries must
    # name each shard by bare filename rather than by the (possibly
    # directory-qualified) path used to write it.
    shard_name = os.path.basename(shard_filename)

    # The input is reopened for each shard, as in the original flow.
    with safe_open(input_file, framework="pt", device="cpu") as f:
        # Carry the source header metadata over to the shard so it is not
        # silently dropped. NOTE(review): some loaders are believed to
        # inspect this metadata (e.g. a "format" entry) - confirm.
        metadata = f.metadata()
        for key in keys:
            print(f" -- Reading: {key}")
            shard[key] = f.get_tensor(key)
            weight_map[key] = shard_name

    print(f" -- Writing: {shard_filename}")
    save_file(shard, shard_filename, metadata=metadata)
# Compile index: total byte count plus the tensor-to-shard mapping, written
# to "<input_file>.index.json".
index_filename = f"{input_file}.index.json"
index = {
    "metadata": {"total_size": total_size},
    "weight_map": weight_map,
}

print(f" -- Writing: {index_filename}")
with open(index_filename, 'w') as f:
    f.write(json.dumps(index, indent=2))

# Done
print(" -- Done")