#-------------------------------------------------------------------------------
#
#  Define a gridded canvas for doing drawing object layout.
#
#  Written by: David C. Morrill
#
#  Date: 11/09/2003
#
#  (c) Copyright 2003 by Enthought, Inc.
#
#  Classes defined: GuideLine, SelectionFrame, GriddedCanvas
#
#-------------------------------------------------------------------------------

#-------------------------------------------------------------------------------
#  Imports:
#-------------------------------------------------------------------------------

from enthought.util.numerix import array

from enthought.enable.base          import transparent_color, xy_in_bounds, \
                                    add_rectangles, green_color, yellow_color, \
                                    black_color, bounds_to_coordinates
from enthought.enable.component     import Component
from enthought.enable.container     import AbstractContainer, Container
from enthought.enable.frame         import Frame
from enthought.enable.enable_traits import white_color_trait, red_color_trait, \
                                    grid_trait, border_size_trait, \
                                    layout_style_trait, selection_state_trait
                                    
from enthought.traits.api               import Trait, TraitList, Event, \
                                           true, false
from enthought.traits.ui.api            import Group, View, Include

#-------------------------------------------------------------------------------
#  Constants:
#-------------------------------------------------------------------------------

# Padding rectangles that are combined with a component's bounds through
# add_rectangles(): for hit-testing guidelines and selection handles, for the
# area redrawn around a selection, and as the drag inset of a multi-selection.
vertical_guideline_rect = (-3.0, 0.0, 6.0, 0.0)
horizontal_guideline_rect = (0.0, -3.0, 0.0, 6.0)
selection_rect = (-3.0, -3.0, 6.0, 6.0)
selection_rect_draw = (-4.0, -4.0, 8.0, 8.0)
selection_rect_drag = (-4.0, -4.0, 3.0, 3.0)

# Outline/handle color of a selection frame, indexed by "is primary
# selection": co-selected frames use yellow, the primary one uses green.
selection_colors = (yellow_color, green_color)

# Mouse pointer shown for each selection zone returned by _get_zone():
selection_pointer = [
    'arrow',              # 0: outside the selection
    'size top left',      # 1
    'size top',           # 2
    'size top right',     # 3
    'size left',          # 4
    'hand',               # 5: interior (move)
    'size right',         # 6
    'size bottom left',   # 7
    'size bottom',        # 8
    'size bottom right',  # 9
]

# Per-zone resize factors (fx, fy, fdx, fdy): for a mouse offset (dx, dy) the
# frame becomes (x + fx*dx, y + fy*dy, width + fdx*dx, height + fdy*dy).
resize_factors = (
    (0, 0, 0, 0),     # 0: outside
    (1, 0, -1, 1),    # 1: top left
    (0, 0, 0, 1),     # 2: top
    (0, 0, 1, 1),     # 3: top right
    (1, 0, -1, 0),    # 4: left
    (0, 0, 0, 0),     # 5: interior
    (0, 0, 1, 0),     # 6: right
    (1, 1, -1, -1),   # 7: bottom left
    (0, 1, 0, -1),    # 8: bottom
    (0, 1, 1, -1),    # 9: bottom right
)

# Dash pattern passed to gc.set_line_dash() for the selection outline:
short_dash_line_style = array([3.0, 3.0])

#-------------------------------------------------------------------------------
#  'GuideLine' class:
#-------------------------------------------------------------------------------
        
