Skip to content

Commit

Permalink
Merge pull request #143 from davidprueser/tests
Browse files Browse the repository at this point in the history
Added views to orm
  • Loading branch information
tomsch420 authored Apr 10, 2024
2 parents 5e1a3f1 + 23ab36a commit 9f0d35b
Show file tree
Hide file tree
Showing 8 changed files with 245 additions and 97 deletions.
2 changes: 1 addition & 1 deletion src/neem_interface_python
Empty file removed src/pycram/orm/queries/__init__.py
Empty file.
51 changes: 0 additions & 51 deletions src/pycram/orm/queries/queries.py

This file was deleted.

149 changes: 149 additions & 0 deletions src/pycram/orm/views.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,149 @@
from sqlalchemy.orm import declarative_base
from typing_extensions import Union
import sqlalchemy.orm
from sqlalchemy import table, inspect, event, select, engine, MetaData, Select, TableClause, ExecutableDDLElement
from sqlalchemy.ext.compiler import compiles
from pycram.orm.action_designator import PickUpAction
from pycram.orm.base import Position, RobotState, Pose, Base, Quaternion
from pycram.orm.object_designator import Object
from pycram.orm.task import TaskTreeNode


class CreateView(ExecutableDDLElement):
"""
Class that is used to create a view. Every instance will be compiled into a SQL CREATE VIEW statement.
"""

def __init__(self, name: str, selectable: Select):
self.name = name
self.selectable = selectable


class DropView(ExecutableDDLElement):
"""
Class that is used to drop a view. Every instance will be compiled into a SQL DROP VIEW statement.
"""

def __init__(self, name: str):
self.name = name


@compiles(CreateView)
def _create_view(element: CreateView, compiler, **kw) -> str:
"""
Compiles a CreateView instance into a SQL CREATE VIEW statement.
:param element: CreateView instance
:param compiler: compiler
:param kw: keyword arguments
:return: SQL CREATE VIEW statement
"""

return "CREATE VIEW %s AS %s" % (
element.name,
compiler.sql_compiler.process(element.selectable, literal_binds=True),
)


@compiles(DropView)
def _drop_view(element: DropView, compiler, **kw) -> str:
"""
Compiles a DropView instance into a SQL DROP VIEW statement.
:param element: DropView instance
:param compiler: compiler
:param kw: keyword arguments
:return: SQL DROP VIEW statement
"""
return "DROP VIEW %s" % element.name


def view_exists(ddl: Union[CreateView, DropView], target, connection: engine, **kw) -> bool:
"""
Check if a view exists.
:param ddl: ddl instance
:param target: target object
:param connection: connection
:param kw: keyword arguments
:return: True if the view exists, False otherwise
"""

return ddl.name in inspect(connection).get_view_names()


def view_doesnt_exist(ddl: Union[CreateView, DropView], target, connection: engine, **kw) -> bool:
"""
Check if a view does not exist.
:param ddl: ddl instance
:param target: target object
:param connection: connection
:param kw: keyword arguments
:return: True if the view does not exist, False otherwise
"""

return not view_exists(ddl, target, connection, **kw)


def view(name: str, metadata: MetaData, selectable: Select) -> TableClause:
"""
Function used to control view creation and deletion. It will listen to the after_create and before_drop events
of the metadata object in order to either create or drop the view. The view needs to have a column id.
"""
view = table(name)

view._columns._populate_separate_keys(
col._make_proxy(view) for col in selectable.selected_columns
)

event.listen(metadata, "after_create", CreateView(name, selectable).execute_if(callable_=view_doesnt_exist))
event.listen(metadata, "before_drop", DropView(name).execute_if(callable_=view_exists))

return view


base = declarative_base(metadata=Base.metadata)


class PickUpWithContextView(base):
"""
View for pickup actions with context.
"""

__robot_position: Position = sqlalchemy.orm.aliased(Position, flat=True)
"""
3D Vector of robot position
"""

__robot_pose: Pose = sqlalchemy.orm.aliased(Pose, flat=True)
"""
Complete robot pose
"""

__object_position: Position = sqlalchemy.orm.aliased(Position, flat=True)
"""
3D Vector for object position
"""

__relative_x = (__robot_position.x - __object_position.x)
"""
Distance on x axis between robot and object
"""

__relative_y = (__robot_position.y - __object_position.y)
"""
Distance on y axis between robot and object
"""

