Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

add iter time tracking via cuda events, add data loading times, add columnar display to show both, show avg iter & data loading times at end of training #87

Merged
merged 6 commits into from
Feb 26, 2024
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
27 changes: 25 additions & 2 deletions train.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,8 @@ class TrainState:
step: int = 0
current_loss: float = -1
losses: List[float] = field(default_factory=list)
iter_times: List[float] = field(default_factory=list)
data_load_times: List[float] = field(default_factory=list)

def state_dict(self) -> Dict[str, Any]:
return {
Expand Down Expand Up @@ -177,15 +179,22 @@ def main(job_config: JobConfig):
):
train_state.step += 1
# get batch
data_load_start = timer()
batch = next(iter(data_loader))
input_ids, labels = batch
input_ids = input_ids.cuda()
labels = labels.cuda()
data_load_time = round(timer() - data_load_start, 4)
train_state.data_load_times.append(data_load_time)
nwords_since_last_log += labels.numel()

optimizer.zero_grad()

# forward
start_timer = torch.cuda.Event(enable_timing=True)
end_timer = torch.cuda.Event(enable_timing=True)
start_timer.record()

pred = model(input_ids)
tok_loss = F.cross_entropy(
pred.flatten(0, 1), labels.flatten(0, 1), reduction="none"
Expand All @@ -207,6 +216,13 @@ def main(job_config: JobConfig):
# updates the scale for next iteration
scaler.update()

# training iteration complete
end_timer.record()
torch.cuda.synchronize()

curr_iter_time = round(start_timer.elapsed_time(end_timer) * 1e-3, 4)
train_state.iter_times.append(curr_iter_time)

# if profiler is active
if torch_profiler:
torch_profiler.step()
Expand Down Expand Up @@ -251,8 +267,8 @@ def main(job_config: JobConfig):
time_last_log = timer()

rank0_log(
f"step: {train_state.step}, current loss: {round(train_state.current_loss,4)},"
f" lr: {round(float(scheduler.get_last_lr()[0]), 8)}"
f"step: {train_state.step:>2} loss: {round(train_state.current_loss,4):>7}"
f" iter: {curr_iter_time:>7} data: {data_load_time:>5} lr: {round(float(scheduler.get_last_lr()[0]), 8):<6}"
)
scheduler.step()

Expand All @@ -261,6 +277,13 @@ def main(job_config: JobConfig):
)

metric_logger.close()
# calc and show average iter time, disregard first three iterations (warmup)
if len(train_state.iter_times) > 3:
avg_iter_time = np.mean(train_state.iter_times[3:])
rank0_log(f"Average iter time: {avg_iter_time:.4f} seconds")
avg_data_load_time = np.mean(train_state.data_load_times[3:])
rank0_log(f"Average data load time: {avg_data_load_time:.4f} seconds")

rank0_log(f"{gpu_metrics.get_current_stats()}")


Expand Down
Loading