#
# This file is part of GNU Enterprise.
#
# GNU Enterprise is free software; you can redistribute it
# and/or modify it under the terms of the GNU General Public
# License as published by the Free Software Foundation; either
# version 2, or (at your option) any later version.
#
# GNU Enterprise is distributed in the hope that it will be
# useful, but WITHOUT ANY WARRANTY; without even the implied
# warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR
# PURPOSE. See the GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public
# License along with program; see the file COPYING. If not,
# write to the Free Software Foundation, Inc., 59 Temple Place
# - Suite 330, Boston, MA 02111-1307, USA.
#
# Copyright 2000-2004 Free Software Foundation
#
# FILE:
# mysql/DBdriver.py
#
# DESCRIPTION:
# Driver to provide access to data via MySQL
#
# NOTES:
# Supports transactions only if the MySQL server is compiled with transaction
# support (which it is NOT by default)


import string
import sys
import types
from gnue.common.datasources import GDataObjects, GConditions
from gnue.common.apps import GDebug
from gnue.common.datasources.drivers.DBSIG2.Driver \
   import DBSIG_RecordSet, DBSIG_ResultSet, DBSIG_DataObject, \
          DBSIG_DataObject_SQL, DBSIG_DataObject_Object

try:
  import MySQLdb
except ImportError, mesg:
  GDebug.printMesg(1,mesg)
  print "-"*79
  print _("\nCould not load MySQLdb.  For MySQL support, please install \n") \
      + _("mysql-python 0.9.0 or later from") \
      + "http://sourceforge.net/projects/mysql-python\n"
  print _("Error:  %s") % mesg
  print "-"*79
  sys.exit()



class MySQL_RecordSet(DBSIG_RecordSet):
  """Record class used by MySQL_ResultSet for each fetched row.

  No MySQL-specific behaviour is needed; everything is inherited from
  DBSIG_RecordSet.
  """


class MySQL_ResultSet(DBSIG_ResultSet): 
  def __init__(self, dataObject, cursor=None, defaultValues={}, masterRecordSet=None): 
    DBSIG_ResultSet.__init__(self, dataObject, \
            cursor, defaultValues, masterRecordSet)
    self._recordSetClass = MySQL_RecordSet

    # Compensate for bug in python mysql drivers older than 0.9.2a2
    if MySQLdb.__version__ >= '0.9.2a2':
      self.fetchBugFix = self._cursor.fetchmany
    else:
      self.__done = 0
      self.fetchBugFix = self.__mySqlNeedsLotsOfTLC

  
  # Compensate for MySQ bug
  def __mySqlNeedsLotsOfTLC(self):
    if self.__done:
      return None

    self.__done = 1
    return self._cursor.fetchall()


  def _loadNextRecord(self):
    if self._cursor:
      rs = None

      try:
        # See __init__ for details
	rsets = self.fetchBugFix()

      except self._dataObject._DatabaseError, err:
        raise GDataObjects.ConnectionError, err

      if rsets and len(rsets):
        for rs in(rsets):
          if rs:
            i = 0
            dict = {}
            for f in (rs):
              if self._dataObject._unicodeMode and type(f)==types.StringType:
                f = unicode(f,self._dataObject._databaseEncoding)
                
              dict[string.lower(self._fieldNames[i])] = f
              i += 1
            self._cachedRecords.append (self._recordSetClass(parent=self, \
                                                             initialData=dict))
          else:
            return 0
        return 1
      else:
        return 0
    else:
     return 0