class GuideLine ( Component ):
    """ A one-pixel vertical or horizontal guide line placed in a container.

        The line can be dragged with the mouse. Whenever its position is
        recomputed it is passed through the container's '_check_snap' method
        (when the container provides one) so it can be aligned to a grid.
    """

    #---------------------------------------------------------------------------
    #  Trait definitions:
    #---------------------------------------------------------------------------

    # Color used to stroke the line:
    color = red_color_trait

    # Orientation: only the first character is tested; 'v' means vertical,
    # anything else is treated as horizontal:
    style = layout_style_trait

    #---------------------------------------------------------------------------
    #  Trait view definitions:
    #---------------------------------------------------------------------------

    traits_view = View( Group( '<component>', id = 'component' ),
                        Group( '<links>',     id = 'links' ),
                        Group( 'color', '_', 'style',
                               id    = 'guideline',
                               style = 'custom' ) )

    colorchip_map = {
        'fg_color': 'color',
    }

    #---------------------------------------------------------------------------
    #  Return 'bounds' adjusted by the container's snapping rules (if any):
    #---------------------------------------------------------------------------

    def _snap_bounds ( self, bounds ):
        """ Returns *bounds* passed through the container's '_check_snap'.

            If there is no container, or the container does not define
            '_check_snap', the bounds are returned unchanged. Errors raised
            by '_check_snap' itself are no longer silently swallowed.
        """
        check_snap = getattr( self.container, '_check_snap', None )
        if check_snap is None:
            return bounds
        return check_snap( self, bounds )

    #---------------------------------------------------------------------------
    #  Draw the component in a specified graphics context:
    #---------------------------------------------------------------------------

    def _draw ( self, gc ):
        gc.save_state()
        gc.set_stroke_color( self.color_ )
        gc.set_line_width( 1.0 )
        gc.begin_path()
        x, y, dx, dy = self.bounds
        # The half-pixel offset presumably centers the 1-pixel stroke on a
        # pixel row/column so it renders crisply:
        if self.style[0] == 'v':
            x += 0.5
            gc.move_to( x, y )
            gc.line_to( x, y + dy )
        else:
            y += 0.5
            gc.move_to( x, y )
            gc.line_to( x + dx, y )
        gc.stroke_path()
        gc.restore_state()

    #---------------------------------------------------------------------------
    #  Make the guideline fit correctly within its container:
    #---------------------------------------------------------------------------

    def _check_bounds ( self ):
        """ Stretches the line across the full height (vertical) or width
            (horizontal) of its container, keeping its position no smaller
            than the container's origin, then snaps it.
        """
        x, y, dx, dy = self.container.bounds
        if self.style[0] == 'v':
            bounds = ( max( self.x, x ), y, 1.0, dy )
        else:
            bounds = ( x, max( self.y, y ), dx, 1.0 )
        self.bounds = self._snap_bounds( bounds )

    #---------------------------------------------------------------------------
    #  Determine whether a point is over the guideline:
    #---------------------------------------------------------------------------

    def _is_over ( self, x, y ):
        # Pad the 1-pixel line so it is easier to hit with the mouse:
        if self.style[0] == 'v':
            rect = vertical_guideline_rect
        else:
            rect = horizontal_guideline_rect
        return xy_in_bounds( x, y, add_rectangles( self.bounds, rect ) )

    #---------------------------------------------------------------------------
    #  Return the components that contain a specified (x,y) point:
    #---------------------------------------------------------------------------

    def components_at ( self, x, y ):
        if self.visible and self._drawable and self._is_over( x, y ):
            return [ self ]
        return []

    #---------------------------------------------------------------------------
    #  Handle the guideline being dropped:
    #---------------------------------------------------------------------------

    def _dropped_changed ( self, event ):
        # Only the coordinate perpendicular to the line follows the drop
        # point; the other axis keeps its current value:
        x, y, dx, dy = self.bounds
        if self.style[0] == 'v':
            x = event.x
        else:
            y = event.y
        self.bounds  = self._snap_bounds( ( x, y, dx, dy ) )
        self.pointer = 'arrow'

    #---------------------------------------------------------------------------
    #  Handle the guideline being dragged:
    #---------------------------------------------------------------------------

    def _drag_handler ( self, event, bounds ):
        return self._snap_bounds( bounds )

    #---------------------------------------------------------------------------
    #  Handle mouse events:
    #---------------------------------------------------------------------------

    def _left_down_changed ( self, event ):
        self.window.mouse_owner = None
        self.window.drag( self, self.container, event,
                          drag_handler = self._drag_handler, alpha = -1.0 )

    def _mouse_move_changed ( self, event ):
        event.handled = True
        if self._is_over( event.x, event.y ):
            # Vertical lines are dragged sideways, horizontal ones up/down:
            self.pointer = [ 'size top', 'size left' ][ self.style[0] == 'v' ]
            self.window.mouse_owner = self
        else:
            self.pointer = 'arrow'
            self.window.mouse_owner = None
            
