| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609 | package mssqlimport (	"database/sql"	"database/sql/driver"	"encoding/binary"	"errors"	"fmt"	"io"	"math"	"net"	"reflect"	"strings"	"time"	"golang.org/x/net/context" // use the "x/net/context" for backwards compatibility.)var driverInstance = &MssqlDriver{processQueryText: true}var driverInstanceNoProcess = &MssqlDriver{processQueryText: false}func init() {	sql.Register("mssql", driverInstance)	sql.Register("sqlserver", driverInstanceNoProcess)}// Abstract the dialer for testing and for non-TCP based connections.type dialer interface {	Dial(addr string) (net.Conn, error)}var createDialer func(p *connectParams) dialertype tcpDialer struct {	nd *net.Dialer}func (d tcpDialer) Dial(addr string) (net.Conn, error) {	return d.nd.Dial("tcp", addr)}type MssqlDriver struct {	log optionalLogger	processQueryText bool}func SetLogger(logger Logger) {	driverInstance.SetLogger(logger)	driverInstanceNoProcess.SetLogger(logger)}func (d *MssqlDriver) SetLogger(logger Logger) {	d.log = optionalLogger{logger}}type MssqlConn struct {	sess           *tdsSession	transactionCtx context.Context	processQueryText bool}func (c *MssqlConn) simpleProcessResp(ctx context.Context) error {	tokchan := make(chan tokenStruct, 5)	go processResponse(ctx, c.sess, tokchan)	for tok := range tokchan {		switch token := tok.(type) {		case doneStruct:			if token.isError() {				return token.getError()			}		case error:			return token		}	}	return nil}func (c *MssqlConn) Commit() error {	if err := c.sendCommitRequest(); err != nil {		return err	}	return c.simpleProcessResp(c.transactionCtx)}func (c *MssqlConn) sendCommitRequest() error {	headers := []headerStruct{		{hdrtype: dataStmHdrTransDescr,			data: transDescrHdr{c.sess.tranid, 1}.pack()},	}	if err := sendCommitXact(c.sess.buf, headers, "", 0, 0, ""); err != nil {		if c.sess.logFlags&logErrors != 0 {			c.sess.log.Printf("Failed to send CommitXact with %v", err)		}		return driver.ErrBadConn	}	return nil}func (c *MssqlConn) Rollback() error {	if err := c.sendRollbackRequest(); err != nil {		return err	}	return c.simpleProcessResp(c.transactionCtx)}func (c *MssqlConn) sendRollbackRequest() error {	headers := []headerStruct{		{hdrtype: dataStmHdrTransDescr,			data: transDescrHdr{c.sess.tranid, 1}.pack()},	}	if err := sendRollbackXact(c.sess.buf, headers, "", 0, 0, ""); err != nil {		if c.sess.logFlags&logErrors != 0 {			c.sess.log.Printf("Failed to send RollbackXact with %v", err)		}		return driver.ErrBadConn	}	return nil}func (c *MssqlConn) Begin() (driver.Tx, error) {	return c.begin(context.Background(), isolationUseCurrent)}func (c *MssqlConn) begin(ctx context.Context, tdsIsolation isoLevel) (driver.Tx, error) {	err := c.sendBeginRequest(ctx, tdsIsolation)	if err != nil {		return nil, err	}	return c.processBeginResponse(ctx)}func (c *MssqlConn) sendBeginRequest(ctx context.Context, tdsIsolation isoLevel) error {	c.transactionCtx = ctx	headers := []headerStruct{		{hdrtype: dataStmHdrTransDescr,			data: transDescrHdr{0, 1}.pack()},	}	if err := sendBeginXact(c.sess.buf, headers, tdsIsolation, ""); err != nil {		if c.sess.logFlags&logErrors != 0 {			c.sess.log.Printf("Failed to send BeginXact with %v", err)		}		return driver.ErrBadConn	}	return nil}func (c *MssqlConn) processBeginResponse(ctx context.Context) (driver.Tx, error) {	if err := c.simpleProcessResp(ctx); err != nil {		return nil, err	}	// successful BEGINXACT request will return sess.tranid	// for started transaction	return c, nil}func (d *MssqlDriver) Open(dsn string) (driver.Conn, error) {	return d.open(dsn)}func (d *MssqlDriver) open(dsn string) (*MssqlConn, error) {	params, err := parseConnectParams(dsn)	if err != nil {		return nil, err	}	sess, err := connect(d.log, params)	if err != nil {		// main server failed, try fail-over partner		if params.failOverPartner == "" {			return nil, err		}		params.host = params.failOverPartner		if params.failOverPort != 0 {			params.port = params.failOverPort		}		sess, err = connect(d.log, params)		if err != nil {			// fail-over partner also failed, now fail			return nil, err		}	}	conn := &MssqlConn{sess, context.Background(), d.processQueryText}	conn.sess.log = d.log	return conn, nil}func (c *MssqlConn) Close() error {	return c.sess.buf.transport.Close()}type MssqlStmt struct {	c          *MssqlConn	query      string	paramCount int	notifSub   *queryNotifSub}type queryNotifSub struct {	msgText string	options string	timeout uint32}func (c *MssqlConn) Prepare(query string) (driver.Stmt, error) {	return c.prepareContext(context.Background(), query)}func (c *MssqlConn) prepareContext(ctx context.Context, query string) (*MssqlStmt, error) {	paramCount := -1	if c.processQueryText {		query, paramCount = parseParams(query)	}	return &MssqlStmt{c, query, paramCount, nil}, nil}func (s *MssqlStmt) Close() error {	return nil}func (s *MssqlStmt) SetQueryNotification(id, options string, timeout time.Duration) {	to := uint32(timeout / time.Second)	if to < 1 {		to = 1	}	s.notifSub = &queryNotifSub{id, options, to}}func (s *MssqlStmt) NumInput() int {	return s.paramCount}func (s *MssqlStmt) sendQuery(args []namedValue) (err error) {	headers := []headerStruct{		{hdrtype: dataStmHdrTransDescr,			data: transDescrHdr{s.c.sess.tranid, 1}.pack()},	}	if s.notifSub != nil {		headers = append(headers, headerStruct{hdrtype: dataStmHdrQueryNotif,			data: queryNotifHdr{s.notifSub.msgText, s.notifSub.options, s.notifSub.timeout}.pack()})	}	// no need to check number of parameters here, it is checked by database/sql	if s.c.sess.logFlags&logSQL != 0 {		s.c.sess.log.Println(s.query)	}	if s.c.sess.logFlags&logParams != 0 && len(args) > 0 {		for i := 0; i < len(args); i++ {			s.c.sess.log.Printf("\t@p%d\t%v\n", i+1, args[i])		}	}	if len(args) == 0 {		if err = sendSqlBatch72(s.c.sess.buf, s.query, headers); err != nil {			if s.c.sess.logFlags&logErrors != 0 {				s.c.sess.log.Printf("Failed to send SqlBatch with %v", err)			}			return driver.ErrBadConn		}	} else {		params := make([]Param, len(args)+2)		decls := make([]string, len(args))		params[0] = makeStrParam(s.query)		for i, val := range args {			params[i+2], err = s.makeParam(val.Value)			if err != nil {				return			}			var name string			if len(val.Name) > 0 {				name = "@" + val.Name			} else {				name = fmt.Sprintf("@p%d", val.Ordinal)			}			params[i+2].Name = name			decls[i] = fmt.Sprintf("%s %s", name, makeDecl(params[i+2].ti))		}		params[1] = makeStrParam(strings.Join(decls, ","))		if err = sendRpc(s.c.sess.buf, headers, Sp_ExecuteSql, 0, params); err != nil {			if s.c.sess.logFlags&logErrors != 0 {				s.c.sess.log.Printf("Failed to send Rpc with %v", err)			}			return driver.ErrBadConn		}	}	return}type namedValue struct {	Name    string	Ordinal int	Value   driver.Value}func convertOldArgs(args []driver.Value) []namedValue {	list := make([]namedValue, len(args))	for i, v := range args {		list[i] = namedValue{			Ordinal: i + 1,			Value:   v,		}	}	return list}func (s *MssqlStmt) Query(args []driver.Value) (driver.Rows, error) {	return s.queryContext(context.Background(), convertOldArgs(args))}func (s *MssqlStmt) queryContext(ctx context.Context, args []namedValue) (driver.Rows, error) {	if err := s.sendQuery(args); err != nil {		return nil, err	}	return s.processQueryResponse(ctx)}func (s *MssqlStmt) processQueryResponse(ctx context.Context) (res driver.Rows, err error) {	tokchan := make(chan tokenStruct, 5)	ctx, cancel := context.WithCancel(ctx)	go processResponse(ctx, s.c.sess, tokchan)	// process metadata	var cols []columnStructloop:	for tok := range tokchan {		switch token := tok.(type) {		// by ignoring DONE token we effectively		// skip empty result-sets		// this improves results in queryes like that:		// set nocount on; select 1		// see TestIgnoreEmptyResults test		//case doneStruct:		//break loop		case []columnStruct:			cols = token			break loop		case doneStruct:			if token.isError() {				return nil, token.getError()			}		case error:			return nil, token		}	}	res = &MssqlRows{sess: s.c.sess, tokchan: tokchan, cols: cols, cancel: cancel}	return}func (s *MssqlStmt) Exec(args []driver.Value) (driver.Result, error) {	return s.exec(context.Background(), convertOldArgs(args))}func (s *MssqlStmt) exec(ctx context.Context, args []namedValue) (driver.Result, error) {	if err := s.sendQuery(args); err != nil {		return nil, err	}	return s.processExec(ctx)}func (s *MssqlStmt) processExec(ctx context.Context) (res driver.Result, err error) {	tokchan := make(chan tokenStruct, 5)	go processResponse(ctx, s.c.sess, tokchan)	var rowCount int64	for token := range tokchan {		switch token := token.(type) {		case doneInProcStruct:			if token.Status&doneCount != 0 {				rowCount += int64(token.RowCount)			}		case doneStruct:			if token.Status&doneCount != 0 {				rowCount += int64(token.RowCount)			}			if token.isError() {				return nil, token.getError()			}		case error:			return nil, token		}	}	return &MssqlResult{s.c, rowCount}, nil}type MssqlRows struct {	sess    *tdsSession	cols    []columnStruct	tokchan chan tokenStruct	nextCols []columnStruct	cancel func()}func (rc *MssqlRows) Close() error {	rc.cancel()	for _ = range rc.tokchan {	}	rc.tokchan = nil	return nil}func (rc *MssqlRows) Columns() (res []string) {	res = make([]string, len(rc.cols))	for i, col := range rc.cols {		res[i] = col.ColName	}	return}func (rc *MssqlRows) Next(dest []driver.Value) error {	if rc.nextCols != nil {		return io.EOF	}	for tok := range rc.tokchan {		switch tokdata := tok.(type) {		case []columnStruct:			rc.nextCols = tokdata			return io.EOF		case []interface{}:			for i := range dest {				dest[i] = tokdata[i]			}			return nil		case doneStruct:			if tokdata.isError() {				return tokdata.getError()			}		case error:			return tokdata		}	}	return io.EOF}func (rc *MssqlRows) HasNextResultSet() bool {	return rc.nextCols != nil}func (rc *MssqlRows) NextResultSet() error {	rc.cols = rc.nextCols	rc.nextCols = nil	if rc.cols == nil {		return io.EOF	}	return nil}// It should return// the value type that can be used to scan types into. For example, the database// column type "bigint" this should return "reflect.TypeOf(int64(0))".func (r *MssqlRows) ColumnTypeScanType(index int) reflect.Type {	return makeGoLangScanType(r.cols[index].ti)}// RowsColumnTypeDatabaseTypeName may be implemented by Rows. It should return the// database system type name without the length. Type names should be uppercase.// Examples of returned types: "VARCHAR", "NVARCHAR", "VARCHAR2", "CHAR", "TEXT",// "DECIMAL", "SMALLINT", "INT", "BIGINT", "BOOL", "[]BIGINT", "JSONB", "XML",// "TIMESTAMP".func (r *MssqlRows) ColumnTypeDatabaseTypeName(index int) string {	return makeGoLangTypeName(r.cols[index].ti)}// RowsColumnTypeLength may be implemented by Rows. It should return the length// of the column type if the column is a variable length type. If the column is// not a variable length type ok should return false.// If length is not limited other than system limits, it should return math.MaxInt64.// The following are examples of returned values for various types://   TEXT          (math.MaxInt64, true)//   varchar(10)   (10, true)//   nvarchar(10)  (10, true)//   decimal       (0, false)//   int           (0, false)//   bytea(30)     (30, true)func (r *MssqlRows) ColumnTypeLength(index int) (int64, bool) {	return makeGoLangTypeLength(r.cols[index].ti)}// It should return// the precision and scale for decimal types. If not applicable, ok should be false.// The following are examples of returned values for various types://   decimal(38, 4)    (38, 4, true)//   int               (0, 0, false)//   decimal           (math.MaxInt64, math.MaxInt64, true)func (r *MssqlRows) ColumnTypePrecisionScale(index int) (int64, int64, bool) {	return makeGoLangTypePrecisionScale(r.cols[index].ti)}// The nullable value should// be true if it is known the column may be null, or false if the column is known// to be not nullable.// If the column nullability is unknown, ok should be false.func (r *MssqlRows) ColumnTypeNullable(index int) (nullable, ok bool) {	nullable = r.cols[index].Flags&colFlagNullable != 0	ok = true	return}func makeStrParam(val string) (res Param) {	res.ti.TypeId = typeNVarChar	res.buffer = str2ucs2(val)	res.ti.Size = len(res.buffer)	return}func (s *MssqlStmt) makeParam(val driver.Value) (res Param, err error) {	if val == nil {		res.ti.TypeId = typeNVarChar		res.buffer = nil		res.ti.Size = 2		return	}	switch val := val.(type) {	case int64:		res.ti.TypeId = typeIntN		res.buffer = make([]byte, 8)		res.ti.Size = 8		binary.LittleEndian.PutUint64(res.buffer, uint64(val))	case float64:		res.ti.TypeId = typeFltN		res.ti.Size = 8		res.buffer = make([]byte, 8)		binary.LittleEndian.PutUint64(res.buffer, math.Float64bits(val))	case []byte:		res.ti.TypeId = typeBigVarBin		res.ti.Size = len(val)		res.buffer = val	case string:		res = makeStrParam(val)	case bool:		res.ti.TypeId = typeBitN		res.ti.Size = 1		res.buffer = make([]byte, 1)		if val {			res.buffer[0] = 1		}	case time.Time:		if s.c.sess.loginAck.TDSVersion >= verTDS73 {			res.ti.TypeId = typeDateTimeOffsetN			res.ti.Scale = 7			res.ti.Size = 10			buf := make([]byte, 10)			res.buffer = buf			days, ns := dateTime2(val)			ns /= 100			buf[0] = byte(ns)			buf[1] = byte(ns >> 8)			buf[2] = byte(ns >> 16)			buf[3] = byte(ns >> 24)			buf[4] = byte(ns >> 32)			buf[5] = byte(days)			buf[6] = byte(days >> 8)			buf[7] = byte(days >> 16)			_, offset := val.Zone()			offset /= 60			buf[8] = byte(offset)			buf[9] = byte(offset >> 8)		} else {			res.ti.TypeId = typeDateTimeN			res.ti.Size = 8			res.buffer = make([]byte, 8)			ref := time.Date(1900, 1, 1, 0, 0, 0, 0, time.UTC)			dur := val.Sub(ref)			days := dur / (24 * time.Hour)			tm := (300 * (dur % (24 * time.Hour))) / time.Second			binary.LittleEndian.PutUint32(res.buffer[0:4], uint32(days))			binary.LittleEndian.PutUint32(res.buffer[4:8], uint32(tm))		}	default:		err = fmt.Errorf("mssql: unknown type for %T", val)		return	}	return}type MssqlResult struct {	c            *MssqlConn	rowsAffected int64}func (r *MssqlResult) RowsAffected() (int64, error) {	return r.rowsAffected, nil}func (r *MssqlResult) LastInsertId() (int64, error) {	s, err := r.c.Prepare("select cast(@@identity as bigint)")	if err != nil {		return 0, err	}	defer s.Close()	rows, err := s.Query(nil)	if err != nil {		return 0, err	}	defer rows.Close()	dest := make([]driver.Value, 1)	err = rows.Next(dest)	if err != nil {		return 0, err	}	if dest[0] == nil {		return -1, errors.New("There is no generated identity value")	}	lastInsertId := dest[0].(int64)	return lastInsertId, nil}
 |