Skip to content

Commit

Permalink
cleanup
Browse files Browse the repository at this point in the history
  • Loading branch information
William Moses committed Sep 5, 2024
1 parent 25e487e commit 03484f2
Show file tree
Hide file tree
Showing 2 changed files with 2 additions and 8 deletions.
3 changes: 1 addition & 2 deletions test/llama.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,6 @@
JaXPipeline,
hlo_opts,
)

import numpy as np
import timeit
from test_utils import *
Expand Down Expand Up @@ -437,7 +436,7 @@ def sfn(x, weights, key_cache, value_cache):

self.fn = partial(forward, config)
self.name = "llama"
self.count = 1 # 100 if jax.default_backend() == "cpu" else 1000
self.count = 100 if jax.default_backend() == "cpu" else 1000
self.revprimal = False
self.AllPipelines = pipelines
self.AllBackends = CurBackends
Expand Down
7 changes: 1 addition & 6 deletions test/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,11 +18,7 @@
if jax.default_backend() != "cpu":
devices = CurBackends

devices = ["gpu"]
CurBackends = ["gpu"]

AllBackends = ["cpu"] + devices
AllBackends = ["gpu"] # cpu"] + devices
AllPipelines = [
("JaX ", None, AllBackends),
("JaXPipe", JaXPipeline(), AllBackends),
Expand Down Expand Up @@ -163,9 +159,8 @@ def harness(self, name, in_fn, ins, dins, douts):
revstr = (
"rev(dout, " + (", ".join(["in" + str(i) for i in range(len(ins))])) + ")"
)
print("all", self.AllBackends)

for backend in self.AllBackends:
print("backend", backend)
ins_backend = [to_backend(x, backend) for x in ins]
dins_backend = [to_backend(x, backend) for x in dins]
douts_backend = [to_backend(x, backend) for x in douts]
Expand Down

0 comments on commit 03484f2

Please sign in to comment.