mirror of https://github.com/gogits/gogs.git
You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
610 lines
15 KiB
610 lines
15 KiB
8 years ago
|
package mssql
|
||
|
|
||
|
import (
|
||
|
"database/sql"
|
||
|
"database/sql/driver"
|
||
|
"encoding/binary"
|
||
|
"errors"
|
||
|
"fmt"
|
||
|
"io"
|
||
|
"math"
|
||
|
"net"
|
||
|
"reflect"
|
||
|
"strings"
|
||
|
"time"
|
||
|
"golang.org/x/net/context" // use the "x/net/context" for backwards compatibility.
|
||
|
)
|
||
|
|
||
|
var driverInstance = &MssqlDriver{processQueryText: true}
|
||
|
var driverInstanceNoProcess = &MssqlDriver{processQueryText: false}
|
||
|
|
||
|
func init() {
|
||
|
sql.Register("mssql", driverInstance)
|
||
|
sql.Register("sqlserver", driverInstanceNoProcess)
|
||
|
}
|
||
|
|
||
|
// dialer opens the network connection for a session. It is abstracted behind
// an interface so tests and non-TCP transports can supply their own
// implementation.
type dialer interface {
	// Dial connects to addr and returns the established connection.
	Dial(addr string) (net.Conn, error)
}
|
||
|
|
||
|
// createDialer builds the dialer used to reach the server described by the
// given connection parameters. It is a package-level variable rather than a
// function. Its assignment is not part of this section of the file, so it is
// presumably set elsewhere in the package; confirm before relying on it.
var createDialer func(*connectParams) dialer
|
||
|
|
||
|
// tcpDialer is a dialer backed by a *net.Dialer that always dials over the
// "tcp" network.
type tcpDialer struct {
	nd *net.Dialer
}

