git-theta-filter — executable file, 168 lines (142 loc), 5.71 KB
#!/usr/bin/env python
import argparse
import sys
import logging
import numpy as np
from git_theta import (
git_utils,
checkpoints,
params,
metadata,
updates,
async_utils,
lsh,
)
from git_theta.utils import EnvVarConstants
# Configure logging for the whole filter process. Git invokes clean/smudge
# filters as subprocesses, so anything written to the console is not visible
# to the user; a file under /tmp is used instead.
logging.basicConfig(
    level=logging.DEBUG,
    # Log to a file for clean/smudge as they don't appear on the console when called via git.
    filename="/tmp/git-theta.log",
    format="git-theta-filter: [%(asctime)s] [%(funcName)s] %(levelname)s - %(message)s",
)
def parse_args():
    """Parse the filter driver's command line.

    Git calls this script as ``git-theta-filter clean <file>`` or
    ``git-theta-filter smudge <file>``.

    Returns:
        argparse.Namespace with a ``file`` attribute (the path git passed)
        and a ``func`` attribute bound to the selected sub-command handler.
    """
    parser = argparse.ArgumentParser(description="git-theta filter program")
    subparsers = parser.add_subparsers(title="Commands", dest="command")
    subparsers.required = True

    # Register both sub-commands from one table so they stay in lockstep.
    for name, handler in (("clean", clean), ("smudge", smudge)):
        sub = subparsers.add_parser(name, help=f"{name} filter")
        sub.add_argument("file", help=f"file being passed to {name} filter")
        sub.set_defaults(func=handler)

    return parser.parse_args()
def clean(args):
    """
    Implements the git clean filter for model checkpoint files.

    Reads the real checkpoint from stdin, writes per-parameter metadata to
    stdout (which is what git actually stores in the repository). The
    metadata file looks as follows:
    {
    "model/scoping/to/param/1-weight": {
    "tensor_metadata": {
    "shape": List[str],
    "dtype": str,
    "hash": str,
    },
    },
    ...,
    "model/scoping/to/param/2-bias": {
    "tensor_metadata": {
    "shape": List[str],
    "dtype": str,
    "hash": str,
    },
    },
    ...,
    }
    """
    logging.debug(f"Running clean filter on {args.file}")
    repo = git_utils.get_git_repo()
    # Checkpoint handler knows how to deserialize the framework-specific
    # checkpoint format arriving on stdin.
    checkpoint_handler = checkpoints.get_checkpoint_handler()
    model_checkpoint = checkpoint_handler.from_file(sys.stdin.buffer)
    # Note: If the update serializer is configurable per-parameter, it will need to be created inside _clean
    update_serializer = params.get_update_serializer()
    # Metadata from the last commit, used to detect which parameters changed.
    prev_metadata = metadata.Metadata.from_commit(repo, args.file, "HEAD").flatten()

    async def _clean(param_keys, new_param):
        # Clean a single parameter tensor; returns (keys, ParamMetadata).
        logging.debug(f"Cleaning {'/'.join(param_keys)}")
        param_metadata = prev_metadata.get(param_keys)
        new_tensor_metadata = metadata.TensorMetadata.from_tensor(new_param)
        # If the parameter tensor has not changed, just keep the metadata the same
        # (only comparable when shape and dtype match the previous commit).
        if (
            param_metadata
            and param_metadata.tensor_metadata.shape == new_tensor_metadata.shape
            and param_metadata.tensor_metadata.dtype == new_tensor_metadata.dtype
        ):
            # Locality-sensitive hash distance is a cheap proxy for tensor
            # closeness, avoiding a full load of the previous parameter.
            hasher = lsh.get_lsh()
            hash_distance = hasher.distance(
                param_metadata.tensor_metadata.hash, new_tensor_metadata.hash
            )
            # If hash_distance < PARAMETER_ATOL, assume the tensors pass np.allclose and parameter hasn't changed
            if hash_distance < EnvVarConstants.PARAMETER_ATOL:
                return param_keys, param_metadata
            # If PARAMETER_ATOL < hash_distance < LSH_THRESHOLD, load parameters and check if parameter has changed with np.allclose
            elif hash_distance < EnvVarConstants.LSH_THRESHOLD:
                update_handler = updates.get_update_handler(
                    param_metadata.theta_metadata.update_type
                )(update_serializer)
                param = await update_handler.apply(
                    param_metadata, param_keys, repo=repo, path=args.file
                )
                if np.allclose(
                    param,
                    new_param,
                    rtol=EnvVarConstants.PARAMETER_RTOL,
                    atol=EnvVarConstants.PARAMETER_ATOL,
                ):
                    return param_keys, param_metadata
        # Parameter changed (or is new): write the update and build fresh
        # metadata for it. get_update_handler() with no argument selects the
        # default update type for newly written values.
        update_handler = updates.get_update_handler()(update_serializer)
        new_theta_metadata = metadata.ThetaMetadata(
            update_type=update_handler.name, last_commit=git_utils.get_head(repo)
        )
        lfs_metadata = await update_handler.write(
            new_param,
            param_keys,
            prev_metadata=param_metadata,
            repo=repo,
            path=args.file,
        )
        new_param_metadata = metadata.ParamMetadata(
            lfs_metadata=lfs_metadata,
            tensor_metadata=new_tensor_metadata,
            theta_metadata=new_theta_metadata,
        )
        return param_keys, new_param_metadata

    # Sort the keys so we don't get changing diffs based on serialization order.
    sorted_checkpoint = dict(sorted(model_checkpoint.flatten().items()))
    # Clean all parameters concurrently; run_map awaits _clean per item.
    new_metadata = metadata.Metadata(
        **async_utils.run(async_utils.run_map(sorted_checkpoint, _clean))
    )
    new_metadata.unflatten().write(sys.stdout)
def smudge(args):
    """Implements the git smudge filter for model checkpoint files.

    Reads git-theta metadata from stdin, materializes every parameter it
    describes, and writes the reconstructed checkpoint to stdout.
    """
    logging.debug(f"Running smudge filter on {args.file}")
    repo = git_utils.get_git_repo()
    curr_metadata = metadata.Metadata.from_file(sys.stdin).flatten()

    async def _smudge(param_keys, param_metadata):
        # Materialize one parameter tensor from its stored update chain.
        logging.debug(f"Smudging {'/'.join(param_keys)}")
        handler_cls = updates.get_update_handler(
            param_metadata.theta_metadata.update_type
        )
        handler = handler_cls(params.get_update_serializer())
        value = await handler.apply(
            param_metadata, param_keys, repo=repo, path=args.file
        )
        return param_keys, value

    model_dict = async_utils.run(async_utils.run_map(curr_metadata, _smudge))
    checkpoint_cls = checkpoints.get_checkpoint_handler()
    checkpoint_cls(model_dict).unflatten().save(sys.stdout.buffer)
if __name__ == "__main__":
    # Entry point when git invokes this script: parse the sub-command,
    # (re)install git-theta's hooks, then dispatch to clean() or smudge().
    args = parse_args()
    git_utils.set_hooks()
    args.func(args)