From 03484f20f1120b1685d37ca3b0d2523ae5bc9165 Mon Sep 17 00:00:00 2001 From: William Moses Date: Thu, 5 Sep 2024 14:50:04 -0400 Subject: [PATCH] cleanup --- test/llama.py | 3 +-- test/test_utils.py | 7 +------ 2 files changed, 2 insertions(+), 8 deletions(-) diff --git a/test/llama.py b/test/llama.py index f04a3bdf..ec1578ae 100644 --- a/test/llama.py +++ b/test/llama.py @@ -10,7 +10,6 @@ JaXPipeline, hlo_opts, ) - import numpy as np import timeit from test_utils import * @@ -437,7 +436,7 @@ def sfn(x, weights, key_cache, value_cache): self.fn = partial(forward, config) self.name = "llama" - self.count = 1 # 100 if jax.default_backend() == "cpu" else 1000 + self.count = 100 if jax.default_backend() == "cpu" else 1000 self.revprimal = False self.AllPipelines = pipelines self.AllBackends = CurBackends diff --git a/test/test_utils.py b/test/test_utils.py index b34c4fba..8d037105 100644 --- a/test/test_utils.py +++ b/test/test_utils.py @@ -18,11 +18,7 @@ if jax.default_backend() != "cpu": devices = CurBackends -devices = ["gpu"] -CurBackends = ["gpu"] - AllBackends = ["cpu"] + devices -AllBackends = ["gpu"] # cpu"] + devices AllPipelines = [ ("JaX ", None, AllBackends), ("JaXPipe", JaXPipeline(), AllBackends), @@ -163,9 +159,8 @@ def harness(self, name, in_fn, ins, dins, douts): revstr = ( "rev(dout, " + (", ".join(["in" + str(i) for i in range(len(ins))])) + ")" ) - print("all", self.AllBackends) + for backend in self.AllBackends: - print("backend", backend) ins_backend = [to_backend(x, backend) for x in ins] dins_backend = [to_backend(x, backend) for x in dins] douts_backend = [to_backend(x, backend) for x in douts]