Skip to content

Instantly share code, notes, and snippets.

@greycodee
Created March 27, 2022 15:30
Show Gist options
  • Star 1 You must be signed in to star a gist
  • Fork 1 You must be signed in to fork a gist
  • Save greycodee/22f98464fece7792a83433a1fba58e2a to your computer and use it in GitHub Desktop.
Save greycodee/22f98464fece7792a83433a1fba58e2a to your computer and use it in GitHub Desktop.
MySQL 协议明文连接代码实现[Go语言]
package main
import (
"crypto/sha1"
"encoding/binary"
"fmt"
"github.com/liushuochen/gotable"
"net"
"strings"
)
// main connects to a MySQL server at 127.0.0.1:3306 with hard-coded
// credentials, runs "show databases" and lets CommandQuery print the reply.
func main() {
	client := new(MySQLClient)
	client.addr = "127.0.0.1:3306"
	client.username = "root"
	client.password = "root"

	// init dials the server and logs in; conn is only valid afterwards.
	client.init()
	defer client.conn.Close()

	client.CommandQuery("show databases")
}
// MySQLClient bundles the network connection and the login settings used
// for one session with a MySQL server.
type MySQLClient struct {
	// conn is the connection to the server; it is nil until startConn runs.
	conn net.Conn
	// addr is the address handed to net.Dial, e.g. "127.0.0.1:3306".
	addr string
	// username and password are the credentials sent during the handshake.
	username, password string
}
// init runs the connection phase: it opens the connection and reads the
// server handshake (startConn), then answers it with the login packet
// (sendHandshakeResponse41).
func (m *MySQLClient) init() {
	m.sendHandshakeResponse41(m.startConn())
}
/*
startConn dials m.addr over TCP, stores the connection in m.conn, reads the
server's initial handshake packet and returns it parsed.

The signature has no error return, so failures are reported by panicking
with a descriptive error: dial failure, a read error before the whole packet
has arrived, or a handshake that ReadHandShakeV10 rejects.
*/
func (m *MySQLClient) startConn() *HandshakeV10 {
	conn, err := net.Dial("tcp", m.addr)
	if err != nil {
		panic(fmt.Errorf("connecting to %s: %w", m.addr, err))
	}
	m.conn = conn

	// readFull keeps reading until buf is full; a single Read on a TCP
	// connection may legally return fewer bytes than one protocol packet.
	readFull := func(buf []byte) error {
		for got := 0; got < len(buf); {
			n, err := conn.Read(buf[got:])
			got += n
			if err != nil && got < len(buf) {
				return err
			}
		}
		return nil
	}

	// Packet header: 3-byte little-endian payload length plus 1-byte sequence id.
	header := make([]byte, 4)
	if err := readFull(header); err != nil {
		panic(fmt.Errorf("reading handshake header from %s: %w", m.addr, err))
	}
	payloadLen := int(header[0]) | int(header[1])<<8 | int(header[2])<<16

	// ReadHandShakeV10 expects the header to still be in front of the payload.
	packet := make([]byte, 4+payloadLen)
	copy(packet, header)
	if err := readFull(packet[4:]); err != nil {
		panic(fmt.Errorf("reading handshake payload from %s: %w", m.addr, err))
	}

	handshake := ReadHandShakeV10(packet)
	if handshake == nil {
		panic(fmt.Errorf("malformed handshake packet from %s", m.addr))
	}
	return handshake
}
/*
ReadHandShakeV10 parses the server's initial handshake packet
(Protocol::HandshakeV10).

data must contain the complete packet including the 4-byte packet header
(3-byte length + sequence id), which is skipped without validation.

It returns nil when the packet is too short to contain the fields it
announces. A missing NUL terminator after the auth plugin name is tolerated:
the name then extends to the end of the packet.
*/
func ReadHandShakeV10(data []byte) *HandshakeV10 {
	const (
		headerLen = 4
		// Bytes between the server version and the scramble's second part:
		// connection id, scramble part 1, filler, lower capabilities,
		// character set, status flags, upper capabilities, auth data
		// length, reserved.
		fixedLen = 4 + 8 + 1 + 2 + 1 + 2 + 2 + 1 + 10

		clientSecureConnection = 1 << 15
		clientPluginAuth       = 1 << 19
	)
	if len(data) <= headerLen {
		return nil
	}

	h := &HandshakeV10{}
	pos := headerLen
	h.ProtocolVersion = int32(data[pos])
	pos++

	version, next, terminated := readNulTerminated(data, pos)
	if !terminated || len(data)-next < fixedLen {
		return nil
	}
	h.ServerVersion = version
	pos = next

	// The connection id is a 4-byte little-endian integer.
	h.ConnectionId = int32(binary.LittleEndian.Uint32(data[pos:]))
	pos += 4

	h.AuthPluginDataPart_1 = string(data[pos : pos+8])
	pos += 8 + 1 // scramble part 1 plus one filler byte

	// The 32 capability bits are split: the lower 16 come first, the upper
	// 16 follow after the character set and the status flags.
	capabilities := uint32(binary.LittleEndian.Uint16(data[pos:]))
	pos += 2

	h.CharacterSet = int32(data[pos])
	pos++

	h.StatusFlags = int32(binary.LittleEndian.Uint16(data[pos:]))
	pos += 2

	capabilities |= uint32(binary.LittleEndian.Uint16(data[pos:])) << 16
	pos += 2

	if capabilities&clientPluginAuth != 0 {
		h.AuthPluginDataLen = int32(data[pos])
	}
	pos++
	pos += 10 // reserved

	if capabilities&clientSecureConnection != 0 {
		// The second part is max(13, AuthPluginDataLen-8) bytes long.
		part2Len := 13
		if n := int(h.AuthPluginDataLen) - 8; n > part2Len {
			part2Len = n
		}
		if len(data)-pos < part2Len {
			return nil
		}
		h.AuthPluginDataPart_2 = string(data[pos : pos+part2Len])
		pos += part2Len
	}

	if capabilities&clientPluginAuth != 0 {
		h.AuthPluginName, _, _ = readNulTerminated(data, pos)
	}
	return h
}

