Skip to content

Commit

Permalink
BUGFIX: Removing torch dependency from testing. Proper Env seeding post unpickling. Better handling for future gym versions
Browse files Browse the repository at this point in the history
  • Loading branch information
vikashplus committed Nov 27, 2023
1 parent 4c3c418 commit 7abd6c8
Showing 1 changed file with 33 additions and 14 deletions.
47 changes: 33 additions & 14 deletions robohive/tests/test_envs.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,8 +11,21 @@
import numpy as np
import pickle
import copy
import torch.testing
# import torch.testing
import os
from flatten_dict import flatten

def assert_close(prm1, prm2, atol=1e-05, rtol=1e-08):
    """Assert that two values, or two (possibly nested) dicts of values, are close.

    Args:
        prm1: First value. May be ``None``, a dict (optionally nested), or
            anything accepted by ``np.testing.assert_allclose``.
        prm2: Second value, expected to have the same structure as ``prm1``.
        atol: Absolute tolerance forwarded to ``np.testing.assert_allclose``.
        rtol: Relative tolerance forwarded to ``np.testing.assert_allclose``.

    Returns:
        True when the two values match within tolerance.

    Raises:
        AssertionError: If exactly one value is ``None``, if two dicts do not
            share the same keys, or if any compared pair differs beyond the
            given tolerances.
    """
    if prm1 is None and prm2 is None:
        return True

    # A one-sided None is a genuine mismatch; report it as an assertion
    # failure instead of letting numpy choke on the None operand.
    if prm1 is None or prm2 is None:
        raise AssertionError(f"Only one value is None: {prm1!r} vs {prm2!r}")

    if isinstance(prm1, dict) and isinstance(prm2, dict):
        # Compare key sets in both directions so that entries present in
        # only one of the dicts are caught rather than silently skipped.
        if prm1.keys() != prm2.keys():
            only_in_1 = [key for key in prm1 if key not in prm2]
            only_in_2 = [key for key in prm2 if key not in prm1]
            raise AssertionError(
                f"Key mismatch: only in first {only_in_1}, only in second {only_in_2}"
            )
        # Recursing handles nested dicts directly, level by level.
        for key, val1 in prm1.items():
            assert_close(val1, prm2[key], atol=atol, rtol=rtol)
        return True

    # numpy is used for the leaf comparison so the tests do not need torch.
    np.testing.assert_allclose(prm1, prm2, atol=atol, rtol=rtol)
    return True

class TestEnvs(unittest.TestCase):

Expand All @@ -35,14 +48,15 @@ def check_env(self, environment_id, input_seed):
# test init
env1 = gym.make(environment_id, seed=input_seed)
assert env1.get_input_seed() == input_seed
# test reset
env1.env.reset()
# test reseed and reset
env1.seed(input_seed)
reset_obs1 = env1.env.reset()

# step
u = 0.01*np.random.uniform(low=0, high=1, size=env1.env.sim.model.nu) # small controls
obs1, rwd1, done1, infos1 = env1.env.step(u.copy())
obs1, rwd1, done1, *_, infos1 = env1.env.step(u.copy())
infos1 = copy.deepcopy(infos1) #info points to internal variables.
proprio1 = env1.env.get_proprioception()
proprio1_t, proprio1_vec, proprio1_dict = env1.env.get_proprioception()
extero1 = env1.env.get_exteroception()
assert len(obs1>0)
# assert len(rwd1>0)
Expand All @@ -57,26 +71,31 @@ def check_env(self, environment_id, input_seed):

# serialize / deserialize env ------------
env2 = pickle.loads(pickle.dumps(env1))
# test reset
env2.reset()
# test seed
assert env2.get_input_seed() == input_seed
assert env1.get_input_seed() == env2.get_input_seed(), {env1.get_input_seed(), env2.get_input_seed()}
# check input output spaces
assert env1.action_space == env2.action_space, (env1.action_space, env2.action_space)
assert env1.observation_space == env2.observation_space, (env1.observation_space, env2.observation_space)

# test reseed and reset
env2.seed(input_seed)
reset_obs2 = env2.env.reset()
assert_close(reset_obs1, reset_obs2)

# step
obs2, rwd2, done2, infos2 = env2.env.step(u)
obs2, rwd2, done2, *_, infos2 = env2.env.step(u)
infos2 = copy.deepcopy(infos2)
proprio2 = env2.env.get_proprioception()
proprio2_t, proprio2_vec, proprio2_dict = env2.env.get_proprioception()
extero2 = env2.env.get_exteroception()
torch.testing.assert_close(obs1, obs2)
torch.testing.assert_close(proprio1, proprio2)
torch.testing.assert_close(extero1, extero2, atol=2, rtol=0.04)
torch.testing.assert_close(rwd1, rwd2)

assert_close(obs1, obs2)
assert_close(proprio1_vec, proprio2_vec)#, f"Difference in Proprio: {proprio1_vec-proprio2_vec}"
assert_close(extero1, extero2, atol=2, rtol=0.04)#, f"Difference in Extero {extero1}, {extero2}"
assert_close(rwd1, rwd2)#, "Difference in Rewards"
assert (done1==done2), (done1, done2)
assert len(infos1)==len(infos2), (infos1, infos2)
torch.testing.assert_close(infos1, infos2)
assert_close(infos1, infos2)
# reset
env2.reset()

Expand Down

0 comments on commit 7abd6c8

Please sign in to comment.