request.go 4.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249
  1. package request
  2. import (
  3. "bufio"
  4. "bytes"
  5. "crypto/tls"
  6. "fmt"
  7. "io"
  8. "net"
  9. "net/http"
  10. "net/url"
  11. "strconv"
  12. "time"
  13. "github.com/fatedier/frp/test/e2e/pkg/rpc"
  14. libdial "github.com/fatedier/golib/net/dial"
  15. )
  // Request is a chainable description of one outgoing tcp, udp, http or https
  // request. Every setter stores its argument and returns the receiver, so a
  // request is normally assembled as New().X(...).Y(...) and then sent with Do.
  type Request struct {
  	// protocol is one of "tcp", "udp", "http" or "https".
  	protocol string

  	// Used by every protocol.
  	addr    string        // target host name or IP address
  	port    int           // target port
  	body    []byte        // payload to send
  	timeout time.Duration // optional timeout; values <= 0 mean none

  	// Used by http and https.
  	method    string            // HTTP method
  	host      string            // Host override; empty keeps the URL's host
  	path      string            // request path appended to addr:port
  	headers   map[string]string // extra request headers
  	tlsConfig *tls.Config       // TLS client settings; may be nil
  	proxyURL  string            // optional proxy URL; Do also honours it for tcp
  }

  // New returns a Request preset to protocol "tcp", address "127.0.0.1",
  // method "GET" and path "/". All other fields keep their zero value.
  func New() *Request {
  	req := new(Request)
  	req.protocol = "tcp"
  	req.addr = "127.0.0.1"
  	req.method = "GET"
  	req.path = "/"
  	return req
  }

  // Protocol sets the protocol name verbatim; Do validates it later.
  func (r *Request) Protocol(protocol string) *Request {
  	r.protocol = protocol
  	return r
  }

  // TCP selects the "tcp" protocol.
  func (r *Request) TCP() *Request { return r.Protocol("tcp") }

  // UDP selects the "udp" protocol.
  func (r *Request) UDP() *Request { return r.Protocol("udp") }

  // HTTP selects the "http" protocol.
  func (r *Request) HTTP() *Request { return r.Protocol("http") }

  // HTTPS selects the "https" protocol.
  func (r *Request) HTTPS() *Request { return r.Protocol("https") }

  // Proxy sets the URL of the proxy the request is sent through.
  func (r *Request) Proxy(url string) *Request {
  	r.proxyURL = url
  	return r
  }

  // Addr sets the target host name or IP address.
  func (r *Request) Addr(addr string) *Request {
  	r.addr = addr
  	return r
  }

  // Port sets the target port.
  func (r *Request) Port(port int) *Request {
  	r.port = port
  	return r
  }

  // HTTPParams sets method, host, path and headers in a single call.
  func (r *Request) HTTPParams(method, host, path string, headers map[string]string) *Request {
  	r.method = method
  	return r.HTTPHost(host).HTTPPath(path).HTTPHeaders(headers)
  }

  // HTTPHost sets the Host override used for http and https requests.
  func (r *Request) HTTPHost(host string) *Request {
  	r.host = host
  	return r
  }

  // HTTPPath sets the request path used for http and https requests.
  func (r *Request) HTTPPath(path string) *Request {
  	r.path = path
  	return r
  }

  // HTTPHeaders sets the extra request headers. The map is stored as-is, not
  // copied, so later changes made by the caller remain visible to the request.
  func (r *Request) HTTPHeaders(headers map[string]string) *Request {
  	r.headers = headers
  	return r
  }

  // TLSConfig sets the TLS client configuration used for http and https.
  func (r *Request) TLSConfig(tlsConfig *tls.Config) *Request {
  	r.tlsConfig = tlsConfig
  	return r
  }

  // Timeout sets the request timeout.
  func (r *Request) Timeout(timeout time.Duration) *Request {
  	r.timeout = timeout
  	return r
  }

  // Body sets the payload to send. The slice is stored as-is, not copied.
  func (r *Request) Body(content []byte) *Request {
  	r.body = content
  	return r
  }
  102. func (r *Request) Do() (*Response, error) {
  103. var (
  104. conn net.Conn
  105. err error
  106. )
  107. addr := net.JoinHostPort(r.addr, strconv.Itoa(r.port))
  108. // for protocol http and https
  109. if r.protocol == "http" || r.protocol == "https" {
  110. return r.sendHTTPRequest(r.method, fmt.Sprintf("%s://%s%s", r.protocol, addr, r.path),
  111. r.host, r.headers, r.proxyURL, r.body, r.tlsConfig)
  112. }
  113. // for protocol tcp and udp
  114. if len(r.proxyURL) > 0 {
  115. if r.protocol != "tcp" {
  116. return nil, fmt.Errorf("only tcp protocol is allowed for proxy")
  117. }
  118. proxyType, proxyAddress, auth, err := libdial.ParseProxyURL(r.proxyURL)
  119. if err != nil {
  120. return nil, fmt.Errorf("parse ProxyURL error: %v", err)
  121. }
  122. conn, err = libdial.Dial(addr, libdial.WithProxy(proxyType, proxyAddress), libdial.WithProxyAuth(auth))
  123. if err != nil {
  124. return nil, err
  125. }
  126. } else {
  127. switch r.protocol {
  128. case "tcp":
  129. conn, err = net.Dial("tcp", addr)
  130. case "udp":
  131. conn, err = net.Dial("udp", addr)
  132. default:
  133. return nil, fmt.Errorf("invalid protocol")
  134. }
  135. if err != nil {
  136. return nil, err
  137. }
  138. }
  139. defer conn.Close()
  140. if r.timeout > 0 {
  141. conn.SetDeadline(time.Now().Add(r.timeout))
  142. }
  143. buf, err := r.sendRequestByConn(conn, r.body)
  144. if err != nil {
  145. return nil, err
  146. }
  147. return &Response{Content: buf}, nil
  148. }
  // Response is the result of Request.Do.
  //
  // Code and Header are filled in only for http and https requests; tcp and udp
  // requests set Content alone.
  type Response struct {
  	// Code is the HTTP status code.
  	Code int
  	// Header holds the HTTP response headers.
  	Header http.Header
  	// Content is the HTTP response body, or the bytes read back from the
  	// tcp/udp connection.
  	Content []byte
  }
  154. func (r *Request) sendHTTPRequest(method, urlstr string, host string, headers map[string]string,
  155. proxy string, body []byte, tlsConfig *tls.Config,
  156. ) (*Response, error) {
  157. var inBody io.Reader
  158. if len(body) != 0 {
  159. inBody = bytes.NewReader(body)
  160. }
  161. req, err := http.NewRequest(method, urlstr, inBody)
  162. if err != nil {
  163. return nil, err
  164. }
  165. if host != "" {
  166. req.Host = host
  167. }
  168. for k, v := range headers {
  169. req.Header.Set(k, v)
  170. }
  171. tr := &http.Transport{
  172. DialContext: (&net.Dialer{
  173. Timeout: time.Second,
  174. KeepAlive: 30 * time.Second,
  175. DualStack: true,
  176. }).DialContext,
  177. MaxIdleConns: 100,
  178. IdleConnTimeout: 90 * time.Second,
  179. TLSHandshakeTimeout: 10 * time.Second,
  180. ExpectContinueTimeout: 1 * time.Second,
  181. TLSClientConfig: tlsConfig,
  182. }
  183. if len(proxy) != 0 {
  184. tr.Proxy = func(req *http.Request) (*url.URL, error) {
  185. return url.Parse(proxy)
  186. }
  187. }
  188. client := http.Client{Transport: tr}
  189. resp, err := client.Do(req)
  190. if err != nil {
  191. return nil, err
  192. }
  193. ret := &Response{Code: resp.StatusCode, Header: resp.Header}
  194. buf, err := io.ReadAll(resp.Body)
  195. if err != nil {
  196. return nil, err
  197. }
  198. ret.Content = buf
  199. return ret, nil
  200. }
  201. func (r *Request) sendRequestByConn(c net.Conn, content []byte) ([]byte, error) {
  202. _, err := rpc.WriteBytes(c, content)
  203. if err != nil {
  204. return nil, fmt.Errorf("write error: %v", err)
  205. }
  206. var reader io.Reader = c
  207. if r.protocol == "udp" {
  208. reader = bufio.NewReader(c)
  209. }
  210. buf, err := rpc.ReadBytes(reader)
  211. if err != nil {
  212. return nil, fmt.Errorf("read error: %v", err)
  213. }
  214. return buf, nil
  215. }