// readNulTerminated returns the string that starts at data[pos] and ends
// before the next NUL byte, together with the index just past that NUL.
// When no NUL follows, the rest of data is returned, next is len(data) and
// terminated is false. pos must be in the range [0, len(data)].
func readNulTerminated(data []byte, pos int) (s string, next int, terminated bool) {
	rest := string(data[pos:])
	if end := strings.IndexByte(rest, 0); end >= 0 {
		return rest[:end], pos + end + 1, true
	}
	return rest, len(data), false
}

// HandshakeV10 holds the fields ReadHandShakeV10 extracts from the server's
// initial handshake packet.
type HandshakeV10 struct {
	// ProtocolVersion is the first payload byte of the packet.
	ProtocolVersion int32 `protobuf:"varint,1,opt,name=protocol_version,json=protocolVersion,proto3" json:"protocol_version,omitempty"`
	// ServerVersion is the NUL-terminated version string, without the NUL.
	ServerVersion string `protobuf:"bytes,2,opt,name=server_version,json=serverVersion,proto3" json:"server_version,omitempty"`
	// ConnectionId is the 4-byte little-endian connection id.
	ConnectionId int32 `protobuf:"varint,3,opt,name=connection_id,json=connectionId,proto3" json:"connection_id,omitempty"`
	// AuthPluginDataPart_1 is the first 8 bytes of the auth scramble.
	AuthPluginDataPart_1 string `protobuf:"bytes,4,opt,name=auth_plugin_data_part_1,json=authPluginDataPart1,proto3" json:"auth_plugin_data_part_1,omitempty"`
	// CharacterSet is the one-byte character set id sent by the server.
	CharacterSet int32 `protobuf:"varint,6,opt,name=character_set,json=characterSet,proto3" json:"character_set,omitempty"`
	// StatusFlags is the 2-byte little-endian server status field.
	StatusFlags int32 `protobuf:"varint,7,opt,name=status_flags,json=statusFlags,proto3" json:"status_flags,omitempty"`
	// AuthPluginDataLen is the announced scramble length; it stays 0 when
	// the server does not set the plugin-auth capability bit.
	AuthPluginDataLen int32 `protobuf:"varint,8,opt,name=auth_plugin_data_len,json=authPluginDataLen,proto3" json:"auth_plugin_data_len,omitempty"`
	// AuthPluginDataPart_2 is the rest of the scramble exactly as sent,
	// max(13, AuthPluginDataLen-8) bytes, so it includes any trailing NUL.
	AuthPluginDataPart_2 string `protobuf:"bytes,9,opt,name=auth_plugin_data_part_2,json=authPluginDataPart2,proto3" json:"auth_plugin_data_part_2,omitempty"`
	// AuthPluginName is the authentication plugin named by the server.
	AuthPluginName string `protobuf:"bytes,10,opt,name=auth_plugin_name,json=authPluginName,proto3" json:"auth_plugin_name,omitempty"`
}
/*
handleResponse reads one server response and dispatches on the first payload
byte: OK_Packet (0x00), ERR_Packet (0xff), or anything else, which is handed
to parseResultSet as a result set.

It returns 0x00 for OK, 0xff for ERR or when no usable response could be
read, and 0xfe after a result set has been parsed.

NOTE: only a single Read of at most 1024 bytes is performed, so a response
that is larger, or that arrives in several reads, is seen truncated.
*/
func (m *MySQLClient) handleResponse() uint8 {
	resp := make([]byte, 1024)
	readLen, err := m.conn.Read(resp)
	// A usable response needs the 4-byte packet header plus at least one
	// payload byte; anything shorter is reported as a failure.
	if readLen < 5 {
		if err != nil {
			fmt.Println("失败:", err)
		} else {
			fmt.Println("失败")
		}
		return 0xff
	}

	// Drop the packet header (3-byte length + sequence id).
	data := resp[4:readLen]
	switch data[0] {
	case 0x00:
		fmt.Println("成功")
		return 0x00
	case 0xff:
		fmt.Println("失败")
		return 0xff
	default:
		parseResultSet(data)
		return 0xfe
	}
}
// parseResultSet prints a text result set as a table.
//
// resp starts at the payload of the first packet (its header already
// removed): one byte holding the column count, followed by the raw packets
// for the column definitions, the rows and a terminating 0xfe packet.
//
// If resp ends before the terminating packet, parsing stops there instead of
// indexing past the buffer: a truncated column section prints a notice and
// returns, a truncated row section prints the notice and the rows read so far.
func parseResultSet(resp []byte) {
	if len(resp) == 0 {
		fmt.Println("result set truncated")
		return
	}

	// packetEnd returns the index just past the packet whose header starts
	// at `at`; ok is false when that packet is not completely inside resp.
	// The payload length is a 3-byte little-endian integer.
	packetEnd := func(at int) (end int, ok bool) {
		if at+4 > len(resp) {
			return 0, false
		}
		end = at + 4 + (int(resp[at]) | int(resp[at+1])<<8 | int(resp[at+2])<<16)
		return end, end <= len(resp)
	}

	fieldLen := int(resp[0])
	index := 1

	// Read the column definitions, one packet per column.
	headRows := make([]string, 0, fieldLen)
	for len(headRows) < fieldLen {
		end, ok := packetEnd(index)
		if !ok {
			fmt.Println("result set truncated")
			return
		}
		// Advance by the length taken from the header above rather than by
		// the length readColumn reports.
		n, _ := readColumn(resp, index)
		headRows = append(headRows, n)
		index = end
	}

	table, err := gotable.Create(headRows...)
	if err != nil {
		fmt.Println("Create table failed: ", err.Error())
		return
	}

	// Read the rows.
	for {
		end, ok := packetEnd(index)
		if !ok {
			fmt.Println("result set truncated")
			break
		}
		// A packet starting with 0xfe and carrying fewer than 9 payload
		// bytes marks the end of the rows.
		if end > index+4 && resp[index+4] == 0xfe && end-index-4 < 9 {
			break
		}
		rows, _ := readRow(resp, index, fieldLen)
		table.AddRow(rows)
		index = end
	}

	// Print the table.
	fmt.Println(table)
}
// readColumn extracts the column name from the column-definition packet
// whose 4-byte header starts at data[startIndex].
//
// It skips the four length-encoded strings that precede the name and returns
// the name together with the total packet length (header included), so the
// caller can advance to the next packet. If the packet header is not
// available it returns ("", 0); if the packet is cut off before the name it
// returns "" with the length taken from the header.
func readColumn(data []byte, startIndex int) (name string, length int) {
	payloadLen, ok := readPacketLength(data, startIndex)
	if !ok {
		return "", 0
	}
	length = payloadLen + 4
	pos := startIndex + 4

	// Skip the four strings in front of the name.
	for skipped := 0; skipped < 4; skipped++ {
		n, size, ok := readLenEncInt(data, pos)
		if !ok || n > uint64(len(data)-pos-size) {
			return "", length
		}
		pos += size + int(n)
	}

	n, size, ok := readLenEncInt(data, pos)
	if !ok || n > uint64(len(data)-pos-size) {
		return "", length
	}
	start := pos + size
	return string(data[start : start+int(n)]), length
}

