Skip to content

Commit

Permalink
Merge branch 'py311' into dev
Browse files Browse the repository at this point in the history
  • Loading branch information
mmcdermott committed Oct 16, 2023
2 parents 8ca8cd0 + f7056cb commit b10e741
Show file tree
Hide file tree
Showing 3 changed files with 11 additions and 7 deletions.
2 changes: 1 addition & 1 deletion EventStream/evaluation/general_generative_evaluation.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,7 +97,7 @@ class GenerateConfig:

do_overwrite: bool = False

optimization_config: OptimizationConfig = OptimizationConfig()
optimization_config: OptimizationConfig = dataclasses.field(default_factory=lambda: OptimizationConfig())

task_df_name: str | None = None

Expand Down
2 changes: 1 addition & 1 deletion EventStream/transformer/lightning_modules/fine_tuning.py
Original file line number Diff line number Diff line change
Expand Up @@ -311,7 +311,7 @@ class FinetuneConfig:
},
}
)
optimization_config: OptimizationConfig = OptimizationConfig()
optimization_config: OptimizationConfig = dataclasses.field(default_factory=lambda: OptimizationConfig())
data_config: dict[str, Any] | None = dataclasses.field(
default_factory=lambda: {
**{k: None for k in PytorchDatasetConfig().to_dict().keys()},
Expand Down
14 changes: 9 additions & 5 deletions EventStream/transformer/lightning_modules/generative_modeling.py
Original file line number Diff line number Diff line change
Expand Up @@ -511,12 +511,16 @@ class PretrainConfig:
},
}
)
optimization_config: OptimizationConfig = OptimizationConfig()
data_config: PytorchDatasetConfig = PytorchDatasetConfig()
pretraining_metrics_config: MetricsConfig = MetricsConfig(
include_metrics={Split.TRAIN: {MetricCategories.LOSS_PARTS: True}},
optimization_config: OptimizationConfig = dataclasses.field(default_factory=lambda: OptimizationConfig())
data_config: PytorchDatasetConfig = dataclasses.field(default_factory=lambda: PytorchDatasetConfig())
pretraining_metrics_config: MetricsConfig = dataclasses.field(
default_factory=lambda: MetricsConfig(
include_metrics={Split.TRAIN: {MetricCategories.LOSS_PARTS: True}},
)
)
final_validation_metrics_config: MetricsConfig = dataclasses.field(
default_factory=lambda: MetricsConfig(do_skip_all_metrics=False)
)
final_validation_metrics_config: MetricsConfig = MetricsConfig(do_skip_all_metrics=False)

trainer_config: dict[str, Any] = dataclasses.field(
default_factory=lambda: {
Expand Down

0 comments on commit b10e741

Please sign in to comment.