// Dial opens a TCP connection to addr using the wrapped net.Dialer.
func (d tcpDialer) Dial(addr string) (net.Conn, error) {
	const network = "tcp"
	conn, err := d.nd.Dial(network, addr)
	return conn, err
}
|
||
|
|
||
|
type MssqlDriver struct {
|
||
|
log optionalLogger
|
||
|
|
||
|
processQueryText bool
|
||
|
}
|
||
|
|
||
|
func SetLogger(logger Logger) {
|
||
|
driverInstance.SetLogger(logger)
|
||
|
driverInstanceNoProcess.SetLogger(logger)
|
||
|
}
|
||
|
|
||
|
func (d *MssqlDriver) SetLogger(logger Logger) {
|
||
|
d.log = optionalLogger{logger}
|
||
|
}
|
||
|
|
||
|
type MssqlConn struct {
|
||
|
sess *tdsSession
|
||
|
transactionCtx context.Context
|
||
|
|
||
|
processQueryText bool
|
||
|
}
|
||
|
|
||
|
func (c *MssqlConn) simpleProcessResp(ctx context.Context) error {
|
||
|
tokchan := make(chan tokenStruct, 5)
|
||
|
go processResponse(ctx, c.sess, tokchan)
|
||
|
for tok := range tokchan {
|
||
|
switch token := tok.(type) {
|
||
|
case doneStruct:
|
||
|
if token.isError() {
|
||
|
return token.getError()
|
||
|
}
|
||
|
case error:
|
||
|
return token
|
||
|
}
|
||
|
}
|
||
|
return nil
|
||
|
}
|
||
|
|
||
|
func (c *MssqlConn) Commit() error {
|
||
|
if err := c.sendCommitRequest(); err != nil {
|
||
|
return err
|
||
|
}
|
||
|
return c.simpleProcessResp(c.transactionCtx)
|
||
|
}
|
||
|
|
||
|
func (c *MssqlConn) sendCommitRequest() error {
|
||
|
headers := []headerStruct{
|
||
|
{hdrtype: dataStmHdrTransDescr,
|
||
|
data: transDescrHdr{c.sess.tranid, 1}.pack()},
|
||
|
}
|
||
|
if err := sendCommitXact(c.sess.buf, headers, "", 0, 0, ""); err != nil {
|
||
|
if c.sess.logFlags&logErrors != 0 {
|
||
|
c.sess.log.Printf("Failed to send CommitXact with %v", err)
|
||
|
}
|
||
|
return driver.ErrBadConn
|
||
|
}
|
||
|
return nil
|
||
|
}
|
||
|
|
||
|
func (c *MssqlConn) Rollback() error {
|
||
|
if err := c.sendRollbackRequest(); err != nil {
|
||
|
return err
|
||
|
}
|
||
|
return c.simpleProcessResp(c.transactionCtx)
|
||
|
}
|
||
|
|
||
|
func (c *MssqlConn) sendRollbackRequest() error {
|
||
|
headers := []headerStruct{
|
||
|
{hdrtype: dataStmHdrTransDescr,
|
||
|
data: transDescrHdr{c.sess.tranid, 1}.pack()},
|
||
|
}
|
||
|
if err := sendRollbackXact(c.sess.buf, headers, "", 0, 0, ""); err != nil {
|
||
|
if c.sess.logFlags&logErrors != 0 {
|
||
|
c.sess.log.Printf("Failed to send RollbackXact with %v", err)
|
||
|
}
|
||
|
return driver.ErrBadConn
|
||
|
}
|
||
|
return nil
|
||
|
}
|
||
|
|
||
|
func (c *MssqlConn) Begin() (driver.Tx, error) {
|
||
|
return c.begin(context.Background(), isolationUseCurrent)
|
||
|
}
|
||
|
|
||
|
func (c *MssqlConn) begin(ctx context.Context, tdsIsolation isoLevel) (driver.Tx, error) {
|
||
|
err := c.sendBeginRequest(ctx, tdsIsolation)
|
||
|
if err != nil {
|
||
|
return nil, err
|
||
|
}
|
||
|
return c.processBeginResponse(ctx)
|
||
|
}
|
||
|
|
||
|
func (c *MssqlConn) sendBeginRequest(ctx context.Context, tdsIsolation isoLevel) error {
|
||
|
c.transactionCtx = ctx
|
||
|
headers := []headerStruct{
|
||
|
{hdrtype: dataStmHdrTransDescr,
|
||
|
data: transDescrHdr{0, 1}.pack()},
|
||
|
}
|
||
|
if err := sendBeginXact(c.sess.buf, headers, tdsIsolation, ""); err != nil {
|
||
|
if c.sess.logFlags&logErrors != 0 {
|
||
|
c.sess.log.Printf("Failed to send BeginXact with %v", err)
|
||
|
}
|
||
|
return driver.ErrBadConn
|
||
|
}
|
||
|
return nil
|
||
|
}
|
||
|
|
||
|
func (c *MssqlConn) processBeginResponse(ctx context.Context) (driver.Tx, error) {
|
||
|
if err := c.simpleProcessResp(ctx); err != nil {
|
||
|
return nil, err
|
||
|
}
|
||
|
// successful BEGINXACT request will return sess.tranid
|
||
|
// for started transaction
|
||
|
return c, nil
|
||
|
}
|
||
|
|
||
|
func (d *MssqlDriver) Open(dsn string) (driver.Conn, error) {
|
||
|
return d.open(dsn)
|
||
|
}
|
||
|
|
||
|
func (d *MssqlDriver) open(dsn string) (*MssqlConn, error) {
|
||
|
params, err := parseConnectParams(dsn)
|
||
|
if err != nil {
|
||
|
return nil, err
|
||
|
}
|
||
|
|
||
|
sess, err := connect(d.log, params)
|
||
|
if err != nil {
|
||
|
// main server failed, try fail-over partner
|
||
|
if params.failOverPartner == "" {
|
||
|
return nil, err
|
||
|
}
|
||
|
|
||
|
params.host = params.failOverPartner
|
||
|
if params.failOverPort != 0 {
|
||
|
params.port = params.failOverPort
|
||
|
}
|
||
|
|
||
|
sess, err = connect(d.log, params)
|
||
|
if err != nil {
|
||
|
// fail-over partner also failed, now fail
|
||
|
return nil, err
|
||
|
}
|
||
|
}
|
||
|
|
||
|
conn := &MssqlConn{sess, context.Background(), d.processQueryText}
|
||
|
conn.sess.log = d.log
|
||
|
return conn, nil
|
||
|
}
|
||
|
|
||
|
func (c *MssqlConn) Close() error {
|
||
|
return c.sess.buf.transport.Close()
|
||
|
}
|
||
|
|
||
|
type MssqlStmt struct {
|
||
|
c *MssqlConn
|
||
|
query string
|
||
|
paramCount int
|
||
|
notifSub *queryNotifSub
|
||
|
}
|
||
|
|
||
|
// queryNotifSub holds a query notification subscription set by
// SetQueryNotification; sendQuery packs it into a request header.
type queryNotifSub struct {
	// msgText is the notification message text (the id passed to
	// SetQueryNotification).
	msgText string

	// options is the subscription options string.
	options string

	// timeout is the subscription timeout in whole seconds.
	timeout uint32
}
|
||
|
|
||
|
func (c *MssqlConn) Prepare(query string) (driver.Stmt, error) {
|
||
|
return c.prepareContext(context.Background(), query)
|
||
|
}
|
||
|
|
||
|
func (c *MssqlConn) prepareContext(ctx context.Context, query string) (*MssqlStmt, error) {
|
||
|
paramCount := -1
|
||
|
if c.processQueryText {
|
||
|
query, paramCount = parseParams(query)
|
||
|
}
|
||
|
return &MssqlStmt{c, query, paramCount, nil}, nil
|
||
|
}
|
||
|
|
||
|
func (s *MssqlStmt) Close() error {
|
||
|
return nil
|
||
|
}
|
||
|
|
||
|
func (s *MssqlStmt) SetQueryNotification(id, options string, timeout time.Duration) {
|
||
|
to := uint32(timeout / time.Second)
|
||
|
if to < 1 {
|
||
|
to = 1
|
||
|
}
|
||
|
s.notifSub = &queryNotifSub{id, options, to}
|
||
|
}
|
||
|
|
||
|
func (s *MssqlStmt) NumInput() int {
|
||
|
return s.paramCount
|
||
|
}
|
||
|
|
||
|
func (s *MssqlStmt) sendQuery(args []namedValue) (err error) {
|
||
|
headers := []headerStruct{
|
||
|
{hdrtype: dataStmHdrTransDescr,
|
||
|
data: transDescrHdr{s.c.sess.tranid, 1}.pack()},
|
||
|
}
|
||
|
|
||
|
if s.notifSub != nil {
|
||
|
headers = append(headers, headerStruct{hdrtype: dataStmHdrQueryNotif,
|
||
|
data: queryNotifHdr{s.notifSub.msgText, s.notifSub.options, s.notifSub.timeout}.pack()})
|
||
|
}
|
||
|
|
||
|
// no need to check number of parameters here, it is checked by database/sql
|
||
|
if s.c.sess.logFlags&logSQL != 0 {
|
||
|
s.c.sess.log.Println(s.query)
|
||
|
}
|
||
|
if s.c.sess.logFlags&logParams != 0 && len(args) > 0 {
|
||
|
for i := 0; i < len(args); i++ {
|
||
|
s.c.sess.log.Printf("\t@p%d\t%v\n", i+1, args[i])
|
||
|
}
|
||
|
|
||
|
}
|
||
|
if len(args) == 0 {
|
||
|
if err = sendSqlBatch72(s.c.sess.buf, s.query, headers); err != nil {
|
||
|
if s.c.sess.logFlags&logErrors != 0 {
|
||
|
s.c.sess.log.Printf("Failed to send SqlBatch with %v", err)
|
||
|
}
|
||
|
return driver.ErrBadConn
|
||
|
}
|
||
|
} else {
|
||
|
params := make([]Param, len(args)+2)
|
||
|
decls := make([]string, len(args))
|
||
|
params[0] = makeStrParam(s.query)
|
||
|
for i, val := range args {
|
||
|
params[i+2], err = s.makeParam(val.Value)
|
||
|
if err != nil {
|
||
|
return
|
||
|
}
|
||
|
var name string
|
||
|
if len(val.Name) > 0 {
|
||
|
name = "@" + val.Name
|
||
|
} else {
|
||
|
name = fmt.Sprintf("@p%d", val.Ordinal)
|
||
|
}
|
||
|
params[i+2].Name = name
|
||
|
decls[i] = fmt.Sprintf("%s %s", name, makeDecl(params[i+2].ti))
|
||
|
}
|
||
|
params[1] = makeStrParam(strings.Join(decls, ","))
|
||
|
if err = sendRpc(s.c.sess.buf, headers, Sp_ExecuteSql, 0, params); err != nil {
|
||
|
if s.c.sess.logFlags&logErrors != 0 {
|
||
|
s.c.sess.log.Printf("Failed to send Rpc with %v", err)
|
||
|
}
|
||
|
return driver.ErrBadConn
|
||
|
}
|
||
|
}
|
||
|
return
|
||
|
}
|
||
|
|
||
|
// namedValue is the driver-internal form of one statement argument.
type namedValue struct {
	// Name is the parameter name without the leading '@'; empty for
	// positional arguments.
	Name string

	// Ordinal is the 1-based position of the argument.
	Ordinal int

	// Value is the argument value.
	Value driver.Value
}

