使用Golang实现内网穿透

什么是内网穿透

内网穿透,即NAT穿透,网络连接时术语,计算机是局域网内时,外网与内网的计算机节点需要连接通信,有时就会出现不支持内网穿透。就是说映射端口,能让外网的电脑找到处于内网的电脑。

内网穿透原理图

内网穿透.png

使用Golang实现

服务端代码 server.go

package main

import (
	"io"
	"log"
	"net"
	"sync"
	"time"
)

// 内网穿透服务端
// tranfer port 8001
// worker port 8002

// 启动后保存来自客户端接收的连接
// 当用户访问ip:8002时,接收到这个连接的后,然后给客户端发送消息请求建立一个转发的连接
// 然后使用io.Copy来传输两个连接的数据

// h 为心跳信息
// n 为建立新连接信息

var (
	tranfersaddr = ":8001"
	workeraddr   = ":8002"
	cacheconn    *net.TCPConn      // 保持一条连接
	conns        chan *net.TCPConn // 新建连接池
)

func main() {
	conns = make(chan *net.TCPConn)

	tranferaddr, err := net.ResolveTCPAddr("tcp", tranfersaddr)
	if err != nil {
		log.Fatalln(err)
	}
	tlis, err := net.ListenTCP("tcp", tranferaddr)
	if err != nil {
		log.Fatalln(err)
	}
	log.Println("listen tranfer addr: ", tranfersaddr)

	go recvhb(tlis)

	// 监听worker地址,如果有新的连接就向客户端发送新建立连接的消息
	// 然后把这两个连接绑定(io.Copy)
	log.Println("listen worker addr: ", workeraddr)
	workeraddr, err := net.ResolveTCPAddr("tcp", workeraddr)
	if err != nil {
		log.Fatalln(err)
	}
	wlis, err := net.ListenTCP("tcp", workeraddr)
	if err != nil {
		log.Fatalln(err)
	}
	for {
		conn, err := wlis.AcceptTCP()
		if err != nil {
			log.Println(err)
			continue
		}
		conn.SetKeepAlive(true)
		go joinconns2(conn)
	}
}

// 接收客户端的
func recvhb(tlis *net.TCPListener) {
	for {
		conn, err := tlis.AcceptTCP()
		if err != nil {
			log.Println(err)
			break
		}

		log.Println("recv conn from ", conn.RemoteAddr())
		go recvconn(conn)
	}
}

func recvconn(conn *net.TCPConn) {
	var hbdata = make([]byte, 1024)
	for {
		_, err := conn.Read(hbdata)
		if err == io.EOF {
			log.Println("read over")
			break
		}
		if err != nil {
			log.Printf("read data failed: %+v\n", err)
			//conn.Close()
			break
		}
		switch string(hbdata[0]) {
		case "f":
			// 首次连接
			cacheconn = conn
			log.Println("recv first conn")
			// 不需要退出
		case "n":
			// 新建连接
			//将连接放入conn
			log.Println("recv new conn", conn.RemoteAddr())
			select {
			case conns <- conn:
			default:
				log.Println("can not save conn")
			}
			// 直接return 返回
			// 这个连接后续的数据不在这里处理
			return
		case "h":
			// 心跳连接
			// log.Printf("recv heart from %s\n", conn.RemoteAddr())
			// 不需要退出
		}
	}
}

// 分别建立到转发服务器、本地服务的连接
func joinconns2(wconn *net.TCPConn) {
	var closeconn = make(chan struct{})
	once := sync.Once{}
	f := func(dst *net.TCPConn, src *net.TCPConn) {
		var buf = make([]byte,1024)
		for {
			select {
			case <-closeconn:
				dst.Close()
				src.Close()
				return
			default:
				n, err := src.Read(buf)
				if n > 0 {
					dst.Write(buf[:n])
					continue
				}
				if err != nil {
					if err == io.EOF {
						log.Println("recv over")
					}else {
						log.Println(err)
					}
					once.Do(func() {
						close(closeconn)
					})
				}

			}

		}

	}
	// 新建立连接
	cacheconn.Write([]byte("n"))
	select {
	case tconn := <-conns:
		go f(wconn, tconn)
		go f(tconn, wconn)
	case <-time.After(time.Second):
		log.Println("can not get conn")
		wconn.Close()
	}
}



客户端 client.go

package main

import (
	"io"
	"log"
	"net"
	"sync"
	"time"
)

// 内网穿透客户端
// tranfer port 8001
// local port 8000

// 启动后建立一个对转发服务器连接

// 连接建立后需要定时的给转发服务器发送心跳信息,表明自已存活,并且还要响应转发服务器建立新连接的请求

// h 为心跳信息
// n 为建立新配对连接数据
// f 为首次简历连接时需要发送的数据

var (
	tranfercaddr = "127.0.0.1:8001"
	localaddr    = "127.0.0.1:8000"
)

func main() {

	tcpaddr, err := net.ResolveTCPAddr("tcp", tranfercaddr)
	if err != nil {
		log.Fatalln(err)
	}

	conn, err := net.DialTCP("tcp", nil, tcpaddr)
	if err != nil {
		log.Fatalln(err)
	}
	log.Println("success conn addr ", tranfercaddr)
	// 首次建立连接发送f
	_, err = conn.Write([]byte("f"))
	if err != nil {
		log.Fatalln(err)
	}
	conn.SetKeepAlive(true)
	log.Printf("conn tranfer server %s success\n", tranfercaddr)
	// 心跳信息
	go sendhb(conn)

	var recv = make([]byte, 65535)
	for {
		n, err := conn.Read(recv)
		if err != nil && err == io.EOF {
			continue
		}
		if err != nil {
			log.Printf("read data failed: %+v\n", err)
			break
		}
		if n == 1 && string(recv[0]) == "n" {
			// create new conn
			// create conn to local server
			log.Printf("start create new conn")
			go joinconns()
		}
	}

}

func sendhb(conn *net.TCPConn) {
	for {
		_, err := conn.Write([]byte("h"))
		if err != nil {
			log.Printf("Send hearbeat failed: %+v\n", err)
			break
		}
		time.Sleep(time.Second)
	}
}

// 分别建立到转发服务器、本地服务的连接
func joinconns() {
	var closeconn = make(chan struct{})
	once := sync.Once{}
	f := func(dst *net.TCPConn, src *net.TCPConn) {
		log.Println("start read data")
		var buf = make([]byte,1024)
		for {
			select {
			case <-closeconn:
				dst.Close()
				src.Close()
				return
			default:
				n, err := src.Read(buf)
				if n > 0 {
					dst.Write(buf[:n])
					continue
				}
				if err != nil {
					if err == io.EOF {
						log.Println("recv over")
					}else {
						log.Println(err)
					}
					once.Do(func() {
						close(closeconn)
					})
				}

			}

		}
	}
	tranferconn, err := newconn(tranfercaddr)
	if err != nil {
		log.Printf("new conn failed: %+v\n", err)
		return
	}
	log.Println("write new conn")
	// 使用首次建立的连接发送标志
	tranferconn.Write([]byte("n"))

	// localconn
	localconn, err := newconn(localaddr)
	if err != nil {
		log.Printf("new conn failed: %+v\n", err)
		return
	}

	go f(localconn, tranferconn)
	go f(tranferconn, localconn)
}

func newconn(addr string) (*net.TCPConn, error) {
	tcpaddr, err := net.ResolveTCPAddr("tcp", addr)
	if err != nil {
		return nil, err
	}

	conn, err := net.DialTCP("tcp", nil, tcpaddr)
	if err != nil {
		return nil, err
	}
	return conn, nil
}

启动内网穿透服务