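# Split a .safetensors file into shards of at most a given size (in megabytes) and write a
# companion .index.json mapping each tensor name to the shard file that contains it
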
import argparse, json, os
from safetensors import safe_open
from safetensors.torch import save_file

parser = argparse.ArgumentParser(description = "Split .safetensors file into shards")
parser.add_argument("input_file", type = str, help = "Path to input file")
parser.add_argument("shard_size", type = int, help = "Shard size in megabytes")
args = parser.parse_args()

input_file = args.input_file
input_base, _ = os.path.splitext(input_file)
shard_size = args.shard_size * 1024**2  # Convert shard size from megabytes to bytes

# Create tensor map

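# Compute the stored size in bytes of a tensor from its shape and dtype, using a slice so
# the tensor data itself is never loaded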
def _tsize(st, key):

    tslice = st.get_slice(key)
    shape = tslice.get_shape()
    numel = 1
    for x in shape: numel *= x
    dtype = tslice.get_dtype()
    del tslice
    if dtype == "I32": return numel * 4
    elif dtype == "I16": return numel * 2
    elif dtype == "F16": return numel * 2
    elif dtype == "F32": return numel * 4
    else: raise ValueError("Unexpected datatype: " + key)

num_files = 0
current_size = shard_size + 1  # Start above the limit so the first tensor opens a new shard
total_size = 0
tensor_map = []

print(f" -- Scanning tensors in {input_file}")

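# First pass: scan tensor headers only, recording each tensor's size and grouping keys
# into per-shard lists without reading any tensor data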
with safe_open(input_file, framework = "pt", device = "cpu") as f:
    
    for key in f.keys():
        
        tensor_size = _tsize(f, key)
        total_size += tensor_size
        
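        # Start a new shard when adding this tensor would exceed the shard size limit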
        if current_size + tensor_size > shard_size:
            
            num_files += 1
            current_size = 0
            current_list = []
            tensor_map.append(current_list)
            
        current_size += tensor_size
        current_list.append(key)

# Split into output files

weight_map = {}

for file_index, keys in enumerate(tensor_map):
    
    shard = {}
    shard_filename = f"{input_base}-{file_index + 1:05}-of-{num_files:05}.safetensors"

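    # Reopen the input for this shard and load only the tensors assigned to it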
    with safe_open(input_file, framework = "pt", device = "cpu") as f:
        for key in keys:
            print(f" -- Reading: {key}")
            shard[key] = f.get_tensor(key)
            weight_map[key] = shard_filename

    print(f" -- Writing: {shard_filename}")
    save_file(shard, shard_filename)
    
# Compile index

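# The index records the total size and a tensor-name-to-shard-file map, the same layout as
# the .index.json files that accompany sharded Hugging Face checkpoints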
index = { "metadata": { "total_size": total_size }, "weight_map": weight_map }
index_filename = f"{input_file}.index.json"

print(f" -- Writing: {index_filename}")

with open(index_filename, 'w') as f:
    json.dump(index, f, indent = 2)

# Done

print(" -- Done")