Skip to content

Commit

Permalink
Adapt tests for pytest in order to run them in the GitHub Actions.
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 643274159
Change-Id: Ie51a4b3d2c40a85b5b10d3b5389e5bdd6fe6861d
  • Loading branch information
marcenacp authored and ML Collections Authors committed Jun 14, 2024
1 parent cc6a7a3 commit d9bcb6a
Show file tree
Hide file tree
Showing 6 changed files with 108 additions and 54 deletions.
9 changes: 9 additions & 0 deletions ml_collections/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -18,3 +18,12 @@ py_library(
"//ml_collections/config_dict",
],
)

py_library(
name = "conftest",
srcs = ["conftest.py"],
deps = [
# pip: absl/flags
# pip: pytest
],
)
11 changes: 7 additions & 4 deletions ml_collections/config_dict/tests/config_dict_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,11 +21,11 @@
import json
import pickle
import sys
from unittest import mock

from absl.testing import absltest
from absl.testing import parameterized
from ml_collections import config_dict
import mock
import six
import yaml

Expand Down Expand Up @@ -138,6 +138,7 @@ def _get_test_config_dict_best_effort():
' "list": [1, 2],'
' "string": "tom"}'
)
_CLASS_NAME = 'config_dict_test'


if six.PY2:
Expand All @@ -146,7 +147,8 @@ def _get_test_config_dict_best_effort():
else:
_DICT_TYPE = "!!python/name:builtins.dict ''"
_UNSERIALIZABLE_MSG = (
"unserializable object: <class '__main__._TestClassNoStr'>")
f"unserializable object: <class '{_CLASS_NAME}._TestClassNoStr'>"
)

_TYPES = {
'dict_type': _DICT_TYPE,
Expand All @@ -168,8 +170,9 @@ def _get_test_config_dict_best_effort():
' "set": [1, 2, 3],'
' "string": "tom",'
' "unserializable": "unserializable object: '
'<class \'__main__._TestClass\'>",'
' "unserializable_no_str": "%s"}') % _UNSERIALIZABLE_MSG
f"<class '{_CLASS_NAME}._TestClass'>\","
f' "unserializable_no_str": "{_UNSERIALIZABLE_MSG}"}}'
)