__table__ = view("PickUpWithContextView", Base.metadata,
(select(PickUpAction.id.label("id"), PickUpAction.arm.label("arm"),
PickUpAction.grasp.label("grasp"), RobotState.torso_height.label("torso_height"),
__relative_x.label("relative_x"), __relative_y.label("relative_y"),
Quaternion.x.label("quaternion_x"), Quaternion.y.label("quaternion_y"),
Quaternion.z.label("quaternion_z"), Quaternion.w.label("quaternion_w"),
Object.obj_type.label("obj_type"), TaskTreeNode.status.label("status"))
.join(TaskTreeNode.action.of_type(PickUpAction))
.join(PickUpAction.robot_state)
.join(__robot_pose, RobotState.pose)
.join(__robot_position, __robot_pose.position)
.join(Pose.orientation)
.join(PickUpAction.object)
.join(Object.pose)
.join(__object_position, Pose.position)))
30 changes: 12 additions & 18 deletions src/pycram/resolver/location/database_location.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,17 +3,12 @@
import sqlalchemy.orm
import sqlalchemy.sql
from sqlalchemy import select, Select
from typing_extensions import List

from typing_extensions import List, Type
from ...costmaps import Rectangle, OccupancyCostmap
from ...designator import LocationDesignatorDescription
from ...designators.location_designator import CostmapLocation
from ...orm.action_designator import PickUpAction
from ...orm.base import RobotState, Quaternion
from ...orm.object_designator import Object
from ...orm.task import TaskTreeNode
from ...orm.views import PickUpWithContextView
from ...datastructures.pose import Pose
from ...orm.queries.queries import PickUpWithContext


@dataclass
Expand Down Expand Up @@ -71,34 +66,33 @@ def __init__(self, target, session: sqlalchemy.orm.Session = None, reachable_for
self.session = session

@staticmethod
def select_statement(query_context: PickUpWithContext) -> Select:
return query_context.join_statement(select(PickUpAction.arm, PickUpAction.grasp, RobotState.torso_height,
query_context.relative_x, query_context.relative_y, Quaternion.x,
Quaternion.y, Quaternion.z, Quaternion.w).distinct())
def select_statement(view: Type[PickUpWithContextView]) -> Select:
return (select(view.arm, view.grasp, view.torso_height, view.relative_x, view.relative_y, view.quaternion_x,
view.quaternion_y, view.quaternion_z, view.quaternion_w).distinct())

def create_query_from_occupancy_costmap(self) -> Select:
"""
Create a query that queries all relative robot positions from an object that are not occluded using an
OccupancyCostmap.
"""

query_context = PickUpWithContext()
view = PickUpWithContextView

# get query
query = self.select_statement(query_context)
query = self.select_statement(view)

# constraint query to correct object type and successful task status
query = query.where(Object.type == self.target.type).where(TaskTreeNode.status == "SUCCEEDED")
query = query.where(view.obj_type == self.target.obj_type).where(view.status == "SUCCEEDED")

filters = []

# for every rectangle
for rectangle in self.create_occupancy_rectangles():
# add sql filter
filters.append(sqlalchemy.and_(query_context.relative_x >= rectangle.x_lower,
query_context.relative_x < rectangle.x_upper,
query_context.relative_y >= rectangle.y_lower,
query_context.relative_y < rectangle.y_upper))
filters.append(sqlalchemy.and_(view.relative_x >= rectangle.x_lower,
view.relative_x < rectangle.x_upper,
view.relative_y >= rectangle.y_lower,
view.relative_y < rectangle.y_upper))

return query.where(sqlalchemy.or_(*filters))

Expand Down
10 changes: 4 additions & 6 deletions src/pycram/resolver/probabilistic/probabilistic_action.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,9 +15,8 @@
from ...designators.actions.actions import MoveAndPickUpPerformable, ActionAbstract
from ...datastructures.enums import Arms, Grasp, TaskStatus
from ...local_transformer import LocalTransformer
from ...orm.queries.queries import PickUpWithContext
from ...orm.task import TaskTreeNode
from ...orm.action_designator import PickUpAction as ORMPickUpAction
from ...orm.views import PickUpWithContextView
from ...plan_failures import ObjectUnreachable, PlanFailure
from ...datastructures.pose import Pose

Expand Down Expand Up @@ -273,10 +272,9 @@ def iterate_without_occupancy_costmap(self) -> Iterator[MoveAndPickUpPerformable

@staticmethod
def query_for_database():
query_context = PickUpWithContext()
query = select(ORMPickUpAction.arm, ORMPickUpAction.grasp,
query_context.relative_x, query_context.relative_y)
query = query_context.join_statement(query).where(TaskTreeNode.status == TaskStatus.SUCCEEDED)
view = PickUpWithContextView
query = (select(view.arm, view.grasp, view.relative_x, view.relative_y)
.where(view.status == TaskStatus.SUCCEEDED))
return query

def batch_rollout(self):
Expand Down
Loading

0 comments on commit 9f0d35b

Please sign in to comment.