#-------------------------------------------------------------------------------
#  'SelectionFrame' class
#-------------------------------------------------------------------------------
            
class SelectionFrame ( Frame ):
    """ A Frame wrapping a component on a canvas that draws selection
        decorations (an outline plus eight handles) around it.

        Clicking selects the frame, dragging its interior moves it (or the
        whole selection), and dragging one of the handles resizes it.
    """

    #---------------------------------------------------------------------------
    #  Trait definitions:
    #---------------------------------------------------------------------------

    # Selection state; only the first character is tested: 'u' (unselected),
    # 's' (primary selection) or 'c' (co-selected):
    state = selection_state_trait

    #---------------------------------------------------------------------------
    #  Handle the state being changed:
    #---------------------------------------------------------------------------

    def _state_changed ( self ):
        # Redraw the frame plus the padding occupied by its handles:
        self.redraw( add_rectangles( self.bounds, selection_rect_draw ) )

    #---------------------------------------------------------------------------
    #  Handle the bounds of the component being changed:
    #---------------------------------------------------------------------------

    def _bounds_changed ( self, old, new ):
        Frame._bounds_changed( self, old, new )
        # Erase the decorations at the old position and draw them at the new:
        self.redraw( add_rectangles( old, selection_rect_draw ) )
        self.redraw( add_rectangles( new, selection_rect_draw ) )

    #---------------------------------------------------------------------------
    #  Do any drawing that needs to be done after drawing the contained
    #  component:
    #---------------------------------------------------------------------------

    def _post_draw ( self, gc ):
        """ Draws the dashed selection outline and the eight handles, unless
            the frame is unselected or the container is in test mode.
        """
        state = self.state[0]
        if (state == 'u') or self.container.test_mode:
            return

        gc.save_state()

        # Draw the outer selection rectangle: a solid colored stroke overlaid
        # with a black dashed stroke:
        xb, yb, dx, dy = self.bounds
        color          = selection_colors[ state == 's' ]

        gc.set_line_width( 1 )
        gc.set_stroke_color( color )

        gc.begin_path()
        gc.rect( xb - 0.5, yb - 0.5, dx, dy )
        gc.stroke_path()

        gc.set_stroke_color( black_color )
        gc.set_line_dash( short_dash_line_style )

        gc.begin_path()
        gc.rect( xb - 0.5, yb - 0.5, dx, dy )
        gc.stroke_path()

        gc.set_fill_color( color )
        gc.set_line_dash( None )

        # Draw each of the eight selection boxes (every corner and edge
        # midpoint, skipping the center):
        xm = xb + round( dx / 2.0 )
        ym = yb + round( dy / 2.0 )
        ys = ( yb, ym, yb + dy )
        for x in [ xb , xm, xb + dx ]:
            for y in ys:
                if (x != xm) or (y != ym):
                   gc.begin_path()
                   gc.rect( x - 3.5, y - 3.5, 6.0, 6.0 )
                   gc.draw_path()

        gc.restore_state()

    #---------------------------------------------------------------------------
    #  Get what region (if any) of the selection a specified point is in:
    #---------------------------------------------------------------------------

    def _get_zone ( self, x, y = None ):
        """ Returns the selection zone containing a point.

            The point is given either as (x, y) or as a single object with
            'x' and 'y' attributes (e.g. a mouse event). The result is 0 when
            the point is outside the frame and its handles, 5 when it is in
            the frame's interior, and 1-4 / 6-9 for the handles (laid out
            top-left to bottom-right, matching 'selection_pointer').
        """
        if y is None:
            y = x.y
            x = x.x

        bounds = self.bounds
        if not xy_in_bounds( x, y, add_rectangles( bounds, selection_rect ) ):
            return 0

        xl, yb, xr, yt = bounds_to_coordinates( bounds )

        if (xl - 3.0) <= x < (xl + 3.0):
            col = 1
        elif (xr - 3.0) <= x < (xr + 3.0):
            col = 3
        else:
            xm = xl + (self.width / 2.0)
            if not ((xm - 3.0) <= x < (xm + 3.0)):
                return xy_in_bounds( x, y, bounds ) * 5
            col = 2

        # The bottom edge (yb) maps to the bottom handle row (6) and the top
        # edge (yt) to the top handle row (0):
        if (yb - 3.0) <= y < (yb + 3.0):
            row = 6
        elif (yt - 3.0) <= y < (yt + 3.0):
            row = 0
        else:
            ym = yb + (self.height / 2.0)
            if not ((ym - 3.0) <= y < (ym + 3.0)):
                return xy_in_bounds( x, y, bounds ) * 5
            row = 3

        return row + col

    #---------------------------------------------------------------------------
    #  Make the selection fit correctly within its container:
    #---------------------------------------------------------------------------

    def _check_bounds ( self ):
        pass

    #---------------------------------------------------------------------------
    #  Return 'bounds' adjusted by the container's snapping rules (if any):
    #---------------------------------------------------------------------------

    def _snap_bounds ( self, bounds ):
        """ Returns *bounds* passed through the container's '_check_snap'.

            If there is no container, or the container does not define
            '_check_snap', the bounds are returned unchanged. Errors raised
            by '_check_snap' itself are no longer silently swallowed.
        """
        check_snap = getattr( self.container, '_check_snap', None )
        if check_snap is None:
            return bounds
        return check_snap( self, bounds )

    #---------------------------------------------------------------------------
    #  Handle the selection being dragged:
    #---------------------------------------------------------------------------

    def _drag_handler ( self, event, bounds ):
        return self._snap_bounds( bounds )

    #---------------------------------------------------------------------------
    #  Return the components that contain a specified (x,y) point:
    #---------------------------------------------------------------------------

    def components_at ( self, x, y ):
        # In test mode the frame is transparent to hit-testing so the wrapped
        # component receives the events directly:
        if self.container.test_mode:
            return self.component.components_at( x, y )
        if self.visible and self._drawable and (self._get_zone( x, y ) != 0):
            return [ self ]
        return []

    #---------------------------------------------------------------------------
    #  Handle mouse events for the selection:
    #---------------------------------------------------------------------------

    def _left_down_changed ( self, event ):
        event.handled = True
        zone          = self._get_zone( event )
        if zone == 0:
            return
        self.window.mouse_owner = self
        self._drag_start        = ( event.x, event.y )
        state                   = self.state[0]
        if (state != 'u') and (zone != 5):
            # Pressing a handle of a selected frame starts a resize. The
            # starting bounds are remembered so each mouse move is computed
            # from the total offset since the press; otherwise grid snapping
            # would discard small incremental moves:
            self._dragging      = zone
            self._resize_bounds = tuple( self.bounds )

    def _left_up_changed ( self, event ):
        event.handled    = True
        self._drag_start = None
        self.window.mouse_owner = None
        state     = self.state[0]
        zone      = self._get_zone( event )
        container = self.container
        if state == 'u':
            if zone != 0:
                if event.control_down or event.shift_down:
                    container.add_selection( self )
                else:
                    container.set_selection( self )
        elif self._dragging is None:
            if zone == 5:
                if (not event.control_down) and (not event.shift_down):
                    container.set_selection( self )
                elif state == 'c':
                    container.select( self )
                elif event.control_down:
                    container.remove_selection( self )
                    self.pointer = 'arrow'
        else:
            # End of a handle resize:
            self._dragging      = None
            self._resize_bounds = None

    def _right_up_changed ( self, event ):
        event.handled = True
        if self._get_zone( event ) != 0:
            # Report the innermost wrapped component under the mouse:
            components = self.component.components_at( event.x, event.y )
            if len( components ) > 0:
                self.container.component_context = components[-1]

    def _mouse_move_changed ( self, event ):
        event.handled = True
        if self._dragging is not None:
            # Resizing via a handle: compute the new bounds from the total
            # mouse offset relative to the press point and starting bounds:
            x0, y0           = self._drag_start
            dx, dy           = (event.x - x0), (event.y - y0)
            cx, cy, cdx, cdy = self._resize_bounds
            fx, fy, fdx, fdy = resize_factors[ self._dragging ]
            ncdx = cdx + fdx * dx
            # Only clamp an axis this handle actually resizes. Edge-middle
            # handles have fdx or fdy == 0, and dividing by that factor
            # would raise ZeroDivisionError:
            if (fdx != 0) and (ncdx < self.min_width):
                ncdx = self.min_width
                dx   = (self.min_width - cdx) / fdx
            ncdy = cdy + fdy * dy
            if (fdy != 0) and (ncdy < self.min_height):
                ncdy = self.min_height
                dy   = (self.min_height - cdy) / fdy
            self.bounds = self._drag_handler( None,
                              ( cx + fx * dx, cy + fy * dy, ncdx, ncdy ) )
            self._state_changed()
        elif self._drag_start is not None:
            # Start a move drag once the mouse travels more than 2 pixels:
            x, y = self._drag_start
            if (abs( x - event.x ) > 2) or (abs( y - event.y ) > 2):
                self._drag_start        = None
                self.window.mouse_owner = None
                self.pointer            = 'arrow'
                container               = self.container
                if self.state[0] == 'u':
                    self.window.drag( self, container, event,
                              drag_handler = self._drag_handler, alpha = -1.0 )
                else:
                    self.window.drag( container.selection[:], container, event,
                              drag_handler = self._drag_handler,
                              alpha = -1.0, inset = selection_rect_drag )
        elif self.state[0] != 'u':
            # Hovering over a selected frame: show the zone's pointer and
            # keep the mouse while it is over the frame or its handles:
            zone = self._get_zone( event )
            self.pointer = selection_pointer[ zone ]
            self.window.mouse_owner = [ None, self ][ zone != 0 ]

    #---------------------------------------------------------------------------
    #  Handle drag events:
    #---------------------------------------------------------------------------

    def _dropped_changed ( self, event ):
        self.location( event.x, event.y )

    #---------------------------------------------------------------------------
    #  Handle a ComponentFactory object being dropped on the selection:
    #---------------------------------------------------------------------------

    def dropped_on_by_componentfactory ( self, factory, event ):
        event.handled = True
        # Only container components can accept a dropped child:
        if isinstance( self.component, AbstractContainer ):
            component = factory.create_component()
            component.location( factory.component.x + event.x - event.x0,
                                factory.component.y + event.y - event.y0 )
            self.component.add( component )

    #---------------------------------------------------------------------------
    #  Handle a ComponentFactoryNode object being dropped on the selection:
    #
    #  HACK: Allows Envisage nodes containing a component factory to be dropped
    #---------------------------------------------------------------------------

    def drag_over_by_componentfactorynode ( self, node, event ):
        event.handled = True

    def drag_leave_by_componentfactorynode ( self, node, event ):
        event.handled = True

    def dropped_on_by_componentfactorynode ( self, factory_node, event ):
        event.handled = True
        if isinstance( self.component, AbstractContainer ):
            component = factory_node.data.create_component()
            component.location( event.x, event.y )
            self.component.add( component )

