websocket实现

编写基本的httpserver

启动一个基本的httpserver,提供两个接口,一个index返回主页,另一个是就是我们自定义的websocket协议接口

@main.go

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
package main

import (
"fmt"
"log"

"github.com/brewlin/net-protocol/pkg/logging"
"github.com/brewlin/net-protocol/protocol/application/http"
"github.com/brewlin/net-protocol/protocol/application/websocket"
)

func init() {
logging.Setup()
}
func main() {
serv := http.NewHTTP("tap1", "192.168.1.0/24", "192.168.1.1", "9502")
serv.HandleFunc("/ws", echo)

serv.HandleFunc("/", func(request *http.Request, response *http.Response) {
response.End("hello")
})
fmt.Println("@main: server is start ip:192.168.1.1 port:9502 ")
serv.ListenAndServ()
}

//websocket处理器
func echo(r *http.Request, w *http.Response) {
fmt.Println("got http request ; start to upgrade websocket protocol....")
//协议升级 c *websocket.Conn
c, err := websocket.Upgrade(r, w)
if err != nil {
//升级协议失败,直接return 交由http处理响应
fmt.Println("Upgrade error:", err)
return
}
defer c.Close()
//循环处理数据,接受数据,然后返回
for {
message, err := c.ReadData()
if err != nil {
log.Println("read:", err)
break
}
fmt.Println("recv client msg:", string(message))
// c.SendData(message )
c.SendData([]byte("hello"))
}
}

echo 接口接受http请求并进行升级我们的websocket

页面如下

index

自定义的webscoket upgrade进行升级

根据之前的协议分析,我知道握手的过程其实就是检查 HTTP 请求头部字段的过程,值得注意的一点就是需要针对客户端发送的 Sec-WebSocket-Key 生成一个正确的 Sec-WebSocket-Accept 只。关于生成的 Sec-WebSocket-Accpet 的实现,可以参考之前的分析。握手过程的具体代码如下:

@upgrade.go

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
package websocket

import(
"net/http"
"net"
"errors"
"log"
"bufio"
)

func Upgrade(w http.ResponseWriter,r *http.Request)(c *Conn,err error){
//是否是Get方法
if r.Method != "GET" {
http.Error(w,http.StatusText(http.StatusMethodNotAllowed),http.StatusMethodNotAllowed)
return nil,errors.New("websocket:method not GET")
}
//检查 Sec-WebSocket-Version 版本
if values := r.Header["Sec-Websocket-Version"];len(values) == 0 || values[0] != "13" {
http.Error(w,http.StatusText(http.StatusBadRequest),http.StatusBadRequest)
return nil,errors.New("websocket:version != 13")
}

//检查Connection 和 Upgrade
if !tokenListContainsValue(r.Header,"Connection","upgrade") {
http.Error(w,http.StatusText(http.StatusBadRequest),http.StatusBadRequest)
return nil,errors.New("websocket:could not find connection header with token 'upgrade'")
}
if !tokenListContainsValue(r.Header,"Upgrade","websocket") {
http.Error(w,http.StatusText(http.StatusBadRequest),http.StatusBadRequest)
return nil,errors.New("websocket:could not find connection header with token 'websocket'")
}


//计算Sec-Websocket-Accept的值
challengeKey := r.Header.Get("Sec-Websocket-Key")
if challengeKey == "" {
http.Error(w,http.StatusText(http.StatusBadRequest),http.StatusBadRequest)
return nil,errors.New("websocket:key missing or blank")
}

var (
netConn net.Conn
br *bufio.Reader
)
h,ok := w.(http.Hijacker)
if !ok {
http.Error(w,http.StatusText(http.StatusInternalServerError),http.StatusInternalServerError)
return nil,errors.New("websocket:response dose not implement http.Hijacker")
}
var rw *bufio.ReadWriter
//接管当前tcp连接,阻止内置http接管连接
netConn,rw,err = h.Hijack()
if err != nil {
http.Error(w,http.StatusText(http.StatusInternalServerError),http.StatusInternalServerError)
return nil,err
}

br = rw.Reader
if br.Buffered() > 0 {
netConn.Close()
return nil,errors.New("websocket:client send data before hanshake is complete")
}
// 构造握手成功后返回的 response
p := []byte{}
p = append(p, "HTTP/1.1 101 Switching Protocols\r\nUpgrade: websocket\r\nConnection: Upgrade\r\nSec-WebSocket-Accept: "...)
p = append(p, computeAcceptKey(challengeKey)...)
p = append(p, "\r\n\r\n"...)
//返回repson 但不关闭连接
if _,err = netConn.Write(p);err != nil {
netConn.Close()
return nil,err
}
//升级为websocket
log.Println("Upgrade http to websocket successfully")
conn := newConn(netConn)
return conn,nil
}

握手过程的代码比较直观,就不多做解释了。到这里 WebSocket 的实现就基本完成了,可以看到有了之前的各种约定,我们实现 WebSocket 协议也是比较简单的。

封装的websocket结构体和对应的方法

@conn.go

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
package websocket

import (
"fmt"
"encoding/binary"
"log"
"errors"
"net"
)

const (
/*
* 是否是最后一个数据帧
* Fin Rsv1 Rsv2 Rsv3 Opcode
* 1 0 0 0 0 0 0 0 => 128
*/
finalBit = 1 << 7

/*
* 是否需要掩码处理
* Mask payload-len 第一位mask表示是否需要进行掩码处理 后面
* 7位表示数据包长度 1.0-125 表示长度 2.126 后面需要扩展2 字节 16bit
* 3.127则扩展8bit
* 1 0 0 0 0 0 0 0 => 128
*/
maskBit = 1 << 7

/*
* 文本帧类型
* 0 0 0 0 0 0 0 1
*/
TextMessage = 1
/*
* 关闭数据帧类型
* 0 0 0 0 1 0 0 0
*/
CloseMessage = 8
)