// convertOldArgs wraps positional driver.Value arguments as unnamed
// namedValues with ordinals starting at 1. The result is always non-nil.
func convertOldArgs(args []driver.Value) []namedValue {
	list := make([]namedValue, 0, len(args))
	for i := range args {
		list = append(list, namedValue{Ordinal: i + 1, Value: args[i]})
	}
	return list
}
|
||
|
|
||
|
func (s *MssqlStmt) Query(args []driver.Value) (driver.Rows, error) {
|
||
|
return s.queryContext(context.Background(), convertOldArgs(args))
|
||
|
}
|
||
|
|
||
|
func (s *MssqlStmt) queryContext(ctx context.Context, args []namedValue) (driver.Rows, error) {
|
||
|
if err := s.sendQuery(args); err != nil {
|
||
|
return nil, err
|
||
|
}
|
||
|
return s.processQueryResponse(ctx)
|
||
|
}
|
||
|
|
||
|
func (s *MssqlStmt) processQueryResponse(ctx context.Context) (res driver.Rows, err error) {
|
||
|
tokchan := make(chan tokenStruct, 5)
|
||
|
ctx, cancel := context.WithCancel(ctx)
|
||
|
go processResponse(ctx, s.c.sess, tokchan)
|
||
|
// process metadata
|
||
|
var cols []columnStruct
|
||
|
loop:
|
||
|
for tok := range tokchan {
|
||
|
switch token := tok.(type) {
|
||
|
// by ignoring DONE token we effectively
|
||
|
// skip empty result-sets
|
||
|
// this improves results in queryes like that:
|
||
|
// set nocount on; select 1
|
||
|
// see TestIgnoreEmptyResults test
|
||
|
//case doneStruct:
|
||
|
//break loop
|
||
|
case []columnStruct:
|
||
|
cols = token
|
||
|
break loop
|
||
|
case doneStruct:
|
||
|
if token.isError() {
|
||
|
return nil, token.getError()
|
||
|
}
|
||
|
case error:
|
||
|
return nil, token
|
||
|
}
|
||
|
}
|
||
|
res = &MssqlRows{sess: s.c.sess, tokchan: tokchan, cols: cols, cancel: cancel}
|
||
|
return
|
||
|
}
|
||
|
|
||
|
func (s *MssqlStmt) Exec(args []driver.Value) (driver.Result, error) {
|
||
|
return s.exec(context.Background(), convertOldArgs(args))
|
||
|
}
|
||
|
|
||
|
func (s *MssqlStmt) exec(ctx context.Context, args []namedValue) (driver.Result, error) {
|
||
|
if err := s.sendQuery(args); err != nil {
|
||
|
return nil, err
|
||
|
}
|
||
|
return s.processExec(ctx)
|
||
|
}
|
||
|
|
||
|
func (s *MssqlStmt) processExec(ctx context.Context) (res driver.Result, err error) {
|
||
|
tokchan := make(chan tokenStruct, 5)
|
||
|
go processResponse(ctx, s.c.sess, tokchan)
|
||
|
var rowCount int64
|
||
|
for token := range tokchan {
|
||
|
switch token := token.(type) {
|
||
|
case doneInProcStruct:
|
||
|
if token.Status&doneCount != 0 {
|
||
|
rowCount += int64(token.RowCount)
|
||
|
}
|
||
|
case doneStruct:
|
||
|
if token.Status&doneCount != 0 {
|
||
|
rowCount += int64(token.RowCount)
|
||
|
}
|
||
|
if token.isError() {
|
||
|
return nil, token.getError()
|
||
|
}
|
||
|
case error:
|
||
|
return nil, token
|
||
|
}
|
||
|
}
|
||
|
return &MssqlResult{s.c, rowCount}, nil
|
||
|
}
|
||
|
|
||
|
type MssqlRows struct {
|
||
|
sess *tdsSession
|
||
|
cols []columnStruct
|
||
|
tokchan chan tokenStruct
|
||
|
|
||
|
nextCols []columnStruct
|
||
|
|
||
|
cancel func()
|
||
|
}
|
||
|
|
||
|
func (rc *MssqlRows) Close() error {
|
||
|
rc.cancel()
|
||
|
for _ = range rc.tokchan {
|
||
|
}
|
||
|
rc.tokchan = nil
|
||
|
return nil
|
||
|
}
|
||
|
|
||
|
func (rc *MssqlRows) Columns() (res []string) {
|
||
|
res = make([]string, len(rc.cols))
|
||
|
for i, col := range rc.cols {
|
||
|
res[i] = col.ColName
|
||
|
}
|
||
|
return
|
||
|
}
|
||
|
|
||
|
func (rc *MssqlRows) Next(dest []driver.Value) error {
|
||
|
if rc.nextCols != nil {
|
||
|
return io.EOF
|
||
|
}
|
||
|
for tok := range rc.tokchan {
|
||
|
switch tokdata := tok.(type) {
|
||
|
case []columnStruct:
|
||
|
rc.nextCols = tokdata
|
||
|
return io.EOF
|
||
|
case []interface{}:
|
||
|
for i := range dest {
|
||
|
dest[i] = tokdata[i]
|
||
|
}
|
||
|
return nil
|
||
|
case doneStruct:
|
||
|
if tokdata.isError() {
|
||
|
return tokdata.getError()
|
||
|
}
|
||
|
case error:
|
||
|
return tokdata
|
||
|
}
|
||
|
}
|
||
|
return io.EOF
|
||
|
}
|
||
|
|
||
|
func (rc *MssqlRows) HasNextResultSet() bool {
|
||
|
return rc.nextCols != nil
|
||
|
}
|
||
|
|
||
|
func (rc *MssqlRows) NextResultSet() error {
|
||
|
rc.cols = rc.nextCols
|
||
|
rc.nextCols = nil
|
||
|
if rc.cols == nil {
|
||
|
return io.EOF
|
||
|
}
|
||
|
return nil
|
||
|
}
|
||
|
|
||
|
// It should return
|
||
|
// the value type that can be used to scan types into. For example, the database
|
||
|
// column type "bigint" this should return "reflect.TypeOf(int64(0))".
|
||
|
func (r *MssqlRows) ColumnTypeScanType(index int) reflect.Type {
|
||
|
return makeGoLangScanType(r.cols[index].ti)
|
||
|
}
|
||
|
|
||
|
// RowsColumnTypeDatabaseTypeName may be implemented by Rows. It should return the
|
||
|
// database system type name without the length. Type names should be uppercase.
|
||
|
// Examples of returned types: "VARCHAR", "NVARCHAR", "VARCHAR2", "CHAR", "TEXT",
|
||
|
// "DECIMAL", "SMALLINT", "INT", "BIGINT", "BOOL", "[]BIGINT", "JSONB", "XML",
|
||
|
// "TIMESTAMP".
|
||
|
func (r *MssqlRows) ColumnTypeDatabaseTypeName(index int) string {
|
||
|
return makeGoLangTypeName(r.cols[index].ti)
|
||
|
}
|
||
|
|
||
|
// RowsColumnTypeLength may be implemented by Rows. It should return the length
|
||
|
// of the column type if the column is a variable length type. If the column is
|
||
|
// not a variable length type ok should return false.
|
||
|
// If length is not limited other than system limits, it should return math.MaxInt64.
|
||
|
// The following are examples of returned values for various types:
|
||
|
// TEXT (math.MaxInt64, true)
|
||
|
// varchar(10) (10, true)
|
||
|
// nvarchar(10) (10, true)
|
||
|
// decimal (0, false)
|
||
|
// int (0, false)
|
||
|
// bytea(30) (30, true)
|
||
|
func (r *MssqlRows) ColumnTypeLength(index int) (int64, bool) {
|
||
|
return makeGoLangTypeLength(r.cols[index].ti)
|
||
|
}
|
||
|
|
||
|
// It should return
|
||
|
// the precision and scale for decimal types. If not applicable, ok should be false.
|
||
|
// The following are examples of returned values for various types:
|
||
|
// decimal(38, 4) (38, 4, true)
|
||
|
// int (0, 0, false)
|
||
|
// decimal (math.MaxInt64, math.MaxInt64, true)
|
||
|
func (r *MssqlRows) ColumnTypePrecisionScale(index int) (int64, int64, bool) {
|
||
|
return makeGoLangTypePrecisionScale(r.cols[index].ti)
|
||
|
}
|
||
|
|
||
|
// The nullable value should
|
||
|
// be true if it is known the column may be null, or false if the column is known
|
||
|
// to be not nullable.
|
||
|
// If the column nullability is unknown, ok should be false.
|
||
|
func (r *MssqlRows) ColumnTypeNullable(index int) (nullable, ok bool) {
|
||
|
nullable = r.cols[index].Flags&colFlagNullable != 0
|
||
|
ok = true
|
||
|
return
|
||
|
}
|
||
|
|
||
|
func makeStrParam(val string) (res Param) {
|
||
|
res.ti.TypeId = typeNVarChar
|
||
|
res.buffer = str2ucs2(val)
|
||
|
res.ti.Size = len(res.buffer)
|
||
|
return
|
||
|
}
|
||
|
|
||
|
func (s *MssqlStmt) makeParam(val driver.Value) (res Param, err error) {
|
||
|
if val == nil {
|
||
|
res.ti.TypeId = typeNVarChar
|
||
|
res.buffer = nil
|
||
|
res.ti.Size = 2
|
||
|
return
|
||
|
}
|
||
|
switch val := val.(type) {
|
||
|
case int64:
|
||
|
res.ti.TypeId = typeIntN
|
||
|
res.buffer = make([]byte, 8)
|
||
|
res.ti.Size = 8
|
||
|
binary.LittleEndian.PutUint64(res.buffer, uint64(val))
|
||
|
case float64:
|
||
|
res.ti.TypeId = typeFltN
|
||
|
res.ti.Size = 8
|
||
|
res.buffer = make([]byte, 8)
|
||
|
binary.LittleEndian.PutUint64(res.buffer, math.Float64bits(val))
|
||
|
case []byte:
|
||
|
res.ti.TypeId = typeBigVarBin
|
||
|
res.ti.Size = len(val)
|
||
|
res.buffer = val
|
||
|
case string:
|
||
|
res = makeStrParam(val)
|
||
|
case bool:
|
||
|
res.ti.TypeId = typeBitN
|
||
|
res.ti.Size = 1
|
||
|
res.buffer = make([]byte, 1)
|
||
|
if val {
|
||
|
res.buffer[0] = 1
|
||
|
}
|
||
|
case time.Time:
|
||
|
if s.c.sess.loginAck.TDSVersion >= verTDS73 {
|
||
|
res.ti.TypeId = typeDateTimeOffsetN
|
||
|
res.ti.Scale = 7
|
||
|
res.ti.Size = 10
|
||
|
buf := make([]byte, 10)
|
||
|
res.buffer = buf
|
||
|
days, ns := dateTime2(val)
|
||
|
ns /= 100
|
||
|
buf[0] = byte(ns)
|
||
|
buf[1] = byte(ns >> 8)
|
||
|
buf[2] = byte(ns >> 16)
|
||
|
buf[3] = byte(ns >> 24)
|
||
|
buf[4] = byte(ns >> 32)
|
||
|
buf[5] = byte(days)
|
||
|
buf[6] = byte(days >> 8)
|
||
|
buf[7] = byte(days >> 16)
|
||
|
_, offset := val.Zone()
|
||
|
offset /= 60
|
||
|
buf[8] = byte(offset)
|
||
|
buf[9] = byte(offset >> 8)
|
||
|
} else {
|
||
|
res.ti.TypeId = typeDateTimeN
|
||
|
res.ti.Size = 8
|
||
|
res.buffer = make([]byte, 8)
|
||
|
ref := time.Date(1900, 1, 1, 0, 0, 0, 0, time.UTC)
|
||
|
dur := val.Sub(ref)
|
||
|
days := dur / (24 * time.Hour)
|
||
|
tm := (300 * (dur % (24 * time.Hour))) / time.Second
|
||
|
binary.LittleEndian.PutUint32(res.buffer[0:4], uint32(days))
|
||
|
binary.LittleEndian.PutUint32(res.buffer[4:8], uint32(tm))
|
||
|
}
|
||
|
default:
|
||
|
err = fmt.Errorf("mssql: unknown type for %T", val)
|
||
|
return
|
||
|
}
|
||
|
return
|
||
|
}
|
||
|
|
||
|
type MssqlResult struct {
|
||
|
c *MssqlConn
|
||
|
rowsAffected int64
|
||
|
}
|
||
|
|
||
|
func (r *MssqlResult) RowsAffected() (int64, error) {
|
||
|
return r.rowsAffected, nil
|
||
|
}
|
||
|
|
||
|
func (r *MssqlResult) LastInsertId() (int64, error) {
|
||
|
s, err := r.c.Prepare("select cast(@@identity as bigint)")
|
||
|
if err != nil {
|
||
|
return 0, err
|
||
|
}
|
||
|
defer s.Close()
|
||
|
rows, err := s.Query(nil)
|
||
|
if err != nil {
|
||
|
return 0, err
|
||
|
}
|
||
|
defer rows.Close()
|
||
|
dest := make([]driver.Value, 1)
|
||
|
err = rows.Next(dest)
|
||
|
if err != nil {
|
||
|
return 0, err
|
||
|
}
|
||
|
if dest[0] == nil {
|
||
|
return -1, errors.New("There is no generated identity value")
|
||
|
}
|
||
|
lastInsertId := dest[0].(int64)
|
||
|
return lastInsertId, nil
|
||
|
}
|