-
Notifications
You must be signed in to change notification settings - Fork 0
LLM Version #32
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Open
gioelemo
wants to merge
12
commits into
main
Choose a base branch
from
llm_moving
base: main
Could not load branches
Branch not found: {{ refName }}
Loading
Could not load tags
Nothing to show
Loading
Are you sure you want to change the base?
Some commits from the old base branch may be removed from the timeline,
and old review comments may become outdated.
Open
LLM Version #32
Changes from all commits
Commits
Show all changes
12 commits
Select commit
Hold shift + click to select a range
53dde23
Adapt Terra to start with custom agent position and angle
9484156
Fix visualization
d1c1701
Small fix in visualization
334fd86
Small fix
421b310
Fix config
e89cefd
Update terra/agent.py
gioelemo 9f7b5b2
Update terra/viz/game/game.py
gioelemo 9f40fcf
Update terra/viz/game/game.py
gioelemo 045b1e0
Support big maps
8dc1b51
small maps as default
5b38b0b
small fix
60868ed
removed comment
File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -17,4 +17,5 @@ terra/digbench/ | |
data/* | ||
!data/custom/ | ||
!data/custom/** | ||
*.pkl | ||
*.pkl | ||
.DS_Store |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,4 +1,4 @@ | ||
from typing import NamedTuple | ||
from typing import NamedTuple, Optional, Tuple | ||
|
||
import jax | ||
import jax.numpy as jnp | ||
|
@@ -49,21 +49,82 @@ def new( | |
max_traversable_y: int, | ||
padding_mask: Array, | ||
action_map: Array, | ||
custom_pos: Optional[Tuple[int, int]] = None, | ||
custom_angle: Optional[int] = None, | ||
|
||
) -> tuple["Agent", jax.random.PRNGKey]: | ||
""" | ||
Create a new agent with specified parameters. | ||
|
||
Args: | ||
key: JAX random key | ||
env_cfg: Environment configuration | ||
max_traversable_x: Maximum traversable x coordinate | ||
max_traversable_y: Maximum traversable y coordinate | ||
padding_mask: Mask indicating obstacles | ||
custom_pos: Optional custom position (x, y) to place the agent | ||
custom_angle: Optional custom angle for the agent | ||
|
||
Returns: | ||
New agent instance and updated random key | ||
""" | ||
# Handle custom position or default based on config | ||
has_custom_args = (custom_pos is not None) or (custom_angle is not None) | ||
|
||
def use_custom_position(k): | ||
# Create position based on custom args or defaults | ||
temp_pos = IntMap(jnp.array(custom_pos)) if custom_pos is not None else IntMap(jnp.array([-1, -1])) | ||
temp_angle = jnp.full((1,), custom_angle, dtype=IntMap) if custom_angle is not None else jnp.full((1,), -1, dtype=IntMap) | ||
|
||
# Get default position for missing components | ||
def_pos, def_angle, _ = _get_top_left_init_state(k, env_cfg) | ||
|
||
# Combine custom and default values | ||
pos = jnp.where(jnp.any(temp_pos < 0), def_pos, temp_pos) | ||
angle = jnp.where(jnp.any(temp_angle < 0), def_angle, temp_angle) | ||
|
||
# Check validity and return result using jax.lax.cond | ||
valid = _validate_agent_position( | ||
pos, angle, env_cfg, padding_mask, | ||
env_cfg.agent.width, env_cfg.agent.height | ||
) | ||
|
||
# Define the true and false branches for jax.lax.cond | ||
def true_fn(_): | ||
return (pos, angle, k) | ||
|
||
def false_fn(_): | ||
return jax.lax.cond( | ||
env_cfg.agent.random_init_state, | ||
lambda k_inner: _get_random_init_state( | ||
k_inner, env_cfg, max_traversable_x, max_traversable_y, | ||
padding_mask, action_map, env_cfg.agent.width, env_cfg.agent.height, | ||
), | ||
lambda k_inner: _get_top_left_init_state(k_inner, env_cfg), | ||
k | ||
) | ||
|
||
# Use jax.lax.cond to handle the validity check | ||
return jax.lax.cond(valid, true_fn, false_fn, None) | ||
|
||
def use_default_position(k): | ||
# Use existing logic for random or top-left position | ||
return jax.lax.cond( | ||
env_cfg.agent.random_init_state, | ||
lambda k_inner: _get_random_init_state( | ||
k_inner, env_cfg, max_traversable_x, max_traversable_y, | ||
padding_mask, action_map, env_cfg.agent.width, env_cfg.agent.height, | ||
), | ||
lambda k_inner: _get_top_left_init_state(k_inner, env_cfg), | ||
k | ||
) | ||
|
||
# Use jax.lax.cond for JAX-compatible control flow | ||
pos_base, angle_base, key = jax.lax.cond( | ||
env_cfg.agent.random_init_state, | ||
lambda k: _get_random_init_state( | ||
k, | ||
env_cfg, | ||
max_traversable_x, | ||
max_traversable_y, | ||
padding_mask, | ||
action_map, | ||
env_cfg.agent.width, | ||
env_cfg.agent.height, | ||
), | ||
lambda k: _get_top_left_init_state(k, env_cfg), | ||
key, | ||
has_custom_args, | ||
use_custom_position, | ||
use_default_position, | ||
key | ||
) | ||
|
||
agent_state = AgentState( | ||
|
@@ -82,6 +143,58 @@ def new( | |
return Agent(agent_state=agent_state, width=width, height=height, moving_dumped_dirt=moving_dumped_dirt), key | ||
|
||
|
||
def _validate_agent_position( | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. The function docstring should document the expected types and ranges for the parameters, particularly the Array types and what constitutes valid position/angle values. Copilot uses AI. Check for mistakes. Positive FeedbackNegative Feedback |
||
pos_base: Array, | ||
angle_base: Array, | ||
env_cfg: EnvConfig, | ||
padding_mask: Array, | ||
agent_width: int, | ||
agent_height: int, | ||
) -> Array: | ||
""" | ||
Validate if an agent position is valid (within bounds and not intersecting obstacles). | ||
|
||
Returns: | ||
JAX array with boolean value indicating if the position is valid | ||
""" | ||
map_width = padding_mask.shape[0] | ||
map_height = padding_mask.shape[1] | ||
|
||
# Check if position is within bounds | ||
max_center_coord = jnp.ceil( | ||
jnp.max(jnp.array([agent_width / 2 - 1, agent_height / 2 - 1])) | ||
).astype(IntMap) | ||
|
||
max_w = jnp.minimum(env_cfg.maps.edge_length_px, map_width) | ||
max_h = jnp.minimum(env_cfg.maps.edge_length_px, map_height) | ||
|
||
within_bounds = jnp.logical_and( | ||
jnp.logical_and(pos_base[0] >= max_center_coord, pos_base[0] < max_w - max_center_coord), | ||
jnp.logical_and(pos_base[1] >= max_center_coord, pos_base[1] < max_h - max_center_coord) | ||
) | ||
|
||
# Check if position intersects with obstacles | ||
def check_obstacle_intersection(_): | ||
agent_corners_xy = get_agent_corners( | ||
pos_base, angle_base, agent_width, agent_height, env_cfg.agent.angles_base | ||
) | ||
polygon_mask = compute_polygon_mask(agent_corners_xy, map_width, map_height) | ||
has_obstacle = jnp.any(jnp.logical_and(polygon_mask, padding_mask == 1)) | ||
return jnp.logical_not(has_obstacle) | ||
|
||
def return_false(_): | ||
return jnp.array(False) | ||
|
||
# Only check obstacles if we're within bounds (to avoid unnecessary computations) | ||
valid = jax.lax.cond( | ||
within_bounds, | ||
check_obstacle_intersection, | ||
return_false, | ||
None | ||
) | ||
|
||
return valid | ||
|
||
def _get_top_left_init_state(key: jax.random.PRNGKey, env_cfg: EnvConfig): | ||
max_center_coord = jnp.ceil( | ||
jnp.max( | ||
|
@@ -174,4 +287,4 @@ def _check_intersection(): | |
), | ||
) | ||
|
||
return pos_base, angle_base, key | ||
return pos_base, angle_base, key |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The sentinel value -1 for unset angle should be defined as a constant to improve code clarity.
Copilot uses AI. Check for mistakes.