Skip to content

Commit

Permalink
Explicitly set AutoFDO profile version in CompileOptions.
Browse files Browse the repository at this point in the history
Set the AutoFDO profile version specified in --jax_xla_profile_version
if non-zero. Otherwise, expect that there is a function set in
get_latest_profile_version that will return a non-zero profile version
that should be used. If this function is not set or it returns 0,
set -1 instead to indicate that no attempt should be made to retrieve
an AutoFDO profile later on.

Testing: updated unit tests.
PiperOrigin-RevId: 555333728
  • Loading branch information
jax authors committed Aug 10, 2023
1 parent aac4cda commit eb076c4
Show file tree
Hide file tree
Showing 2 changed files with 73 additions and 2 deletions.
34 changes: 32 additions & 2 deletions jax/_src/xla_bridge.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,6 +88,12 @@
'comma-separate list of integer device IDs.')


# Will be monkeypatched with the function that gets the XLA-AutoFDO profile
# version. The default (-1) takes care of errors.
def get_latest_profile_version() -> int:
return -1


def get_compile_options(
num_replicas: int,
num_partitions: int,
Expand Down Expand Up @@ -165,12 +171,36 @@ def get_compile_options(
debug_options.xla_gpu_cuda_data_dir = lib.cuda_path

if _DISABLE_MOST_OPTIMIZATIONS.value:

debug_options.xla_backend_optimization_level = 0
debug_options.xla_llvm_disable_expensive_passes = True
debug_options.xla_test_all_input_layouts = False

compile_options.profile_version = config.jax_xla_profile_version
# XLA-AutoFDO profile version: precedence order is:
# 1. Whatever --jax_xla_profile_version is set to.
# 2. If --jax_xla_profile_version is not set (i.e., 0), call the function
# set in get_latest_profile_version and use the return value if non-zero.
# If the function returns 0, set -1; this is an error.
# -1 indicates that no attempt should be made to retrieve the latest profile
# later on.
jax_xla_profile_version = config.jax_xla_profile_version
if jax_xla_profile_version > 0:
compile_options.profile_version = jax_xla_profile_version
logger.debug("get_compile_options XLA-AutoFDO profile: " +
"using JAX XLA profile version %d from flag",
jax_xla_profile_version)
else:
fdo_profile_version = get_latest_profile_version()
if fdo_profile_version != 0:
compile_options.profile_version = fdo_profile_version
logger.debug("get_compile_options XLA-AutoFDO profile: " +
"using XLA-AutoFDO profile version %d",
fdo_profile_version)
else:
no_profile_dont_retrieve = -1
compile_options.profile_version = no_profile_dont_retrieve
logger.error("get_compile_options XLA-AutoFDO profile: " +
"XLA-AutoFDO profile version is 0; this should not happen")

return compile_options


Expand Down
41 changes: 41 additions & 0 deletions tests/xla_bridge_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
from jax._src.lib import xla_client as xc
from jax._src.interpreters import xla

from jax._src import config as jax_config
from jax._src.config import config
config.parse_flags_with_absl()
FLAGS = config.FLAGS
Expand Down Expand Up @@ -57,6 +58,46 @@ def test_set_fdo_profile(self):
compile_options.executable_build_options.fdo_profile, "test_profile"
)

def test_autofdo_profile(self):
# --jax_xla_profile_version takes precedence.
jax_flag_profile = 1
another_profile = 2
with jax_config.jax_xla_profile_version(jax_flag_profile):
with mock.patch.object(xb, "get_latest_profile_version",
side_effect=lambda: another_profile):
self.assertEqual(
xb.get_compile_options(
num_replicas=3, num_partitions=4
).profile_version,
jax_flag_profile,
)

# Use whatever non-zero value the function get_latest_profile_version
# returns if --jax_xla_profile_version is not set.
profile_version = 1
with mock.patch.object(xb, "get_latest_profile_version",
side_effect=lambda: profile_version):
self.assertEqual(
xb.get_compile_options(
num_replicas=3, num_partitions=4
).profile_version,
profile_version,
)

# If the function returns 0, something is wrong, so expect that we set
# profile_version to -1 instead to ensure that no attempt is made to
# retrieve the latest profile later.
error_return = 0
no_profile_dont_retrieve = -1
with mock.patch.object(xb, "get_latest_profile_version",
side_effect=lambda: error_return):
self.assertEqual(
xb.get_compile_options(
num_replicas=3, num_partitions=4
).profile_version,
no_profile_dont_retrieve,
)

def test_parameter_replication_default(self):
c = xc.XlaBuilder("test")
_ = xla.parameter(c, 0, xc.Shape.array_shape(xc.PrimitiveType.F32, ()))
Expand Down

0 comments on commit eb076c4

Please sign in to comment.