@ -18,21 +18,6 @@ import (
"github.com/go-xorm/core"
"github.com/go-xorm/core"
)
)
type incrParam struct {
colName string
arg interface { }
}
type decrParam struct {
colName string
arg interface { }
}
type exprParam struct {
colName string
expr string
}
// Statement save all the sql info for executing SQL
// Statement save all the sql info for executing SQL
type Statement struct {
type Statement struct {
RefTable * core . Table
RefTable * core . Table
@ -47,7 +32,6 @@ type Statement struct {
HavingStr string
HavingStr string
ColumnStr string
ColumnStr string
selectStr string
selectStr string
columnMap map [ string ] bool
useAllCols bool
useAllCols bool
OmitStr string
OmitStr string
AltTableName string
AltTableName string
@ -67,6 +51,8 @@ type Statement struct {
allUseBool bool
allUseBool bool
checkVersion bool
checkVersion bool
unscoped bool
unscoped bool
columnMap columnMap
omitColumnMap columnMap
mustColumnMap map [ string ] bool
mustColumnMap map [ string ] bool
nullableMap map [ string ] bool
nullableMap map [ string ] bool
incrColumns map [ string ] incrParam
incrColumns map [ string ] incrParam
@ -89,7 +75,8 @@ func (statement *Statement) Init() {
statement . HavingStr = ""
statement . HavingStr = ""
statement . ColumnStr = ""
statement . ColumnStr = ""
statement . OmitStr = ""
statement . OmitStr = ""
statement . columnMap = make ( map [ string ] bool )
statement . columnMap = columnMap { }
statement . omitColumnMap = columnMap { }
statement . AltTableName = ""
statement . AltTableName = ""
statement . tableName = ""
statement . tableName = ""
statement . idParam = nil
statement . idParam = nil
@ -221,34 +208,33 @@ func (statement *Statement) setRefValue(v reflect.Value) error {
if err != nil {
if err != nil {
return err
return err
}
}
statement . tableName = statement . Engine . tb Name( v )
statement . tableName = statement . Engine . Table Name( v , true )
return nil
return nil
}
}
// Table tempororily set table name, the parameter could be a string or a pointer of struct
func ( statement * Statement ) setRefBean ( bean interface { } ) error {
func ( statement * Statement ) Table ( tableNameOrBean interface { } ) * Statement {
v := rValue ( tableNameOrBean )
t := v . Type ( )
if t . Kind ( ) == reflect . String {
statement . AltTableName = tableNameOrBean . ( string )
} else if t . Kind ( ) == reflect . Struct {
var err error
var err error
statement . RefTable , err = statement . Engine . autoMapType ( v )
statement . RefTable , err = statement . Engine . autoMapType ( rValue ( bean ) )
if err != nil {
if err != nil {
statement . Engine . logger . Error ( err )
return err
return statement
}
statement . AltTableName = statement . Engine . tbName ( v )
}
}
return statement
statement . tableName = statement . Engine . TableName ( bean , true )
return nil
}
}
// Auto generating update columnes and values according a struct
// Auto generating update columnes and values according a struct
func buildUpdates ( engine * Engine , table * core . Table , bean interface { } ,
func ( statement * Statement ) buildUpdates ( bean interface { } ,
includeVersion bool , includeUpdated bool , includeNil bool ,
includeVersion , includeUpdated , includeNil ,
includeAutoIncr bool , allUseBool bool , useAllCols bool ,
includeAutoIncr , update bool ) ( [ ] string , [ ] interface { } ) {
mustColumnMap map [ string ] bool , nullableMap map [ string ] bool ,
engine := statement . Engine
columnMap map [ string ] bool , update , unscoped bool ) ( [ ] string , [ ] interface { } ) {
table := statement . RefTable
allUseBool := statement . allUseBool
useAllCols := statement . useAllCols
mustColumnMap := statement . mustColumnMap
nullableMap := statement . nullableMap
columnMap := statement . columnMap
omitColumnMap := statement . omitColumnMap
unscoped := statement . unscoped
var colNames = make ( [ ] string , 0 )
var colNames = make ( [ ] string , 0 )
var args = make ( [ ] interface { } , 0 )
var args = make ( [ ] interface { } , 0 )
@ -268,7 +254,10 @@ func buildUpdates(engine *Engine, table *core.Table, bean interface{},
if col . IsDeleted && ! unscoped {
if col . IsDeleted && ! unscoped {
continue
continue
}
}
if use , ok := columnMap [ strings . ToLower ( col . Name ) ] ; ok && ! use {
if omitColumnMap . contain ( col . Name ) {
continue
}
if len ( columnMap ) > 0 && ! columnMap . contain ( col . Name ) {
continue
continue
}
}
@ -604,17 +593,10 @@ func (statement *Statement) col2NewColsWithQuote(columns ...string) []string {
}
}
func ( statement * Statement ) colmap2NewColsWithQuote ( ) [ ] string {
func ( statement * Statement ) colmap2NewColsWithQuote ( ) [ ] string {
newColumns := make ( [ ] string , 0 , len ( statement . columnMap ) )
newColumns := make ( [ ] string , len ( statement . columnMap ) , len ( statement . columnMap ) )
for col := range statement . columnMap {
copy ( newColumns , statement . columnMap )
fields := strings . Split ( strings . TrimSpace ( col ) , "." )
for i := 0 ; i < len ( statement . columnMap ) ; i ++ {
if len ( fields ) == 1 {
newColumns [ i ] = statement . Engine . Quote ( newColumns [ i ] )
newColumns = append ( newColumns , statement . Engine . quote ( fields [ 0 ] ) )
} else if len ( fields ) == 2 {
newColumns = append ( newColumns , statement . Engine . quote ( fields [ 0 ] ) + "." +
statement . Engine . quote ( fields [ 1 ] ) )
} else {
panic ( errors . New ( "unwanted colnames" ) )
}
}
}
return newColumns
return newColumns
}
}
@ -642,10 +624,11 @@ func (statement *Statement) Select(str string) *Statement {
func ( statement * Statement ) Cols ( columns ... string ) * Statement {
func ( statement * Statement ) Cols ( columns ... string ) * Statement {
cols := col2NewCols ( columns ... )
cols := col2NewCols ( columns ... )
for _ , nc := range cols {
for _ , nc := range cols {
statement . columnMap [ strings . ToLower ( nc ) ] = true
statement . columnMap . add ( nc )
}
}
newColumns := statement . colmap2NewColsWithQuote ( )
newColumns := statement . colmap2NewColsWithQuote ( )
statement . ColumnStr = strings . Join ( newColumns , ", " )
statement . ColumnStr = strings . Join ( newColumns , ", " )
statement . ColumnStr = strings . Replace ( statement . ColumnStr , statement . Engine . quote ( "*" ) , "*" , - 1 )
statement . ColumnStr = strings . Replace ( statement . ColumnStr , statement . Engine . quote ( "*" ) , "*" , - 1 )
return statement
return statement
@ -680,7 +663,7 @@ func (statement *Statement) UseBool(columns ...string) *Statement {
func ( statement * Statement ) Omit ( columns ... string ) {
func ( statement * Statement ) Omit ( columns ... string ) {
newColumns := col2NewCols ( columns ... )
newColumns := col2NewCols ( columns ... )
for _ , nc := range newColumns {
for _ , nc := range newColumns {
statement . columnMap [ strings . ToLower ( nc ) ] = false
statement . omitColumnMap = append ( statement . omitColumnMap , nc )
}
}
statement . OmitStr = statement . Engine . Quote ( strings . Join ( newColumns , statement . Engine . Quote ( ", " ) ) )
statement . OmitStr = statement . Engine . Quote ( strings . Join ( newColumns , statement . Engine . Quote ( ", " ) ) )
}
}
@ -743,6 +726,23 @@ func (statement *Statement) Asc(colNames ...string) *Statement {
return statement
return statement
}
}
// Table tempororily set table name, the parameter could be a string or a pointer of struct
func ( statement * Statement ) Table ( tableNameOrBean interface { } ) * Statement {
v := rValue ( tableNameOrBean )
t := v . Type ( )
if t . Kind ( ) == reflect . Struct {
var err error
statement . RefTable , err = statement . Engine . autoMapType ( v )
if err != nil {
statement . Engine . logger . Error ( err )
return statement
}
}
statement . AltTableName = statement . Engine . TableName ( tableNameOrBean , true )
return statement
}
// Join The joinOP should be one of INNER, LEFT OUTER, CROSS etc - this will be prepended to JOIN
// Join The joinOP should be one of INNER, LEFT OUTER, CROSS etc - this will be prepended to JOIN
func ( statement * Statement ) Join ( joinOP string , tablename interface { } , condition string , args ... interface { } ) * Statement {
func ( statement * Statement ) Join ( joinOP string , tablename interface { } , condition string , args ... interface { } ) * Statement {
var buf bytes . Buffer
var buf bytes . Buffer
@ -752,39 +752,9 @@ func (statement *Statement) Join(joinOP string, tablename interface{}, condition
fmt . Fprintf ( & buf , "%v JOIN " , joinOP )
fmt . Fprintf ( & buf , "%v JOIN " , joinOP )
}
}
switch tablename . ( type ) {
tbName := statement . Engine . TableName ( tablename , true )
case [ ] string :
t := tablename . ( [ ] string )
if len ( t ) > 1 {
fmt . Fprintf ( & buf , "%v AS %v" , statement . Engine . Quote ( t [ 0 ] ) , statement . Engine . Quote ( t [ 1 ] ) )
} else if len ( t ) == 1 {
fmt . Fprintf ( & buf , statement . Engine . Quote ( t [ 0 ] ) )
}
case [ ] interface { } :
t := tablename . ( [ ] interface { } )
l := len ( t )
var table string
if l > 0 {
f := t [ 0 ]
v := rValue ( f )
t := v . Type ( )
if t . Kind ( ) == reflect . String {
table = f . ( string )
} else if t . Kind ( ) == reflect . Struct {
table = statement . Engine . tbName ( v )
}
}
if l > 1 {
fmt . Fprintf ( & buf , "%v AS %v" , statement . Engine . Quote ( table ) ,
statement . Engine . Quote ( fmt . Sprintf ( "%v" , t [ 1 ] ) ) )
} else if l == 1 {
fmt . Fprintf ( & buf , statement . Engine . Quote ( table ) )
}
default :
fmt . Fprintf ( & buf , statement . Engine . Quote ( fmt . Sprintf ( "%v" , tablename ) ) )
}
fmt . Fprintf ( & buf , " ON %v" , condition )
fmt . Fprintf ( & buf , "%s ON %v" , tbName , condition )
statement . JoinStr = buf . String ( )
statement . JoinStr = buf . String ( )
statement . joinArgs = append ( statement . joinArgs , args ... )
statement . joinArgs = append ( statement . joinArgs , args ... )
return statement
return statement
@ -817,10 +787,12 @@ func (statement *Statement) genColumnStr() string {
columns := statement . RefTable . Columns ( )
columns := statement . RefTable . Columns ( )
for _ , col := range columns {
for _ , col := range columns {
if statement . OmitStr != "" {
if statement . omitColumnMap . contain ( col . Name ) {
if _ , ok := getFlagForColumn ( statement . columnMap , col ) ; ok {
continue
continue
}
}
if len ( statement . columnMap ) > 0 && ! statement . columnMap . contain ( col . Name ) {
continue
}
}
if col . MapType == core . ONLYTODB {
if col . MapType == core . ONLYTODB {
@ -831,10 +803,6 @@ func (statement *Statement) genColumnStr() string {
buf . WriteString ( ", " )
buf . WriteString ( ", " )
}
}
if col . IsPrimaryKey && statement . Engine . Dialect ( ) . DBType ( ) == "ql" {
buf . WriteString ( "id() AS " )
}
if statement . JoinStr != "" {
if statement . JoinStr != "" {
if statement . TableAlias != "" {
if statement . TableAlias != "" {
buf . WriteString ( statement . TableAlias )
buf . WriteString ( statement . TableAlias )
@ -859,11 +827,13 @@ func (statement *Statement) genCreateTableSQL() string {
func ( statement * Statement ) genIndexSQL ( ) [ ] string {
func ( statement * Statement ) genIndexSQL ( ) [ ] string {
var sqls [ ] string
var sqls [ ] string
tbName := statement . TableName ( )
tbName := statement . TableName ( )
quote := statement . Engine . Quote
for _ , index := range statement . RefTable . Indexes {
for idxName , index := range statement . RefTable . Indexes {
if index . Type == core . IndexType {
if index . Type == core . IndexType {
sql := fmt . Sprintf ( "CREATE INDEX %v ON %v (%v);" , quote ( indexName ( tbName , idxName ) ) ,
sql := statement . Engine . dialect . CreateIndexSql ( tbName , index )
quote ( tbName ) , quote ( strings . Join ( index . Cols , quote ( "," ) ) ) )
/ * idxTBName := strings . Replace ( tbName , "." , "_" , - 1 )
idxTBName = strings . Replace ( idxTBName , ` " ` , "" , - 1 )
sql := fmt . Sprintf ( "CREATE INDEX %v ON %v (%v);" , quote ( indexName ( idxTBName , idxName ) ) ,
quote ( tbName ) , quote ( strings . Join ( index . Cols , quote ( "," ) ) ) ) * /
sqls = append ( sqls , sql )
sqls = append ( sqls , sql )
}
}
}
}
@ -889,16 +859,18 @@ func (statement *Statement) genUniqueSQL() []string {
func ( statement * Statement ) genDelIndexSQL ( ) [ ] string {
func ( statement * Statement ) genDelIndexSQL ( ) [ ] string {
var sqls [ ] string
var sqls [ ] string
tbName := statement . TableName ( )
tbName := statement . TableName ( )
idxPrefixName := strings . Replace ( tbName , ` " ` , "" , - 1 )
idxPrefixName = strings . Replace ( idxPrefixName , ` . ` , "_" , - 1 )
for idxName , index := range statement . RefTable . Indexes {
for idxName , index := range statement . RefTable . Indexes {
var rIdxName string
var rIdxName string
if index . Type == core . UniqueType {
if index . Type == core . UniqueType {
rIdxName = uniqueName ( tb Name, idxName )
rIdxName = uniqueName ( idxPrefix Name, idxName )
} else if index . Type == core . IndexType {
} else if index . Type == core . IndexType {
rIdxName = indexName ( tb Name, idxName )
rIdxName = indexName ( idxPrefix Name, idxName )
}
}
sql := fmt . Sprintf ( "DROP INDEX %v" , statement . Engine . Quote ( rIdxName ) )
sql := fmt . Sprintf ( "DROP INDEX %v" , statement . Engine . Quote ( statement . Engine . TableName ( rIdxName , true ) ) )
if statement . Engine . dialect . IndexOnTable ( ) {
if statement . Engine . dialect . IndexOnTable ( ) {
sql += fmt . Sprintf ( " ON %v" , statement . Engine . Quote ( s tatement . Ta ble Name( ) ) )
sql += fmt . Sprintf ( " ON %v" , statement . Engine . Quote ( tbName ) )
}
}
sqls = append ( sqls , sql )
sqls = append ( sqls , sql )
}
}
@ -949,7 +921,7 @@ func (statement *Statement) genGetSQL(bean interface{}) (string, []interface{},
v := rValue ( bean )
v := rValue ( bean )
isStruct := v . Kind ( ) == reflect . Struct
isStruct := v . Kind ( ) == reflect . Struct
if isStruct {
if isStruct {
statement . setRefValue ( v )
statement . setRefBean ( bean )
}
}
var columnStr = statement . ColumnStr
var columnStr = statement . ColumnStr
@ -982,13 +954,17 @@ func (statement *Statement) genGetSQL(bean interface{}) (string, []interface{},
if err := statement . mergeConds ( bean ) ; err != nil {
if err := statement . mergeConds ( bean ) ; err != nil {
return "" , nil , err
return "" , nil , err
}
}
} else {
if err := statement . processIDParam ( ) ; err != nil {
return "" , nil , err
}
}
}
condSQL , condArgs , err := builder . ToSQL ( statement . cond )
condSQL , condArgs , err := builder . ToSQL ( statement . cond )
if err != nil {
if err != nil {
return "" , nil , err
return "" , nil , err
}
}
sqlStr , err := statement . genSelectSQL ( columnStr , condSQL , true )
sqlStr , err := statement . genSelectSQL ( columnStr , condSQL , true , true )
if err != nil {
if err != nil {
return "" , nil , err
return "" , nil , err
}
}
@ -1001,7 +977,7 @@ func (statement *Statement) genCountSQL(beans ...interface{}) (string, []interfa
var condArgs [ ] interface { }
var condArgs [ ] interface { }
var err error
var err error
if len ( beans ) > 0 {
if len ( beans ) > 0 {
statement . setRefValue ( rValue ( beans [ 0 ] ) )
statement . setRefBean ( beans [ 0 ] )
condSQL , condArgs , err = statement . genConds ( beans [ 0 ] )
condSQL , condArgs , err = statement . genConds ( beans [ 0 ] )
} else {
} else {
condSQL , condArgs , err = builder . ToSQL ( statement . cond )
condSQL , condArgs , err = builder . ToSQL ( statement . cond )
@ -1018,7 +994,7 @@ func (statement *Statement) genCountSQL(beans ...interface{}) (string, []interfa
selectSQL = "count(*)"
selectSQL = "count(*)"
}
}
}
}
sqlStr , err := statement . genSelectSQL ( selectSQL , condSQL , false )
sqlStr , err := statement . genSelectSQL ( selectSQL , condSQL , false , false )
if err != nil {
if err != nil {
return "" , nil , err
return "" , nil , err
}
}
@ -1027,7 +1003,7 @@ func (statement *Statement) genCountSQL(beans ...interface{}) (string, []interfa
}
}
func ( statement * Statement ) genSumSQL ( bean interface { } , columns ... string ) ( string , [ ] interface { } , error ) {
func ( statement * Statement ) genSumSQL ( bean interface { } , columns ... string ) ( string , [ ] interface { } , error ) {
statement . setRefValue ( rValue ( bean ) )
statement . setRefBean ( bean )
var sumStrs = make ( [ ] string , 0 , len ( columns ) )
var sumStrs = make ( [ ] string , 0 , len ( columns ) )
for _ , colName := range columns {
for _ , colName := range columns {
@ -1043,7 +1019,7 @@ func (statement *Statement) genSumSQL(bean interface{}, columns ...string) (stri
return "" , nil , err
return "" , nil , err
}
}
sqlStr , err := statement . genSelectSQL ( sumSelect , condSQL , true )
sqlStr , err := statement . genSelectSQL ( sumSelect , condSQL , true , true )
if err != nil {
if err != nil {
return "" , nil , err
return "" , nil , err
}
}
@ -1051,7 +1027,7 @@ func (statement *Statement) genSumSQL(bean interface{}, columns ...string) (stri
return sqlStr , append ( statement . joinArgs , condArgs ... ) , nil
return sqlStr , append ( statement . joinArgs , condArgs ... ) , nil
}
}
func ( statement * Statement ) genSelectSQL ( columnStr , condSQL string , needLimit bool ) ( a string , err error ) {
func ( statement * Statement ) genSelectSQL ( columnStr , condSQL string , needLimit , needOrderBy bool ) ( a string , err error ) {
var distinct string
var distinct string
if statement . IsDistinct && ! strings . HasPrefix ( columnStr , "count" ) {
if statement . IsDistinct && ! strings . HasPrefix ( columnStr , "count" ) {
distinct = "DISTINCT "
distinct = "DISTINCT "
@ -1062,10 +1038,6 @@ func (statement *Statement) genSelectSQL(columnStr, condSQL string, needLimit bo
var top string
var top string
var mssqlCondi string
var mssqlCondi string
if err := statement . processIDParam ( ) ; err != nil {
return "" , err
}
var buf bytes . Buffer
var buf bytes . Buffer
if len ( condSQL ) > 0 {
if len ( condSQL ) > 0 {
fmt . Fprintf ( & buf , " WHERE %v" , condSQL )
fmt . Fprintf ( & buf , " WHERE %v" , condSQL )
@ -1118,9 +1090,10 @@ func (statement *Statement) genSelectSQL(columnStr, condSQL string, needLimit bo
}
}
var orderStr string
var orderStr string
if len ( statement . OrderStr ) > 0 {
if needOrderBy && len ( statement . OrderStr ) > 0 {
orderStr = " ORDER BY " + statement . OrderStr
orderStr = " ORDER BY " + statement . OrderStr
}
}
var groupStr string
var groupStr string
if len ( statement . GroupByStr ) > 0 {
if len ( statement . GroupByStr ) > 0 {
groupStr = " GROUP BY " + statement . GroupByStr
groupStr = " GROUP BY " + statement . GroupByStr
@ -1146,7 +1119,7 @@ func (statement *Statement) genSelectSQL(columnStr, condSQL string, needLimit bo
if statement . HavingStr != "" {
if statement . HavingStr != "" {
a = fmt . Sprintf ( "%v %v" , a , statement . HavingStr )
a = fmt . Sprintf ( "%v %v" , a , statement . HavingStr )
}
}
if statement . OrderStr != "" {
if needOrderBy && statement . OrderStr != "" {
a = fmt . Sprintf ( "%v ORDER BY %v" , a , statement . OrderStr )
a = fmt . Sprintf ( "%v ORDER BY %v" , a , statement . OrderStr )
}
}
if needLimit {
if needLimit {
@ -1170,7 +1143,7 @@ func (statement *Statement) genSelectSQL(columnStr, condSQL string, needLimit bo
}
}
func ( statement * Statement ) processIDParam ( ) error {
func ( statement * Statement ) processIDParam ( ) error {
if statement . idParam == nil {
if statement . idParam == nil || statement . RefTable == nil {
return nil
return nil
}
}