class MySQL_DataObject(DBSIG_DataObject):
  def __init__(self):
    DBSIG_DataObject.__init__(self)
    self._DatabaseError = MySQLdb.DatabaseError
    self._resultSetClass = MySQL_ResultSet
    
  def connect(self, connectData={}):
    GDebug.printMesg(1,"Mysql database driver initializing")
   
    # 1. just allow string type username/password 2. None -> ''
    user   = str(connectData['_username'] or '')
    passwd = str(connectData['_password'] or '')

    try:
      self._dataConnection = MySQLdb.connect(user=user,
                   passwd=passwd,
                   host=connectData['host'],
                   db=connectData['dbname'])
    except self._DatabaseError, value:
      raise GDataObjects.LoginError, value

    self._beginTransaction()
    self._postConnect()


  def _postConnect(self): 
    self.triggerExtensions = TriggerExtensions(self._dataConnection)


  def _beginTransaction(self):
    try:
      self._dataConnection.begin()
    except: 
      pass


  #
  # Schema (metadata) functions
  #

  # Return a list of the types of Schema objects this driver provides
  def getSchemaTypes(self):
    return [('table',_('Tables'),1)]

  # Return a list of Schema objects
  def getSchemaList(self, type=None):

    # TODO: This excludes any system tables and views. Should it?
    statement = "SHOW TABLES"

    cursor = self._dataConnection.cursor()
    cursor.execute(statement)

    list = []
    for rs in cursor.fetchall():
      list.append(GDataObjects.Schema(attrs={'id':rs[0], 'name':rs[0],
                         'type':'table',
                         'primarykey': self.__getPrimaryKey(rs[0])},
                         getChildSchema=self.__getFieldSchema))

    cursor.close()
    return list


  # Find a schema object with specified name
  def getSchemaByName(self, name, type=None):
    statement = "DESCRIBE %s" % (name)

    cursor = self._dataConnection.cursor()
    cursor.execute(statement)

    rs = cursor.fetchone()
    if rs:
      schema = GDataObjects.Schema(attrs={'id':name, 'name':name,
                           'type':'table',
                           'primarykey': self.__getPrimaryKey(name,cursor)},
                           getChildSchema=self.__getFieldSchema)
    else:
      schema = None

    cursor.close()
    return schema


  def __getPrimaryKey(self, id, cursor=None):
    statement = "DESCRIBE %s" % id
    if not cursor:
      cursor = self._dataConnection.cursor()
      close_cursor = 1
    else:
      close_cursor = 0
    cursor.execute(statement)

    lst = []
    for rs in cursor.fetchall():
      if rs[3] == 'PRI':
        lst.append(rs[0])

    if close_cursor:
      cursor.close()

    return tuple(lst)

  # Get fields for a table
  def __getFieldSchema(self, parent):

    statement = "DESCRIBE %s" % parent.id

    cursor = self._dataConnection.cursor()
    cursor.execute(statement)

    list = []
    for rs in cursor.fetchall():

      nativetype = string.split(string.replace(rs[1],')',''),'(')


      attrs={'id': "%s.%s" % (parent.id, rs[0]), 'name': rs[0],
             'type':'field', 'nativetype': nativetype[0],
             'required': rs[2] != 'YES'}

      if nativetype[0] in ('int','integer','bigint','mediumint',
                           'smallint','tinyint','float','real',
                           'double','decimal'):
        attrs['datatype']='number'
      elif nativetype[0] in ('date','time','timestamp','datetime'):
        attrs['datatype']='date'
      else:
        attrs['datatype']='text'

      try:
        if len(nativetype) == 2:
          try:
            ln, prec = nativetype[1].split(',')
          except:
            ln = nativetype[1]
            prec = None
          attrs['length'] = int(ln.split()[0])
          if prec != None:
            attrs['precision'] = int(prec)
      except ValueError:
        GDebug.printMesg(1,'WARNING: mysql native type error: %s' % nativetype)

      if rs[4] not in ('NULL', '0000-00-00 00:00:00','', None):
        attrs['defaulttype'] = 'constant'
        attrs['defaultval'] = rs[4]

      if rs[5] == 'auto_increment':
        attrs['defaulttype'] = 'serial'


      list.append(GDataObjects.Schema(attrs=attrs))

    cursor.close()
    return list




class MySQL_DataObject_Object(MySQL_DataObject, DBSIG_DataObject_Object):
  """MySQL data object for datasources of type 'object'.

  The query builder is taken explicitly from DBSIG_DataObject_Object.
  """

  def __init__(self):
    # Only the MySQL initialiser runs here.
    # NOTE(review): unlike MySQL_DataObject_SQL, the DBSIG object mixin's own
    # __init__ (if it defines one) is not called -- confirm this is intended.
    MySQL_DataObject.__init__(self)

  def _buildQuery(self, conditions={},forDetail=None,additionalSQL=""):
    """Build the query with the generic DBSIG object-query builder."""
    objectBuilder = DBSIG_DataObject_Object._buildQuery
    return objectBuilder(self, conditions, forDetail, additionalSQL)


class MySQL_DataObject_SQL(MySQL_DataObject, DBSIG_DataObject_SQL):
  """MySQL data object for datasources of type 'sql'."""

  def __init__(self):
    # Order matters: the DBSIG SQL initialiser goes first so that
    # MySQL_DataObject.__init__ can overwrite some of the values it sets.
    DBSIG_DataObject_SQL.__init__(self)
    MySQL_DataObject.__init__(self)

  def _buildQuery(self, conditions={}):
    """Build the query with the generic DBSIG SQL-query builder."""
    sqlBuilder = DBSIG_DataObject_SQL._buildQuery
    return sqlBuilder(self, conditions)


#
#  Extensions to Trigger Namespaces
#  
class TriggerExtensions:
  """Extensions to the trigger namespace for a MySQL connection.

  Only stores the connection it is given.  No extension functions are
  implemented yet; planned (currently disabled) ones are:

    getDate()          -- return the current date, according to database
    getSequence(name)  -- return a sequence number from sequence 'name'
    sql(statement)     -- run the SQL statement 'statement'
  """

  def __init__(self, connection):
    # Kept private (name-mangled) for use by future extension methods
    self.__connection = connection



######################################
#
#  The following hashes describe
#  this driver's characteristics.
#
######################################

#
#  All datasource "types" and their corresponding DataObject classes
# 
# Maps each datasource type name to the DataObject class that serves it
supportedDataObjects = {}
supportedDataObjects['object'] = MySQL_DataObject_Object
supportedDataObjects['sql'] = MySQL_DataObject_SQL