// readRow extracts up to fieldNum length-encoded values from the row packet
// whose 4-byte header starts at data[startIndex].
//
// It returns the values as strings (a NULL field becomes "") and the total
// packet length, header included. If data ends inside the row, the values
// read so far are returned; if the header itself is missing, (nil, 0).
func readRow(data []byte, startIndex int, fieldNum int) (name []string, length int) {
	payloadLen, ok := readPacketLength(data, startIndex)
	if !ok {
		return nil, 0
	}
	length = payloadLen + 4
	pos := startIndex + 4

	for f := 0; f < fieldNum; f++ {
		n, size, ok := readLenEncInt(data, pos)
		if !ok || n > uint64(len(data)-pos-size) {
			break
		}
		start := pos + size
		end := start + int(n)
		name = append(name, string(data[start:end]))
		pos = end
	}
	return name, length
}

// readPacketLength decodes the 3-byte little-endian payload length of the
// packet header at data[pos]; ok is false when the 4-byte header does not fit.
func readPacketLength(data []byte, pos int) (payloadLen int, ok bool) {
	if pos < 0 || pos+4 > len(data) {
		return 0, false
	}
	return int(data[pos]) | int(data[pos+1])<<8 | int(data[pos+2])<<16, true
}

// readLenEncInt decodes the length-encoded integer at data[pos] and returns
// its value and the number of bytes it occupies. A first byte below 0xfb is
// the value itself; 0xfc, 0xfd and 0xfe announce a 2-, 3- or 8-byte
// little-endian value. 0xfb (NULL) and 0xff yield value 0 with size 1, so
// callers see an empty field. ok is false when the bytes are not available.
func readLenEncInt(data []byte, pos int) (value uint64, size int, ok bool) {
	if pos < 0 || pos >= len(data) {
		return 0, 0, false
	}
	first := data[pos]
	width := 0
	switch {
	case first < 0xfb:
		return uint64(first), 1, true
	case first == 0xfc:
		width = 2
	case first == 0xfd:
		width = 3
	case first == 0xfe:
		width = 8
	default:
		return 0, 1, true
	}
	if pos+1+width > len(data) {
		return 0, 0, false
	}
	// Fold the bytes from most to least significant (little-endian input).
	for i := width - 1; i >= 0; i-- {
		value = value<<8 | uint64(data[pos+1+i])
	}
	return value, 1 + width, true
}
/*
sendHandshakeResponse41 sends the login packet (HandshakeResponse41), which
carries the user name and the scrambled password, and then reads the
server's verdict with handleResponse.

serverResp is the parsed server handshake; its two scramble parts are joined
and at most the first 20 bytes are used as the password salt.

It panics with "连接失败" when serverResp is nil, when the packet cannot be
written, or when the server answers with an error.
*/
func (m *MySQLClient) sendHandshakeResponse41(serverResp *HandshakeV10) {
	if serverResp == nil {
		panic("连接失败")
	}

	scramble := []byte(serverResp.AuthPluginDataPart_1 + serverResp.AuthPluginDataPart_2)
	if len(scramble) > 20 {
		scramble = scramble[:20]
	}
	// CalcPassword returns nil for an empty password, so the length byte
	// below must come from the actual response, not a fixed 20.
	authResp := CalcPassword(scramble, []byte(m.password))

	// Field order follows Protocol::HandshakeResponse41.
	resp := make([]byte, 0, 64+len(m.username))
	resp = append(resp, Int32ToBytesOfLittle(19833351)...) // client capability flags
	resp = append(resp, Int32ToBytesOfLittle(16777215)...) // max packet size
	resp = append(resp, 33)                                // character set id
	resp = append(resp, make([]byte, 23)...)               // reserved filler
	resp = append(resp, m.username...)
	resp = append(resp, 0)
	resp = append(resp, byte(len(authResp)))
	resp = append(resp, authResp...)
	resp = append(resp, "mysql_native_password"...)
	resp = append(resp, 0)

	if _, err := m.conn.Write(Pack(resp, 1)); err != nil {
		panic("连接失败: " + err.Error())
	}
	if m.handleResponse() == 0xff {
		panic("连接失败")
	}
}
/*
CalcPassword computes the mysql_native_password authentication token:

	SHA1(password) XOR SHA1(scramble + SHA1(SHA1(password)))

It returns nil for an empty password, otherwise a fresh 20-byte slice.
The scramble slice passed in is not modified.
*/
func CalcPassword(scramble, password []byte) []byte {
	if len(password) == 0 {
		return nil
	}

	pwHash := sha1.Sum(password)
	pwDoubleHash := sha1.Sum(pwHash[:])

	salted := sha1.New()
	salted.Write(scramble)
	salted.Write(pwDoubleHash[:])
	token := salted.Sum(nil)

	for i, b := range pwHash {
		token[i] ^= b
	}
	return token
}
/*
CommandQuery sends sql to the server as a command packet (command byte 3,
sequence id 0) and lets handleResponse read and print the reply.

If the packet cannot be written, the error is printed and no reply is read;
waiting for a response to a command that was never sent would block.
*/
func (m *MySQLClient) CommandQuery(sql string) {
	payload := make([]byte, 0, 1+len(sql))
	payload = append(payload, 3)
	payload = append(payload, sql...)

	if _, err := m.conn.Write(Pack(payload, 0)); err != nil {
		fmt.Println("失败:", err)
		return
	}
	m.handleResponse()
}
/*
Pack wraps data in the final wire format: a 3-byte little-endian payload
length, the one-byte sequence id seqId, then data itself.
*/
func Pack(data []byte, seqId uint8) []byte {
	// The 4-byte little-endian length has room for the header: its top byte
	// is replaced by the sequence id.
	packet := Int32ToBytesOfLittle(int32(len(data)))
	packet[3] = seqId
	return append(packet, data...)
}

/*
Int32ToBytesOfLittle
returns i encoded as a new 4-byte little-endian slice.
*/
func Int32ToBytesOfLittle(i int32) []byte {
	out := []byte{0, 0, 0, 0}
	binary.LittleEndian.PutUint32(out, uint32(i))
	return out
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment