| 12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270370470570670770870971071171271371471571671771871972072172272372472572672772872973073173273373473573673773873974074174274374474574674774874975075175275375475575675775875976076176276376476576676776876977077177277377477577677777877978078178278378478578678778878979079179279379479579679779879980080180280380480580680780880981081181281381481581681781881982082182282382482582682782882983083183283383483583683783883984084184284384484584684784884985085185285385485585685785885986086186286386486586686786886987087187287387487587687787887988088188288388488588688788888989089189289389489589689789889990090190290390490590690790890991091191291391491591691791891992092192292392492592692792892993093193293393493593693793893994094194294394494594694794894995095195295395495595695795895996096196296396496596696796896997097197297397497597697797897998098198298398498598698798898999099199299399499599699799899910001001100210031004100510061007100810091010101110121013101410151016101710181019102010211022102310241025102610271028102910301031103210331034103510361037103810391040104110421043104410451046104710481049105010511052105310541055105610571058105910601061106210631064106510661067106810691070107110721073107410751076107710781079108010811082108310841085108610871088108910901091109210931094109510961097109810991100110111021103110411051106110711081109111011111112111311141115111611171118111911201121112211231124112511261127112811291130113111321133113411351136113711381139114011411142114311441145114611471148114911501151115211531154115511561157115811591160116111621163116411651166116711681169117011711172117311741175117611771178117911801181118211831184118511861187118811891190119111921193119411951196119711981199120012011202120312041205120612071208120912101211121212131214121512161217121812191220122112221223122412251226122712281229123012311232123312341235123612371238123912401241124212431244124512461247124812491250125112521253125412551256125712581259126012611262126312641265126612671268126912701271127212731274127512761277127812791280128112821283128412851286128712881289129012911292129312941295129612971298129913001301130213031304130513061307130813091310131113121313131413151316131713181319132013211322132313241325 | package mssqlimport (	"crypto/tls"	"crypto/x509"	"encoding/binary"	"errors"	"fmt"	"io"	"io/ioutil"	"net"	"net/url"	"os"	"sort"	"strconv"	"strings"	"time"	"unicode"	"unicode/utf16"	"unicode/utf8"	"golang.org/x/net/context" // use the "x/net/context" for backwards compatibility.)func parseInstances(msg []byte) map[string]map[string]string {	results := map[string]map[string]string{}	if len(msg) > 3 && msg[0] == 5 {		out_s := string(msg[3:])		tokens := strings.Split(out_s, ";")		instdict := map[string]string{}		got_name := false		var name string		for _, token := range tokens {			if got_name {				instdict[name] = token				got_name = false			} else {				name = token				if len(name) == 0 {					if len(instdict) == 0 {						break					}					results[strings.ToUpper(instdict["InstanceName"])] = instdict					instdict = map[string]string{}					continue				}				got_name = true			}		}	}	return results}func getInstances(address string) (map[string]map[string]string, error) {	conn, err := net.DialTimeout("udp", address+":1434", 5*time.Second)	if err != nil {		return nil, err	}	defer conn.Close()	conn.SetDeadline(time.Now().Add(5 * time.Second))	_, err = conn.Write([]byte{3})	if err != nil {		return nil, err	}	var resp = make([]byte, 16*1024-1)	read, err := conn.Read(resp)	if err != nil {		return nil, err	}	return parseInstances(resp[:read]), nil}// tds versionsconst (	verTDS70     = 0x70000000	verTDS71     = 0x71000000	verTDS71rev1 = 0x71000001	verTDS72     = 0x72090002	verTDS73A    = 0x730A0003	verTDS73     = verTDS73A	verTDS73B    = 0x730B0003	verTDS74     = 0x74000004)// packet types// https://msdn.microsoft.com/en-us/library/dd304214.aspxconst (	packSQLBatch   packetType = 1	packRPCRequest            = 3	packReply                 = 4	// 2.2.1.7 Attention: https://msdn.microsoft.com/en-us/library/dd341449.aspx	// 4.19.2 Out-of-Band Attention Signal: https://msdn.microsoft.com/en-us/library/dd305167.aspx	packAttention = 6	packBulkLoadBCP = 7	packTransMgrReq = 14	packNormal      = 15	packLogin7      = 16	packSSPIMessage = 17	packPrelogin    = 18)// prelogin fields// http://msdn.microsoft.com/en-us/library/dd357559.aspxconst (	preloginVERSION    = 0	preloginENCRYPTION = 1	preloginINSTOPT    = 2	preloginTHREADID   = 3	preloginMARS       = 4	preloginTRACEID    = 5	preloginTERMINATOR = 0xff)const (	encryptOff    = 0 // Encryption is available but off.	encryptOn     = 1 // Encryption is available and on.	encryptNotSup = 2 // Encryption is not available.	encryptReq    = 3 // Encryption is required.)type tdsSession struct {	buf          *tdsBuffer	loginAck     loginAckStruct	database     string	partner      string	columns      []columnStruct	tranid       uint64	logFlags     uint64	log          optionalLogger	routedServer string	routedPort   uint16}const (	logErrors      = 1	logMessages    = 2	logRows        = 4	logSQL         = 8	logParams      = 16	logTransaction = 32	logDebug       = 64)type columnStruct struct {	UserType uint32	Flags    uint16	ColName  string	ti       typeInfo}type KeySlice []uint8func (p KeySlice) Len() int           { return len(p) }func (p KeySlice) Less(i, j int) bool { return p[i] < p[j] }func (p KeySlice) Swap(i, j int)      { p[i], p[j] = p[j], p[i] }// http://msdn.microsoft.com/en-us/library/dd357559.aspxfunc writePrelogin(w *tdsBuffer, fields map[uint8][]byte) error {	var err error	w.BeginPacket(packPrelogin)	offset := uint16(5*len(fields) + 1)	keys := make(KeySlice, 0, len(fields))	for k, _ := range fields {		keys = append(keys, k)	}	sort.Sort(keys)	// writing header	for _, k := range keys {		err = w.WriteByte(k)		if err != nil {			return err		}		err = binary.Write(w, binary.BigEndian, offset)		if err != nil {			return err		}		v := fields[k]		size := uint16(len(v))		err = binary.Write(w, binary.BigEndian, size)		if err != nil {			return err		}		offset += size	}	err = w.WriteByte(preloginTERMINATOR)	if err != nil {		return err	}	// writing values	for _, k := range keys {		v := fields[k]		written, err := w.Write(v)		if err != nil {			return err		}		if written != len(v) {			return errors.New("Write method didn't write the whole value")		}	}	return w.FinishPacket()}func readPrelogin(r *tdsBuffer) (map[uint8][]byte, error) {	packet_type, err := r.BeginRead()	if err != nil {		return nil, err	}	struct_buf, err := ioutil.ReadAll(r)	if err != nil {		return nil, err	}	if packet_type != 4 {		return nil, errors.New("Invalid respones, expected packet type 4, PRELOGIN RESPONSE")	}	offset := 0	results := map[uint8][]byte{}	for true {		rec_type := struct_buf[offset]		if rec_type == preloginTERMINATOR {			break		}		rec_offset := binary.BigEndian.Uint16(struct_buf[offset+1:])		rec_len := binary.BigEndian.Uint16(struct_buf[offset+3:])		value := struct_buf[rec_offset : rec_offset+rec_len]		results[rec_type] = value		offset += 5	}	return results, nil}// OptionFlags2// http://msdn.microsoft.com/en-us/library/dd304019.aspxconst (	fLanguageFatal = 1	fODBC          = 2	fTransBoundary = 4	fCacheConnect  = 8	fIntSecurity   = 0x80)// TypeFlagsconst (	// 4 bits for fSQLType	// 1 bit for fOLEDB	fReadOnlyIntent = 32)type login struct {	TDSVersion     uint32	PacketSize     uint32	ClientProgVer  uint32	ClientPID      uint32	ConnectionID   uint32	OptionFlags1   uint8	OptionFlags2   uint8	TypeFlags      uint8	OptionFlags3   uint8	ClientTimeZone int32	ClientLCID     uint32	HostName       string	UserName       string	Password       string	AppName        string	ServerName     string	CtlIntName     string	Language       string	Database       string	ClientID       [6]byte	SSPI           []byte	AtchDBFile     string	ChangePassword string}type loginHeader struct {	Length               uint32	TDSVersion           uint32	PacketSize           uint32	ClientProgVer        uint32	ClientPID            uint32	ConnectionID         uint32	OptionFlags1         uint8	OptionFlags2         uint8	TypeFlags            uint8	OptionFlags3         uint8	ClientTimeZone       int32	ClientLCID           uint32	HostNameOffset       uint16	HostNameLength       uint16	UserNameOffset       uint16	UserNameLength       uint16	PasswordOffset       uint16	PasswordLength       uint16	AppNameOffset        uint16	AppNameLength        uint16	ServerNameOffset     uint16	ServerNameLength     uint16	ExtensionOffset      uint16	ExtensionLenght      uint16	CtlIntNameOffset     uint16	CtlIntNameLength     uint16	LanguageOffset       uint16	LanguageLength       uint16	DatabaseOffset       uint16	DatabaseLength       uint16	ClientID             [6]byte	SSPIOffset           uint16	SSPILength           uint16	AtchDBFileOffset     uint16	AtchDBFileLength     uint16	ChangePasswordOffset uint16	ChangePasswordLength uint16	SSPILongLength       uint32}// convert Go string to UTF-16 encoded []byte (littleEndian)// done manually rather than using bytes and binary packages// for performance reasonsfunc str2ucs2(s string) []byte {	res := utf16.Encode([]rune(s))	ucs2 := make([]byte, 2*len(res))	for i := 0; i < len(res); i++ {		ucs2[2*i] = byte(res[i])		ucs2[2*i+1] = byte(res[i] >> 8)	}	return ucs2}func ucs22str(s []byte) (string, error) {	if len(s)%2 != 0 {		return "", fmt.Errorf("Illegal UCS2 string length: %d", len(s))	}	buf := make([]uint16, len(s)/2)	for i := 0; i < len(s); i += 2 {		buf[i/2] = binary.LittleEndian.Uint16(s[i:])	}	return string(utf16.Decode(buf)), nil}func manglePassword(password string) []byte {	var ucs2password []byte = str2ucs2(password)	for i, ch := range ucs2password {		ucs2password[i] = ((ch<<4)&0xff | (ch >> 4)) ^ 0xA5	}	return ucs2password}// http://msdn.microsoft.com/en-us/library/dd304019.aspxfunc sendLogin(w *tdsBuffer, login login) error {	w.BeginPacket(packLogin7)	hostname := str2ucs2(login.HostName)	username := str2ucs2(login.UserName)	password := manglePassword(login.Password)	appname := str2ucs2(login.AppName)	servername := str2ucs2(login.ServerName)	ctlintname := str2ucs2(login.CtlIntName)	language := str2ucs2(login.Language)	database := str2ucs2(login.Database)	atchdbfile := str2ucs2(login.AtchDBFile)	changepassword := str2ucs2(login.ChangePassword)	hdr := loginHeader{		TDSVersion:           login.TDSVersion,		PacketSize:           login.PacketSize,		ClientProgVer:        login.ClientProgVer,		ClientPID:            login.ClientPID,		ConnectionID:         login.ConnectionID,		OptionFlags1:         login.OptionFlags1,		OptionFlags2:         login.OptionFlags2,		TypeFlags:            login.TypeFlags,		OptionFlags3:         login.OptionFlags3,		ClientTimeZone:       login.ClientTimeZone,		ClientLCID:           login.ClientLCID,		HostNameLength:       uint16(utf8.RuneCountInString(login.HostName)),		UserNameLength:       uint16(utf8.RuneCountInString(login.UserName)),		PasswordLength:       uint16(utf8.RuneCountInString(login.Password)),		AppNameLength:        uint16(utf8.RuneCountInString(login.AppName)),		ServerNameLength:     uint16(utf8.RuneCountInString(login.ServerName)),		CtlIntNameLength:     uint16(utf8.RuneCountInString(login.CtlIntName)),		LanguageLength:       uint16(utf8.RuneCountInString(login.Language)),		DatabaseLength:       uint16(utf8.RuneCountInString(login.Database)),		ClientID:             login.ClientID,		SSPILength:           uint16(len(login.SSPI)),		AtchDBFileLength:     uint16(utf8.RuneCountInString(login.AtchDBFile)),		ChangePasswordLength: uint16(utf8.RuneCountInString(login.ChangePassword)),	}	offset := uint16(binary.Size(hdr))	hdr.HostNameOffset = offset	offset += uint16(len(hostname))	hdr.UserNameOffset = offset	offset += uint16(len(username))	hdr.PasswordOffset = offset	offset += uint16(len(password))	hdr.AppNameOffset = offset	offset += uint16(len(appname))	hdr.ServerNameOffset = offset	offset += uint16(len(servername))	hdr.CtlIntNameOffset = offset	offset += uint16(len(ctlintname))	hdr.LanguageOffset = offset	offset += uint16(len(language))	hdr.DatabaseOffset = offset	offset += uint16(len(database))	hdr.SSPIOffset = offset	offset += uint16(len(login.SSPI))	hdr.AtchDBFileOffset = offset	offset += uint16(len(atchdbfile))	hdr.ChangePasswordOffset = offset	offset += uint16(len(changepassword))	hdr.Length = uint32(offset)	var err error	err = binary.Write(w, binary.LittleEndian, &hdr)	if err != nil {		return err	}	_, err = w.Write(hostname)	if err != nil {		return err	}	_, err = w.Write(username)	if err != nil {		return err	}	_, err = w.Write(password)	if err != nil {		return err	}	_, err = w.Write(appname)	if err != nil {		return err	}	_, err = w.Write(servername)	if err != nil {		return err	}	_, err = w.Write(ctlintname)	if err != nil {		return err	}	_, err = w.Write(language)	if err != nil {		return err	}	_, err = w.Write(database)	if err != nil {		return err	}	_, err = w.Write(login.SSPI)	if err != nil {		return err	}	_, err = w.Write(atchdbfile)	if err != nil {		return err	}	_, err = w.Write(changepassword)	if err != nil {		return err	}	return w.FinishPacket()}func readUcs2(r io.Reader, numchars int) (res string, err error) {	buf := make([]byte, numchars*2)	_, err = io.ReadFull(r, buf)	if err != nil {		return "", err	}	return ucs22str(buf)}func readUsVarChar(r io.Reader) (res string, err error) {	var numchars uint16	err = binary.Read(r, binary.LittleEndian, &numchars)	if err != nil {		return "", err	}	return readUcs2(r, int(numchars))}func writeUsVarChar(w io.Writer, s string) (err error) {	buf := str2ucs2(s)	var numchars int = len(buf) / 2	if numchars > 0xffff {		panic("invalid size for US_VARCHAR")	}	err = binary.Write(w, binary.LittleEndian, uint16(numchars))	if err != nil {		return	}	_, err = w.Write(buf)	return}func readBVarChar(r io.Reader) (res string, err error) {	var numchars uint8	err = binary.Read(r, binary.LittleEndian, &numchars)	if err != nil {		return "", err	}	return readUcs2(r, int(numchars))}func writeBVarChar(w io.Writer, s string) (err error) {	buf := str2ucs2(s)	var numchars int = len(buf) / 2	if numchars > 0xff {		panic("invalid size for B_VARCHAR")	}	err = binary.Write(w, binary.LittleEndian, uint8(numchars))	if err != nil {		return	}	_, err = w.Write(buf)	return}func readBVarByte(r io.Reader) (res []byte, err error) {	var length uint8	err = binary.Read(r, binary.LittleEndian, &length)	if err != nil {		return	}	res = make([]byte, length)	_, err = io.ReadFull(r, res)	return}func readUshort(r io.Reader) (res uint16, err error) {	err = binary.Read(r, binary.LittleEndian, &res)	return}func readByte(r io.Reader) (res byte, err error) {	var b [1]byte	_, err = r.Read(b[:])	res = b[0]	return}// Packet Data Stream Headers// http://msdn.microsoft.com/en-us/library/dd304953.aspxtype headerStruct struct {	hdrtype uint16	data    []byte}const (	dataStmHdrQueryNotif    = 1 // query notifications	dataStmHdrTransDescr    = 2 // MARS transaction descriptor (required)	dataStmHdrTraceActivity = 3)// Query Notifications Header// http://msdn.microsoft.com/en-us/library/dd304949.aspxtype queryNotifHdr struct {	notifyId      string	ssbDeployment string	notifyTimeout uint32}func (hdr queryNotifHdr) pack() (res []byte) {	notifyId := str2ucs2(hdr.notifyId)	ssbDeployment := str2ucs2(hdr.ssbDeployment)	res = make([]byte, 2+len(notifyId)+2+len(ssbDeployment)+4)	b := res	binary.LittleEndian.PutUint16(b, uint16(len(notifyId)))	b = b[2:]	copy(b, notifyId)	b = b[len(notifyId):]	binary.LittleEndian.PutUint16(b, uint16(len(ssbDeployment)))	b = b[2:]	copy(b, ssbDeployment)	b = b[len(ssbDeployment):]	binary.LittleEndian.PutUint32(b, hdr.notifyTimeout)	return res}// MARS Transaction Descriptor Header// http://msdn.microsoft.com/en-us/library/dd340515.aspxtype transDescrHdr struct {	transDescr        uint64 // transaction descriptor returned from ENVCHANGE	outstandingReqCnt uint32 // outstanding request count}func (hdr transDescrHdr) pack() (res []byte) {	res = make([]byte, 8+4)	binary.LittleEndian.PutUint64(res, hdr.transDescr)	binary.LittleEndian.PutUint32(res[8:], hdr.outstandingReqCnt)	return res}func writeAllHeaders(w io.Writer, headers []headerStruct) (err error) {	// calculatint total length	var totallen uint32 = 4	for _, hdr := range headers {		totallen += 4 + 2 + uint32(len(hdr.data))	}	// writing	err = binary.Write(w, binary.LittleEndian, totallen)	if err != nil {		return err	}	for _, hdr := range headers {		var headerlen uint32 = 4 + 2 + uint32(len(hdr.data))		err = binary.Write(w, binary.LittleEndian, headerlen)		if err != nil {			return err		}		err = binary.Write(w, binary.LittleEndian, hdr.hdrtype)		if err != nil {			return err		}		_, err = w.Write(hdr.data)		if err != nil {			return err		}	}	return nil}func sendSqlBatch72(buf *tdsBuffer,	sqltext string,	headers []headerStruct) (err error) {	buf.BeginPacket(packSQLBatch)	if err = writeAllHeaders(buf, headers); err != nil {		return	}	_, err = buf.Write(str2ucs2(sqltext))	if err != nil {		return	}	return buf.FinishPacket()}// 2.2.1.7 Attention: https://msdn.microsoft.com/en-us/library/dd341449.aspx// 4.19.2 Out-of-Band Attention Signal: https://msdn.microsoft.com/en-us/library/dd305167.aspxfunc sendAttention(buf *tdsBuffer) error {	buf.BeginPacket(packAttention)	return buf.FinishPacket()}type connectParams struct {	logFlags               uint64	port                   uint64	host                   string	instance               string	database               string	user                   string	password               string	dial_timeout           time.Duration	conn_timeout           time.Duration	keepAlive              time.Duration	encrypt                bool	disableEncryption      bool	trustServerCertificate bool	certificate            string	hostInCertificate      string	serverSPN              string	workstation            string	appname                string	typeFlags              uint8	failOverPartner        string	failOverPort           uint64}func splitConnectionString(dsn string) (res map[string]string) {	res = map[string]string{}	parts := strings.Split(dsn, ";")	for _, part := range parts {		if len(part) == 0 {			continue		}		lst := strings.SplitN(part, "=", 2)		name := strings.TrimSpace(strings.ToLower(lst[0]))		if len(name) == 0 {			continue		}		var value string = ""		if len(lst) > 1 {			value = strings.TrimSpace(lst[1])		}		res[name] = value	}	return res}// Splits a URL in the ODBC formatfunc splitConnectionStringOdbc(dsn string) (map[string]string, error) {	res := map[string]string{}	type parserState int	const (		// Before the start of a key		parserStateBeforeKey parserState = iota		// Inside a key		parserStateKey		// Beginning of a value. May be bare or braced		parserStateBeginValue		// Inside a bare value		parserStateBareValue		// Inside a braced value		parserStateBracedValue		// A closing brace inside a braced value.		// May be the end of the value or an escaped closing brace, depending on the next character		parserStateBracedValueClosingBrace		// After a value. Next character should be a semi-colon or whitespace.		parserStateEndValue	)	var state = parserStateBeforeKey	var key string	var value string	for i, c := range dsn {		switch state {		case parserStateBeforeKey:			switch {			case c == '=':				return res, fmt.Errorf("Unexpected character = at index %d. Expected start of key or semi-colon or whitespace.", i)			case !unicode.IsSpace(c) && c != ';':				state = parserStateKey				key += string(c)			}		case parserStateKey:			switch c {			case '=':				key = normalizeOdbcKey(key)				if len(key) == 0 {					return res, fmt.Errorf("Unexpected end of key at index %d.", i)				}				state = parserStateBeginValue			case ';':				// Key without value				key = normalizeOdbcKey(key)				if len(key) == 0 {					return res, fmt.Errorf("Unexpected end of key at index %d.", i)				}				res[key] = value				key = ""				value = ""				state = parserStateBeforeKey			default:				key += string(c)			}		case parserStateBeginValue:			switch {			case c == '{':				state = parserStateBracedValue			case c == ';':				// Empty value				res[key] = value				key = ""				state = parserStateBeforeKey			case unicode.IsSpace(c):				// Ignore whitespace			default:				state = parserStateBareValue				value += string(c)			}		case parserStateBareValue:			if c == ';' {				res[key] = strings.TrimRightFunc(value, unicode.IsSpace)				key = ""				value = ""				state = parserStateBeforeKey			} else {				value += string(c)			}		case parserStateBracedValue:			if c == '}' {				state = parserStateBracedValueClosingBrace			} else {				value += string(c)			}		case parserStateBracedValueClosingBrace:			if c == '}' {				// Escaped closing brace				value += string(c)				state = parserStateBracedValue				continue			}			// End of braced value			res[key] = value			key = ""			value = ""			// This character is the first character past the end,			// so it needs to be parsed like the parserStateEndValue state.			state = parserStateEndValue			switch {			case c == ';':				state = parserStateBeforeKey			case unicode.IsSpace(c):				// Ignore whitespace			default:				return res, fmt.Errorf("Unexpected character %c at index %d. Expected semi-colon or whitespace.", c, i)			}		case parserStateEndValue:			switch {			case c == ';':				state = parserStateBeforeKey			case unicode.IsSpace(c):				// Ignore whitespace			default:				return res, fmt.Errorf("Unexpected character %c at index %d. Expected semi-colon or whitespace.", c, i)			}		}	}	switch state {	case parserStateBeforeKey: // Okay	case parserStateKey: // Unfinished key. Treat as key without value.		key = normalizeOdbcKey(key)		if len(key) == 0 {			return res, fmt.Errorf("Unexpected end of key at index %d.", len(dsn))		}		res[key] = value	case parserStateBeginValue: // Empty value		res[key] = value	case parserStateBareValue:		res[key] = strings.TrimRightFunc(value, unicode.IsSpace)	case parserStateBracedValue:		return res, fmt.Errorf("Unexpected end of braced value at index %d.", len(dsn))	case parserStateBracedValueClosingBrace: // End of braced value		res[key] = value	case parserStateEndValue: // Okay	}	return res, nil}// Normalizes the given string as an ODBC-format keyfunc normalizeOdbcKey(s string) string {	return strings.ToLower(strings.TrimRightFunc(s, unicode.IsSpace))}// Splits a URL of the form sqlserver://username:password@host/instance?param1=value¶m2=valuefunc splitConnectionStringURL(dsn string) (map[string]string, error) {	res := map[string]string{}	u, err := url.Parse(dsn)	if err != nil {		return res, err	}	if u.Scheme != "sqlserver" {		return res, fmt.Errorf("scheme %s is not recognized", u.Scheme)	}	if u.User != nil {		res["user id"] = u.User.Username()		p, exists := u.User.Password()		if exists {			res["password"] = p		}	}	host, port, err := net.SplitHostPort(u.Host)	if err != nil {		host = u.Host	}	if len(u.Path) > 0 {		res["server"] = host + "\\" + u.Path[1:]	} else {		res["server"] = host	}	if len(port) > 0 {		res["port"] = port	}	query := u.Query()	for k, v := range query {		if len(v) > 1 {			return res, fmt.Errorf("key %s provided more than once", k)		}		res[k] = v[0]	}	return res, nil}func parseConnectParams(dsn string) (connectParams, error) {	var p connectParams	var params map[string]string	if strings.HasPrefix(dsn, "odbc:") {		parameters, err := splitConnectionStringOdbc(dsn[len("odbc:"):])		if err != nil {			return p, err		}		params = parameters	} else if strings.HasPrefix(dsn, "sqlserver://") {		parameters, err := splitConnectionStringURL(dsn)		if err != nil {			return p, err		}		params = parameters	} else {		params = splitConnectionString(dsn)	}	strlog, ok := params["log"]	if ok {		var err error		p.logFlags, err = strconv.ParseUint(strlog, 10, 0)		if err != nil {			return p, fmt.Errorf("Invalid log parameter '%s': %s", strlog, err.Error())		}	}	server := params["server"]	parts := strings.SplitN(server, "\\", 2)	p.host = parts[0]	if p.host == "." || strings.ToUpper(p.host) == "(LOCAL)" || p.host == "" {		p.host = "localhost"	}	if len(parts) > 1 {		p.instance = parts[1]	}	p.database = params["database"]	p.user = params["user id"]	p.password = params["password"]	p.port = 1433	strport, ok := params["port"]	if ok {		var err error		p.port, err = strconv.ParseUint(strport, 0, 16)		if err != nil {			f := "Invalid tcp port '%v': %v"			return p, fmt.Errorf(f, strport, err.Error())		}	}	// https://msdn.microsoft.com/en-us/library/dd341108.aspx	p.dial_timeout = 15 * time.Second	p.conn_timeout = 30 * time.Second	strconntimeout, ok := params["connection timeout"]	if ok {		timeout, err := strconv.ParseUint(strconntimeout, 0, 16)		if err != nil {			f := "Invalid connection timeout '%v': %v"			return p, fmt.Errorf(f, strconntimeout, err.Error())		}		p.conn_timeout = time.Duration(timeout) * time.Second	}	strdialtimeout, ok := params["dial timeout"]	if ok {		timeout, err := strconv.ParseUint(strdialtimeout, 0, 16)		if err != nil {			f := "Invalid dial timeout '%v': %v"			return p, fmt.Errorf(f, strdialtimeout, err.Error())		}		p.dial_timeout = time.Duration(timeout) * time.Second	}	// default keep alive should be 30 seconds according to spec:	// https://msdn.microsoft.com/en-us/library/dd341108.aspx	p.keepAlive = 30 * time.Second	keepAlive, ok := params["keepalive"]	if ok {		timeout, err := strconv.ParseUint(keepAlive, 0, 16)		if err != nil {			f := "Invalid keepAlive value '%s': %s"			return p, fmt.Errorf(f, keepAlive, err.Error())		}		p.keepAlive = time.Duration(timeout) * time.Second	}	encrypt, ok := params["encrypt"]	if ok {		if strings.ToUpper(encrypt) == "DISABLE" {			p.disableEncryption = true		} else {			var err error			p.encrypt, err = strconv.ParseBool(encrypt)			if err != nil {				f := "Invalid encrypt '%s': %s"				return p, fmt.Errorf(f, encrypt, err.Error())			}		}	} else {		p.trustServerCertificate = true	}	trust, ok := params["trustservercertificate"]	if ok {		var err error		p.trustServerCertificate, err = strconv.ParseBool(trust)		if err != nil {			f := "Invalid trust server certificate '%s': %s"			return p, fmt.Errorf(f, trust, err.Error())		}	}	p.certificate = params["certificate"]	p.hostInCertificate, ok = params["hostnameincertificate"]	if !ok {		p.hostInCertificate = p.host	}	serverSPN, ok := params["serverspn"]	if ok {		p.serverSPN = serverSPN	} else {		p.serverSPN = fmt.Sprintf("MSSQLSvc/%s:%d", p.host, p.port)	}	workstation, ok := params["workstation id"]	if ok {		p.workstation = workstation	} else {		workstation, err := os.Hostname()		if err == nil {			p.workstation = workstation		}	}	appname, ok := params["app name"]	if !ok {		appname = "go-mssqldb"	}	p.appname = appname	appintent, ok := params["applicationintent"]	if ok {		if appintent == "ReadOnly" {			p.typeFlags |= fReadOnlyIntent		}	}	failOverPartner, ok := params["failoverpartner"]	if ok {		p.failOverPartner = failOverPartner	}	failOverPort, ok := params["failoverport"]	if ok {		var err error		p.failOverPort, err = strconv.ParseUint(failOverPort, 0, 16)		if err != nil {			f := "Invalid tcp port '%v': %v"			return p, fmt.Errorf(f, failOverPort, err.Error())		}	}	return p, nil}type Auth interface {	InitialBytes() ([]byte, error)	NextBytes([]byte) ([]byte, error)	Free()}// SQL Server AlwaysOn Availability Group Listeners are bound by DNS to a// list of IP addresses.  So if there is more than one, try them all and// use the first one that allows a connection.func dialConnection(p connectParams) (conn net.Conn, err error) {	var ips []net.IP	ips, err = net.LookupIP(p.host)	if err != nil {		ip := net.ParseIP(p.host)		if ip == nil {			return nil, err		}		ips = []net.IP{ip}	}	if len(ips) == 1 {		d := createDialer(&p)		addr := net.JoinHostPort(ips[0].String(), strconv.Itoa(int(p.port)))		conn, err = d.Dial(addr)	} else {		//Try Dials in parallel to avoid waiting for timeouts.		connChan := make(chan net.Conn, len(ips))		errChan := make(chan error, len(ips))		portStr := strconv.Itoa(int(p.port))		for _, ip := range ips {			go func(ip net.IP) {				d := createDialer(&p)				addr := net.JoinHostPort(ip.String(), portStr)				conn, err := d.Dial(addr)				if err == nil {					connChan <- conn				} else {					errChan <- err				}			}(ip)		}		// Wait for either the *first* successful connection, or all the errors	wait_loop:		for i, _ := range ips {			select {			case conn = <-connChan:				// Got a connection to use, close any others				go func(n int) {					for i := 0; i < n; i++ {						select {						case conn := <-connChan:							conn.Close()						case <-errChan:						}					}				}(len(ips) - i - 1)				// Remove any earlier errors we may have collected				err = nil				break wait_loop			case err = <-errChan:			}		}	}	// Can't do the usual err != nil check, as it is possible to have gotten an error before a successful connection	if conn == nil {		f := "Unable to open tcp connection with host '%v:%v': %v"		return nil, fmt.Errorf(f, p.host, p.port, err.Error())	}	return conn, err}func connect(log optionalLogger, p connectParams) (res *tdsSession, err error) {	res = nil	// if instance is specified use instance resolution service	if p.instance != "" {		p.instance = strings.ToUpper(p.instance)		instances, err := getInstances(p.host)		if err != nil {			f := "Unable to get instances from Sql Server Browser on host %v: %v"			return nil, fmt.Errorf(f, p.host, err.Error())		}		strport, ok := instances[p.instance]["tcp"]		if !ok {			f := "No instance matching '%v' returned from host '%v'"			return nil, fmt.Errorf(f, p.instance, p.host)		}		p.port, err = strconv.ParseUint(strport, 0, 16)		if err != nil {			f := "Invalid tcp port returned from Sql Server Browser '%v': %v"			return nil, fmt.Errorf(f, strport, err.Error())		}	}initiate_connection:	conn, err := dialConnection(p)	if err != nil {		return nil, err	}	toconn := NewTimeoutConn(conn, p.conn_timeout)	outbuf := newTdsBuffer(4096, toconn)	sess := tdsSession{		buf:      outbuf,		log:      log,		logFlags: p.logFlags,	}	instance_buf := []byte(p.instance)	instance_buf = append(instance_buf, 0) // zero terminate instance name	var encrypt byte	if p.disableEncryption {		encrypt = encryptNotSup	} else if p.encrypt {		encrypt = encryptOn	} else {		encrypt = encryptOff	}	fields := map[uint8][]byte{		preloginVERSION:    {0, 0, 0, 0, 0, 0},		preloginENCRYPTION: {encrypt},		preloginINSTOPT:    instance_buf,		preloginTHREADID:   {0, 0, 0, 0},		preloginMARS:       {0}, // MARS disabled	}	err = writePrelogin(outbuf, fields)	if err != nil {		return nil, err	}	fields, err = readPrelogin(outbuf)	if err != nil {		return nil, err	}	encryptBytes, ok := fields[preloginENCRYPTION]	if !ok {		return nil, fmt.Errorf("Encrypt negotiation failed")	}	encrypt = encryptBytes[0]	if p.encrypt && (encrypt == encryptNotSup || encrypt == encryptOff) {		return nil, fmt.Errorf("Server does not support encryption")	}	if encrypt != encryptNotSup {		var config tls.Config		if p.certificate != "" {			pem, err := ioutil.ReadFile(p.certificate)			if err != nil {				f := "Cannot read certificate '%s': %s"				return nil, fmt.Errorf(f, p.certificate, err.Error())			}			certs := x509.NewCertPool()			certs.AppendCertsFromPEM(pem)			config.RootCAs = certs		}		if p.trustServerCertificate {			config.InsecureSkipVerify = true		}		config.ServerName = p.hostInCertificate		outbuf.transport = conn		toconn.buf = outbuf		tlsConn := tls.Client(toconn, &config)		err = tlsConn.Handshake()		toconn.buf = nil		outbuf.transport = tlsConn		if err != nil {			f := "TLS Handshake failed: %s"			return nil, fmt.Errorf(f, err.Error())		}		if encrypt == encryptOff {			outbuf.afterFirst = func() {				outbuf.transport = toconn			}		}	}	login := login{		TDSVersion:   verTDS74,		PacketSize:   outbuf.PackageSize(),		Database:     p.database,		OptionFlags2: fODBC, // to get unlimited TEXTSIZE		HostName:     p.workstation,		ServerName:   p.host,		AppName:      p.appname,		TypeFlags:    p.typeFlags,	}	auth, auth_ok := getAuth(p.user, p.password, p.serverSPN, p.workstation)	if auth_ok {		login.SSPI, err = auth.InitialBytes()		if err != nil {			return nil, err		}		login.OptionFlags2 |= fIntSecurity		defer auth.Free()	} else {		login.UserName = p.user		login.Password = p.password	}	err = sendLogin(outbuf, login)	if err != nil {		return nil, err	}	// processing login response	var sspi_msg []bytecontinue_login:	tokchan := make(chan tokenStruct, 5)	go processResponse(context.Background(), &sess, tokchan)	success := false	for tok := range tokchan {		switch token := tok.(type) {		case sspiMsg:			sspi_msg, err = auth.NextBytes(token)			if err != nil {				return nil, err			}		case loginAckStruct:			success = true			sess.loginAck = token		case error:			return nil, fmt.Errorf("Login error: %s", token.Error())		}	}	if sspi_msg != nil {		outbuf.BeginPacket(packSSPIMessage)		_, err = outbuf.Write(sspi_msg)		if err != nil {			return nil, err		}		err = outbuf.FinishPacket()		if err != nil {			return nil, err		}		sspi_msg = nil		goto continue_login	}	if !success {		return nil, fmt.Errorf("Login failed")	}	if sess.routedServer != "" {		toconn.Close()		p.host = sess.routedServer		p.port = uint64(sess.routedPort)		goto initiate_connection	}	return &sess, nil}
 |