Skip to content

Commit 61e9346

Browse files
Merge pull request #39 from YuechengLiu/gym_compatibility
Gym compatibility for 'Quadrotor' Environment
2 parents aa50b7f + bfcfb4e commit 61e9346

File tree

2 files changed

+11
-19
lines changed

2 files changed

+11
-19
lines changed

rlschool/quadrotor/env.py

Lines changed: 9 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
import numpy as np
1717
from math import floor, ceil
1818
from collections import namedtuple
19+
import gym
1920

2021
from rlschool.quadrotor.quadrotorsim import QuadrotorSim
2122

@@ -26,7 +27,7 @@
2627
NO_DISPLAY = True
2728

2829

29-
class Quadrotor(object):
30+
class Quadrotor(gym.Env):
3031
"""
3132
Quadrotor environment.
3233
@@ -41,6 +42,7 @@ class Quadrotor(object):
4142
map is a 100x100 flatten floor.
4243
simulator_conf (None|str): path to simulator config xml file.
4344
"""
45+
4446
def __init__(self,
4547
dt=0.01,
4648
nt=1000,
@@ -68,13 +70,11 @@ def __init__(self,
6870

6971
cfg_dict = self.simulator.get_config(simulator_conf)
7072
self.valid_range = cfg_dict['range']
71-
self.action_space = namedtuple(
72-
'action_space', ['shape', 'high', 'low', 'sample'])
73-
self.action_space.shape = [4]
74-
self.action_space.high = [cfg_dict['action_space_high']] * 4
75-
self.action_space.low = [cfg_dict['action_space_low']] * 4
76-
self.action_space.sample = Quadrotor.random_action(
77-
cfg_dict['action_space_low'], cfg_dict['action_space_high'], 4)
73+
self.action_space = gym.spaces.Box(
74+
low=np.array([cfg_dict['action_space_low']] * 4, dtype='float32'),
75+
high=np.array(
76+
[cfg_dict['action_space_high']] * 4, dtype='float32'),
77+
shape=[4])
7878

7979
self.body_velocity_keys = ['b_v_x', 'b_v_y', 'b_v_z']
8080
self.body_position_keys = ['b_x', 'b_y', 'b_z']
@@ -91,8 +91,7 @@ def __init__(self,
9191
len(self.flight_pose_keys) + len(self.barometer_keys)
9292
if self.task == 'velocity_control':
9393
obs_dim += len(self.task_velocity_control_keys)
94-
self.observation_space = namedtuple('observation_space', ['shape'])
95-
self.observation_space.shape = [obs_dim]
94+
self.observation_space = gym.Space(shape=[obs_dim], dtype='float32')
9695

9796
self.state = {}
9897
self.viewer = None
@@ -300,14 +299,6 @@ def load_map(map_file):
300299

301300
return np.array(map_lists)
302301

303-
@staticmethod
304-
def random_action(low, high, dim):
305-
@staticmethod
306-
def sample():
307-
act = np.random.random_sample((dim,))
308-
return (high - low) * act + low
309-
return sample
310-
311302

312303
if __name__ == '__main__':
313304
import sys

setup.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -52,7 +52,8 @@
5252
'trimesh>=3.2.39',
5353
'networkx>=2.2',
5454
'colour>=0.1.5',
55-
'scipy>=0.12.0'
55+
'scipy>=0.12.0',
56+
'gym==0.18.0',
5657
],
5758
zip_safe=False,
5859
)

0 commit comments

Comments
 (0)