diff --git a/tests/test_pipeline/test_schedule/test_zerobubble_pp.py b/tests/test_pipeline/test_schedule/test_zerobubble_pp.py index 46bd4a58104c..0f2d6c49c749 100644 --- a/tests/test_pipeline/test_schedule/test_zerobubble_pp.py +++ b/tests/test_pipeline/test_schedule/test_zerobubble_pp.py @@ -34,7 +34,6 @@ def __init__( def forward( self, - model=None, data: torch.Tensor = None, hidden_states: torch.Tensor = None, stage_index=None, @@ -622,10 +621,7 @@ def criterion_base(x, *args, **kwargs): stage_manager.stage_indices = stage_manager.get_stage_index(layers_per_stage) model_pp._forward = model_pp.forward - # model_pp.forward = MethodType( - # partial(model_pp._forward, stage_mgr=stage_manager), - # model_pp, - # ) + model_pp.forward = partial(model_pp._forward, stage_mgr=stage_manager) # init optimizer