#-------------------------------------------------------------------------------
#  'GriddedCanvas' class:
#-------------------------------------------------------------------------------

class GriddedCanvas ( Container ):
    """ A container that draws a background color and an optional grid, and
        manages the list of currently selected SelectionFrame objects.

        GuideLine and SelectionFrame objects call '_check_snap' to have
        their bounds aligned to the grid when 'snap_to_grid' is enabled.
    """

    #---------------------------------------------------------------------------
    #  Trait definitions:
    #---------------------------------------------------------------------------

    # Fill color of the canvas background:
    bg_color          = Trait( ( .925, .914, .847, 1.0 ), white_color_trait )

    # Stroke color of the grid lines:
    grid_color        = Trait( ( .760, .840, .920, 1.0 ), white_color_trait )

    # Spacing between vertical / horizontal grid lines:
    grid_width        = grid_trait
    grid_height       = grid_trait

    # Line width used to draw the grid:
    grid_size         = Trait( 1, border_size_trait )

    grid_visible      = true
    snap_to_grid      = false

    # NOTE(review): not referenced anywhere in this file; guideline snapping
    # may be unimplemented:
    snap_to_guide     = false

    # When true, SelectionFrames hide their decorations and pass hit-testing
    # through to their wrapped components:
    test_mode         = false

    # Currently selected frames; index 0 is the primary selection:
    selection         = TraitList( SelectionFrame )

    # Fired on right-button release with the component under the mouse:
    component_context = Event( Component )

    #---------------------------------------------------------------------------
    #  Trait view definitions:
    #---------------------------------------------------------------------------

    traits_view = View( Group( '<component>',
                               'grid_visible', 'snap_to_grid',
                               'snap_to_guide', 'test_mode',
                               id = 'component' ),
                        Group( '<links>', id = 'links' ),
                        Group( 'bg_color{Background color}', '_',
                               'grid_color',
                               id    = 'color',
                               style = 'custom' ),
                        Group( 'grid_width', 'grid_height', '_', 'grid_size',
                               id    = 'size',
                               style = 'custom' ) )

    colorchip_map = {
        'fg_color': 'grid_color',
        'bg_color': 'bg_color'
    }

    #---------------------------------------------------------------------------
    #  Handle a GuideLine being added or removed from the canvas:
    #---------------------------------------------------------------------------

    def _add ( self, component ):
        if isinstance( component, GuideLine ):
            if self._guidelines is None:
                self._guidelines = []
            self._guidelines.append( component )
            component._check_bounds()

    def _remove ( self, component ):
        # Explicitly tolerate "no guidelines yet" and "not tracked" instead of
        # swallowing every possible exception:
        guidelines = self._guidelines
        if (isinstance( component, GuideLine ) and guidelines and
            (component in guidelines)):
            guidelines.remove( component )

    #---------------------------------------------------------------------------
    #  Ask a contained component to fit itself within the canvas:
    #---------------------------------------------------------------------------

    def _check_bounds ( self, component ):
        component._check_bounds()

    #---------------------------------------------------------------------------
    #  Verify that the suggested bounds for a component match the current
    #  snap to grid mode and grid size. If not, adjust them accordingly:
    #---------------------------------------------------------------------------

    def _check_snap ( self, component, bounds ):
        """ Returns *bounds* (x, y, dx, dy) aligned to the grid.

            When 'snap_to_grid' is off the bounds are returned unchanged.
            Otherwise the position is rounded to the nearest grid line on
            each axis whose grid spacing exceeds 1. For components other
            than GuideLines the size is also rounded to a multiple of the
            spacing, but a non-empty size never snaps to less than one grid
            cell.
        """
        if not self.snap_to_grid:
            return bounds
        x, y, dx, dy     = self.bounds
        cx, cy, cdx, cdy = bounds
        is_guideline     = isinstance( component, GuideLine )
        gdx = self.grid_width
        if gdx > 1:
            # float() forces true division, so positions round to the nearest
            # grid line even when every operand is an integer (Python 2):
            cx = x + gdx * round( (cx - x) / float( gdx ) )
            if not is_guideline:
                snapped = gdx * round( cdx / float( gdx ) )
                # Never let snapping collapse a non-empty component to zero:
                if (cdx > 0) and (snapped < gdx):
                    snapped = gdx
                cdx = snapped
        gdy = self.grid_height
        if gdy > 1:
            cy = y + gdy * round( (cy - y) / float( gdy ) )
            if not is_guideline:
                snapped = gdy * round( cdy / float( gdy ) )
                if (cdy > 0) and (snapped < gdy):
                    snapped = gdy
                cdy = snapped
        return ( cx, cy, cdx, cdy )

    #---------------------------------------------------------------------------
    #  Draw the container background (fill color plus optional grid) in a
    #  specified graphics context:
    #---------------------------------------------------------------------------

    def _draw_container ( self, gc ):
        gc.save_state()

        x, y, dx, dy = self.bounds

        # Fill the background region (if required). NOTE(review): this is an
        # identity test against the 'transparent_color' object, not a value
        # comparison:
        bg_color = self.bg_color_
        if bg_color is not transparent_color:
            gc.set_fill_color( bg_color )
            gc.begin_path()
            gc.rect( x, y, dx, dy )
            gc.fill_path()

        # Draw the grid (if required). Lines are inset by half the line width
        # so the outermost lines stay inside the canvas bounds:
        if self.grid_visible:
            gs = self.grid_size
            if gs > 0:
                gsh = gs / 2.0
                gc.set_stroke_color( self.grid_color_ )
                gc.set_line_width( gs )
                gc.begin_path()
                yb  = y + gsh
                yt  = y + dy - gsh
                xl  = xc = x + gsh
                xr  = x + dx - gsh
                gdx = self.grid_width
                if gdx > 0:
                    while xc < xr:
                        gc.move_to( xc, yb )
                        gc.line_to( xc, yt )
                        xc += gdx
                gdy = self.grid_height
                if gdy > 0:
                    while yb < yt:
                        gc.move_to( xl, yb )
                        gc.line_to( xr, yb )
                        yb += gdy
                gc.stroke_path()

        gc.restore_state()

    #---------------------------------------------------------------------------
    #  Selection management:
    #---------------------------------------------------------------------------

    def clear_selection ( self ):
        for item in self.selection:
            item.state = 'unselected'
        self.selection = []

    def set_selection ( self, selection ):
        """ Makes *selection* the only (primary) selected frame. """
        for item in self.selection:
            item.state = 'unselected'
        self.selection  = [ selection ]
        selection.state = 'selected'

    def add_selection ( self, selection ):
        """ Appends *selection*; it becomes primary only if it is the first. """
        self.selection.append( selection )
        selection.state = [ 'selected', 'coselected' ][
                          len( self.selection ) > 1 ]

    def remove_selection ( self, selection ):
        """ Removes *selection* (if present). If it was the primary selection,
            the next frame in the list is promoted to primary.
        """
        items = self.selection
        try:
            index = items.index( selection )
        except ValueError:
            # Not currently selected: nothing to do.
            return
        del items[ index ]
        selection.state = 'unselected'
        if (index == 0) and (len( items ) > 0):
            items[0].state = 'selected'

    def select ( self, selection ):
        """ Makes *selection* the primary selection, adding it if necessary.
            The previous primary selection (if any) becomes co-selected.
        """
        items = self.selection
        try:
            index = items.index( selection )
        except ValueError:
            index = None
        if index == 0:
            return
        if index is not None:
            del items[ index ]
        items[0:0] = [ selection ]
        selection.state = 'selected'
        if len( items ) > 1:
            items[1].state = 'coselected'

    def selection_empty ( self ):
        return (len( self.selection ) == 0)

    #---------------------------------------------------------------------------
    #  Handle mouse events:
    #---------------------------------------------------------------------------

    def _left_up_changed ( self, event ):
        # Clicking on empty canvas deselects everything:
        event.handled = True
        self.clear_selection()

    def _right_up_changed ( self, event ):
        event.handled = True
        self.component_context = self

    #---------------------------------------------------------------------------
    #  Handle a ComponentFactory object being dropped on the canvas:
    #---------------------------------------------------------------------------

    def dropped_on_by_componentfactory ( self, factory, event ):
        event.handled = True
        component     = SelectionFrame( factory.create_component() )
        component.location( factory.component.x + event.x - event.x0,
                            factory.component.y + event.y - event.y0 )
        self.add( component )
        self.set_selection( component )

    #---------------------------------------------------------------------------
    #  Handle a ComponentFactoryNode object being dropped on the canvas:
    #
    #  HACK: Allows Envisage nodes containing a component factory to be dropped
    #---------------------------------------------------------------------------

    def drag_over_by_componentfactorynode ( self, node, event ):
        event.handled = True

    def drag_leave_by_componentfactorynode ( self, node, event ):
        event.handled = True

    def dropped_on_by_componentfactorynode ( self, factory_node, event ):
        event.handled = True
        component     = SelectionFrame( factory_node.data.create_component() )
        component.location( event.x, event.y )
        self.add( component )
        self.set_selection( component )
