Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -17,4 +17,5 @@ terra/digbench/
data/*
!data/custom/
!data/custom/**
*.pkl
*.pkl
.DS_Store
143 changes: 128 additions & 15 deletions terra/agent.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import NamedTuple
from typing import NamedTuple, Optional, Tuple

import jax
import jax.numpy as jnp
Expand Down Expand Up @@ -49,21 +49,82 @@ def new(
max_traversable_y: int,
padding_mask: Array,
action_map: Array,
custom_pos: Optional[Tuple[int, int]] = None,
custom_angle: Optional[int] = None,

) -> tuple["Agent", jax.random.PRNGKey]:
"""
Create a new agent with specified parameters.

Args:
key: JAX random key
env_cfg: Environment configuration
max_traversable_x: Maximum traversable x coordinate
max_traversable_y: Maximum traversable y coordinate
padding_mask: Mask indicating obstacles
custom_pos: Optional custom position (x, y) to place the agent
custom_angle: Optional custom angle for the agent

Returns:
New agent instance and updated random key
"""
# Handle custom position or default based on config
has_custom_args = (custom_pos is not None) or (custom_angle is not None)

def use_custom_position(k):
# Create position based on custom args or defaults
temp_pos = IntMap(jnp.array(custom_pos)) if custom_pos is not None else IntMap(jnp.array([-1, -1]))
temp_angle = jnp.full((1,), custom_angle, dtype=IntMap) if custom_angle is not None else jnp.full((1,), -1, dtype=IntMap)
Copy link
Preview

Copilot AI Aug 20, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The sentinel value -1 for unset angle should be defined as a constant to improve code clarity.

Suggested change
temp_angle = jnp.full((1,), custom_angle, dtype=IntMap) if custom_angle is not None else jnp.full((1,), -1, dtype=IntMap)
temp_angle = jnp.full((1,), custom_angle, dtype=IntMap) if custom_angle is not None else jnp.full((1,), UNSET_ANGLE, dtype=IntMap)

Copilot uses AI. Check for mistakes.


# Get default position for missing components
def_pos, def_angle, _ = _get_top_left_init_state(k, env_cfg)

# Combine custom and default values
pos = jnp.where(jnp.any(temp_pos < 0), def_pos, temp_pos)
angle = jnp.where(jnp.any(temp_angle < 0), def_angle, temp_angle)

# Check validity and return result using jax.lax.cond
valid = _validate_agent_position(
pos, angle, env_cfg, padding_mask,
env_cfg.agent.width, env_cfg.agent.height
)

# Define the true and false branches for jax.lax.cond
def true_fn(_):
return (pos, angle, k)

def false_fn(_):
return jax.lax.cond(
env_cfg.agent.random_init_state,
lambda k_inner: _get_random_init_state(
k_inner, env_cfg, max_traversable_x, max_traversable_y,
padding_mask, action_map, env_cfg.agent.width, env_cfg.agent.height,
),
lambda k_inner: _get_top_left_init_state(k_inner, env_cfg),
k
)

# Use jax.lax.cond to handle the validity check
return jax.lax.cond(valid, true_fn, false_fn, None)

def use_default_position(k):
# Use existing logic for random or top-left position
return jax.lax.cond(
env_cfg.agent.random_init_state,
lambda k_inner: _get_random_init_state(
k_inner, env_cfg, max_traversable_x, max_traversable_y,
padding_mask, action_map, env_cfg.agent.width, env_cfg.agent.height,
),
lambda k_inner: _get_top_left_init_state(k_inner, env_cfg),
k
)

# Use jax.lax.cond for JAX-compatible control flow
pos_base, angle_base, key = jax.lax.cond(
env_cfg.agent.random_init_state,
lambda k: _get_random_init_state(
k,
env_cfg,
max_traversable_x,
max_traversable_y,
padding_mask,
action_map,
env_cfg.agent.width,
env_cfg.agent.height,
),
lambda k: _get_top_left_init_state(k, env_cfg),
key,
has_custom_args,
use_custom_position,
use_default_position,
key
)

agent_state = AgentState(
Expand All @@ -82,6 +143,58 @@ def new(
return Agent(agent_state=agent_state, width=width, height=height, moving_dumped_dirt=moving_dumped_dirt), key


def _validate_agent_position(
Copy link
Preview

Copilot AI Aug 20, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The function docstring should document the expected types and ranges for the parameters, particularly the Array types and what constitutes valid position/angle values.

Copilot uses AI. Check for mistakes.

pos_base: Array,
angle_base: Array,
env_cfg: EnvConfig,
padding_mask: Array,
agent_width: int,
agent_height: int,
) -> Array:
"""
Validate if an agent position is valid (within bounds and not intersecting obstacles).

Returns:
JAX array with boolean value indicating if the position is valid
"""
map_width = padding_mask.shape[0]
map_height = padding_mask.shape[1]

# Check if position is within bounds
max_center_coord = jnp.ceil(
jnp.max(jnp.array([agent_width / 2 - 1, agent_height / 2 - 1]))
).astype(IntMap)

max_w = jnp.minimum(env_cfg.maps.edge_length_px, map_width)
max_h = jnp.minimum(env_cfg.maps.edge_length_px, map_height)

within_bounds = jnp.logical_and(
jnp.logical_and(pos_base[0] >= max_center_coord, pos_base[0] < max_w - max_center_coord),
jnp.logical_and(pos_base[1] >= max_center_coord, pos_base[1] < max_h - max_center_coord)
)

# Check if position intersects with obstacles
def check_obstacle_intersection(_):
agent_corners_xy = get_agent_corners(
pos_base, angle_base, agent_width, agent_height, env_cfg.agent.angles_base
)
polygon_mask = compute_polygon_mask(agent_corners_xy, map_width, map_height)
has_obstacle = jnp.any(jnp.logical_and(polygon_mask, padding_mask == 1))
return jnp.logical_not(has_obstacle)

def return_false(_):
return jnp.array(False)

# Only check obstacles if we're within bounds (to avoid unnecessary computations)
valid = jax.lax.cond(
within_bounds,
check_obstacle_intersection,
return_false,
None
)

return valid

def _get_top_left_init_state(key: jax.random.PRNGKey, env_cfg: EnvConfig):
max_center_coord = jnp.ceil(
jnp.max(
Expand Down Expand Up @@ -174,4 +287,4 @@ def _check_intersection():
),
)

return pos_base, angle_base, key
return pos_base, angle_base, key
76 changes: 38 additions & 38 deletions terra/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -176,47 +176,47 @@ class CurriculumGlobalConfig(NamedTuple):
# NOTE: all maps need to have the same size
levels = [
{
"maps_path": "terra/foundations",
"max_steps_in_episode": 400,
"rewards_type": RewardsType.DENSE,
"apply_trench_rewards": False,
},
{
"maps_path": "terra/trenches/single",
"max_steps_in_episode": 400,
"rewards_type": RewardsType.DENSE,
"apply_trench_rewards": True,
},
{
"maps_path": "terra/trenches/double",
"max_steps_in_episode": 400,
"rewards_type": RewardsType.DENSE,
"apply_trench_rewards": True,
},
{
"maps_path": "terra/trenches/double_diagonal",
"max_steps_in_episode": 400,
"rewards_type": RewardsType.DENSE,
"apply_trench_rewards": True,
},
{
"maps_path": "terra/foundations",
"max_steps_in_episode": 400,
"rewards_type": RewardsType.DENSE,
"apply_trench_rewards": False,
},
{
"maps_path": "terra/trenches/triple_diagonal",
"max_steps_in_episode": 400,
"rewards_type": RewardsType.DENSE,
"apply_trench_rewards": True,
},
{
"maps_path": "terra/foundations_large",
"max_steps_in_episode": 500,
"maps_path": "foundations",
"max_steps_in_episode": 1000,
"rewards_type": RewardsType.DENSE,
"apply_trench_rewards": False,
},
# {
# "maps_path": "trenches/single",
# "max_steps_in_episode": 400,
# "rewards_type": RewardsType.DENSE,
# "apply_trench_rewards": True,
# },
# {
# "maps_path": "trenches/double",
# "max_steps_in_episode": 400,
# "rewards_type": RewardsType.DENSE,
# "apply_trench_rewards": True,
# },
# {
# "maps_path": "terra/trenches/double_diagonal",
# "max_steps_in_episode": 400,
# "rewards_type": RewardsType.DENSE,
# "apply_trench_rewards": True,
# },
# {
# "maps_path": "trenches/triple",
# "max_steps_in_episode": 400,
# "rewards_type": RewardsType.DENSE,
# "apply_trench_rewards": False,
# },
# {
# "maps_path": "trenches/triple_diagonal",
# "max_steps_in_episode": 400,
# "rewards_type": RewardsType.DENSE,
# "apply_trench_rewards": True,
# },
# {
# "maps_path": "terra/foundations_large",
# "max_steps_in_episode": 500,
# "rewards_type": RewardsType.DENSE,
# "apply_trench_rewards": False,
# },
]


Expand Down
54 changes: 42 additions & 12 deletions terra/env.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from collections.abc import Callable
from functools import partial
from typing import NamedTuple
from typing import Any, Optional, Tuple

import jax
import jax.numpy as jnp
Expand Down Expand Up @@ -39,6 +40,7 @@ def new(
n_envs_x: int = 1,
n_envs_y: int = 1,
display: bool = False,
agent_config_override: Optional[dict[str, Any]] = None,
) -> "TerraEnv":
re = None
tile_size_rendering = MAP_TILES // maps_size_px
Expand Down Expand Up @@ -68,6 +70,8 @@ def new(
n_envs_x=n_envs_x,
n_envs_y=n_envs_y,
display=display,
agent_config=agent_config_override if agent_config_override is not None else None,

)
return TerraEnv(rendering_engine=re)

Expand All @@ -82,6 +86,8 @@ def reset(
dumpability_mask_init: Array,
action_map: Array,
env_cfg: EnvConfig,
custom_pos: Optional[Tuple[int, int]] = None,
custom_angle: Optional[int] = None,
) -> tuple[State, dict[str, Array]]:
"""
Resets the environment using values from config files, and a seed.
Expand All @@ -95,6 +101,8 @@ def reset(
trench_type,
dumpability_mask_init,
action_map,
custom_pos,
custom_angle,
)
state = self.wrap_state(state)

Expand Down Expand Up @@ -154,7 +162,7 @@ def render_obs_pygame(
Renders the environment at a given observation.
"""
if info is not None:
target_tiles = info["target_tiles"]
target_tiles = info.get("target_tiles", None)
else:
target_tiles = None

Expand All @@ -169,6 +177,7 @@ def render_obs_pygame(
loaded=obs["agent_state"][..., [5]],
target_tiles=target_tiles,
generate_gif=generate_gif,
info=info,
)

@partial(jax.jit, static_argnums=(0,))
Expand Down Expand Up @@ -360,7 +369,9 @@ def _get_map(self, maps_buffer_keys: jax.random.PRNGKey, env_cfgs: EnvConfig):
return jax.vmap(self.maps_buffer.get_map)(maps_buffer_keys, env_cfgs)

@partial(jax.jit, static_argnums=(0,))
def reset(self, env_cfgs: EnvConfig, rng_key: jax.random.PRNGKey) -> State:
def reset(self, env_cfgs: EnvConfig, rng_key: jax.random.PRNGKey,
custom_pos: Optional[Tuple[int, int]] = None,
custom_angle: Optional[int] = None) -> State:
env_cfgs = self.curriculum_manager.reset_cfgs(env_cfgs)
env_cfgs = self.update_env_cfgs(env_cfgs)
(
Expand All @@ -372,17 +383,36 @@ def reset(self, env_cfgs: EnvConfig, rng_key: jax.random.PRNGKey) -> State:
action_maps,
new_rng_key,
) = self._get_map_init(rng_key, env_cfgs)
timestep = jax.vmap(self.terra_env.reset)(
rng_key,
target_maps,
padding_masks,
trench_axes,
trench_type,
dumpability_mask_init,
action_maps,
env_cfgs,
)
timestep = jax.vmap(
self.terra_env.reset,
in_axes=(0, 0, 0, 0, 0, 0, 0, 0, None, None)
)(
rng_key,
target_maps,
padding_masks,
trench_axes,
trench_type,
dumpability_mask_init,
action_maps,
env_cfgs,
custom_pos,
custom_angle,
)
return timestep

@property
def actions_size(self) -> int:
"""
Number of actions played at every env step.
"""
return self.num_actions

@property
def num_actions(self) -> int:
"""
Total number of actions
"""
return self.batch_cfg.action_type.get_num_actions()

@partial(jax.jit, static_argnums=(0,))
def step(
Expand Down
Loading