16
16
import numpy as np
17
17
from math import floor , ceil
18
18
from collections import namedtuple
19
+ import gym
19
20
20
21
from rlschool .quadrotor .quadrotorsim import QuadrotorSim
21
22
26
27
NO_DISPLAY = True
27
28
28
29
29
- class Quadrotor (object ):
30
+ class Quadrotor (gym . Env ):
30
31
"""
31
32
Quadrotor environment.
32
33
@@ -41,6 +42,7 @@ class Quadrotor(object):
41
42
map is a 100x100 flatten floor.
42
43
simulator_conf (None|str): path to simulator config xml file.
43
44
"""
45
+
44
46
def __init__ (self ,
45
47
dt = 0.01 ,
46
48
nt = 1000 ,
@@ -68,13 +70,11 @@ def __init__(self,
68
70
69
71
cfg_dict = self .simulator .get_config (simulator_conf )
70
72
self .valid_range = cfg_dict ['range' ]
71
- self .action_space = namedtuple (
72
- 'action_space' , ['shape' , 'high' , 'low' , 'sample' ])
73
- self .action_space .shape = [4 ]
74
- self .action_space .high = [cfg_dict ['action_space_high' ]] * 4
75
- self .action_space .low = [cfg_dict ['action_space_low' ]] * 4
76
- self .action_space .sample = Quadrotor .random_action (
77
- cfg_dict ['action_space_low' ], cfg_dict ['action_space_high' ], 4 )
73
+ self .action_space = gym .spaces .Box (
74
+ low = np .array ([cfg_dict ['action_space_low' ]] * 4 , dtype = 'float32' ),
75
+ high = np .array (
76
+ [cfg_dict ['action_space_high' ]] * 4 , dtype = 'float32' ),
77
+ shape = [4 ])
78
78
79
79
self .body_velocity_keys = ['b_v_x' , 'b_v_y' , 'b_v_z' ]
80
80
self .body_position_keys = ['b_x' , 'b_y' , 'b_z' ]
@@ -91,8 +91,7 @@ def __init__(self,
91
91
len (self .flight_pose_keys ) + len (self .barometer_keys )
92
92
if self .task == 'velocity_control' :
93
93
obs_dim += len (self .task_velocity_control_keys )
94
- self .observation_space = namedtuple ('observation_space' , ['shape' ])
95
- self .observation_space .shape = [obs_dim ]
94
+ self .observation_space = gym .Space (shape = [obs_dim ], dtype = 'float32' )
96
95
97
96
self .state = {}
98
97
self .viewer = None
@@ -300,14 +299,6 @@ def load_map(map_file):
300
299
301
300
return np .array (map_lists )
302
301
303
- @staticmethod
304
- def random_action (low , high , dim ):
305
- @staticmethod
306
- def sample ():
307
- act = np .random .random_sample ((dim ,))
308
- return (high - low ) * act + low
309
- return sample
310
-
311
302
312
303
if __name__ == '__main__' :
313
304
import sys
0 commit comments