mirror of https://github.com/gogits/gogs.git
You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
461 lines
9.8 KiB
461 lines
9.8 KiB
// Go MySQL Driver - A MySQL-Driver for Go's database/sql package |
|
// |
|
// Copyright 2012 The Go-MySQL-Driver Authors. All rights reserved. |
|
// |
|
// This Source Code Form is subject to the terms of the Mozilla Public |
|
// License, v. 2.0. If a copy of the MPL was not distributed with this file, |
|
// You can obtain one at http://mozilla.org/MPL/2.0/. |
|
|
|
package mysql |
|
|
|
import ( |
|
"database/sql/driver" |
|
"io" |
|
"net" |
|
"strconv" |
|
"strings" |
|
"time" |
|
) |
|
|
|
// mysqlContext mirrors the subset of context.Context that this driver
// relies on, so the package still builds on Go 1.7 and earlier.
//
// The remaining context.Context methods are intentionally left out because
// the driver never calls them:
//
//	Deadline() (deadline time.Time, ok bool)
//	Value(key interface{}) interface{}
type mysqlContext interface {
	// Done returns a channel used to signal that the context is finished.
	Done() <-chan struct{}

	// Err returns the error associated with the context.
	Err() error
}
|
|
|
type mysqlConn struct { |
|
buf buffer |
|
netConn net.Conn |
|
affectedRows uint64 |
|
insertId uint64 |
|
cfg *Config |
|
maxAllowedPacket int |
|
maxWriteSize int |
|
writeTimeout time.Duration |
|
flags clientFlag |
|
status statusFlag |
|
sequence uint8 |
|
parseTime bool |
|
|
|
// for context support (Go 1.8+) |
|
watching bool |
|
watcher chan<- mysqlContext |
|
closech chan struct{} |
|
finished chan<- struct{} |
|
canceled atomicError // set non-nil if conn is canceled |
|
closed atomicBool // set when conn is closed, before closech is closed |
|
} |
|
|
|
// Handles parameters set in DSN after the connection is established |
|
func (mc *mysqlConn) handleParams() (err error) { |
|
for param, val := range mc.cfg.Params { |
|
switch param { |
|
// Charset |
|
case "charset": |
|
charsets := strings.Split(val, ",") |
|
for i := range charsets { |
|
// ignore errors here - a charset may not exist |
|
err = mc.exec("SET NAMES " + charsets[i]) |
|
if err == nil { |
|
break |
|
} |
|
} |
|
if err != nil { |
|
return |
|
} |
|
|
|
// System Vars |
|
default: |
|
err = mc.exec("SET " + param + "=" + val + "") |
|
if err != nil { |
|
return |
|
} |
|
} |
|
} |
|
|
|
return |
|
} |
|
|
|
func (mc *mysqlConn) markBadConn(err error) error { |
|
if mc == nil { |
|
return err |
|
} |
|
if err != errBadConnNoWrite { |
|
return err |
|
} |
|
return driver.ErrBadConn |
|
} |
|
|
|
func (mc *mysqlConn) Begin() (driver.Tx, error) { |
|
return mc.begin(false) |
|
} |
|
|
|
func (mc *mysqlConn) begin(readOnly bool) (driver.Tx, error) { |
|
if mc.closed.IsSet() { |
|
errLog.Print(ErrInvalidConn) |
|
return nil, driver.ErrBadConn |
|
} |
|
var q string |
|
if readOnly { |
|
q = "START TRANSACTION READ ONLY" |
|
} else { |
|
q = "START TRANSACTION" |
|
} |
|
err := mc.exec(q) |
|
if err == nil { |
|
return &mysqlTx{mc}, err |
|
} |
|
return nil, mc.markBadConn(err) |
|
} |
|
|
|
func (mc *mysqlConn) Close() (err error) { |
|
// Makes Close idempotent |
|
if !mc.closed.IsSet() { |
|
err = mc.writeCommandPacket(comQuit) |
|
} |
|
|
|
mc.cleanup() |
|
|
|
return |
|
} |
|
|
|
// Closes the network connection and unsets internal variables. Do not call this |
|
// function after successfully authentication, call Close instead. This function |
|
// is called before auth or on auth failure because MySQL will have already |
|
// closed the network connection. |
|
func (mc *mysqlConn) cleanup() { |
|
if !mc.closed.TrySet(true) { |
|
return |
|
} |
|
|
|
// Makes cleanup idempotent |
|
close(mc.closech) |
|
if mc.netConn == nil { |
|
return |
|
} |
|
if err := mc.netConn.Close(); err != nil { |
|
errLog.Print(err) |
|
} |
|
} |
|
|
|
func (mc *mysqlConn) error() error { |
|
if mc.closed.IsSet() { |
|
if err := mc.canceled.Value(); err != nil { |
|
return err |
|
} |
|
return ErrInvalidConn |
|
} |
|
return nil |
|
} |
|
|
|
func (mc *mysqlConn) Prepare(query string) (driver.Stmt, error) { |
|
if mc.closed.IsSet() { |
|
errLog.Print(ErrInvalidConn) |
|
return nil, driver.ErrBadConn |
|
} |
|
// Send command |
|
err := mc.writeCommandPacketStr(comStmtPrepare, query) |
|
if err != nil { |
|
return nil, mc.markBadConn(err) |
|
} |
|
|
|
stmt := &mysqlStmt{ |
|
mc: mc, |
|
} |
|
|
|
// Read Result |
|
columnCount, err := stmt.readPrepareResultPacket() |
|
if err == nil { |
|
if stmt.paramCount > 0 { |
|
if err = mc.readUntilEOF(); err != nil { |
|
return nil, err |
|
} |
|
} |
|
|
|
if columnCount > 0 { |
|
err = mc.readUntilEOF() |
|
} |
|
} |
|
|
|
return stmt, err |
|
} |
|
|
|
func (mc *mysqlConn) interpolateParams(query string, args []driver.Value) (string, error) { |
|
// Number of ? should be same to len(args) |
|
if strings.Count(query, "?") != len(args) { |
|
return "", driver.ErrSkip |
|
} |
|
|
|
buf := mc.buf.takeCompleteBuffer() |
|
if buf == nil { |
|
// can not take the buffer. Something must be wrong with the connection |
|
errLog.Print(ErrBusyBuffer) |
|
return "", ErrInvalidConn |
|
} |
|
buf = buf[:0] |
|
argPos := 0 |
|
|
|
for i := 0; i < len(query); i++ { |
|
q := strings.IndexByte(query[i:], '?') |
|
if q == -1 { |
|
buf = append(buf, query[i:]...) |
|
break |
|
} |
|
buf = append(buf, query[i:i+q]...) |
|
i += q |
|
|
|
arg := args[argPos] |
|
argPos++ |
|
|
|
if arg == nil { |
|
buf = append(buf, "NULL"...) |
|
continue |
|
} |
|
|
|
switch v := arg.(type) { |
|
case int64: |
|
buf = strconv.AppendInt(buf, v, 10) |
|
case float64: |
|
buf = strconv.AppendFloat(buf, v, 'g', -1, 64) |
|
case bool: |
|
if v { |
|
buf = append(buf, '1') |
|
} else { |
|
buf = append(buf, '0') |
|
} |
|
case time.Time: |
|
if v.IsZero() { |
|
buf = append(buf, "'0000-00-00'"...) |
|
} else { |
|
v := v.In(mc.cfg.Loc) |
|
v = v.Add(time.Nanosecond * 500) // To round under microsecond |
|
year := v.Year() |
|
year100 := year / 100 |
|
year1 := year % 100 |
|
month := v.Month() |
|
day := v.Day() |
|
hour := v.Hour() |
|
minute := v.Minute() |
|
second := v.Second() |
|
micro := v.Nanosecond() / 1000 |
|
|
|
buf = append(buf, []byte{ |
|
'\'', |
|
digits10[year100], digits01[year100], |
|
digits10[year1], digits01[year1], |
|
'-', |
|
digits10[month], digits01[month], |
|
'-', |
|
digits10[day], digits01[day], |
|
' ', |
|
digits10[hour], digits01[hour], |
|
':', |
|
digits10[minute], digits01[minute], |
|
':', |
|
digits10[second], digits01[second], |
|
}...) |
|
|
|
if micro != 0 { |
|
micro10000 := micro / 10000 |
|
micro100 := micro / 100 % 100 |
|
micro1 := micro % 100 |
|
buf = append(buf, []byte{ |
|
'.', |
|
digits10[micro10000], digits01[micro10000], |
|
digits10[micro100], digits01[micro100], |
|
digits10[micro1], digits01[micro1], |
|
}...) |
|
} |
|
buf = append(buf, '\'') |
|
} |
|
case []byte: |
|
if v == nil { |
|
buf = append(buf, "NULL"...) |
|
} else { |
|
buf = append(buf, "_binary'"...) |
|
if mc.status&statusNoBackslashEscapes == 0 { |
|
buf = escapeBytesBackslash(buf, v) |
|
} else { |
|
buf = escapeBytesQuotes(buf, v) |
|
} |
|
buf = append(buf, '\'') |
|
} |
|
case string: |
|
buf = append(buf, '\'') |
|
if mc.status&statusNoBackslashEscapes == 0 { |
|
buf = escapeStringBackslash(buf, v) |
|
} else { |
|
buf = escapeStringQuotes(buf, v) |
|
} |
|
buf = append(buf, '\'') |
|
default: |
|
return "", driver.ErrSkip |
|
} |
|
|
|
if len(buf)+4 > mc.maxAllowedPacket { |
|
return "", driver.ErrSkip |
|
} |
|
} |
|
if argPos != len(args) { |
|
return "", driver.ErrSkip |
|
} |
|
return string(buf), nil |
|
} |
|
|
|
func (mc *mysqlConn) Exec(query string, args []driver.Value) (driver.Result, error) { |
|
if mc.closed.IsSet() { |
|
errLog.Print(ErrInvalidConn) |
|
return nil, driver.ErrBadConn |
|
} |
|
if len(args) != 0 { |
|
if !mc.cfg.InterpolateParams { |
|
return nil, driver.ErrSkip |
|
} |
|
// try to interpolate the parameters to save extra roundtrips for preparing and closing a statement |
|
prepared, err := mc.interpolateParams(query, args) |
|
if err != nil { |
|
return nil, err |
|
} |
|
query = prepared |
|
} |
|
mc.affectedRows = 0 |
|
mc.insertId = 0 |
|
|
|
err := mc.exec(query) |
|
if err == nil { |
|
return &mysqlResult{ |
|
affectedRows: int64(mc.affectedRows), |
|
insertId: int64(mc.insertId), |
|
}, err |
|
} |
|
return nil, mc.markBadConn(err) |
|
} |
|
|
|
// Internal function to execute commands |
|
func (mc *mysqlConn) exec(query string) error { |
|
// Send command |
|
if err := mc.writeCommandPacketStr(comQuery, query); err != nil { |
|
return mc.markBadConn(err) |
|
} |
|
|
|
// Read Result |
|
resLen, err := mc.readResultSetHeaderPacket() |
|
if err != nil { |
|
return err |
|
} |
|
|
|
if resLen > 0 { |
|
// columns |
|
if err := mc.readUntilEOF(); err != nil { |
|
return err |
|
} |
|
|
|
// rows |
|
if err := mc.readUntilEOF(); err != nil { |
|
return err |
|
} |
|
} |
|
|
|
return mc.discardResults() |
|
} |
|
|
|
func (mc *mysqlConn) Query(query string, args []driver.Value) (driver.Rows, error) { |
|
return mc.query(query, args) |
|
} |
|
|
|
func (mc *mysqlConn) query(query string, args []driver.Value) (*textRows, error) { |
|
if mc.closed.IsSet() { |
|
errLog.Print(ErrInvalidConn) |
|
return nil, driver.ErrBadConn |
|
} |
|
if len(args) != 0 { |
|
if !mc.cfg.InterpolateParams { |
|
return nil, driver.ErrSkip |
|
} |
|
// try client-side prepare to reduce roundtrip |
|
prepared, err := mc.interpolateParams(query, args) |
|
if err != nil { |
|
return nil, err |
|
} |
|
query = prepared |
|
} |
|
// Send command |
|
err := mc.writeCommandPacketStr(comQuery, query) |
|
if err == nil { |
|
// Read Result |
|
var resLen int |
|
resLen, err = mc.readResultSetHeaderPacket() |
|
if err == nil { |
|
rows := new(textRows) |
|
rows.mc = mc |
|
|
|
if resLen == 0 { |
|
rows.rs.done = true |
|
|
|
switch err := rows.NextResultSet(); err { |
|
case nil, io.EOF: |
|
return rows, nil |
|
default: |
|
return nil, err |
|
} |
|
} |
|
|
|
// Columns |
|
rows.rs.columns, err = mc.readColumns(resLen) |
|
return rows, err |
|
} |
|
} |
|
return nil, mc.markBadConn(err) |
|
} |
|
|
|
// Gets the value of the given MySQL System Variable |
|
// The returned byte slice is only valid until the next read |
|
func (mc *mysqlConn) getSystemVar(name string) ([]byte, error) { |
|
// Send command |
|
if err := mc.writeCommandPacketStr(comQuery, "SELECT @@"+name); err != nil { |
|
return nil, err |
|
} |
|
|
|
// Read Result |
|
resLen, err := mc.readResultSetHeaderPacket() |
|
if err == nil { |
|
rows := new(textRows) |
|
rows.mc = mc |
|
rows.rs.columns = []mysqlField{{fieldType: fieldTypeVarChar}} |
|
|
|
if resLen > 0 { |
|
// Columns |
|
if err := mc.readUntilEOF(); err != nil { |
|
return nil, err |
|
} |
|
} |
|
|
|
dest := make([]driver.Value, resLen) |
|
if err = rows.readRow(dest); err == nil { |
|
return dest[0].([]byte), mc.readUntilEOF() |
|
} |
|
} |
|
return nil, err |
|
} |
|
|
|
// finish is called when the query has canceled. |
|
func (mc *mysqlConn) cancel(err error) { |
|
mc.canceled.Set(err) |
|
mc.cleanup() |
|
} |
|
|
|
// finish is called when the query has succeeded. |
|
func (mc *mysqlConn) finish() { |
|
if !mc.watching || mc.finished == nil { |
|
return |
|
} |
|
select { |
|
case mc.finished <- struct{}{}: |
|
mc.watching = false |
|
case <-mc.closech: |
|
} |
|
}
|
|
|