diff --git a/angler_control/controllers/robot_trajectory_controller/__init__.py b/angler_control/controllers/robot_trajectory_controller/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/angler_control/controllers/robot_trajectory_controller/base_multidof_joint_trajectory_controller.py b/angler_control/controllers/robot_trajectory_controller/base_multidof_joint_trajectory_controller.py deleted file mode 100644 index e92011c..0000000 --- a/angler_control/controllers/robot_trajectory_controller/base_multidof_joint_trajectory_controller.py +++ /dev/null @@ -1,117 +0,0 @@ -# Copyright 2023, Evan Palmer -# -# Permission is hereby granted, free of charge, to any person obtaining a copy -# of this software and associated documentation files (the "Software"), to deal -# in the Software without restriction, including without limitation the rights -# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell -# copies of the Software, and to permit persons to whom the Software is -# furnished to do so, subject to the following conditions: -# -# The above copyright notice and this permission notice shall be included in -# all copies or substantial portions of the Software. -# -# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR -# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, -# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL -# THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER -# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, -# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN -# THE SOFTWARE. - -from abc import abstractmethod -from typing import Any - -from control_msgs.action import FollowJointTrajectory -from controllers.controller import BaseController -from rclpy.action import ActionServer -from rclpy.time import Time -from trajectory_msgs.msg import MultiDOFJointTrajectory, MultiDOFJointTrajectoryPoint - - -class BaseMultiDOFJointTrajectoryController(BaseController): - def __init__(self, controller_name: str) -> None: - super().__init__(controller_name) - - self.execute_trajectory_client = ActionServer( - self, - FollowJointTrajectory, - "/angler/action/execute_cartesian_trajectory", - self.execute_trajectory, - ) - - @abstractmethod - def execute_trajectory( - self, goal: FollowJointTrajectory.Goal - ) -> FollowJointTrajectory.Result: - ... - - -class Trajectory: - def __init__(self, trajectory: MultiDOFJointTrajectory | None = None) -> None: - if trajectory is not None: - self.trajectory = trajectory - self.start_time = Time( - seconds=trajectory.header.stamp.sec, - nanoseconds=trajectory.header.stamp.nanosec, - ) - else: - self.trajectory = MultiDOFJointTrajectory() - self.start_time = Time() - - self.time_before_trajectory = Time() - self.state_before_trajectory = MultiDOFJointTrajectoryPoint() - - def _set_point_before_trajectory_msg( - self, - current_time: Time, - current_point: MultiDOFJointTrajectoryPoint, - ) -> None: - self.time_before_trajectory = current_time - self.state_before_trajectory = current_point - - @staticmethod - def merge_trajectory_with_current_state( - current_time: Time, - state: MultiDOFJointTrajectoryPoint, - trajectory: MultiDOFJointTrajectory, - ) -> Any: - traj = Trajectory(trajectory) - traj._set_point_before_trajectory_msg(current_time, state) - - return traj - - def update(self, trajectory: MultiDOFJointTrajectory): - self.trajectory = trajectory - self.start_time = Time( - seconds=trajectory.header.stamp.sec, - nanoseconds=trajectory.header.stamp.nanosec, - ) - self.sampled_already = False - - def sample( - self, - t: Time, - start_point: MultiDOFJointTrajectoryPoint, - end_point: MultiDOFJointTrajectoryPoint, - ) -> MultiDOFJointTrajectoryPoint | None: - if len(self.trajectory.points) <= 0: - return None - - if not self.sampled_already: - if self.start_time.seconds_nanoseconds()[0] == 0: - self.start_time = t - self.sampled_already = True - - if t < self.time_before_trajectory: - return None - - output_state = MultiDOFJointTrajectoryPoint() - first_point_in_msg = self.trajectory.points[0] - first_point_timestamp = self.start_time + first_point_in_msg.time_from_start - - if t < first_point_timestamp: - if True: - ... - - def interpolate(self): - ... diff --git a/angler_control/controllers/robot_trajectory_controller/multidof_joint_trajectory_controller.py b/angler_control/controllers/robot_trajectory_controller/multidof_joint_trajectory_controller.py index 1769134..7582222 100644 --- a/angler_control/controllers/robot_trajectory_controller/multidof_joint_trajectory_controller.py +++ b/angler_control/controllers/robot_trajectory_controller/multidof_joint_trajectory_controller.py @@ -23,6 +23,7 @@ import rclpy from control_msgs.action import FollowJointTrajectory from controllers.controller import BaseController +from controllers.robot_trajectory_controller.trajectory import MultiDOFTrajectory from rclpy.action import ActionServer, GoalResponse from rclpy.timer import Timer @@ -38,11 +39,16 @@ def __init__(self, controller_name: str) -> None: """ super().__init__(controller_name) + self.declare_parameter("goal_tolerance", 0.2) + self.goal_handle: FollowJointTrajectory.Goal | None = None self.goal_result: FollowJointTrajectory.Result | None = None self._preempted = False self._running = False self._timer: Timer | None = None + self.goal_tolerance = ( + self.get_parameter("goal_tolerance").get_parameter_value().double_value + ) self.execute_trajectory_client = ActionServer( self, @@ -55,6 +61,14 @@ def __init__(self, controller_name: str) -> None: def handle_action_request_cb( self, goal: FollowJointTrajectory.Goal ) -> GoalResponse: + """Accept or reject the action request. + + Args: + goal: The action goal. + + Returns: + Whether or not the goal has been accepted. + """ if not self.armed: return GoalResponse.REJECT @@ -66,12 +80,26 @@ def handle_action_request_cb( @abstractmethod def on_update(self) -> None: + """Return if no goal has been received.""" if self.goal_handle is None: return + # TODO(evan): execute the trajectory + async def execute_trajectory_cb( self, goal: FollowJointTrajectory.Goal ) -> FollowJointTrajectory.Result: + """Execute a trajectory. + + Args: + goal: The goal with the desired trajectory to execute. + + Raises: + RuntimeError: Unhandled state. + + Returns: + The trajectory execution result. + """ # Wait for any previous goals to be cleaned up timer = self.create_rate(self.dt) while self._running and rclpy.ok(): @@ -85,6 +113,11 @@ async def execute_trajectory_cb( self._running = True self.goal_handle = goal self.goal_result = None + self.trajectory = MultiDOFTrajectory( + goal.multi_dof_trajectory, + self.state.multi_dof_joint_state, + self.get_clock().now(), + ) # Spin until we get a response from the controller or the goal is preempted/ # canceled @@ -112,6 +145,11 @@ async def execute_trajectory_cb( ) def preempt_current_goal(self) -> FollowJointTrajectory.Result: + """Preempt the current goal with a new goal. + + Returns: + The result of the goal that has been preempted. + """ result = FollowJointTrajectory.Result() result.error_code = FollowJointTrajectory.Result.INVALID_GOAL result.error_string = ( @@ -121,6 +159,11 @@ def preempt_current_goal(self) -> FollowJointTrajectory.Result: return result def cancel_current_goal(self) -> FollowJointTrajectory.Result: + """Cancel the current goal. + + Returns: + The result of the goal that has been cancelled. + """ result = FollowJointTrajectory.Result() result.error_code = FollowJointTrajectory.Result.INVALID_GOAL result.error_string = "The current action was cancelled by a client!" diff --git a/angler_control/controllers/robot_trajectory_controller/trajectory.py b/angler_control/controllers/robot_trajectory_controller/trajectory.py index 664f5d5..1453b25 100644 --- a/angler_control/controllers/robot_trajectory_controller/trajectory.py +++ b/angler_control/controllers/robot_trajectory_controller/trajectory.py @@ -18,20 +18,239 @@ # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN # THE SOFTWARE. -from rclpy.time import Time +import numpy as np +from geometry_msgs.msg import Transform, Twist +from rclpy.time import Duration, Time +from scipy.interpolate import CubicHermiteSpline, interp1d +from scipy.spatial.transform import Rotation as R +from sensor_msgs.msg import MultiDOFJointState from trajectory_msgs.msg import MultiDOFJointTrajectory, MultiDOFJointTrajectoryPoint -class Trajectory: - def __init__(self, trajectory: MultiDOFJointTrajectory | None = None) -> None: - pass +class MultiDOFTrajectory: + """Wraps a MultiDOFJointTrajectory for sampling purposes. + + The MultiDOFTrajectory wrapper provides and interface for sampling points between + a discrete set of MultiDOFJointTrajectoryPoint messages. This helps provide + continuity between points. + """ + + def __init__( + self, + trajectory: MultiDOFJointTrajectory, + current_state: MultiDOFJointState, + current_time: Time, + ) -> None: + """Create a new MultiDOFJointTrajectory interface. + + Args: + trajectory: The trajectory to create an interface for. + current_state: The current state of the robot prior to beginning the + trajectory. + current_time: The current ROS timestamp. + """ + self.trajectory = trajectory + + # Convert the current state into a MultiDOFJointTrajectoryPoint for consistency + # Note that there is no velocity component in the MultiDOFJointState message + self.starting_state = MultiDOFJointTrajectoryPoint() + + self.starting_state.transforms = current_state.transforms + self.starting_state.velocities = current_state.twist + + self.starting_time = current_time + + def sample(self, t: Time) -> MultiDOFJointTrajectoryPoint | None: + """Sample the trajectory at a given timestamp. + + Args: + t: The timestamp to sample the trajectory at. + + Returns: + The point at the given timestamp if it exists; None, otherwise. + """ + if len(self.trajectory.points) <= 0: + return None + + # Attempting to sample before the initial state + if t < self.starting_time: + return None + + first_point = self.trajectory.points[0] # type: ignore + first_point_timestamp = first_point.time_from_start + self.starting_time + + if t < first_point_timestamp: + return self.interpolate( + self.starting_state, + first_point, + self.starting_time, + first_point_timestamp, + t, + ) + + for i in range(len(self.trajectory.points) - 1): + point = self.trajectory.points[i] # type: ignore + next_point = self.trajectory.points[i + 1] # type: ignore + + t0 = self.starting_time + point.time_from_start + t1 = self.starting_time + next_point.time_from_start + + if t0 <= t < t1: + return self.interpolate(point, next_point, t0, t1, t) + + # Return None by default because we didn't find a point in the trajectory + return None def interpolate( self, start_point: MultiDOFJointTrajectoryPoint, end_point: MultiDOFJointTrajectoryPoint, - ): - ... + start_time: Time, + end_time: Time, + sample_time: Time, + ) -> MultiDOFJointTrajectoryPoint: + """Interpolate between two points and sample the resulting function. + + This method currently only supports interpolation between positions and + velocities. Interpolation between points with accelerations is not yet + supported. + + Args: + start_point: The segment starting point. + end_point: The segment end point. + start_time: The timestamp of the first point in the segment. + end_time: The timestamp of the second point in the segment. + sample_time: The timestamp at which to sample. + + Returns: + _description_ + """ + duration_from_start: Duration = sample_time - start_time + duration_between_points: Duration = end_time - start_time + + has_velocity = start_point.velocities and end_point.velocities + + # Check if the sample time is before the start point or if the sample time is + # after the end point + if ( + duration_from_start.nanoseconds < 0 + or duration_from_start.nanoseconds > duration_between_points.nanoseconds + ): + has_velocity = False + + result = MultiDOFJointTrajectoryPoint() + result.time_from_start.nanosec = sample_time + + # Create a few helper functions to clean things up + def convert_tf_to_array(tf: Transform) -> np.ndarray: + return np.array( + [ + tf.translation.x, + tf.translation.y, + tf.translation.z, + *R.from_quat( + [tf.rotation.x, tf.rotation.y, tf.rotation.z, tf.rotation.w] + ).as_euler("xyz"), + ] + ) + + def convert_twist_to_array(twist: Twist) -> np.ndarray: + return np.array( + [ + twist.linear.x, + twist.linear.y, + twist.linear.z, + twist.angular.x, + twist.angular.z, + ] + ) + + if not has_velocity: + # Perform linear interpolation between points + for i in range(len(start_point.transforms)): + # Get the desired positions + start_tf: Transform = start_point.transforms[i] # type: ignore + end_tf: Transform = end_point.transforms[i] # type: ignore + + tfs = np.array( + [convert_tf_to_array(start_tf), convert_tf_to_array(end_tf)] + ) + t = np.array([start_time.nanoseconds, end_time.nanoseconds]) + + # Perform linear interpolation + interpolation = interp1d(t, tfs, kind="linear") + sample = interpolation(sample_time.nanoseconds) + + # Reconstruct the sample as a ROS message + sample_msg = Transform() + + ( + sample_msg.translation.x, + sample_msg.translation.y, + sample_msg.translation.z, + ) = sample[:3] + + ( + sample_msg.rotation.x, + sample_msg.rotation.y, + sample_msg.rotation.z, + sample_msg.rotation.w, + ) = R.from_euler("xyz", sample[3:]).as_quat() + + result.transforms.append(sample_msg) # type: ignore + elif has_velocity: + for i in range(len(start_point.transforms)): + # Get the desired positions + start_tf: Transform = start_point.transforms[i] # type: ignore + end_tf: Transform = end_point.transforms[i] # type: ignore + + # Get the desired velocities + start_vel = start_point.velocities[i] # type: ignore + end_vel = end_point.velocities[i] # type: ignore + + tfs = np.array( + [convert_tf_to_array(start_tf), convert_tf_to_array(end_tf)] + ) + vels = np.array( + [convert_twist_to_array(start_vel), convert_twist_to_array(end_vel)] + ) + t = np.array([start_time.nanoseconds, end_time.nanoseconds]) + + # Interpolate using a cubic hermite spline + spline = CubicHermiteSpline(t, tfs, vels) + sample_tf = spline(sample_time.nanoseconds) + sample_vel = spline(sample_time.nanoseconds, 1) + + # Now reconstruct the TF message + sample_tf_msg = Transform() + + ( + sample_tf_msg.translation.x, + sample_tf_msg.translation.y, + sample_tf_msg.translation.z, + ) = sample_tf[:3] + + ( + sample_tf_msg.rotation.x, + sample_tf_msg.rotation.y, + sample_tf_msg.rotation.z, + sample_tf_msg.rotation.w, + ) = R.from_euler("xyz", sample_tf[3:]).as_quat() + + # Reconstruct the velocity message + sample_twist_msg = Twist() + + ( + sample_twist_msg.linear.x, + sample_twist_msg.linear.y, + sample_twist_msg.linear.z, + sample_twist_msg.angular.x, + sample_twist_msg.angular.y, + sample_twist_msg.angular.z, + ) = sample_vel + + # Finish up by saving the samples + result.transforms.append(sample_tf_msg) # type: ignore + result.velocities.append(sample_twist_msg) # type: ignore - def sample(self, t: Time): - ... + return result