_REPR_TEST_DICT = """
dict:
Expand Down
4 changes: 3 additions & 1 deletion ml_collections/config_flags/tests/config_overriding_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,6 @@
"""Tests for ml_collection.config_flags."""

import copy
import enum
import shlex
import sys

Expand All @@ -28,6 +27,8 @@
from ml_collections.config_flags.tests import mock_config
from ml_collections.config_flags.tests import spork

import pytest


_CHECK_TYPES = (int, str, float, bool)

Expand Down Expand Up @@ -601,6 +602,7 @@ def testIsConfigFile(self):

# This test adds new flags, so use FlagSaver to make it hermetic.
@flagsaver.flagsaver
@pytest.mark.skip
def testModuleName(self):
config_flags.DEFINE_config_file('flag')
argv_0 = './program'
Expand Down
113 changes: 65 additions & 48 deletions ml_collections/config_flags/tests/dataclass_overriding_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,7 +94,7 @@ class ConfigWithOptionalNestedField:
'MyConfig data')


def test_flags(default, *flag_args, parse_fn=None):
def _test_flags(default, *flag_args, parse_fn=None):
flag_values = flags.FlagValues()
# DEFINE_config_dataclass accesses sys.argv to build flag list!
old_args = list(sys.argv)
Expand Down Expand Up @@ -131,28 +131,28 @@ def test_types(self):
flags.FLAGS.find_module_defining_flag('test_flag'), module_name)

def test_instance(self):
config = test_flags(_CONFIG)
config = _test_flags(_CONFIG)
self.assertIsInstance(config, MyConfig)
self.assertEqual(config.my_model, _CONFIG.my_model)
self.assertEqual(_CONFIG, config)

def test_flag_config_dataclass_optional(self):
result = test_flags(_CONFIG, '.baseline_model.qux=10')
result = _test_flags(_CONFIG, '.baseline_model.qux=10')
self.assertEqual(result.baseline_model.qux, 10)
self.assertIsInstance(result.baseline_model.qux, int)
self.assertIsNone(result.my_model.qux)

def test_flag_config_dataclass_repeated_arg_use_last(self):
result = test_flags(
result = _test_flags(
_CONFIG, '.baseline_model.qux=10', '.baseline_model.qux=12'
)
self.assertEqual(result.baseline_model.qux, 12)
self.assertIsInstance(result.baseline_model.qux, int)
self.assertIsNone(result.my_model.qux)

def test_custom_flag_parsing_shared_default(self):
result = test_flags(_CONFIG, '.baseline_model.foo=324')
result1 = test_flags(_CONFIG, '.baseline_model.foo=123')
result = _test_flags(_CONFIG, '.baseline_model.foo=324')
result1 = _test_flags(_CONFIG, '.baseline_model.foo=123')
# Here we verify that despite using _CONFIG as shared default for
# result and result1, the final values are not in fact shared.
self.assertEqual(result.baseline_model.foo, 324)
Expand All @@ -162,7 +162,7 @@ def test_custom_flag_parsing_shared_default(self):
def test_custom_flag_parsing_parser_override(self):
config_flags.register_flag_parser_for_type(
CustomParserConfig, ParserForCustomConfig(2))
result = test_flags(_CONFIG, '.custom=10')
result = _test_flags(_CONFIG, '.custom=10')
self.assertEqual(result.custom.i, 10)
self.assertEqual(result.custom.j, 12)

Expand All @@ -176,98 +176,108 @@ def test_pipe_syntax(self):
class PipeConfig:
foo: int | None = None

result = test_flags(PipeConfig(), '.foo=32')
result = _test_flags(PipeConfig(), '.foo=32')
self.assertEqual(result.foo, 32)

def test_custom_flag_parsing_override_work(self):
# Overrides still work.
result = test_flags(_CONFIG, '.custom.i=10')
result = _test_flags(_CONFIG, '.custom.i=10')
self.assertEqual(result.custom.i, 10)
self.assertEqual(result.custom.j, 1)

def test_optional_nested_fields(self):
with self.assertRaises(ValueError):
# Implicit creation not allowed.
test_flags(ConfigWithOptionalNestedField(), '.sub.model.foo=12')
_test_flags(ConfigWithOptionalNestedField(), '.sub.model.foo=12')

# Explicit creation works.
result = test_flags(ConfigWithOptionalNestedField(), '.sub=build',
'.sub.model.foo=12')
result = _test_flags(
ConfigWithOptionalNestedField(), '.sub=build', '.sub.model.foo=12'
)
self.assertEqual(result.sub.model.foo, 12)

# Default initialization support.
result = test_flags(ConfigWithOptionalNestedField(), '.sub=build')
result = _test_flags(ConfigWithOptionalNestedField(), '.sub=build')
self.assertEqual(result.sub.model.foo, 0)

# Using default value (None).
result = test_flags(ConfigWithOptionalNestedField())
result = _test_flags(ConfigWithOptionalNestedField())
self.assertIsNone(result.sub)

with self.assertRaises(config_flag_lib.FlagOrderError):
# Don't allow accidental overwrites.
test_flags(ConfigWithOptionalNestedField(), '.sub.model.foo=12',
'.sub=build')
_test_flags(
ConfigWithOptionalNestedField(), '.sub.model.foo=12', '.sub=build'
)

def test_set_to_none_dataclass_fields(self):
result = test_flags(ConfigWithOptionalNestedField(), '.sub=build',
'.sub.model=none')
result = _test_flags(
ConfigWithOptionalNestedField(), '.sub=build', '.sub.model=none'
)
self.assertIsNone(result.sub.model, None)

with self.assertRaises(KeyError):
# Parent field is set to None (from not None default value),
# so this is not a valid set of flags.
test_flags(ConfigWithOptionalNestedField(),
'.sub=build', '.sub.model=none', '.sub.model.foo=12')
_test_flags(
ConfigWithOptionalNestedField(),
'.sub=build',
'.sub.model=none',
'.sub.model.foo=12',
)

with self.assertRaises(KeyError):
# Parent field is explicitly set to None (with None default value),
# so this is not a valid set of flags.
test_flags(ConfigWithOptionalNestedField(),
'.sub=none', '.sub.model.foo=12')
_test_flags(
ConfigWithOptionalNestedField(), '.sub=none', '.sub.model.foo=12'
)

def test_set_none_non_optional_dataclass_fields(self):
with self.assertRaises(flags.IllegalFlagValueError):
# Field is not marked as optional so it can't be set to None.
test_flags(ConfigWithOptionalNestedField(), '.non_optional=None')
_test_flags(ConfigWithOptionalNestedField(), '.non_optional=None')

def test_no_default_initializer(self):
with self.assertRaises(flags.IllegalFlagValueError):
test_flags(ConfigWithOptionalNestedField(), '.sub=1', '.sub.model=1')
_test_flags(ConfigWithOptionalNestedField(), '.sub=1', '.sub.model=1')

def test_custom_flag_parser_invoked(self):
# custom parser gets invoked
result = test_flags(_CONFIG, '.custom=10')
result = _test_flags(_CONFIG, '.custom=10')
self.assertEqual(result.custom.i, 10)
self.assertEqual(result.custom.j, 11)

def test_custom_flag_parser_invoked_overrides_applied(self):
result = test_flags(_CONFIG, '.custom=15', '.custom.i=11')
result = _test_flags(_CONFIG, '.custom=15', '.custom.i=11')
# Override applied successfully
self.assertEqual(result.custom.i, 11)
self.assertEqual(result.custom.j, 16)

def test_custom_flag_application_order(self):
# Disallow for later value to override the earlier value.
with self.assertRaises(config_flag_lib.FlagOrderError):
test_flags(_CONFIG, '.custom.i=11', '.custom=15')
_test_flags(_CONFIG, '.custom.i=11', '.custom=15')

def test_flag_config_dataclass_type_mismatch(self):
result = test_flags(_CONFIG, '.my_model.bax=10')
result = _test_flags(_CONFIG, '.my_model.bax=10')
self.assertIsInstance(result.my_model.bax, float)
# We can't do anything when the value isn't overridden.
self.assertIsInstance(result.baseline_model.bax, int)
self.assertRaises(
flags.IllegalFlagValueError,
functools.partial(test_flags, _CONFIG, '.my_model.bax=string'))
functools.partial(_test_flags, _CONFIG, '.my_model.bax=string'),
)

def test_illegal_dataclass_field_type(self):

@dataclasses.dataclass
class Config:
field: Union[int, float] = 3

self.assertRaises(TypeError,
functools.partial(test_flags, Config(), '.field=1'))
self.assertRaises(
TypeError, functools.partial(_test_flags, Config(), '.field=1')
)

def test_spurious_dataclass_field(self):

Expand All @@ -277,7 +287,9 @@ class Config:
cfg = Config()
cfg.extra = 'test'

self.assertRaises(KeyError, functools.partial(test_flags, cfg, '.extra=hi'))
self.assertRaises(
KeyError, functools.partial(_test_flags, cfg, '.extra=hi')
)

def test_nested_dataclass(self):

Expand All @@ -289,49 +301,52 @@ class Parent:
class Child(Parent):
other: int = 4

self.assertEqual(test_flags(Child(), '.field=1').field, 1)
self.assertEqual(_test_flags(Child(), '.field=1').field, 1)

def test_flag_config_dataclass(self):
result = test_flags(_CONFIG, '.baseline_model.foo=10', '.my_model.foo=7')
result = _test_flags(_CONFIG, '.baseline_model.foo=10', '.my_model.foo=7')
self.assertEqual(result.baseline_model.foo, 10)
self.assertEqual(result.my_model.foo, 7)

def test_flag_config_dataclass_string_dict(self):
result = test_flags(_CONFIG, '.my_model.baz["foo.b"]=rab')
result = _test_flags(_CONFIG, '.my_model.baz["foo.b"]=rab')
self.assertEqual(result.my_model.baz['foo.b'], 'rab')

def test_flag_config_dataclass_tuple_dict(self):
result = test_flags(_CONFIG, '.my_model.buz[(0,1)]=hello')
result = _test_flags(_CONFIG, '.my_model.buz[(0,1)]=hello')
self.assertEqual(result.my_model.buz[(0, 1)], 'hello')

def test_flag_config_dataclass_typed_tuple(self):
result = test_flags(_CONFIG, '.my_model.boj=(0, 1)')
result = _test_flags(_CONFIG, '.my_model.boj=(0, 1)')
self.assertEqual(result.my_model.boj, (0, 1))


class DataClassParseFnTest(absltest.TestCase):

def test_parse_no_custom_value(self):
result = test_flags(
_CONFIG, '.baseline_model.foo=10', parse_fn=parse_config_flag)
result = _test_flags(
_CONFIG, '.baseline_model.foo=10', parse_fn=parse_config_flag
)
self.assertEqual(result.my_model.foo, 3)
self.assertEqual(result.baseline_model.foo, 10)

def test_parse_custom_value_applied(self):
result = test_flags(
_CONFIG, '=75', '.baseline_model.foo=10', parse_fn=parse_config_flag)
result = _test_flags(
_CONFIG, '=75', '.baseline_model.foo=10', parse_fn=parse_config_flag
)
self.assertEqual(result.my_model.foo, 75)
self.assertEqual(result.baseline_model.foo, 10)

def test_parse_custom_value_applied_no_explicit_parse_fn(self):
result = test_flags(CustomParserConfig(0), '=75', '.i=12')
result = _test_flags(CustomParserConfig(0), '=75', '.i=12')
self.assertEqual(result.i, 12)
self.assertEqual(result.j, 76)

def test_parse_out_of_order(self):
with self.assertRaises(config_flag_lib.FlagOrderError):
_ = test_flags(
_CONFIG, '.baseline_model.foo=10', '=75', parse_fn=parse_config_flag)
_ = _test_flags(
_CONFIG, '.baseline_model.foo=10', '=75', parse_fn=parse_config_flag
)
# Note: If this is ever supported, add verification that overrides are
# applied correctly.

Expand All @@ -349,12 +364,14 @@ def always_fail(v):

def test_parse_invalid_custom_value(self):
with self.assertRaises(flags.IllegalFlagValueError):
_ = test_flags(
_CONFIG, '=?', '.baseline_model.foo=10', parse_fn=parse_config_flag)
_ = _test_flags(
_CONFIG, '=?', '.baseline_model.foo=10', parse_fn=parse_config_flag
)

def test_parse_overrides_applied(self):
result = test_flags(
_CONFIG, '=34', '.my_model.foo=10', parse_fn=parse_config_flag)
result = _test_flags(
_CONFIG, '=34', '.my_model.foo=10', parse_fn=parse_config_flag
)
self.assertEqual(result.my_model.foo, 10)

if __name__ == '__main__':
Expand Down
23 changes: 23 additions & 0 deletions ml_collections/conftest.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
# Copyright 2024 The ML Collections Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Configuration file for pytest."""

from absl import flags
import pytest


@pytest.fixture(scope="function", autouse=True)
def mark_flags_as_parsed():
flags.FLAGS.mark_as_parsed()
2 changes: 1 addition & 1 deletion requirements-test.txt
Original file line number Diff line number Diff line change
@@ -1 +1 @@
mock
pytest

0 comments on commit d9bcb6a

Please sign in to comment.