//websocket 连接
type Conn struct {
writeBuf []byte
maskKey [4]byte
conn net.Conn
}
func newConn(conn net.Conn)*Conn{
return &Conn{conn:conn}
}
func (c *Conn)Close(){
c.conn.Close()
}

//发送数据
func (c *Conn)SendData(data []byte){
length := len(data)
c.writeBuf = make([]byte,10 + length)

//数据开始和结束位置
payloadStart := 2
/**
*数据帧的第一个字节,不支持且只能发送文本类型数据
*finalBit 1 0 0 0 0 0 0 0
* |
*Text 0 0 0 0 0 0 0 1
* => 1 0 0 0 0 0 0 1
*/
c.writeBuf[0] = byte(TextMessage) | finalBit
fmt.Printf("1 bit:%b\n",c.writeBuf[0])

//数据帧第二个字节,服务器发送的数据不需要进行掩码处理
switch{
//大于2字节的长度
case length >= 1 << 16 ://65536
//c.writeBuf[1] = byte(0x00) | 127 // 127
c.writeBuf[1] = byte(127) // 127
//大端写入64位
binary.BigEndian.PutUint64(c.writeBuf[payloadStart:],uint64(length))
//需要8byte来存储数据长度
payloadStart += 8
case length > 125:
//c.writeBuf[1] = byte(0x00) | 126
c.writeBuf[1] = byte(126)
binary.BigEndian.PutUint16(c.writeBuf[payloadStart:],uint16(length))
payloadStart += 2
default:
//c.writeBuf[1] = byte(0x00) | byte(length)
c.writeBuf[1] = byte(length)
}
fmt.Printf("2 bit:%b\n",c.writeBuf[1])

copy(c.writeBuf[payloadStart:],data[:])
c.conn.Write(c.writeBuf[:payloadStart+length])
}

//读取数据
func (c *Conn)ReadData()(data []byte,err error){
var b [8]byte
//读取数据帧的前两个字节
if _,err := c.conn.Read(b[:2]); err != nil {
return nil,err
}
//开始解析第一个字节 是否还有后续数据帧
final := b[0] & finalBit != 0
fmt.Printf("read data 1 bit :%b\n",b[0])
//不支持数据分片
if !final {
log.Println("Recived fragemented frame,not support")
return nil,errors.New("not suppeort fragmented message")
}

//数据帧类型
/*
*1 0 0 0 0 0 0 1
* &
*0 0 0 0 1 1 1 1
*0 0 0 0 0 0 0 1
* => 1 这样就可以直接获取到类型了
*/
frameType := int(b[0] & 0xf)
//如果 关闭类型,则关闭连接
if frameType == CloseMessage {
c.conn.Close()
log.Println("Recived closed message,connection will be closed")
return nil,errors.New("recived closed message")
}
//只实现了文本格式的传输,编码utf-8
if frameType != TextMessage {
return nil,errors.New("only support text message")
}
//检查数据帧是否被掩码处理
//maskBit => 1 0 0 0 0 0 0 0 任何与他 要么为0 要么为 128
mask := b[1] & maskBit != 0
//数据长度
payloadLen := int64(b[1] & 0x7F)//0 1 1 1 1 1 1 1 1 127
dataLen := int64(payloadLen)
//根据payload length 判断数据的真实长度
switch payloadLen {
case 126://扩展2字节
if _,err := c.conn.Read(b[:2]);err != nil {
return nil,err
}
//获取扩展二字节的真实数据长度
dataLen = int64(binary.BigEndian.Uint16(b[:2]))
case 127 :
if _,err := c.conn.Read(b[:8]);err != nil {
return nil,err
}
dataLen = int64(binary.BigEndian.Uint64(b[:8]))
}

log.Printf("Read data length :%d,payload length %d",payloadLen,dataLen)
//读取mask key
if mask {//如果需要掩码处理的话 需要取出key
//maskKey 是 4 字节 32位
if _,err := c.conn.Read(c.maskKey[:]);err != nil {
return nil ,err
}
}
//读取数据内容
p := make([]byte,dataLen)
if _,err := c.conn.Read(p);err != nil {
return nil,err
}
if mask {
maskBytes(c.maskKey,p)//进行解码
}
return p,nil
}

http 头部检查

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
import (
"crypto/sha1"
"encoding/base64"
"strings"
"net/http"
)


var KeyGUID = []byte("258EAFA5-E914-47DA-95CA-C5AB0DC85B11")
//握手阶段使用 加密key返回 进行握手
func computeAcceptKey(challengeKey string)string{
h := sha1.New()
h.Write([]byte(challengeKey))
h.Write(KeyGUID)
return base64.StdEncoding.EncodeToString(h.Sum(nil))
}

//解码
func maskBytes(key [4]byte,b []byte){
pos := 0
for i := range b {
b[i] ^= key[pos & 3]
pos ++
}
}

// 检查http 头部字段中是否包含指定的值
func tokenListContainsValue(header http.Header, name string, value string)bool{
for _,v := range header[name] {
for _, s := range strings.Split(v,","){
if strings.EqualFold(value,strings.TrimSpace(s)) {
return true
}
}
}
return false
}