websocket.go 1.5 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374757677
  1. package net
  2. import (
  3. "errors"
  4. "net"
  5. "net/http"
  6. "strconv"
  7. "golang.org/x/net/websocket"
  8. )
  9. var (
  10. ErrWebsocketListenerClosed = errors.New("websocket listener closed")
  11. )
  12. const (
  13. FrpWebsocketPath = "/~!frp"
  14. )
  15. type WebsocketListener struct {
  16. ln net.Listener
  17. acceptCh chan net.Conn
  18. server *http.Server
  19. httpMutex *http.ServeMux
  20. }
  21. // NewWebsocketListener to handle websocket connections
  22. // ln: tcp listener for websocket connections
  23. func NewWebsocketListener(ln net.Listener) (wl *WebsocketListener) {
  24. wl = &WebsocketListener{
  25. acceptCh: make(chan net.Conn),
  26. }
  27. muxer := http.NewServeMux()
  28. muxer.Handle(FrpWebsocketPath, websocket.Handler(func(c *websocket.Conn) {
  29. notifyCh := make(chan struct{})
  30. conn := WrapCloseNotifyConn(c, func() {
  31. close(notifyCh)
  32. })
  33. wl.acceptCh <- conn
  34. <-notifyCh
  35. }))
  36. wl.server = &http.Server{
  37. Addr: ln.Addr().String(),
  38. Handler: muxer,
  39. }
  40. go wl.server.Serve(ln)
  41. return
  42. }
  43. func ListenWebsocket(bindAddr string, bindPort int) (*WebsocketListener, error) {
  44. tcpLn, err := net.Listen("tcp", net.JoinHostPort(bindAddr, strconv.Itoa(bindPort)))
  45. if err != nil {
  46. return nil, err
  47. }
  48. l := NewWebsocketListener(tcpLn)
  49. return l, nil
  50. }
  51. func (p *WebsocketListener) Accept() (net.Conn, error) {
  52. c, ok := <-p.acceptCh
  53. if !ok {
  54. return nil, ErrWebsocketListenerClosed
  55. }
  56. return c, nil
  57. }
  58. func (p *WebsocketListener) Close() error {
  59. return p.server.Close()
  60. }
  61. func (p *WebsocketListener) Addr() net.Addr {
  62. return p.ln.Addr()
  63. }