Commit 2dd96a4b by anun

Initial commit

parents
version: 2
updates:
# Maintain dependencies for GitHub Actions
- package-ecosystem: "github-actions"
directory: "/"
schedule:
interval: "monthly"
# Dependencies listed in go.mod
- package-ecosystem: "gomod"
directory: "/" # Location of package manifests
schedule:
interval: "monthly"
# test this goreleaser config with:
# - cd chisel
# - goreleaser --skip-publish --rm-dist --config .github/goreleaser.yml
builds:
- env:
- CGO_ENABLED=0
ldflags:
- -s -w -X github.com/jpillora/chisel/share.BuildVersion={{.Version}}
flags:
- -trimpath
goos:
- linux
- darwin
- windows
goarch:
- 386
- amd64
- arm
- arm64
- ppc64
- ppc64le
- mips
- mipsle
- mips64
- mips64le
- s390x
goarm:
- 5
- 6
- 7
gomips:
- hardfloat
- softfloat
archives:
- format: gz
files:
- none*
release:
draft: true
prerelease: auto
changelog:
sort: asc
filters:
exclude:
- "^docs:"
- "^test:"
on: [push, pull_request]
name: CI
jobs:
# ================
# TEST JOB
# runs on every push and PR
# runs 2x3 times (see matrix)
# ================
test:
name: Test
strategy:
matrix:
go-version: [1.21.x]
platform: [ubuntu-latest, macos-latest, windows-latest]
runs-on: ${{ matrix.platform }}
steps:
- name: Install Go
uses: actions/setup-go@v3
with:
go-version: ${{ matrix.go-version }}
- name: Checkout code
uses: actions/checkout@v3
- name: Build
run: go build -v -o /dev/null .
- name: Test
run: go test -v ./...
# ================
# RELEASE JOBS
# runs after a success test
# only runs on push "v*" tag
# ================
release_binaries:
name: Release Binaries
needs: test
if: startsWith(github.ref, 'refs/tags/v')
runs-on: ubuntu-latest
steps:
- name: Check out code
uses: actions/checkout@v3
- name: goreleaser
if: success()
uses: docker://goreleaser/goreleaser:latest
env:
GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
with:
args: release --config .github/goreleaser.yml
release_docker:
name: Release Docker Images
needs: test
if: startsWith(github.ref, 'refs/tags/v')
runs-on: ubuntu-latest
steps:
- name: Check out code
uses: actions/checkout@v3
- name: Set up QEMU
uses: docker/setup-qemu-action@v1
- name: Set up Docker Buildx
id: buildx
uses: docker/setup-buildx-action@v1
- name: Login to DockerHub
uses: docker/login-action@v2
with:
username: jpillora
password: ${{ secrets.DOCKERHUB_TOKEN }}
- name: Docker meta
id: docker_meta
uses: docker/metadata-action@v4
with:
images: jpillora/chisel
tags: |
type=semver,pattern={{version}}
type=semver,pattern={{major}}.{{minor}}
type=semver,pattern={{major}}
- name: Build and push
uses: docker/build-push-action@v3
with:
context: .
platforms: linux/amd64,linux/arm64,linux/ppc64le,linux/386,linux/arm/v7,linux/arm/v6
push: true
tags: ${{ steps.docker_meta.outputs.tags }}
labels: ${{ steps.docker_meta.outputs.labels }}
dist/
*.swp
.idea/
chisel
bin/
release/
tmp/
*.orig
debug
# Compiled Object files, Static and Dynamic libs (Shared Objects)
*.o
*.a
*.so
# Folders
_obj
_test
# Architecture specific extensions/prefixes
*.[568vq]
[568vq].out
*.cgo1.go
*.cgo2.c
_cgo_defun.c
_cgo_gotypes.go
_cgo_export.*
_testmain.go
*.exe
*.test
*.prof
# build stage
FROM golang:alpine AS build
RUN apk update && apk add git
ADD . /src
WORKDIR /src
ENV CGO_ENABLED 0
RUN go build \
-ldflags "-X github.com/jpillora/chisel/share.BuildVersion=$(git describe --abbrev=0 --tags)" \
-o /tmp/bin
# run stage
FROM scratch
LABEL maintainer="dev@jpillora.com"
COPY --from=build /etc/ssl/certs/ca-certificates.crt /etc/ssl/certs/
WORKDIR /app
COPY --from=build /tmp/bin /app/bin
ENTRYPOINT ["/app/bin"]
\ No newline at end of file
MIT License
Copyright (c) 2020 Jaime Pillora <dev@jpillora.com>
Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
VERSION=$(shell git describe --abbrev=0 --tags)
BUILD=$(shell git rev-parse HEAD)
DIRBASE=./build
DIR=${DIRBASE}/${VERSION}/${BUILD}/bin
LDFLAGS=-ldflags "-s -w ${XBUILD} -buildid=${BUILD} -X github.com/jpillora/chisel/share.BuildVersion=${VERSION}"
GOFILES=`go list ./...`
GOFILESNOTEST=`go list ./... | grep -v test`
# Make Directory to store executables
$(shell mkdir -p ${DIR})
all:
@goreleaser build --skip-validate --single-target --config .github/goreleaser.yml
freebsd: lint
env CGO_ENABLED=0 GOOS=freebsd GOARCH=amd64 go build -trimpath ${LDFLAGS} ${GCFLAGS} ${ASMFLAGS} -o ${DIR}/chisel-freebsd_amd64 .
linux: lint
env CGO_ENABLED=1 GOOS=linux GOARCH=amd64 go build -trimpath ${LDFLAGS} ${GCFLAGS} ${ASMFLAGS} -o ${DIR}/chisel-linux_amd64 .
windows: lint
env CGO_ENABLED=1 GOOS=windows GOARCH=amd64 go build -trimpath ${LDFLAGS} ${GCFLAGS} ${ASMFLAGS} -o ${DIR}/chisel-windows_amd64 .
darwin:
env CGO_ENABLED=0 GOOS=darwin GOARCH=amd64 go build -trimpath ${LDFLAGS} ${GCFLAGS} ${ASMFLAGS} -o ${DIR}/chisel-darwin_amd64 .
docker:
@docker build .
dep: ## Get the dependencies
@go get -u github.com/goreleaser/goreleaser
@go get -u github.com/boumenot/gocover-cobertura
@go get -v -d ./...
@go get -u all
@go mod tidy
lint: ## Lint the files
@go fmt ${GOFILES}
@go vet ${GOFILESNOTEST}
test: ## Run unit tests
@go test -coverprofile=${DIR}/coverage.out -race -short ${GOFILESNOTEST}
@go tool cover -html=${DIR}/coverage.out -o ${DIR}/coverage.html
@gocover-cobertura < ${DIR}/coverage.out > ${DIR}/coverage.xml
release: lint test
goreleaser release --config .github/goreleaser.yml
clean:
rm -rf ${DIRBASE}/*
.PHONY: all freebsd linux windows docker dep lint test release clean
\ No newline at end of file
This diff is collapsed. Click to expand it.
package chclient
import (
"context"
"crypto/md5"
"crypto/tls"
"crypto/x509"
"encoding/base64"
"errors"
"fmt"
"io/ioutil"
"net"
"net/http"
"net/url"
"regexp"
"strings"
"time"
"github.com/gorilla/websocket"
chshare "github.com/jpillora/chisel/share"
"github.com/jpillora/chisel/share/ccrypto"
"github.com/jpillora/chisel/share/cio"
"github.com/jpillora/chisel/share/cnet"
"github.com/jpillora/chisel/share/settings"
"github.com/jpillora/chisel/share/tunnel"
"golang.org/x/crypto/ssh"
"golang.org/x/net/proxy"
"golang.org/x/sync/errgroup"
)
// Config represents a client configuration
type Config struct {
Fingerprint string
Auth string
KeepAlive time.Duration
MaxRetryCount int
MaxRetryInterval time.Duration
Server string
Proxy string
Remotes []string
Headers http.Header
TLS TLSConfig
DialContext func(ctx context.Context, network, addr string) (net.Conn, error)
Verbose bool
}
// TLSConfig for a Client
type TLSConfig struct {
SkipVerify bool
CA string
Cert string
Key string
ServerName string
}
// Client represents a client instance
type Client struct {
*cio.Logger
config *Config
computed settings.Config
sshConfig *ssh.ClientConfig
tlsConfig *tls.Config
proxyURL *url.URL
server string
connCount cnet.ConnCount
stop func()
eg *errgroup.Group
tunnel *tunnel.Tunnel
}
// NewClient creates a new client instance
func NewClient(c *Config) (*Client, error) {
//apply default scheme
if !strings.HasPrefix(c.Server, "http") {
c.Server = "http://" + c.Server
}
if c.MaxRetryInterval < time.Second {
c.MaxRetryInterval = 5 * time.Minute
}
u, err := url.Parse(c.Server)
if err != nil {
return nil, err
}
//swap to websockets scheme
u.Scheme = strings.Replace(u.Scheme, "http", "ws", 1)
//apply default port
if !regexp.MustCompile(`:\d+$`).MatchString(u.Host) {
if u.Scheme == "wss" {
u.Host = u.Host + ":443"
} else {
u.Host = u.Host + ":80"
}
}
hasReverse := false
hasSocks := false
hasStdio := false
client := &Client{
Logger: cio.NewLogger("client"),
config: c,
computed: settings.Config{
Version: chshare.BuildVersion,
},
server: u.String(),
tlsConfig: nil,
}
//set default log level
client.Logger.Info = true
//configure tls
if u.Scheme == "wss" {
tc := &tls.Config{}
if c.TLS.ServerName != "" {
tc.ServerName = c.TLS.ServerName
}
//certificate verification config
if c.TLS.SkipVerify {
client.Infof("TLS verification disabled")
tc.InsecureSkipVerify = true
} else if c.TLS.CA != "" {
rootCAs := x509.NewCertPool()
if b, err := ioutil.ReadFile(c.TLS.CA); err != nil {
return nil, fmt.Errorf("Failed to load file: %s", c.TLS.CA)
} else if ok := rootCAs.AppendCertsFromPEM(b); !ok {
return nil, fmt.Errorf("Failed to decode PEM: %s", c.TLS.CA)
} else {
client.Infof("TLS verification using CA %s", c.TLS.CA)
tc.RootCAs = rootCAs
}
}
//provide client cert and key pair for mtls
if c.TLS.Cert != "" && c.TLS.Key != "" {
c, err := tls.LoadX509KeyPair(c.TLS.Cert, c.TLS.Key)
if err != nil {
return nil, fmt.Errorf("Error loading client cert and key pair: %v", err)
}
tc.Certificates = []tls.Certificate{c}
} else if c.TLS.Cert != "" || c.TLS.Key != "" {
return nil, fmt.Errorf("Please specify client BOTH cert and key")
}
client.tlsConfig = tc
}
//validate remotes
for _, s := range c.Remotes {
r, err := settings.DecodeRemote(s)
if err != nil {
return nil, fmt.Errorf("Failed to decode remote '%s': %s", s, err)
}
if r.Socks {
hasSocks = true
}
if r.Reverse {
hasReverse = true
}
if r.Stdio {
if hasStdio {
return nil, errors.New("Only one stdio is allowed")
}
hasStdio = true
}
//confirm non-reverse tunnel is available
if !r.Reverse && !r.Stdio && !r.CanListen() {
return nil, fmt.Errorf("Client cannot listen on %s", r.String())
}
client.computed.Remotes = append(client.computed.Remotes, r)
}
//outbound proxy
if p := c.Proxy; p != "" {
client.proxyURL, err = url.Parse(p)
if err != nil {
return nil, fmt.Errorf("Invalid proxy URL (%s)", err)
}
}
//ssh auth and config
user, pass := settings.ParseAuth(c.Auth)
client.sshConfig = &ssh.ClientConfig{
User: user,
Auth: []ssh.AuthMethod{ssh.Password(pass)},
ClientVersion: "SSH-" + chshare.ProtocolVersion + "-client",
HostKeyCallback: client.verifyServer,
Timeout: settings.EnvDuration("SSH_TIMEOUT", 30*time.Second),
}
//prepare client tunnel
client.tunnel = tunnel.New(tunnel.Config{
Logger: client.Logger,
Inbound: true, //client always accepts inbound
Outbound: hasReverse,
Socks: hasReverse && hasSocks,
KeepAlive: client.config.KeepAlive,
})
return client, nil
}
// Run starts client and blocks while connected
func (c *Client) Run() error {
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
if err := c.Start(ctx); err != nil {
return err
}
return c.Wait()
}
func (c *Client) verifyServer(hostname string, remote net.Addr, key ssh.PublicKey) error {
expect := c.config.Fingerprint
if expect == "" {
return nil
}
got := ccrypto.FingerprintKey(key)
_, err := base64.StdEncoding.DecodeString(expect)
if _, ok := err.(base64.CorruptInputError); ok {
c.Logger.Infof("Specified deprecated MD5 fingerprint (%s), please update to the new SHA256 fingerprint: %s", expect, got)
return c.verifyLegacyFingerprint(key)
} else if err != nil {
return fmt.Errorf("Error decoding fingerprint: %w", err)
}
if got != expect {
return fmt.Errorf("Invalid fingerprint (%s)", got)
}
//overwrite with complete fingerprint
c.Infof("Fingerprint %s", got)
return nil
}
// verifyLegacyFingerprint calculates and compares legacy MD5 fingerprints
func (c *Client) verifyLegacyFingerprint(key ssh.PublicKey) error {
bytes := md5.Sum(key.Marshal())
strbytes := make([]string, len(bytes))
for i, b := range bytes {
strbytes[i] = fmt.Sprintf("%02x", b)
}
got := strings.Join(strbytes, ":")
expect := c.config.Fingerprint
if !strings.HasPrefix(got, expect) {
return fmt.Errorf("Invalid fingerprint (%s)", got)
}
return nil
}
// Start client and does not block
func (c *Client) Start(ctx context.Context) error {
ctx, cancel := context.WithCancel(ctx)
c.stop = cancel
eg, ctx := errgroup.WithContext(ctx)
c.eg = eg
via := ""
if c.proxyURL != nil {
via = " via " + c.proxyURL.String()
}
c.Infof("Connecting to %s%s\n", c.server, via)
//connect to chisel server
eg.Go(func() error {
return c.connectionLoop(ctx)
})
//listen sockets
eg.Go(func() error {
clientInbound := c.computed.Remotes.Reversed(false)
if len(clientInbound) == 0 {
return nil
}
return c.tunnel.BindRemotes(ctx, clientInbound)
})
return nil
}
func (c *Client) setProxy(u *url.URL, d *websocket.Dialer) error {
// CONNECT proxy
if !strings.HasPrefix(u.Scheme, "socks") {
d.Proxy = func(*http.Request) (*url.URL, error) {
return u, nil
}
return nil
}
// SOCKS5 proxy
if u.Scheme != "socks" && u.Scheme != "socks5h" {
return fmt.Errorf(
"unsupported socks proxy type: %s:// (only socks5h:// or socks:// is supported)",
u.Scheme,
)
}
var auth *proxy.Auth
if u.User != nil {
pass, _ := u.User.Password()
auth = &proxy.Auth{
User: u.User.Username(),
Password: pass,
}
}
socksDialer, err := proxy.SOCKS5("tcp", u.Host, auth, proxy.Direct)
if err != nil {
return err
}
d.NetDial = socksDialer.Dial
return nil
}
// Wait blocks while the client is running.
func (c *Client) Wait() error {
return c.eg.Wait()
}
// Close manually stops the client
func (c *Client) Close() error {
if c.stop != nil {
c.stop()
}
return nil
}
package chclient
import (
"context"
"errors"
"fmt"
"io"
"strings"
"time"
"github.com/gorilla/websocket"
"github.com/jpillora/backoff"
chshare "github.com/jpillora/chisel/share"
"github.com/jpillora/chisel/share/cnet"
"github.com/jpillora/chisel/share/cos"
"github.com/jpillora/chisel/share/settings"
"golang.org/x/crypto/ssh"
)
func (c *Client) connectionLoop(ctx context.Context) error {
//connection loop!
b := &backoff.Backoff{Max: c.config.MaxRetryInterval}
for {
connected, err := c.connectionOnce(ctx)
//reset backoff after successful connections
if connected {
b.Reset()
}
//connection error
attempt := int(b.Attempt())
maxAttempt := c.config.MaxRetryCount
//dont print closed-connection errors
if strings.HasSuffix(err.Error(), "use of closed network connection") {
err = io.EOF
}
//show error message and attempt counts (excluding disconnects)
if err != nil && err != io.EOF {
msg := fmt.Sprintf("Connection error: %s", err)
if attempt > 0 {
maxAttemptVal := fmt.Sprint(maxAttempt)
if maxAttempt < 0 {
maxAttemptVal = "unlimited"
}
msg += fmt.Sprintf(" (Attempt: %d/%s)", attempt, maxAttemptVal)
}
c.Infof(msg)
}
//give up?
if maxAttempt >= 0 && attempt >= maxAttempt {
c.Infof("Give up")
break
}
d := b.Duration()
c.Infof("Retrying in %s...", d)
select {
case <-cos.AfterSignal(d):
continue //retry now
case <-ctx.Done():
c.Infof("Cancelled")
return nil
}
}
c.Close()
return nil
}
// connectionOnce connects to the chisel server and blocks
func (c *Client) connectionOnce(ctx context.Context) (connected bool, err error) {
//already closed?
select {
case <-ctx.Done():
return false, errors.New("Cancelled")
default:
//still open
}
ctx, cancel := context.WithCancel(ctx)
defer cancel()
//prepare dialer
d := websocket.Dialer{
HandshakeTimeout: settings.EnvDuration("WS_TIMEOUT", 45*time.Second),
Subprotocols: []string{chshare.ProtocolVersion},
TLSClientConfig: c.tlsConfig,
ReadBufferSize: settings.EnvInt("WS_BUFF_SIZE", 0),
WriteBufferSize: settings.EnvInt("WS_BUFF_SIZE", 0),
NetDialContext: c.config.DialContext,
}
//optional proxy
if p := c.proxyURL; p != nil {
if err := c.setProxy(p, &d); err != nil {
return false, err
}
}
wsConn, _, err := d.DialContext(ctx, c.server, c.config.Headers)
if err != nil {
return false, err
}
conn := cnet.NewWebSocketConn(wsConn)
// perform SSH handshake on net.Conn
c.Debugf("Handshaking...")
sshConn, chans, reqs, err := ssh.NewClientConn(conn, "", c.sshConfig)
if err != nil {
e := err.Error()
if strings.Contains(e, "unable to authenticate") {
c.Infof("Authentication failed")
c.Debugf(e)
} else {
c.Infof(e)
}
return false, err
}
defer sshConn.Close()
// chisel client handshake (reverse of server handshake)
// send configuration
c.Debugf("Sending config")
t0 := time.Now()
_, configerr, err := sshConn.SendRequest(
"config",
true,
settings.EncodeConfig(c.computed),
)
if err != nil {
c.Infof("Config verification failed")
return false, err
}
if len(configerr) > 0 {
return false, errors.New(string(configerr))
}
c.Infof("Connected (Latency %s)", time.Since(t0))
//connected, handover ssh connection for tunnel to use, and block
err = c.tunnel.BindSSH(ctx, sshConn, reqs, chans)
c.Infof("Disconnected")
connected = time.Since(t0) > 5*time.Second
return connected, err
}
package chclient
import (
"crypto/elliptic"
"log"
"net/http"
"net/http/httptest"
"sync"
"testing"
"time"
"github.com/jpillora/chisel/share/ccrypto"
"golang.org/x/crypto/ssh"
)
func TestCustomHeaders(t *testing.T) {
//fake server
wg := sync.WaitGroup{}
wg.Add(1)
server := httptest.NewServer(http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) {
if req.Header.Get("Foo") != "Bar" {
t.Fatal("expected header Foo to be 'Bar'")
}
wg.Done()
}))
defer server.Close()
//client
headers := http.Header{}
headers.Set("Foo", "Bar")
config := Config{
KeepAlive: time.Second,
MaxRetryInterval: time.Second,
Server: server.URL,
Remotes: []string{"9000"},
Headers: headers,
}
c, err := NewClient(&config)
if err != nil {
log.Fatal(err)
}
go c.Run()
//wait for test to complete
wg.Wait()
c.Close()
}
func TestFallbackLegacyFingerprint(t *testing.T) {
config := Config{
Fingerprint: "a5:32:92:c6:56:7a:9e:61:26:74:1b:81:a6:f5:1b:44",
}
c, err := NewClient(&config)
if err != nil {
t.Fatal(err)
}
r := ccrypto.NewDetermRand([]byte("test123"))
priv, err := ccrypto.GenerateKeyGo119(elliptic.P256(), r)
if err != nil {
t.Fatal(err)
}
pub, err := ssh.NewPublicKey(&priv.PublicKey)
if err != nil {
t.Fatal(err)
}
err = c.verifyServer("", nil, pub)
if err != nil {
t.Fatal(err)
}
}
func TestVerifyLegacyFingerprint(t *testing.T) {
config := Config{
Fingerprint: "a5:32:92:c6:56:7a:9e:61:26:74:1b:81:a6:f5:1b:44",
}
c, err := NewClient(&config)
if err != nil {
t.Fatal(err)
}
r := ccrypto.NewDetermRand([]byte("test123"))
priv, err := ccrypto.GenerateKeyGo119(elliptic.P256(), r)
if err != nil {
t.Fatal(err)
}
pub, err := ssh.NewPublicKey(&priv.PublicKey)
if err != nil {
t.Fatal(err)
}
err = c.verifyLegacyFingerprint(pub)
if err != nil {
t.Fatal(err)
}
}
func TestVerifyFingerprint(t *testing.T) {
config := Config{
Fingerprint: "qmrRoo8MIqePv3jC8+wv49gU6uaFgD3FASQx9V8KdmY=",
}
c, err := NewClient(&config)
if err != nil {
t.Fatal(err)
}
r := ccrypto.NewDetermRand([]byte("test123"))
priv, err := ccrypto.GenerateKeyGo119(elliptic.P256(), r)
if err != nil {
t.Fatal(err)
}
pub, err := ssh.NewPublicKey(&priv.PublicKey)
if err != nil {
t.Fatal(err)
}
err = c.verifyServer("", nil, pub)
if err != nil {
t.Fatal(err)
}
}
FROM jpillora/chisel
ENTRYPOINT ["/app/bin", "server", "--port", "443", "--tls-domain", "chisel.jpillora.com"]
\ No newline at end of file
app = "jp-chisel"
kill_signal = "SIGINT"
kill_timeout = 5
processes = []
[build]
dockerfile = "Flyfile"
[[services]]
internal_port = 443
protocol = "tcp"
[[services.ports]]
port = "443"
\ No newline at end of file
# Reverse Tunneling
> **Use Case**: Host a website on your Raspberry Pi without opening ports on your router.
This guide will show you how to use an internet-facing server (for example, a cloud VPS) as a relay to bounce down TCP traffic on port 80 to your Raspberry Pi.
## Chisel CLI
### Server
Setup a relay server on the VPS to bounce down TCP traffic on port 80:
```bash
#!/bin/bash
# ⬇️ Start Chisel server in Reverse mode
chisel server --reverse \
# ⬇️ Use the include users.json as an authfile
--authfile="./users.json" \
```
The corresponding `authfile` might look like this:
```json
{
"foo:bar": ["0.0.0.0:80"]
}
```
### Client
Setup a chisel client to receive bounced-down traffic and forward it to the webserver running on the Pi:
```bash
#!/bin/bash
chisel client \
# ⬇️ Authenticates user "foo" with password "bar"
--auth="foo:bar" \
# ⬇️ Connects to chisel relay server example.com
# listening on the default ("fallback") port, 8080
example.com \
# ⬇️ Reverse tunnels port 80 on the relay server to
# port 80 on your Pi.
R:80:localhost:80
```
---
## Chisel Container
This guide makes use of Docker and Docker compose to accomplish the same task as the above guide.
### Server
Setup a relay server on the VPS to bounce down TCP traffic on port 80:
```yaml
version: '3'
services:
chisel:
image: jpillora/chisel
restart: unless-stopped
container_name: chisel
# ⬇️ Pass CLI arguments one at a time in an array, as required by Docker compose.
command:
- 'server'
# ⬇️ Use the --key=value syntax, since Docker compose doesn't parse whitespace well.
- '--authfile=/users.json'
- '--reverse'
# ⬇️ Mount the authfile as a Docker volume
volumes:
- './users.json:/users.json'
# ⬇️ Give the container unrestricted access to the Docker host's network
network_mode: host
```
The `authfile` (`users.json`) remains the same as in the non-containerized version - shown again with the username `foo` and password `bar`.
```json
{
"foo:bar": ["0.0.0.0:80"]
}
```
### Client
Setup an instance of the Chisel client on the Pi to receive relayed TCP traffic and feed it to the web server:
```yaml
version: '3'
services:
chisel:
# ⬇️ Delay starting Chisel server until the web server container is started.
depends_on:
- webserver
image: jpillora/chisel
restart: unless-stopped
container_name: 'chisel'
command:
- 'client'
# ⬇️ Use username `foo` and password `bar` to authenticate with Chisel server.
- '--auth=foo:bar'
# ⬇️ Domain & port of Chisel server. Port defaults to 8080 on server, but must be manually set on client.
- 'proxy.example.com:8080'
# ⬇️ Reverse tunnel traffic from the chisel server to the web server container, identified in Docker using DNS by its service name `webserver`.
- 'R:80:webserver:80'
networks:
- internal
# ⬇️ Basic Nginx webserver for demo purposes.
webserver:
image: nginx
restart: unless-stopped
container_name: nginx
networks:
- internal
# ⬇️ Make use of a Docker network called `internal`.
networks:
internal:
```
{
"root:toor": [
""
],
"foo:bar": [
"^0.0.0.0:3000$"
],
"ping:pong": [
"^0.0.0.0:[45]000$",
"^example.com:80$",
"^R:0.0.0.0:7000$"
]
}
module github.com/jpillora/chisel
go 1.21
require (
github.com/armon/go-socks5 v0.0.0-20160902184237-e75332964ef5
github.com/fsnotify/fsnotify v1.6.0
github.com/gorilla/websocket v1.5.0
github.com/jpillora/backoff v1.0.0
github.com/jpillora/requestlog v1.0.0
github.com/jpillora/sizestr v1.0.0
golang.org/x/crypto v0.12.0
golang.org/x/net v0.14.0
golang.org/x/sync v0.3.0
)
require (
github.com/andrew-d/go-termutil v0.0.0-20150726205930-009166a695a2 // indirect
github.com/jpillora/ansi v1.0.3 // indirect
github.com/tomasen/realip v0.0.0-20180522021738-f0c99a92ddce // indirect
golang.org/x/sys v0.11.0 // indirect
golang.org/x/text v0.12.0 // indirect
)
github.com/andrew-d/go-termutil v0.0.0-20150726205930-009166a695a2 h1:axBiC50cNZOs7ygH5BgQp4N+aYrZ2DNpWZ1KG3VOSOM=
github.com/andrew-d/go-termutil v0.0.0-20150726205930-009166a695a2/go.mod h1:jnzFpU88PccN/tPPhCpnNU8mZphvKxYM9lLNkd8e+os=
github.com/armon/go-socks5 v0.0.0-20160902184237-e75332964ef5 h1:0CwZNZbxp69SHPdPJAN/hZIm0C4OItdklCFmMRWYpio=
github.com/armon/go-socks5 v0.0.0-20160902184237-e75332964ef5/go.mod h1:wHh0iHkYZB8zMSxRWpUBQtwG5a7fFgvEO+odwuTv2gs=
github.com/fsnotify/fsnotify v1.6.0 h1:n+5WquG0fcWoWp6xPWfHdbskMCQaFnG6PfBrh1Ky4HY=
github.com/fsnotify/fsnotify v1.6.0/go.mod h1:sl3t1tCWJFWoRz9R8WJCbQihKKwmorjAbSClcnxKAGw=
github.com/gorilla/websocket v1.5.0 h1:PPwGk2jz7EePpoHN/+ClbZu8SPxiqlu12wZP/3sWmnc=
github.com/gorilla/websocket v1.5.0/go.mod h1:YR8l580nyteQvAITg2hZ9XVh4b55+EU/adAjf1fMHhE=
github.com/jpillora/ansi v1.0.3 h1:nn4Jzti0EmRfDxm7JtEs5LzCbNwd5sv+0aE+LdS9/ZQ=
github.com/jpillora/ansi v1.0.3/go.mod h1:D2tT+6uzJvN1nBVQILYWkIdq7zG+b5gcFN5WI/VyjMY=
github.com/jpillora/backoff v1.0.0 h1:uvFg412JmmHBHw7iwprIxkPMI+sGQ4kzOWsMeHnm2EA=
github.com/jpillora/backoff v1.0.0/go.mod h1:J/6gKK9jxlEcS3zixgDgUAsiuZ7yrSoa/FX5e0EB2j4=
github.com/jpillora/requestlog v1.0.0 h1:bg++eJ74T7DYL3DlIpiwknrtfdUA9oP/M4fL+PpqnyA=
github.com/jpillora/requestlog v1.0.0/go.mod h1:HTWQb7QfDc2jtHnWe2XEIEeJB7gJPnVdpNn52HXPvy8=
github.com/jpillora/sizestr v1.0.0 h1:4tr0FLxs1Mtq3TnsLDV+GYUWG7Q26a6s+tV5Zfw2ygw=
github.com/jpillora/sizestr v1.0.0/go.mod h1:bUhLv4ctkknatr6gR42qPxirmd5+ds1u7mzD+MZ33f0=
github.com/tomasen/realip v0.0.0-20180522021738-f0c99a92ddce h1:fb190+cK2Xz/dvi9Hv8eCYJYvIGUTN2/KLq1pT6CjEc=
github.com/tomasen/realip v0.0.0-20180522021738-f0c99a92ddce/go.mod h1:o8v6yHRoik09Xen7gje4m9ERNah1d1PPsVq1VEx9vE4=
golang.org/x/crypto v0.12.0 h1:tFM/ta59kqch6LlvYnPa0yx5a83cL2nHflFhYKvv9Yk=
golang.org/x/crypto v0.12.0/go.mod h1:NF0Gs7EO5K4qLn+Ylc+fih8BSTeIjAP05siRnAh98yw=
golang.org/x/net v0.14.0 h1:BONx9s002vGdD9umnlX1Po8vOZmrgH34qlHcD1MfK14=
golang.org/x/net v0.14.0/go.mod h1:PpSgVXXLK0OxS0F31C1/tv6XNguvCrnXIDrFMspZIUI=
golang.org/x/sync v0.3.0 h1:ftCYgMx6zT/asHUrPw8BLLscYtGznsLAnjq5RH9P66E=
golang.org/x/sync v0.3.0/go.mod h1:FU7BRWz2tNW+3quACPkgCx/L+uEAv1htQ0V83Z9Rj+Y=
golang.org/x/sys v0.0.0-20220908164124-27713097b956/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.11.0 h1:eG7RXZHdqOJ1i+0lgLgCpSXAp6M3LYlAo6osgSi0xOM=
golang.org/x/sys v0.11.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/term v0.11.0 h1:F9tnn/DA/Im8nCwm+fX+1/eBwi4qFjRT++MhtVC4ZX0=
golang.org/x/term v0.11.0/go.mod h1:zC9APTIj3jG3FdV/Ons+XE1riIZXG4aZ4GTHiPZJPIU=
golang.org/x/text v0.12.0 h1:k+n5B8goJNdU7hSvEtMUz3d1Q6D/XW4COJSJR6fN0mc=
golang.org/x/text v0.12.0/go.mod h1:TvPlkZtksWOMsz7fbANvkp4WM8x/WCo/om8BMLbz+aE=
This diff is collapsed. Click to expand it.
package chserver
import (
"context"
"errors"
"log"
"net/http"
"net/http/httputil"
"net/url"
"os"
"regexp"
"time"
"github.com/gorilla/websocket"
chshare "github.com/jpillora/chisel/share"
"github.com/jpillora/chisel/share/ccrypto"
"github.com/jpillora/chisel/share/cio"
"github.com/jpillora/chisel/share/cnet"
"github.com/jpillora/chisel/share/settings"
"github.com/jpillora/requestlog"
"golang.org/x/crypto/ssh"
)
// Config is the configuration for the chisel service
type Config struct {
KeySeed string
KeyFile string
AuthFile string
Auth string
Proxy string
Socks5 bool
Reverse bool
KeepAlive time.Duration
TLS TLSConfig
}
// Server respresent a chisel service
type Server struct {
*cio.Logger
config *Config
fingerprint string
httpServer *cnet.HTTPServer
reverseProxy *httputil.ReverseProxy
sessCount int32
sessions *settings.Users
sshConfig *ssh.ServerConfig
users *settings.UserIndex
}
var upgrader = websocket.Upgrader{
CheckOrigin: func(r *http.Request) bool { return true },
ReadBufferSize: settings.EnvInt("WS_BUFF_SIZE", 0),
WriteBufferSize: settings.EnvInt("WS_BUFF_SIZE", 0),
}
// NewServer creates and returns a new chisel server
func NewServer(c *Config) (*Server, error) {
server := &Server{
config: c,
httpServer: cnet.NewHTTPServer(),
Logger: cio.NewLogger("server"),
sessions: settings.NewUsers(),
}
server.Info = true
server.users = settings.NewUserIndex(server.Logger)
if c.AuthFile != "" {
if err := server.users.LoadUsers(c.AuthFile); err != nil {
return nil, err
}
}
if c.Auth != "" {
u := &settings.User{Addrs: []*regexp.Regexp{settings.UserAllowAll}}
u.Name, u.Pass = settings.ParseAuth(c.Auth)
if u.Name != "" {
server.users.AddUser(u)
}
}
var pemBytes []byte
var err error
if c.KeyFile != "" {
var key []byte
if ccrypto.IsChiselKey([]byte(c.KeyFile)) {
key = []byte(c.KeyFile)
} else {
key, err = os.ReadFile(c.KeyFile)
if err != nil {
log.Fatalf("Failed to read key file %s", c.KeyFile)
}
}
pemBytes = key
if ccrypto.IsChiselKey(key) {
pemBytes, err = ccrypto.ChiselKey2PEM(key)
if err != nil {
log.Fatalf("Invalid key %s", string(key))
}
}
} else {
//generate private key (optionally using seed)
pemBytes, err = ccrypto.Seed2PEM(c.KeySeed)
if err != nil {
log.Fatal("Failed to generate key")
}
}
//convert into ssh.PrivateKey
private, err := ssh.ParsePrivateKey(pemBytes)
if err != nil {
log.Fatal("Failed to parse key")
}
//fingerprint this key
server.fingerprint = ccrypto.FingerprintKey(private.PublicKey())
//create ssh config
server.sshConfig = &ssh.ServerConfig{
ServerVersion: "SSH-" + chshare.ProtocolVersion + "-server",
PasswordCallback: server.authUser,
}
server.sshConfig.AddHostKey(private)
//setup reverse proxy
if c.Proxy != "" {
u, err := url.Parse(c.Proxy)
if err != nil {
return nil, err
}
if u.Host == "" {
return nil, server.Errorf("Missing protocol (%s)", u)
}
server.reverseProxy = httputil.NewSingleHostReverseProxy(u)
//always use proxy host
server.reverseProxy.Director = func(r *http.Request) {
//enforce origin, keep path
r.URL.Scheme = u.Scheme
r.URL.Host = u.Host
r.Host = u.Host
}
}
//print when reverse tunnelling is enabled
if c.Reverse {
server.Infof("Reverse tunnelling enabled")
}
return server, nil
}
// Run is responsible for starting the chisel service.
// Internally this calls Start then Wait.
func (s *Server) Run(host, port string) error {
if err := s.Start(host, port); err != nil {
return err
}
return s.Wait()
}
// Start is responsible for kicking off the http server
func (s *Server) Start(host, port string) error {
return s.StartContext(context.Background(), host, port)
}
// StartContext is responsible for kicking off the http server,
// and can be closed by cancelling the provided context
func (s *Server) StartContext(ctx context.Context, host, port string) error {
s.Infof("Fingerprint %s", s.fingerprint)
if s.users.Len() > 0 {
s.Infof("User authentication enabled")
}
if s.reverseProxy != nil {
s.Infof("Reverse proxy enabled")
}
l, err := s.listener(host, port)
if err != nil {
return err
}
h := http.Handler(http.HandlerFunc(s.handleClientHandler))
if s.Debug {
o := requestlog.DefaultOptions
o.TrustProxy = true
h = requestlog.WrapWith(h, o)
}
return s.httpServer.GoServe(ctx, l, h)
}
// Wait waits for the http server to close
func (s *Server) Wait() error {
return s.httpServer.Wait()
}
// Close forcibly closes the http server
func (s *Server) Close() error {
return s.httpServer.Close()
}
// GetFingerprint is used to access the server fingerprint
func (s *Server) GetFingerprint() string {
return s.fingerprint
}
// authUser is responsible for validating the ssh user / password combination
func (s *Server) authUser(c ssh.ConnMetadata, password []byte) (*ssh.Permissions, error) {
// check if user authentication is enabled and if not, allow all
if s.users.Len() == 0 {
return nil, nil
}
// check the user exists and has matching password
n := c.User()
user, found := s.users.Get(n)
if !found || user.Pass != string(password) {
s.Debugf("Login failed for user: %s", n)
return nil, errors.New("Invalid authentication for username: %s")
}
// insert the user session map
// TODO this should probably have a lock on it given the map isn't thread-safe
s.sessions.Set(string(c.SessionID()), user)
return nil, nil
}
// AddUser adds a new user into the server user index
func (s *Server) AddUser(user, pass string, addrs ...string) error {
authorizedAddrs := []*regexp.Regexp{}
for _, addr := range addrs {
authorizedAddr, err := regexp.Compile(addr)
if err != nil {
return err
}
authorizedAddrs = append(authorizedAddrs, authorizedAddr)
}
s.users.AddUser(&settings.User{
Name: user,
Pass: pass,
Addrs: authorizedAddrs,
})
return nil
}
// DeleteUser removes a user from the server user index
func (s *Server) DeleteUser(user string) {
s.users.Del(user)
}
// ResetUsers in the server user index.
// Use nil to remove all.
func (s *Server) ResetUsers(users []*settings.User) {
s.users.Reset(users)
}
package chserver
import (
"net/http"
"strings"
"sync/atomic"
"time"
chshare "github.com/jpillora/chisel/share"
"github.com/jpillora/chisel/share/cnet"
"github.com/jpillora/chisel/share/settings"
"github.com/jpillora/chisel/share/tunnel"
"golang.org/x/crypto/ssh"
"golang.org/x/sync/errgroup"
)
// handleClientHandler is the main http websocket handler for the chisel server
func (s *Server) handleClientHandler(w http.ResponseWriter, r *http.Request) {
//websockets upgrade AND has chisel prefix
upgrade := strings.ToLower(r.Header.Get("Upgrade"))
protocol := r.Header.Get("Sec-WebSocket-Protocol")
if upgrade == "websocket" {
if protocol == chshare.ProtocolVersion {
s.handleWebsocket(w, r)
return
}
//print into server logs and silently fall-through
s.Infof("ignored client connection using protocol '%s', expected '%s'",
protocol, chshare.ProtocolVersion)
}
//proxy target was provided
if s.reverseProxy != nil {
s.reverseProxy.ServeHTTP(w, r)
return
}
//no proxy defined, provide access to health/version checks
switch r.URL.Path {
case "/health":
w.Write([]byte("OK\n"))
return
case "/version":
w.Write([]byte(chshare.BuildVersion))
return
}
//missing :O
w.WriteHeader(404)
w.Write([]byte("Not found"))
}
// handleWebsocket is responsible for handling the websocket connection
func (s *Server) handleWebsocket(w http.ResponseWriter, req *http.Request) {
id := atomic.AddInt32(&s.sessCount, 1)
l := s.Fork("session#%d", id)
wsConn, err := upgrader.Upgrade(w, req, nil)
if err != nil {
l.Debugf("Failed to upgrade (%s)", err)
return
}
conn := cnet.NewWebSocketConn(wsConn)
// perform SSH handshake on net.Conn
l.Debugf("Handshaking with %s...", req.RemoteAddr)
sshConn, chans, reqs, err := ssh.NewServerConn(conn, s.sshConfig)
if err != nil {
s.Debugf("Failed to handshake (%s)", err)
return
}
// pull the users from the session map
var user *settings.User
if s.users.Len() > 0 {
sid := string(sshConn.SessionID())
u, ok := s.sessions.Get(sid)
if !ok {
panic("bug in ssh auth handler")
}
user = u
s.sessions.Del(sid)
}
// chisel server handshake (reverse of client handshake)
// verify configuration
l.Debugf("Verifying configuration")
// wait for request, with timeout
var r *ssh.Request
select {
case r = <-reqs:
case <-time.After(settings.EnvDuration("CONFIG_TIMEOUT", 10*time.Second)):
l.Debugf("Timeout waiting for configuration")
sshConn.Close()
return
}
failed := func(err error) {
l.Debugf("Failed: %s", err)
r.Reply(false, []byte(err.Error()))
}
if r.Type != "config" {
failed(s.Errorf("expecting config request"))
return
}
c, err := settings.DecodeConfig(r.Payload)
if err != nil {
failed(s.Errorf("invalid config"))
return
}
//print if client and server versions dont match
if c.Version != chshare.BuildVersion {
v := c.Version
if v == "" {
v = "<unknown>"
}
l.Infof("Client version (%s) differs from server version (%s)",
v, chshare.BuildVersion)
}
//validate remotes
for _, r := range c.Remotes {
//if user is provided, ensure they have
//access to the desired remotes
if user != nil {
addr := r.UserAddr()
if !user.HasAccess(addr) {
failed(s.Errorf("access to '%s' denied", addr))
return
}
}
//confirm reverse tunnels are allowed
if r.Reverse && !s.config.Reverse {
l.Debugf("Denied reverse port forwarding request, please enable --reverse")
failed(s.Errorf("Reverse port forwaring not enabled on server"))
return
}
//confirm reverse tunnel is available
if r.Reverse && !r.CanListen() {
failed(s.Errorf("Server cannot listen on %s", r.String()))
return
}
}
//successfuly validated config!
r.Reply(true, nil)
//tunnel per ssh connection
tunnel := tunnel.New(tunnel.Config{
Logger: l,
Inbound: s.config.Reverse,
Outbound: true, //server always accepts outbound
Socks: s.config.Socks5,
KeepAlive: s.config.KeepAlive,
})
//bind
eg, ctx := errgroup.WithContext(req.Context())
eg.Go(func() error {
//connected, handover ssh connection for tunnel to use, and block
return tunnel.BindSSH(ctx, sshConn, reqs, chans)
})
eg.Go(func() error {
//connected, setup reversed-remotes?
serverInbound := c.Remotes.Reversed(true)
if len(serverInbound) == 0 {
return nil
}
//block
return tunnel.BindRemotes(ctx, serverInbound)
})
err = eg.Wait()
if err != nil && !strings.HasSuffix(err.Error(), "EOF") {
l.Debugf("Closed connection (%s)", err)
} else {
l.Debugf("Closed connection")
}
}
package chserver
import (
"crypto/tls"
"crypto/x509"
"errors"
"io/ioutil"
"net"
"os"
"os/user"
"path/filepath"
"github.com/jpillora/chisel/share/settings"
"golang.org/x/crypto/acme/autocert"
)
//TLSConfig enables configures TLS
type TLSConfig struct {
Key string
Cert string
Domains []string
CA string
}
func (s *Server) listener(host, port string) (net.Listener, error) {
hasDomains := len(s.config.TLS.Domains) > 0
hasKeyCert := s.config.TLS.Key != "" && s.config.TLS.Cert != ""
if hasDomains && hasKeyCert {
return nil, errors.New("cannot use key/cert and domains")
}
var tlsConf *tls.Config
if hasDomains {
tlsConf = s.tlsLetsEncrypt(s.config.TLS.Domains)
}
extra := ""
if hasKeyCert {
c, err := s.tlsKeyCert(s.config.TLS.Key, s.config.TLS.Cert, s.config.TLS.CA)
if err != nil {
return nil, err
}
tlsConf = c
if port != "443" && hasDomains {
extra = " (WARNING: LetsEncrypt will attempt to connect to your domain on port 443)"
}
}
//tcp listen
l, err := net.Listen("tcp", host+":"+port)
if err != nil {
return nil, err
}
//optionally wrap in tls
proto := "http"
if tlsConf != nil {
proto += "s"
l = tls.NewListener(l, tlsConf)
}
if err == nil {
s.Infof("Listening on %s://%s:%s%s", proto, host, port, extra)
}
return l, nil
}
func (s *Server) tlsLetsEncrypt(domains []string) *tls.Config {
//prepare cert manager
m := &autocert.Manager{
Prompt: func(tosURL string) bool {
s.Infof("Accepting LetsEncrypt TOS and fetching certificate...")
return true
},
Email: settings.Env("LE_EMAIL"),
HostPolicy: autocert.HostWhitelist(domains...),
}
//configure file cache
c := settings.Env("LE_CACHE")
if c == "" {
h := os.Getenv("HOME")
if h == "" {
if u, err := user.Current(); err == nil {
h = u.HomeDir
}
}
c = filepath.Join(h, ".cache", "chisel")
}
if c != "-" {
s.Infof("LetsEncrypt cache directory %s", c)
m.Cache = autocert.DirCache(c)
}
//return lets-encrypt tls config
return m.TLSConfig()
}
func (s *Server) tlsKeyCert(key, cert string, ca string) (*tls.Config, error) {
keypair, err := tls.LoadX509KeyPair(cert, key)
if err != nil {
return nil, err
}
//file based tls config using tls defaults
c := &tls.Config{
Certificates: []tls.Certificate{keypair},
}
//mTLS requires server's CA
if ca != "" {
if err := addCA(ca, c); err != nil {
return nil, err
}
s.Infof("Loaded CA path: %s", ca)
}
return c, nil
}
func addCA(ca string, c *tls.Config) error {
fileInfo, err := os.Stat(ca)
if err != nil {
return err
}
clientCAPool := x509.NewCertPool()
if fileInfo.IsDir() {
//this is a directory holding CA bundle files
files, err := ioutil.ReadDir(ca)
if err != nil {
return err
}
//add all cert files from path
for _, file := range files {
f := file.Name()
if err := addPEMFile(filepath.Join(ca, f), clientCAPool); err != nil {
return err
}
}
} else {
//this is a CA bundle file
if err := addPEMFile(ca, clientCAPool); err != nil {
return err
}
}
//set client CAs and enable cert verification
c.ClientCAs = clientCAPool
c.ClientAuth = tls.RequireAndVerifyClientCert
return nil
}
func addPEMFile(path string, pool *x509.CertPool) error {
content, err := ioutil.ReadFile(path)
if err != nil {
return err
}
if !pool.AppendCertsFromPEM(content) {
return errors.New("Fail to load certificates from : " + path)
}
return nil
}
package ccrypto
// Deterministic crypto.Reader
// overview: half the result is used as the output
// [a|...] -> sha512(a) -> [b|output] -> sha512(b)
import (
"crypto/sha512"
"io"
)
const DetermRandIter = 2048
func NewDetermRand(seed []byte) io.Reader {
var out []byte
//strengthen seed
var next = seed
for i := 0; i < DetermRandIter; i++ {
next, out = hash(next)
}
return &determRand{
next: next,
out: out,
}
}
type determRand struct {
next, out []byte
}
func (d *determRand) Read(b []byte) (int, error) {
n := 0
l := len(b)
for n < l {
next, out := hash(d.next)
n += copy(b[n:], out)
d.next = next
}
return n, nil
}
func hash(input []byte) (next []byte, output []byte) {
nextout := sha512.Sum512(input)
return nextout[:sha512.Size/2], nextout[sha512.Size/2:]
}
package ccrypto
import (
"crypto/ecdsa"
"crypto/elliptic"
"io"
"math/big"
)
var one = new(big.Int).SetInt64(1)
// This function is copied from ecdsa.GenerateKey() of Go 1.19
func GenerateKeyGo119(c elliptic.Curve, rand io.Reader) (*ecdsa.PrivateKey, error) {
k, err := randFieldElement(c, rand)
if err != nil {
return nil, err
}
priv := new(ecdsa.PrivateKey)
priv.PublicKey.Curve = c
priv.D = k
priv.PublicKey.X, priv.PublicKey.Y = c.ScalarBaseMult(k.Bytes())
return priv, nil
}
// This function is copied from Go 1.19
func randFieldElement(c elliptic.Curve, rand io.Reader) (k *big.Int, err error) {
params := c.Params()
// Note that for P-521 this will actually be 63 bits more than the order, as
// division rounds down, but the extra bit is inconsequential.
b := make([]byte, params.N.BitLen()/8+8)
_, err = io.ReadFull(rand, b)
if err != nil {
return
}
k = new(big.Int).SetBytes(b)
n := new(big.Int).Sub(params.N, one)
k.Mod(k, n)
k.Add(k, one)
return
}
package ccrypto
import (
"crypto/sha256"
"encoding/base64"
"fmt"
"os"
"golang.org/x/crypto/ssh"
)
// GenerateKey generates a PEM key
func GenerateKey(seed string) ([]byte, error) {
return Seed2PEM(seed)
}
// GenerateKeyFile generates an ChiselKey
func GenerateKeyFile(keyFilePath, seed string) error {
chiselKey, err := seed2ChiselKey(seed)
if err != nil {
return err
}
if keyFilePath == "-" {
fmt.Print(string(chiselKey))
return nil
}
return os.WriteFile(keyFilePath, chiselKey, 0600)
}
// FingerprintKey calculates the SHA256 hash of an SSH public key
func FingerprintKey(k ssh.PublicKey) string {
bytes := sha256.Sum256(k.Marshal())
return base64.StdEncoding.EncodeToString(bytes[:])
}
package ccrypto
import (
"crypto/ecdsa"
"crypto/elliptic"
"crypto/rand"
"crypto/x509"
"encoding/base64"
"encoding/pem"
"strings"
)
const ChiselKeyPrefix = "ck-"
// Relations between entities:
//
// .............> PEM <...........
// . ^ .
// . | .
// . | .
// Seed -------> PrivateKey .
// . ^ .
// . | .
// . V .
// ..........> ChiselKey .........
func Seed2PEM(seed string) ([]byte, error) {
privateKey, err := seed2PrivateKey(seed)
if err != nil {
return nil, err
}
return privateKey2PEM(privateKey)
}
func seed2ChiselKey(seed string) ([]byte, error) {
privateKey, err := seed2PrivateKey(seed)
if err != nil {
return nil, err
}
return privateKey2ChiselKey(privateKey)
}
func seed2PrivateKey(seed string) (*ecdsa.PrivateKey, error) {
if seed == "" {
return ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
} else {
return GenerateKeyGo119(elliptic.P256(), NewDetermRand([]byte(seed)))
}
}
func privateKey2ChiselKey(privateKey *ecdsa.PrivateKey) ([]byte, error) {
b, err := x509.MarshalECPrivateKey(privateKey)
if err != nil {
return nil, err
}
encodedPrivateKey := make([]byte, base64.RawStdEncoding.EncodedLen(len(b)))
base64.RawStdEncoding.Encode(encodedPrivateKey, b)
return append([]byte(ChiselKeyPrefix), encodedPrivateKey...), nil
}
func privateKey2PEM(privateKey *ecdsa.PrivateKey) ([]byte, error) {
b, err := x509.MarshalECPrivateKey(privateKey)
if err != nil {
return nil, err
}
return pem.EncodeToMemory(&pem.Block{Type: "EC PRIVATE KEY", Bytes: b}), nil
}
func chiselKey2PrivateKey(chiselKey []byte) (*ecdsa.PrivateKey, error) {
rawChiselKey := chiselKey[len(ChiselKeyPrefix):]
decodedPrivateKey := make([]byte, base64.RawStdEncoding.DecodedLen(len(rawChiselKey)))
_, err := base64.RawStdEncoding.Decode(decodedPrivateKey, rawChiselKey)
if err != nil {
return nil, err
}
return x509.ParseECPrivateKey(decodedPrivateKey)
}
func ChiselKey2PEM(chiselKey []byte) ([]byte, error) {
privateKey, err := chiselKey2PrivateKey(chiselKey)
if err == nil {
return privateKey2PEM(privateKey)
}
return nil, err
}
func IsChiselKey(chiselKey []byte) bool {
return strings.HasPrefix(string(chiselKey), ChiselKeyPrefix)
}
package cio
import (
"fmt"
"log"
"os"
)
//Logger is pkg/log Logger with prefixing and 2 log levels
type Logger struct {
Info, Debug bool
//internal
prefix string
logger *log.Logger
info, debug *bool
}
func NewLogger(prefix string) *Logger {
return NewLoggerFlag(prefix, log.Ldate|log.Ltime)
}
func NewLoggerFlag(prefix string, flag int) *Logger {
l := &Logger{
prefix: prefix,
logger: log.New(os.Stderr, "", flag),
Info: false,
Debug: false,
}
return l
}
func (l *Logger) Infof(f string, args ...interface{}) {
if l.IsInfo() {
l.logger.Printf(l.prefix+": "+f, args...)
}
}
func (l *Logger) Debugf(f string, args ...interface{}) {
if l.IsDebug() {
l.logger.Printf(l.prefix+": "+f, args...)
}
}
func (l *Logger) Errorf(f string, args ...interface{}) error {
return fmt.Errorf(l.prefix+": "+f, args...)
}
func (l *Logger) Fork(prefix string, args ...interface{}) *Logger {
//slip the parent prefix at the front
args = append([]interface{}{l.prefix}, args...)
ll := NewLogger(fmt.Sprintf("%s: "+prefix, args...))
//store link to parent settings too
ll.Info = l.Info
if l.info != nil {
ll.info = l.info
} else {
ll.info = &l.Info
}
ll.Debug = l.Debug
if l.debug != nil {
ll.debug = l.debug
} else {
ll.debug = &l.Debug
}
return ll
}
func (l *Logger) Prefix() string {
return l.prefix
}
func (l *Logger) IsInfo() bool {
return l.Info || (l.info != nil && *l.info)
}
func (l *Logger) IsDebug() bool {
return l.Debug || (l.debug != nil && *l.debug)
}
package cio
import (
"io"
"log"
"sync"
)
func Pipe(src io.ReadWriteCloser, dst io.ReadWriteCloser) (int64, int64) {
var sent, received int64
var wg sync.WaitGroup
var o sync.Once
close := func() {
src.Close()
dst.Close()
}
wg.Add(2)
go func() {
received, _ = io.Copy(src, dst)
o.Do(close)
wg.Done()
}()
go func() {
sent, _ = io.Copy(dst, src)
o.Do(close)
wg.Done()
}()
wg.Wait()
return sent, received
}
const vis = false
type pipeVisPrinter struct {
name string
}
func (p pipeVisPrinter) Write(b []byte) (int, error) {
log.Printf(">>> %s: %x", p.name, b)
return len(b), nil
}
func pipeVis(name string, r io.Reader) io.Reader {
if vis {
return io.TeeReader(r, pipeVisPrinter{name})
}
return r
}
package cio
import (
"io"
"io/ioutil"
"os"
)
//Stdio as a ReadWriteCloser
var Stdio = &struct {
io.ReadCloser
io.Writer
}{
ioutil.NopCloser(os.Stdin),
os.Stdout,
}
package cnet
import (
"io"
"net"
"time"
)
type rwcConn struct {
io.ReadWriteCloser
buff []byte
}
//NewRWCConn converts a RWC into a net.Conn
func NewRWCConn(rwc io.ReadWriteCloser) net.Conn {
c := rwcConn{
ReadWriteCloser: rwc,
}
return &c
}
func (c *rwcConn) LocalAddr() net.Addr {
return c
}
func (c *rwcConn) RemoteAddr() net.Addr {
return c
}
func (c *rwcConn) Network() string {
return "tcp"
}
func (c *rwcConn) String() string {
return ""
}
func (c *rwcConn) SetDeadline(t time.Time) error {
return nil //no-op
}
func (c *rwcConn) SetReadDeadline(t time.Time) error {
return nil //no-op
}
func (c *rwcConn) SetWriteDeadline(t time.Time) error {
return nil //no-op
}
package cnet
import (
"net"
"time"
"github.com/gorilla/websocket"
)
type wsConn struct {
*websocket.Conn
buff []byte
}
//NewWebSocketConn converts a websocket.Conn into a net.Conn
func NewWebSocketConn(websocketConn *websocket.Conn) net.Conn {
c := wsConn{
Conn: websocketConn,
}
return &c
}
//Read is not threadsafe though thats okay since there
//should never be more than one reader
func (c *wsConn) Read(dst []byte) (int, error) {
ldst := len(dst)
//use buffer or read new message
var src []byte
if len(c.buff) > 0 {
src = c.buff
c.buff = nil
} else if _, msg, err := c.Conn.ReadMessage(); err == nil {
src = msg
} else {
return 0, err
}
//copy src->dest
var n int
if len(src) > ldst {
//copy as much as possible of src into dst
n = copy(dst, src[:ldst])
//copy remainder into buffer
r := src[ldst:]
lr := len(r)
c.buff = make([]byte, lr)
copy(c.buff, r)
} else {
//copy all of src into dst
n = copy(dst, src)
}
//return bytes copied
return n, nil
}
func (c *wsConn) Write(b []byte) (int, error) {
if err := c.Conn.WriteMessage(websocket.BinaryMessage, b); err != nil {
return 0, err
}
n := len(b)
return n, nil
}
func (c *wsConn) SetDeadline(t time.Time) error {
if err := c.Conn.SetReadDeadline(t); err != nil {
return err
}
return c.Conn.SetWriteDeadline(t)
}
package cnet
import (
"fmt"
"sync/atomic"
)
//ConnCount is a connection counter
type ConnCount struct {
count int32
open int32
}
func (c *ConnCount) New() int32 {
return atomic.AddInt32(&c.count, 1)
}
func (c *ConnCount) Open() {
atomic.AddInt32(&c.open, 1)
}
func (c *ConnCount) Close() {
atomic.AddInt32(&c.open, -1)
}
func (c *ConnCount) String() string {
return fmt.Sprintf("[%d/%d]", atomic.LoadInt32(&c.open), atomic.LoadInt32(&c.count))
}
package cnet
import (
"context"
"errors"
"net"
"net/http"
"sync"
"golang.org/x/sync/errgroup"
)
//HTTPServer extends net/http Server and
//adds graceful shutdowns
type HTTPServer struct {
*http.Server
waiterMux sync.Mutex
waiter *errgroup.Group
listenErr error
}
//NewHTTPServer creates a new HTTPServer
func NewHTTPServer() *HTTPServer {
return &HTTPServer{
Server: &http.Server{},
}
}
func (h *HTTPServer) GoListenAndServe(addr string, handler http.Handler) error {
return h.GoListenAndServeContext(context.Background(), addr, handler)
}
func (h *HTTPServer) GoListenAndServeContext(ctx context.Context, addr string, handler http.Handler) error {
if ctx == nil {
return errors.New("ctx must be set")
}
l, err := net.Listen("tcp", addr)
if err != nil {
return err
}
return h.GoServe(ctx, l, handler)
}
func (h *HTTPServer) GoServe(ctx context.Context, l net.Listener, handler http.Handler) error {
if ctx == nil {
return errors.New("ctx must be set")
}
h.waiterMux.Lock()
defer h.waiterMux.Unlock()
h.Handler = handler
h.waiter, ctx = errgroup.WithContext(ctx)
h.waiter.Go(func() error {
return h.Serve(l)
})
go func() {
<-ctx.Done()
h.Close()
}()
return nil
}
func (h *HTTPServer) Close() error {
h.waiterMux.Lock()
defer h.waiterMux.Unlock()
if h.waiter == nil {
return errors.New("not started yet")
}
return h.Server.Close()
}
func (h *HTTPServer) Wait() error {
h.waiterMux.Lock()
unset := h.waiter == nil
h.waiterMux.Unlock()
if unset {
return errors.New("not started yet")
}
h.waiterMux.Lock()
wait := h.waiter.Wait
h.waiterMux.Unlock()
err := wait()
if err == http.ErrServerClosed {
err = nil //success
}
return err
}
package cnet
import (
"io"
"net"
"sync/atomic"
"time"
"github.com/jpillora/chisel/share/cio"
"github.com/jpillora/sizestr"
)
//NewMeter to measure readers/writers
func NewMeter(l *cio.Logger) *Meter {
return &Meter{l: l}
}
//Meter can be inserted in the path or
//of a reader or writer to measure the
//throughput
type Meter struct {
//meter state
sent, recv int64
//print state
l *cio.Logger
printing uint32
last int64
lsent, lrecv int64
}
func (m *Meter) print() {
//move out of the read/write path asap
if atomic.CompareAndSwapUint32(&m.printing, 0, 1) {
go m.goprint()
}
}
func (m *Meter) goprint() {
time.Sleep(time.Second)
//snapshot
s := atomic.LoadInt64(&m.sent)
r := atomic.LoadInt64(&m.recv)
//compute speed
curr := time.Now().UnixNano()
last := atomic.LoadInt64(&m.last)
dt := time.Duration(curr-last) * time.Nanosecond
ls := atomic.LoadInt64(&m.lsent)
lr := atomic.LoadInt64(&m.lrecv)
//DEBUG
// m.l.Infof("%s = %d(%d-%d), %d(%d-%d)", dt, s-ls, s, ls, r-lr, r, lr)
//scale to per second V=D/T
sps := int64(float64(s-ls) / float64(dt) * float64(time.Second))
rps := int64(float64(r-lr) / float64(dt) * float64(time.Second))
if last > 0 && (sps != 0 || rps != 0) {
m.l.Debugf("write %s/s read %s/s", sizestr.ToString(sps), sizestr.ToString(rps))
}
//record last printed
atomic.StoreInt64(&m.lsent, s)
atomic.StoreInt64(&m.lrecv, r)
//done
atomic.StoreInt64(&m.last, curr)
atomic.StoreUint32(&m.printing, 0)
}
//TeeReader inserts Meter into the read path
//if the linked logger is in debug mode,
//otherwise this is a no-op
func (m *Meter) TeeReader(r io.Reader) io.Reader {
if m.l.IsDebug() {
return &meterReader{m, r}
}
return r
}
type meterReader struct {
*Meter
inner io.Reader
}
func (m *meterReader) Read(p []byte) (n int, err error) {
n, err = m.inner.Read(p)
atomic.AddInt64(&m.recv, int64(n))
m.Meter.print()
return
}
//TeeWriter inserts Meter into the write path
//if the linked logger is in debug mode,
//otherwise this is a no-op
func (m *Meter) TeeWriter(w io.Writer) io.Writer {
if m.l.IsDebug() {
return &meterWriter{m, w}
}
return w
}
type meterWriter struct {
*Meter
inner io.Writer
}
func (m *meterWriter) Write(p []byte) (n int, err error) {
n, err = m.inner.Write(p)
atomic.AddInt64(&m.sent, int64(n))
m.Meter.print()
return
}
//MeterConn inserts Meter into the connection path
//if the linked logger is in debug mode,
//otherwise this is a no-op
func MeterConn(l *cio.Logger, conn net.Conn) net.Conn {
m := NewMeter(l)
return &meterConn{
mread: m.TeeReader(conn),
mwrite: m.TeeWriter(conn),
Conn: conn,
}
}
type meterConn struct {
mread io.Reader
mwrite io.Writer
net.Conn
}
func (m *meterConn) Read(p []byte) (n int, err error) {
return m.mread.Read(p)
}
func (m *meterConn) Write(p []byte) (n int, err error) {
return m.mwrite.Write(p)
}
//MeterRWC inserts Meter into the RWC path
//if the linked logger is in debug mode,
//otherwise this is a no-op
func MeterRWC(l *cio.Logger, rwc io.ReadWriteCloser) io.ReadWriteCloser {
m := NewMeter(l)
return &struct {
io.Reader
io.Writer
io.Closer
}{
Reader: m.TeeReader(rwc),
Writer: m.TeeWriter(rwc),
Closer: rwc,
}
}
package chshare
//this file exists to maintain backwards compatibility
import (
"github.com/jpillora/chisel/share/ccrypto"
"github.com/jpillora/chisel/share/cio"
"github.com/jpillora/chisel/share/cnet"
"github.com/jpillora/chisel/share/cos"
"github.com/jpillora/chisel/share/settings"
"github.com/jpillora/chisel/share/tunnel"
)
const (
DetermRandIter = ccrypto.DetermRandIter
)
type (
Config = settings.Config
Remote = settings.Remote
Remotes = settings.Remotes
User = settings.User
Users = settings.Users
UserIndex = settings.UserIndex
HTTPServer = cnet.HTTPServer
ConnStats = cnet.ConnCount
Logger = cio.Logger
TCPProxy = tunnel.Proxy
)
var (
NewDetermRand = ccrypto.NewDetermRand
GenerateKey = ccrypto.GenerateKey
FingerprintKey = ccrypto.FingerprintKey
Pipe = cio.Pipe
NewLoggerFlag = cio.NewLoggerFlag
NewLogger = cio.NewLogger
Stdio = cio.Stdio
DecodeConfig = settings.DecodeConfig
DecodeRemote = settings.DecodeRemote
NewUsers = settings.NewUsers
NewUserIndex = settings.NewUserIndex
UserAllowAll = settings.UserAllowAll
ParseAuth = settings.ParseAuth
NewRWCConn = cnet.NewRWCConn
NewWebSocketConn = cnet.NewWebSocketConn
NewHTTPServer = cnet.NewHTTPServer
GoStats = cos.GoStats
SleepSignal = cos.SleepSignal
NewTCPProxy = tunnel.NewProxy
)
//EncodeConfig old version
func EncodeConfig(c *settings.Config) ([]byte, error) {
return settings.EncodeConfig(*c), nil
}
package cos
import (
"context"
"os"
"os/signal"
"time"
)
//InterruptContext returns a context which is
//cancelled on OS Interrupt
func InterruptContext() context.Context {
ctx, cancel := context.WithCancel(context.Background())
go func() {
sig := make(chan os.Signal, 1)
signal.Notify(sig, os.Interrupt) //windows compatible?
<-sig
signal.Stop(sig)
cancel()
}()
return ctx
}
//SleepSignal sleeps for the given duration,
//or until a SIGHUP is received
func SleepSignal(d time.Duration) {
<-AfterSignal(d)
}
// +build pprof
package cos
import (
"log"
"net/http"
_ "net/http/pprof" //import http profiler api
)
func init() {
go func() {
log.Fatal(http.ListenAndServe("localhost:6060", nil))
}()
log.Printf("[pprof] listening on 6060")
}
//+build !windows
package cos
import (
"log"
"os"
"os/signal"
"runtime"
"syscall"
"time"
"github.com/jpillora/sizestr"
)
//GoStats prints statistics to
//stdout on SIGUSR2 (posix-only)
func GoStats() {
//silence complaints from windows
const SIGUSR2 = syscall.Signal(0x1f)
time.Sleep(time.Second)
c := make(chan os.Signal, 1)
signal.Notify(c, SIGUSR2)
for range c {
memStats := runtime.MemStats{}
runtime.ReadMemStats(&memStats)
log.Printf("recieved SIGUSR2, go-routines: %d, go-memory-usage: %s",
runtime.NumGoroutine(),
sizestr.ToString(int64(memStats.Alloc)))
}
}
//AfterSignal returns a channel which will be closed
//after the given duration or until a SIGHUP is received
func AfterSignal(d time.Duration) <-chan struct{} {
ch := make(chan struct{})
go func() {
sig := make(chan os.Signal, 1)
signal.Notify(sig, syscall.SIGHUP)
select {
case <-time.After(d):
case <-sig:
}
signal.Stop(sig)
close(ch)
}()
return ch
}
//+build windows
package cos
import (
"time"
)
func GoStats() {
//noop
}
func AfterSignal(d time.Duration) <-chan struct{} {
ch := make(chan struct{})
go func() {
<-time.After(d)
close(ch)
}()
return ch
}
package settings
import (
"encoding/json"
"fmt"
)
type Config struct {
Version string
Remotes
}
func DecodeConfig(b []byte) (*Config, error) {
c := &Config{}
err := json.Unmarshal(b, c)
if err != nil {
return nil, fmt.Errorf("Invalid JSON config")
}
return c, nil
}
func EncodeConfig(c Config) []byte {
//Config doesn't have types that can fail to marshal
b, _ := json.Marshal(c)
return b
}
package settings
import (
"os"
"strconv"
"strings"
"time"
)
// Env returns a chisel environment variable
func Env(name string) string {
return os.Getenv("CHISEL_" + name)
}
// EnvInt returns an integer using an environment variable, with a default fallback
func EnvInt(name string, def int) int {
if n, err := strconv.Atoi(Env(name)); err == nil {
return n
}
return def
}
// EnvDuration returns a duration using an environment variable, with a default fallback
func EnvDuration(name string, def time.Duration) time.Duration {
if n, err := time.ParseDuration(Env(name)); err == nil {
return n
}
return def
}
// EnvBool returns a boolean using an environment variable
func EnvBool(name string) bool {
v := Env(name)
return v == "1" || strings.ToLower(v) == "true"
}
package settings
import (
"errors"
"net"
"net/url"
"regexp"
"strconv"
"strings"
)
// short-hand conversions (see remote_test)
// 3000 ->
// local 127.0.0.1:3000
// remote 127.0.0.1:3000
// foobar.com:3000 ->
// local 127.0.0.1:3000
// remote foobar.com:3000
// 3000:google.com:80 ->
// local 127.0.0.1:3000
// remote google.com:80
// 192.168.0.1:3000:google.com:80 ->
// local 192.168.0.1:3000
// remote google.com:80
// 127.0.0.1:1080:socks
// local 127.0.0.1:1080
// remote socks
// stdio:example.com:22
// local stdio
// remote example.com:22
// 1.1.1.1:53/udp
// local 127.0.0.1:53/udp
// remote 1.1.1.1:53/udp
type Remote struct {
LocalHost, LocalPort, LocalProto string
RemoteHost, RemotePort, RemoteProto string
Socks, Reverse, Stdio bool
}
const revPrefix = "R:"
func DecodeRemote(s string) (*Remote, error) {
reverse := false
if strings.HasPrefix(s, revPrefix) {
s = strings.TrimPrefix(s, revPrefix)
reverse = true
}
parts := regexp.MustCompile(`(\[[^\[\]]+\]|[^\[\]:]+):?`).FindAllStringSubmatch(s, -1)
if len(parts) <= 0 || len(parts) >= 5 {
return nil, errors.New("Invalid remote")
}
r := &Remote{Reverse: reverse}
//parse from back to front, to set 'remote' fields first,
//then to set 'local' fields second (allows the 'remote' side
//to provide the defaults)
for i := len(parts) - 1; i >= 0; i-- {
p := parts[i][1]
//remote portion is socks?
if i == len(parts)-1 && p == "socks" {
r.Socks = true
continue
}
//local portion is stdio?
if i == 0 && p == "stdio" {
r.Stdio = true
continue
}
p, proto := L4Proto(p)
if proto != "" {
if r.RemotePort == "" {
r.RemoteProto = proto
} else if r.LocalProto == "" {
r.LocalProto = proto
}
}
if isPort(p) {
if !r.Socks && r.RemotePort == "" {
r.RemotePort = p
}
r.LocalPort = p
continue
}
if !r.Socks && (r.RemotePort == "" && r.LocalPort == "") {
return nil, errors.New("Missing ports")
}
if !isHost(p) {
return nil, errors.New("Invalid host")
}
if !r.Socks && r.RemoteHost == "" {
r.RemoteHost = p
} else {
r.LocalHost = p
}
}
//remote string parsed, apply defaults...
if r.Socks {
//socks defaults
if r.LocalHost == "" {
r.LocalHost = "127.0.0.1"
}
if r.LocalPort == "" {
r.LocalPort = "1080"
}
} else {
//non-socks defaults
if r.LocalHost == "" {
r.LocalHost = "0.0.0.0"
}
if r.RemoteHost == "" {
r.RemoteHost = "127.0.0.1"
}
}
if r.RemoteProto == "" {
r.RemoteProto = "tcp"
}
if r.LocalProto == "" {
r.LocalProto = r.RemoteProto
}
if r.LocalProto != r.RemoteProto {
//TODO support cross protocol
//tcp <-> udp, is faily straight forward
//udp <-> tcp, is trickier since udp is stateless and tcp is not
return nil, errors.New("cross-protocol remotes are not supported yet")
}
if r.Socks && r.RemoteProto != "tcp" {
return nil, errors.New("only TCP SOCKS is supported")
}
if r.Stdio && r.Reverse {
return nil, errors.New("stdio cannot be reversed")
}
return r, nil
}
func isPort(s string) bool {
n, err := strconv.Atoi(s)
if err != nil {
return false
}
if n <= 0 || n > 65535 {
return false
}
return true
}
func isHost(s string) bool {
_, err := url.Parse("//" + s)
if err != nil {
return false
}
return true
}
var l4Proto = regexp.MustCompile(`(?i)\/(tcp|udp)$`)
//L4Proto extacts the layer-4 protocol from the given string
func L4Proto(s string) (head, proto string) {
if l4Proto.MatchString(s) {
l := len(s)
return strings.ToLower(s[:l-4]), s[l-3:]
}
return s, ""
}
//implement Stringer
func (r Remote) String() string {
sb := strings.Builder{}
if r.Reverse {
sb.WriteString(revPrefix)
}
sb.WriteString(strings.TrimPrefix(r.Local(), "0.0.0.0:"))
sb.WriteString("=>")
sb.WriteString(strings.TrimPrefix(r.Remote(), "127.0.0.1:"))
if r.RemoteProto == "udp" {
sb.WriteString("/udp")
}
return sb.String()
}
//Encode remote to a string
func (r Remote) Encode() string {
if r.LocalPort == "" {
r.LocalPort = r.RemotePort
}
local := r.Local()
remote := r.Remote()
if r.RemoteProto == "udp" {
remote += "/udp"
}
if r.Reverse {
return "R:" + local + ":" + remote
}
return local + ":" + remote
}
//Local is the decodable local portion
func (r Remote) Local() string {
if r.Stdio {
return "stdio"
}
if r.LocalHost == "" {
r.LocalHost = "0.0.0.0"
}
return r.LocalHost + ":" + r.LocalPort
}
//Remote is the decodable remote portion
func (r Remote) Remote() string {
if r.Socks {
return "socks"
}
if r.RemoteHost == "" {
r.RemoteHost = "127.0.0.1"
}
return r.RemoteHost + ":" + r.RemotePort
}
//UserAddr is checked when checking if a
//user has access to a given remote
func (r Remote) UserAddr() string {
if r.Reverse {
return "R:" + r.LocalHost + ":" + r.LocalPort
}
return r.RemoteHost + ":" + r.RemotePort
}
//CanListen checks if the port can be listened on
func (r Remote) CanListen() bool {
//valid protocols
switch r.LocalProto {
case "tcp":
conn, err := net.Listen("tcp", r.Local())
if err == nil {
conn.Close()
return true
}
return false
case "udp":
addr, err := net.ResolveUDPAddr("udp", r.Local())
if err != nil {
return false
}
conn, err := net.ListenUDP(r.LocalProto, addr)
if err == nil {
conn.Close()
return true
}
return false
}
//invalid
return false
}
type Remotes []*Remote
//Filter out forward reversed/non-reversed remotes
func (rs Remotes) Reversed(reverse bool) Remotes {
subset := Remotes{}
for _, r := range rs {
match := r.Reverse == reverse
if match {
subset = append(subset, r)
}
}
return subset
}
//Encode back into strings
func (rs Remotes) Encode() []string {
s := make([]string, len(rs))
for i, r := range rs {
s[i] = r.Encode()
}
return s
}
package settings
import (
"reflect"
"testing"
)
func TestRemoteDecode(t *testing.T) {
//test table
for i, test := range []struct {
Input string
Output Remote
Encoded string
}{
{
"3000",
Remote{
LocalPort: "3000",
RemoteHost: "127.0.0.1",
RemotePort: "3000",
},
"0.0.0.0:3000:127.0.0.1:3000",
},
{
"google.com:80",
Remote{
LocalPort: "80",
RemoteHost: "google.com",
RemotePort: "80",
},
"0.0.0.0:80:google.com:80",
},
{
"R:google.com:80",
Remote{
LocalPort: "80",
RemoteHost: "google.com",
RemotePort: "80",
Reverse: true,
},
"R:0.0.0.0:80:google.com:80",
},
{
"示例網站.com:80",
Remote{
LocalPort: "80",
RemoteHost: "示例網站.com",
RemotePort: "80",
},
"0.0.0.0:80:示例網站.com:80",
},
{
"socks",
Remote{
LocalHost: "127.0.0.1",
LocalPort: "1080",
Socks: true,
},
"127.0.0.1:1080:socks",
},
{
"127.0.0.1:1081:socks",
Remote{
LocalHost: "127.0.0.1",
LocalPort: "1081",
Socks: true,
},
"127.0.0.1:1081:socks",
},
{
"1.1.1.1:53/udp",
Remote{
LocalPort: "53",
LocalProto: "udp",
RemoteHost: "1.1.1.1",
RemotePort: "53",
RemoteProto: "udp",
},
"0.0.0.0:53:1.1.1.1:53/udp",
},
{
"localhost:5353:1.1.1.1:53/udp",
Remote{
LocalHost: "localhost",
LocalPort: "5353",
LocalProto: "udp",
RemoteHost: "1.1.1.1",
RemotePort: "53",
RemoteProto: "udp",
},
"localhost:5353:1.1.1.1:53/udp",
},
{
"[::1]:8080:google.com:80",
Remote{
LocalHost: "[::1]",
LocalPort: "8080",
RemoteHost: "google.com",
RemotePort: "80",
},
"[::1]:8080:google.com:80",
},
{
"R:[::]:3000:[::1]:3000",
Remote{
LocalHost: "[::]",
LocalPort: "3000",
RemoteHost: "[::1]",
RemotePort: "3000",
Reverse: true,
},
"R:[::]:3000:[::1]:3000",
},
} {
//expected defaults
expected := test.Output
if expected.LocalHost == "" {
expected.LocalHost = "0.0.0.0"
}
if expected.RemoteProto == "" {
expected.RemoteProto = "tcp"
}
if expected.LocalProto == "" {
expected.LocalProto = "tcp"
}
//compare
got, err := DecodeRemote(test.Input)
if err != nil {
t.Fatalf("decode #%d '%s' failed: %s", i+1, test.Input, err)
}
if !reflect.DeepEqual(got, &expected) {
t.Fatalf("decode #%d '%s' expected\n %#v\ngot\n %#v", i+1, test.Input, expected, got)
}
if e := got.Encode(); test.Encoded != e {
t.Fatalf("encode #%d '%s' expected\n %#v\ngot\n %#v", i+1, test.Input, test.Encoded, e)
}
}
}
package settings
import (
"regexp"
"strings"
)
var UserAllowAll = regexp.MustCompile("")
func ParseAuth(auth string) (string, string) {
if strings.Contains(auth, ":") {
pair := strings.SplitN(auth, ":", 2)
return pair[0], pair[1]
}
return "", ""
}
type User struct {
Name string
Pass string
Addrs []*regexp.Regexp
}
func (u *User) HasAccess(addr string) bool {
m := false
for _, r := range u.Addrs {
if r.MatchString(addr) {
m = true
break
}
}
return m
}
package settings
import (
"encoding/json"
"errors"
"fmt"
"io/ioutil"
"regexp"
"sync"
"github.com/fsnotify/fsnotify"
"github.com/jpillora/chisel/share/cio"
)
type Users struct {
sync.RWMutex
inner map[string]*User
}
func NewUsers() *Users {
return &Users{inner: map[string]*User{}}
}
// Len returns the numbers of users
func (u *Users) Len() int {
u.RLock()
l := len(u.inner)
u.RUnlock()
return l
}
// Get user from the index by key
func (u *Users) Get(key string) (*User, bool) {
u.RLock()
user, found := u.inner[key]
u.RUnlock()
return user, found
}
// Set a users into the list by specific key
func (u *Users) Set(key string, user *User) {
u.Lock()
u.inner[key] = user
u.Unlock()
}
// Del ete a users from the list
func (u *Users) Del(key string) {
u.Lock()
delete(u.inner, key)
u.Unlock()
}
// AddUser adds a users to the set
func (u *Users) AddUser(user *User) {
u.Set(user.Name, user)
}
// Reset all users to the given set,
// Use nil to remove all.
func (u *Users) Reset(users []*User) {
m := map[string]*User{}
for _, u := range users {
m[u.Name] = u
}
u.Lock()
u.inner = m
u.Unlock()
}
// UserIndex is a reloadable user source
type UserIndex struct {
*cio.Logger
*Users
configFile string
}
// NewUserIndex creates a source for users
func NewUserIndex(logger *cio.Logger) *UserIndex {
return &UserIndex{
Logger: logger.Fork("users"),
Users: NewUsers(),
}
}
// LoadUsers is responsible for loading users from a file
func (u *UserIndex) LoadUsers(configFile string) error {
u.configFile = configFile
u.Infof("Loading configuration file %s", configFile)
if err := u.loadUserIndex(); err != nil {
return err
}
if err := u.addWatchEvents(); err != nil {
return err
}
return nil
}
// watchEvents is responsible for watching for updates to the file and reloading
func (u *UserIndex) addWatchEvents() error {
watcher, err := fsnotify.NewWatcher()
if err != nil {
return err
}
if err := watcher.Add(u.configFile); err != nil {
return err
}
go func() {
for e := range watcher.Events {
if e.Op&fsnotify.Write != fsnotify.Write {
continue
}
if err := u.loadUserIndex(); err != nil {
u.Infof("Failed to reload the users configuration: %s", err)
} else {
u.Debugf("Users configuration successfully reloaded from: %s", u.configFile)
}
}
}()
return nil
}
// loadUserIndex is responsible for loading the users configuration
func (u *UserIndex) loadUserIndex() error {
if u.configFile == "" {
return errors.New("configuration file not set")
}
b, err := ioutil.ReadFile(u.configFile)
if err != nil {
return fmt.Errorf("Failed to read auth file: %s, error: %s", u.configFile, err)
}
var raw map[string][]string
if err := json.Unmarshal(b, &raw); err != nil {
return errors.New("Invalid JSON: " + err.Error())
}
users := []*User{}
for auth, remotes := range raw {
user := &User{}
user.Name, user.Pass = ParseAuth(auth)
if user.Name == "" {
return errors.New("Invalid user:pass string")
}
for _, r := range remotes {
if r == "" || r == "*" {
user.Addrs = append(user.Addrs, UserAllowAll)
} else {
re, err := regexp.Compile(r)
if err != nil {
return errors.New("Invalid address regex")
}
user.Addrs = append(user.Addrs, re)
}
}
users = append(users, user)
}
//swap
u.Reset(users)
return nil
}
package tunnel
import (
"bytes"
"context"
"errors"
"io/ioutil"
"log"
"os"
"sync"
"time"
"github.com/armon/go-socks5"
"github.com/jpillora/chisel/share/cio"
"github.com/jpillora/chisel/share/cnet"
"github.com/jpillora/chisel/share/settings"
"golang.org/x/crypto/ssh"
"golang.org/x/sync/errgroup"
)
//Config a Tunnel
type Config struct {
*cio.Logger
Inbound bool
Outbound bool
Socks bool
KeepAlive time.Duration
}
//Tunnel represents an SSH tunnel with proxy capabilities.
//Both chisel client and server are Tunnels.
//chisel client has a single set of remotes, whereas
//chisel server has multiple sets of remotes (one set per client).
//Each remote has a 1:1 mapping to a proxy.
//Proxies listen, send data over ssh, and the other end of the ssh connection
//communicates with the endpoint and returns the response.
type Tunnel struct {
Config
//ssh connection
activeConnMut sync.RWMutex
activatingConn waitGroup
activeConn ssh.Conn
//proxies
proxyCount int
//internals
connStats cnet.ConnCount
socksServer *socks5.Server
}
//New Tunnel from the given Config
func New(c Config) *Tunnel {
c.Logger = c.Logger.Fork("tun")
t := &Tunnel{
Config: c,
}
t.activatingConn.Add(1)
//setup socks server (not listening on any port!)
extra := ""
if c.Socks {
sl := log.New(ioutil.Discard, "", 0)
if t.Logger.Debug {
sl = log.New(os.Stdout, "[socks]", log.Ldate|log.Ltime)
}
t.socksServer, _ = socks5.New(&socks5.Config{Logger: sl})
extra += " (SOCKS enabled)"
}
t.Debugf("Created%s", extra)
return t
}
//BindSSH provides an active SSH for use for tunnelling
func (t *Tunnel) BindSSH(ctx context.Context, c ssh.Conn, reqs <-chan *ssh.Request, chans <-chan ssh.NewChannel) error {
//link ctx to ssh-conn
go func() {
<-ctx.Done()
if c.Close() == nil {
t.Debugf("SSH cancelled")
}
t.activatingConn.DoneAll()
}()
//mark active and unblock
t.activeConnMut.Lock()
if t.activeConn != nil {
panic("double bind ssh")
}
t.activeConn = c
t.activeConnMut.Unlock()
t.activatingConn.Done()
//optional keepalive loop against this connection
if t.Config.KeepAlive > 0 {
go t.keepAliveLoop(c)
}
//block until closed
go t.handleSSHRequests(reqs)
go t.handleSSHChannels(chans)
t.Debugf("SSH connected")
err := c.Wait()
t.Debugf("SSH disconnected")
//mark inactive and block
t.activatingConn.Add(1)
t.activeConnMut.Lock()
t.activeConn = nil
t.activeConnMut.Unlock()
return err
}
//getSSH blocks while connecting
func (t *Tunnel) getSSH(ctx context.Context) ssh.Conn {
//cancelled already?
if isDone(ctx) {
return nil
}
t.activeConnMut.RLock()
c := t.activeConn
t.activeConnMut.RUnlock()
//connected already?
if c != nil {
return c
}
//connecting...
select {
case <-ctx.Done(): //cancelled
return nil
case <-time.After(settings.EnvDuration("SSH_WAIT", 35*time.Second)):
return nil //a bit longer than ssh timeout
case <-t.activatingConnWait():
t.activeConnMut.RLock()
c := t.activeConn
t.activeConnMut.RUnlock()
return c
}
}
func (t *Tunnel) activatingConnWait() <-chan struct{} {
ch := make(chan struct{})
go func() {
t.activatingConn.Wait()
close(ch)
}()
return ch
}
//BindRemotes converts the given remotes into proxies, and blocks
//until the caller cancels the context or there is a proxy error.
func (t *Tunnel) BindRemotes(ctx context.Context, remotes []*settings.Remote) error {
if len(remotes) == 0 {
return errors.New("no remotes")
}
if !t.Inbound {
return errors.New("inbound connections blocked")
}
proxies := make([]*Proxy, len(remotes))
for i, remote := range remotes {
p, err := NewProxy(t.Logger, t, t.proxyCount, remote)
if err != nil {
return err
}
proxies[i] = p
t.proxyCount++
}
//TODO: handle tunnel close
eg, ctx := errgroup.WithContext(ctx)
for _, proxy := range proxies {
p := proxy
eg.Go(func() error {
return p.Run(ctx)
})
}
t.Debugf("Bound proxies")
err := eg.Wait()
t.Debugf("Unbound proxies")
return err
}
func (t *Tunnel) keepAliveLoop(sshConn ssh.Conn) {
//ping forever
for {
time.Sleep(t.Config.KeepAlive)
_, b, err := sshConn.SendRequest("ping", true, nil)
if err != nil {
break
}
if len(b) > 0 && !bytes.Equal(b, []byte("pong")) {
t.Debugf("strange ping response")
break
}
}
//close ssh connection on abnormal ping
sshConn.Close()
}
package tunnel
import (
"context"
"io"
"net"
"sync"
"github.com/jpillora/chisel/share/cio"
"github.com/jpillora/chisel/share/settings"
"github.com/jpillora/sizestr"
"golang.org/x/crypto/ssh"
)
//sshTunnel exposes a subset of Tunnel to subtypes
type sshTunnel interface {
getSSH(ctx context.Context) ssh.Conn
}
//Proxy is the inbound portion of a Tunnel
type Proxy struct {
*cio.Logger
sshTun sshTunnel
id int
count int
remote *settings.Remote
dialer net.Dialer
tcp *net.TCPListener
udp *udpListener
mu sync.Mutex
}
//NewProxy creates a Proxy
func NewProxy(logger *cio.Logger, sshTun sshTunnel, index int, remote *settings.Remote) (*Proxy, error) {
id := index + 1
p := &Proxy{
Logger: logger.Fork("proxy#%s", remote.String()),
sshTun: sshTun,
id: id,
remote: remote,
}
return p, p.listen()
}
func (p *Proxy) listen() error {
if p.remote.Stdio {
//TODO check if pipes active?
} else if p.remote.LocalProto == "tcp" {
addr, err := net.ResolveTCPAddr("tcp", p.remote.LocalHost+":"+p.remote.LocalPort)
if err != nil {
return p.Errorf("resolve: %s", err)
}
l, err := net.ListenTCP("tcp", addr)
if err != nil {
return p.Errorf("tcp: %s", err)
}
p.Infof("Listening")
p.tcp = l
} else if p.remote.LocalProto == "udp" {
l, err := listenUDP(p.Logger, p.sshTun, p.remote)
if err != nil {
return err
}
p.Infof("Listening")
p.udp = l
} else {
return p.Errorf("unknown local proto")
}
return nil
}
//Run enables the proxy and blocks while its active,
//close the proxy by cancelling the context.
func (p *Proxy) Run(ctx context.Context) error {
if p.remote.Stdio {
return p.runStdio(ctx)
} else if p.remote.LocalProto == "tcp" {
return p.runTCP(ctx)
} else if p.remote.LocalProto == "udp" {
return p.udp.run(ctx)
}
panic("should not get here")
}
func (p *Proxy) runStdio(ctx context.Context) error {
defer p.Infof("Closed")
for {
p.pipeRemote(ctx, cio.Stdio)
select {
case <-ctx.Done():
return nil
default:
// the connection is not ready yet, keep waiting
}
}
}
func (p *Proxy) runTCP(ctx context.Context) error {
done := make(chan struct{})
//implements missing net.ListenContext
go func() {
select {
case <-ctx.Done():
p.tcp.Close()
case <-done:
}
}()
for {
src, err := p.tcp.Accept()
if err != nil {
select {
case <-ctx.Done():
//listener closed
err = nil
default:
p.Infof("Accept error: %s", err)
}
close(done)
return err
}
go p.pipeRemote(ctx, src)
}
}
func (p *Proxy) pipeRemote(ctx context.Context, src io.ReadWriteCloser) {
defer src.Close()
p.mu.Lock()
p.count++
cid := p.count
p.mu.Unlock()
l := p.Fork("conn#%d", cid)
l.Debugf("Open")
sshConn := p.sshTun.getSSH(ctx)
if sshConn == nil {
l.Debugf("No remote connection")
return
}
//ssh request for tcp connection for this proxy's remote
dst, reqs, err := sshConn.OpenChannel("chisel", []byte(p.remote.Remote()))
if err != nil {
l.Infof("Stream error: %s", err)
return
}
go ssh.DiscardRequests(reqs)
//then pipe
s, r := cio.Pipe(src, dst)
l.Debugf("Close (sent %s received %s)", sizestr.ToString(s), sizestr.ToString(r))
}
package tunnel
import (
"context"
"encoding/gob"
"fmt"
"io"
"net"
"strings"
"sync"
"sync/atomic"
"time"
"github.com/jpillora/chisel/share/cio"
"github.com/jpillora/chisel/share/settings"
"github.com/jpillora/sizestr"
"golang.org/x/crypto/ssh"
"golang.org/x/sync/errgroup"
)
//listenUDP is a special listener which forwards packets via
//the bound ssh connection. tricky part is multiplexing lots of
//udp clients through the entry node. each will listen on its
//own source-port for a response:
// (random)
// src-1 1111->... dst-1 6345->7777
// src-2 2222->... <---> udp <---> udp <-> dst-1 7543->7777
// src-3 3333->... listener handler dst-1 1444->7777
//
//we must store these mappings (1111-6345, etc) in memory for a length
//of time, so that when the exit node receives a response on 6345, it
//knows to return it to 1111.
func listenUDP(l *cio.Logger, sshTun sshTunnel, remote *settings.Remote) (*udpListener, error) {
a, err := net.ResolveUDPAddr("udp", remote.Local())
if err != nil {
return nil, l.Errorf("resolve: %s", err)
}
conn, err := net.ListenUDP("udp", a)
if err != nil {
return nil, l.Errorf("listen: %s", err)
}
//ready
u := &udpListener{
Logger: l,
sshTun: sshTun,
remote: remote,
inbound: conn,
maxMTU: settings.EnvInt("UDP_MAX_SIZE", 9012),
}
u.Debugf("UDP max size: %d bytes", u.maxMTU)
return u, nil
}
type udpListener struct {
*cio.Logger
sshTun sshTunnel
remote *settings.Remote
inbound *net.UDPConn
outboundMut sync.Mutex
outbound *udpChannel
sent, recv int64
maxMTU int
}
func (u *udpListener) run(ctx context.Context) error {
defer u.inbound.Close()
//udp doesnt accept connections,
//udp simply forwards packets
//and therefore only needs to listen
eg, ctx := errgroup.WithContext(ctx)
eg.Go(func() error {
return u.runInbound(ctx)
})
eg.Go(func() error {
return u.runOutbound(ctx)
})
if err := eg.Wait(); err != nil {
u.Debugf("listen: %s", err)
return err
}
u.Debugf("Close (sent %s received %s)", sizestr.ToString(u.sent), sizestr.ToString(u.recv))
return nil
}
func (u *udpListener) runInbound(ctx context.Context) error {
buff := make([]byte, u.maxMTU)
for !isDone(ctx) {
//read from inbound udp
u.inbound.SetReadDeadline(time.Now().Add(time.Second))
n, addr, err := u.inbound.ReadFromUDP(buff)
if e, ok := err.(net.Error); ok && (e.Timeout() || e.Temporary()) {
continue
}
if err != nil {
return u.Errorf("read error: %w", err)
}
//upsert ssh channel
uc, err := u.getUDPChan(ctx)
if err != nil {
if strings.HasSuffix(err.Error(), "EOF") {
continue
}
return u.Errorf("inbound-udpchan: %w", err)
}
//send over channel, including source address
b := buff[:n]
if err := uc.encode(addr.String(), b); err != nil {
if strings.HasSuffix(err.Error(), "EOF") {
continue //dropped packet...
}
return u.Errorf("encode error: %w", err)
}
//stats
atomic.AddInt64(&u.sent, int64(n))
}
return nil
}
func (u *udpListener) runOutbound(ctx context.Context) error {
for !isDone(ctx) {
//upsert ssh channel
uc, err := u.getUDPChan(ctx)
if err != nil {
if strings.HasSuffix(err.Error(), "EOF") {
continue
}
return u.Errorf("outbound-udpchan: %w", err)
}
//receive from channel, including source address
p := udpPacket{}
if err := uc.decode(&p); err == io.EOF {
//outbound ssh disconnected, get new connection...
continue
} else if err != nil {
return u.Errorf("decode error: %w", err)
}
//write back to inbound udp
addr, err := net.ResolveUDPAddr("udp", p.Src)
if err != nil {
return u.Errorf("resolve error: %w", err)
}
n, err := u.inbound.WriteToUDP(p.Payload, addr)
if err != nil {
return u.Errorf("write error: %w", err)
}
//stats
atomic.AddInt64(&u.recv, int64(n))
}
return nil
}
func (u *udpListener) getUDPChan(ctx context.Context) (*udpChannel, error) {
u.outboundMut.Lock()
defer u.outboundMut.Unlock()
//cached
if u.outbound != nil {
return u.outbound, nil
}
//not cached, bind
sshConn := u.sshTun.getSSH(ctx)
if sshConn == nil {
return nil, fmt.Errorf("ssh-conn nil")
}
//ssh request for udp packets for this proxy's remote,
//just "udp" since the remote address is sent with each packet
dstAddr := u.remote.Remote() + "/udp"
rwc, reqs, err := sshConn.OpenChannel("chisel", []byte(dstAddr))
if err != nil {
return nil, fmt.Errorf("ssh-chan error: %s", err)
}
go ssh.DiscardRequests(reqs)
//remove on disconnect
go u.unsetUDPChan(sshConn)
//ready
o := &udpChannel{
r: gob.NewDecoder(rwc),
w: gob.NewEncoder(rwc),
c: rwc,
}
u.outbound = o
u.Debugf("aquired channel")
return o, nil
}
func (u *udpListener) unsetUDPChan(sshConn ssh.Conn) {
sshConn.Wait()
u.Debugf("lost channel")
u.outboundMut.Lock()
u.outbound = nil
u.outboundMut.Unlock()
}
package tunnel
import (
"fmt"
"io"
"net"
"strings"
"github.com/jpillora/chisel/share/cio"
"github.com/jpillora/chisel/share/cnet"
"github.com/jpillora/chisel/share/settings"
"github.com/jpillora/sizestr"
"golang.org/x/crypto/ssh"
)
func (t *Tunnel) handleSSHRequests(reqs <-chan *ssh.Request) {
for r := range reqs {
switch r.Type {
case "ping":
r.Reply(true, []byte("pong"))
default:
t.Debugf("Unknown request: %s", r.Type)
}
}
}
func (t *Tunnel) handleSSHChannels(chans <-chan ssh.NewChannel) {
for ch := range chans {
go t.handleSSHChannel(ch)
}
}
func (t *Tunnel) handleSSHChannel(ch ssh.NewChannel) {
if !t.Config.Outbound {
t.Debugf("Denied outbound connection")
ch.Reject(ssh.Prohibited, "Denied outbound connection")
return
}
remote := string(ch.ExtraData())
//extract protocol
hostPort, proto := settings.L4Proto(remote)
udp := proto == "udp"
socks := hostPort == "socks"
if socks && t.socksServer == nil {
t.Debugf("Denied socks request, please enable socks")
ch.Reject(ssh.Prohibited, "SOCKS5 is not enabled")
return
}
sshChan, reqs, err := ch.Accept()
if err != nil {
t.Debugf("Failed to accept stream: %s", err)
return
}
stream := io.ReadWriteCloser(sshChan)
//cnet.MeterRWC(t.Logger.Fork("sshchan"), sshChan)
defer stream.Close()
go ssh.DiscardRequests(reqs)
l := t.Logger.Fork("conn#%d", t.connStats.New())
//ready to handle
t.connStats.Open()
l.Debugf("Open %s", t.connStats.String())
if socks {
err = t.handleSocks(stream)
} else if udp {
err = t.handleUDP(l, stream, hostPort)
} else {
err = t.handleTCP(l, stream, hostPort)
}
t.connStats.Close()
errmsg := ""
if err != nil && !strings.HasSuffix(err.Error(), "EOF") {
errmsg = fmt.Sprintf(" (error %s)", err)
}
l.Debugf("Close %s%s", t.connStats.String(), errmsg)
}
func (t *Tunnel) handleSocks(src io.ReadWriteCloser) error {
return t.socksServer.ServeConn(cnet.NewRWCConn(src))
}
func (t *Tunnel) handleTCP(l *cio.Logger, src io.ReadWriteCloser, hostPort string) error {
dst, err := net.Dial("tcp", hostPort)
if err != nil {
return err
}
s, r := cio.Pipe(src, dst)
l.Debugf("sent %s received %s", sizestr.ToString(s), sizestr.ToString(r))
return nil
}
package tunnel
import (
"encoding/gob"
"io"
"net"
"os"
"sync"
"time"
"github.com/jpillora/chisel/share/cio"
"github.com/jpillora/chisel/share/settings"
)
func (t *Tunnel) handleUDP(l *cio.Logger, rwc io.ReadWriteCloser, hostPort string) error {
conns := &udpConns{
Logger: l,
m: map[string]*udpConn{},
}
defer conns.closeAll()
h := &udpHandler{
Logger: l,
hostPort: hostPort,
udpChannel: &udpChannel{
r: gob.NewDecoder(rwc),
w: gob.NewEncoder(rwc),
c: rwc,
},
udpConns: conns,
maxMTU: settings.EnvInt("UDP_MAX_SIZE", 9012),
}
h.Debugf("UDP max size: %d bytes", h.maxMTU)
for {
p := udpPacket{}
if err := h.handleWrite(&p); err != nil {
return err
}
}
}
type udpHandler struct {
*cio.Logger
hostPort string
*udpChannel
*udpConns
maxMTU int
}
func (h *udpHandler) handleWrite(p *udpPacket) error {
if err := h.r.Decode(&p); err != nil {
return err
}
//dial now, we know we must write
conn, exists, err := h.udpConns.dial(p.Src, h.hostPort)
if err != nil {
return err
}
//however, we dont know if we must read...
//spawn up to <max-conns> go-routines to wait
//for a reply.
//TODO configurable
//TODO++ dont use go-routines, switch to pollable
// array of listeners where all listeners are
// sweeped periodically, removing the idle ones
const maxConns = 100
if !exists {
if h.udpConns.len() <= maxConns {
go h.handleRead(p, conn)
} else {
h.Debugf("exceeded max udp connections (%d)", maxConns)
}
}
_, err = conn.Write(p.Payload)
if err != nil {
return err
}
return nil
}
func (h *udpHandler) handleRead(p *udpPacket, conn *udpConn) {
//ensure connection is cleaned up
defer h.udpConns.remove(conn.id)
buff := make([]byte, h.maxMTU)
for {
//response must arrive within 15 seconds
deadline := settings.EnvDuration("UDP_DEADLINE", 15*time.Second)
conn.SetReadDeadline(time.Now().Add(deadline))
//read response
n, err := conn.Read(buff)
if err != nil {
if !os.IsTimeout(err) && err != io.EOF {
h.Debugf("read error: %s", err)
}
break
}
b := buff[:n]
//encode back over ssh connection
err = h.udpChannel.encode(p.Src, b)
if err != nil {
h.Debugf("encode error: %s", err)
return
}
}
}
type udpConns struct {
*cio.Logger
sync.Mutex
m map[string]*udpConn
}
func (cs *udpConns) dial(id, addr string) (*udpConn, bool, error) {
cs.Lock()
defer cs.Unlock()
conn, ok := cs.m[id]
if !ok {
c, err := net.Dial("udp", addr)
if err != nil {
return nil, false, err
}
conn = &udpConn{
id: id,
Conn: c, // cnet.MeterConn(cs.Logger.Fork(addr), c),
}
cs.m[id] = conn
}
return conn, ok, nil
}
func (cs *udpConns) len() int {
cs.Lock()
l := len(cs.m)
cs.Unlock()
return l
}
func (cs *udpConns) remove(id string) {
cs.Lock()
delete(cs.m, id)
cs.Unlock()
}
func (cs *udpConns) closeAll() {
cs.Lock()
for id, conn := range cs.m {
conn.Close()
delete(cs.m, id)
}
cs.Unlock()
}
type udpConn struct {
id string
net.Conn
}
package tunnel
import (
"context"
"encoding/gob"
"io"
)
type udpPacket struct {
Src string
Payload []byte
}
func init() {
gob.Register(&udpPacket{})
}
//udpChannel encodes/decodes udp payloads over a stream
type udpChannel struct {
r *gob.Decoder
w *gob.Encoder
c io.Closer
}
func (o *udpChannel) encode(src string, b []byte) error {
return o.w.Encode(udpPacket{
Src: src,
Payload: b,
})
}
func (o *udpChannel) decode(p *udpPacket) error {
return o.r.Decode(p)
}
func isDone(ctx context.Context) bool {
select {
case <-ctx.Done():
return true
default:
return false
}
}
package tunnel
import (
"sync"
"sync/atomic"
)
type waitGroup struct {
inner sync.WaitGroup
n int32
}
func (w *waitGroup) Add(n int) {
atomic.AddInt32(&w.n, int32(n))
w.inner.Add(n)
}
func (w *waitGroup) Done() {
if n := atomic.LoadInt32(&w.n); n > 0 && atomic.CompareAndSwapInt32(&w.n, n, n-1) {
w.inner.Done()
}
}
func (w *waitGroup) DoneAll() {
for atomic.LoadInt32(&w.n) > 0 {
w.Done()
}
}
func (w *waitGroup) Wait() {
w.inner.Wait()
}
package chshare
//ProtocolVersion of chisel. When backwards
//incompatible changes are made, this will
//be incremented to signify a protocol
//mismatch.
var ProtocolVersion = "chisel-v3"
var BuildVersion = "0.0.0-src"
//chisel end-to-end test
//======================
//
// (direct)
// .--------------->----------------.
// / chisel chisel \
// request--->client:2001--->server:2002---->fileserver:3000
// \ /
// '--> crowbar:4001--->crowbar:4002'
// client server
//
// crowbar and chisel binaries should be in your PATH
package main
import (
"flag"
"fmt"
"io"
"io/ioutil"
"log"
"net/http"
"os"
"os/exec"
"path"
"strconv"
"github.com/jpillora/chisel/share/cnet"
"time"
)
const ENABLE_CROWBAR = false
const (
B = 1
KB = 1000 * B
MB = 1000 * KB
GB = 1000 * MB
)
func run() {
flag.Parse()
args := flag.Args()
if len(args) == 0 {
fatal("go run main.go [test] or [bench]")
}
for _, a := range args {
switch a {
case "test":
test()
case "bench":
bench()
}
}
}
//test
func test() {
testTunnel("2001", 500)
testTunnel("2001", 50000)
}
//benchmark
func bench() {
benchSizes("3000")
benchSizes("2001")
if ENABLE_CROWBAR {
benchSizes("4001")
}
}
func benchSizes(port string) {
for size := 1; size <= 100*MB; size *= 10 {
testTunnel(port, size)
}
}
func testTunnel(port string, size int) {
t0 := time.Now()
resp, err := requestFile(port, size)
if err != nil {
fatal(err)
}
if resp.StatusCode != 200 {
fatal(err)
}
n, err := io.Copy(ioutil.Discard, resp.Body)
if err != nil {
fatal(err)
}
t1 := time.Now()
fmt.Printf(":%s => %d bytes in %s\n", port, size, t1.Sub(t0))
if int(n) != size {
fatalf("%d bytes expected, got %d", size, n)
}
}
//============================
func requestFile(port string, size int) (*http.Response, error) {
url := "http://127.0.0.1:" + port + "/" + strconv.Itoa(size)
// fmt.Println(url)
return http.Get(url)
}
func makeFileServer() *cnet.HTTPServer {
bsize := 3 * MB
bytes := make([]byte, bsize)
//filling huge buffer
for i := 0; i < len(bytes); i++ {
bytes[i] = byte(i)
}
s := cnet.NewHTTPServer()
s.Server.SetKeepAlivesEnabled(false)
handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
rsize, _ := strconv.Atoi(r.URL.Path[1:])
for rsize >= bsize {
w.Write(bytes)
rsize -= bsize
}
w.Write(bytes[:rsize])
})
s.GoListenAndServe("0.0.0.0:3000", handler)
return s
}
//============================
func fatal(args ...interface{}) {
panic(fmt.Sprint(args...))
}
func fatalf(f string, args ...interface{}) {
panic(fmt.Sprintf(f, args...))
}
//global setup
func main() {
fs := makeFileServer()
go func() {
err := fs.Wait()
if err != nil {
fmt.Printf("fs server closed (%s)\n", err)
}
}()
if ENABLE_CROWBAR {
dir, _ := os.Getwd()
cd := exec.Command("crowbard",
`-listen`, "0.0.0.0:4002",
`-userfile`, path.Join(dir, "userfile"))
if err := cd.Start(); err != nil {
fatal(err)
}
go func() {
fatalf("crowbard: %v", cd.Wait())
}()
defer cd.Process.Kill()
time.Sleep(100 * time.Millisecond)
cf := exec.Command("crowbar-forward",
"-local=0.0.0.0:4001",
"-server=http://127.0.0.1:4002",
"-remote=127.0.0.1:3000",
"-username", "foo",
"-password", "bar")
if err := cf.Start(); err != nil {
fatal(err)
}
defer cf.Process.Kill()
}
time.Sleep(100 * time.Millisecond)
hd := exec.Command("chisel", "server",
// "-v",
"--key", "foobar",
"--port", "2002")
hd.Stdout = os.Stdout
if err := hd.Start(); err != nil {
fatal(err)
}
defer hd.Process.Kill()
time.Sleep(100 * time.Millisecond)
hf := exec.Command("chisel", "client",
// "-v",
"--fingerprint", "mOz4rg9zlQ409XAhhj6+fDDVwQMY42CL3Zg2W2oTYxA=",
"127.0.0.1:2002",
"2001:3000")
hf.Stdout = os.Stdout
if err := hf.Start(); err != nil {
fatal(err)
}
defer hf.Process.Kill()
time.Sleep(100 * time.Millisecond)
defer func() {
if r := recover(); r != nil {
log.Print(r)
}
}()
run()
fs.Close()
}
### Performance
With [crowbar](https://github.com/q3k/crowbar), a connection is tunneled by repeatedly querying the server with updates. This results in a large amount of HTTP and TCP connection overhead. Chisel overcomes this using WebSockets combined with [crypto/ssh](https://golang.org/x/crypto/ssh) to create hundreds of logical connections, resulting in **one** TCP connection per client.
In this simple benchmark, we have:
```
(direct)
.--------------->----------------.
/ chisel chisel \
request--->client:2001--->server:2002---->fileserver:3000
\ /
'--> crowbar:4001--->crowbar:4002'
client server
```
Note, we're using an in-memory "file" server on localhost for these tests
_direct_
```
:3000 => 1 bytes in 1.291417ms
:3000 => 10 bytes in 713.525µs
:3000 => 100 bytes in 562.48µs
:3000 => 1000 bytes in 595.445µs
:3000 => 10000 bytes in 1.053298ms
:3000 => 100000 bytes in 741.351µs
:3000 => 1000000 bytes in 1.367143ms
:3000 => 10000000 bytes in 8.601549ms
:3000 => 100000000 bytes in 76.3939ms
```
`chisel`
```
:2001 => 1 bytes in 1.351976ms
:2001 => 10 bytes in 1.106086ms
:2001 => 100 bytes in 1.005729ms
:2001 => 1000 bytes in 1.254396ms
:2001 => 10000 bytes in 1.139777ms
:2001 => 100000 bytes in 2.35437ms
:2001 => 1000000 bytes in 11.502673ms
:2001 => 10000000 bytes in 123.130246ms
:2001 => 100000000 bytes in 966.48636ms
```
~100MB in **~1 second**
`crowbar`
```
:4001 => 1 bytes in 3.335797ms
:4001 => 10 bytes in 1.453007ms
:4001 => 100 bytes in 1.811727ms
:4001 => 1000 bytes in 1.621525ms
:4001 => 10000 bytes in 5.20729ms
:4001 => 100000 bytes in 38.461926ms
:4001 => 1000000 bytes in 358.784864ms
:4001 => 10000000 bytes in 3.603206487s
:4001 => 100000000 bytes in 36.332395213s
```
~100MB in **36 seconds**
See `test/bench/main.go`
\ No newline at end of file
foo:bar
\ No newline at end of file
package e2e_test
import (
"testing"
chclient "github.com/jpillora/chisel/client"
chserver "github.com/jpillora/chisel/server"
)
//TODO tests for:
// - failed auth
// - dynamic auth (server add/remove user)
// - watch auth file
func TestAuth(t *testing.T) {
tmpPort1 := availablePort()
tmpPort2 := availablePort()
//setup server, client, fileserver
teardown := simpleSetup(t,
&chserver.Config{
KeySeed: "foobar",
Auth: "../bench/userfile",
},
&chclient.Config{
Remotes: []string{
"0.0.0.0:" + tmpPort1 + ":127.0.0.1:$FILEPORT",
"0.0.0.0:" + tmpPort2 + ":localhost:$FILEPORT",
},
Auth: "foo:bar",
})
defer teardown()
//test first remote
result, err := post("http://localhost:"+tmpPort1, "foo")
if err != nil {
t.Fatal(err)
}
if result != "foo!" {
t.Fatalf("expected exclamation mark added")
}
//test second remote
result, err = post("http://localhost:"+tmpPort2, "bar")
if err != nil {
t.Fatal(err)
}
if result != "bar!" {
t.Fatalf("expected exclamation mark added again")
}
}
package e2e_test
import (
"testing"
chclient "github.com/jpillora/chisel/client"
chserver "github.com/jpillora/chisel/server"
)
func TestBase(t *testing.T) {
tmpPort := availablePort()
//setup server, client, fileserver
teardown := simpleSetup(t,
&chserver.Config{},
&chclient.Config{
Remotes: []string{tmpPort + ":$FILEPORT"},
})
defer teardown()
//test remote
result, err := post("http://localhost:"+tmpPort, "foo")
if err != nil {
t.Fatal(err)
}
if result != "foo!" {
t.Fatalf("expected exclamation mark added")
}
}
func TestReverse(t *testing.T) {
tmpPort := availablePort()
//setup server, client, fileserver
teardown := simpleSetup(t,
&chserver.Config{
Reverse: true,
},
&chclient.Config{
Remotes: []string{"R:" + tmpPort + ":$FILEPORT"},
})
defer teardown()
//test remote (this goes through the server and out the client)
result, err := post("http://localhost:"+tmpPort, "foo")
if err != nil {
t.Fatal(err)
}
if result != "foo!" {
t.Fatalf("expected exclamation mark added")
}
}
package e2e_test
import (
"bytes"
"crypto/ecdsa"
"crypto/ed25519"
"crypto/elliptic"
"crypto/rand"
"crypto/rsa"
"crypto/x509"
"crypto/x509/pkix"
"encoding/pem"
"fmt"
"io/ioutil"
"math/big"
"net"
"os"
"path"
"time"
chclient "github.com/jpillora/chisel/client"
chserver "github.com/jpillora/chisel/server"
)
type tlsConfig struct {
serverTLS *chserver.TLSConfig
clientTLS *chclient.TLSConfig
tmpDir string
}
func (t *tlsConfig) Close() {
if t.tmpDir != "" {
os.RemoveAll(t.tmpDir)
}
}
func newTestTLSConfig() (*tlsConfig, error) {
tlsConfig := &tlsConfig{}
_, serverCertPEM, serverKeyPEM, err := certGetCertificate(&certConfig{
hosts: []string{
"0.0.0.0",
"localhost",
},
extKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth},
})
if err != nil {
return nil, err
}
_, clientCertPEM, clientKeyPEM, err := certGetCertificate(&certConfig{
extKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageClientAuth},
})
if err != nil {
return nil, err
}
tlsConfig.tmpDir, err = ioutil.TempDir("", "")
if err != nil {
return nil, err
}
dirServerCA := path.Join(tlsConfig.tmpDir, "server-ca")
if err := os.Mkdir(dirServerCA, 0777); err != nil {
return nil, err
}
pathServerCACrt := path.Join(dirServerCA, "client.crt")
if err := ioutil.WriteFile(pathServerCACrt, clientCertPEM, 0666); err != nil {
return nil, err
}
dirClientCA := path.Join(tlsConfig.tmpDir, "client-ca")
if err := os.Mkdir(dirClientCA, 0777); err != nil {
return nil, err
}
pathClientCACrt := path.Join(dirClientCA, "server.crt")
if err := ioutil.WriteFile(pathClientCACrt, serverCertPEM, 0666); err != nil {
return nil, err
}
dirServerCrt := path.Join(tlsConfig.tmpDir, "server-crt")
if err := os.Mkdir(dirServerCrt, 0777); err != nil {
return nil, err
}
pathServerCrtCrt := path.Join(dirServerCrt, "server.crt")
if err := ioutil.WriteFile(pathServerCrtCrt, serverCertPEM, 0666); err != nil {
return nil, err
}
pathServerCrtKey := path.Join(dirServerCrt, "server.key")
if err := ioutil.WriteFile(pathServerCrtKey, serverKeyPEM, 0666); err != nil {
return nil, err
}
dirClientCrt := path.Join(tlsConfig.tmpDir, "client-crt")
if err := os.Mkdir(dirClientCrt, 0777); err != nil {
return nil, err
}
pathClientCrtCrt := path.Join(dirClientCrt, "client.crt")
if err := ioutil.WriteFile(pathClientCrtCrt, clientCertPEM, 0666); err != nil {
return nil, err
}
pathClientCrtKey := path.Join(dirClientCrt, "client.key")
if err := ioutil.WriteFile(pathClientCrtKey, clientKeyPEM, 0666); err != nil {
return nil, err
}
// for self signed cert, it needs the server cert, for real cert, this need to be the trusted CA cert
tlsConfig.serverTLS = &chserver.TLSConfig{
CA: pathServerCACrt,
Cert: pathServerCrtCrt,
Key: pathServerCrtKey,
}
tlsConfig.clientTLS = &chclient.TLSConfig{
CA: pathClientCACrt,
Cert: pathClientCrtCrt,
Key: pathClientCrtKey,
}
return tlsConfig, nil
}
type certConfig struct {
signCA *x509.Certificate
isCA bool
hosts []string
validFrom *time.Time
validFor *time.Time
extKeyUsage []x509.ExtKeyUsage
rsaBits int
ecdsaCurve string
ed25519Key bool
}
func certGetCertificate(c *certConfig) (*x509.Certificate, []byte, []byte, error) {
var err error
var priv interface{}
switch c.ecdsaCurve {
case "":
if c.ed25519Key {
_, priv, err = ed25519.GenerateKey(rand.Reader)
} else {
rsaBits := c.rsaBits
if rsaBits == 0 {
rsaBits = 2048
}
priv, err = rsa.GenerateKey(rand.Reader, rsaBits)
}
case "P224":
priv, err = ecdsa.GenerateKey(elliptic.P224(), rand.Reader)
case "P256":
priv, err = ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
case "P384":
priv, err = ecdsa.GenerateKey(elliptic.P384(), rand.Reader)
case "P521":
priv, err = ecdsa.GenerateKey(elliptic.P521(), rand.Reader)
default:
return nil, nil, nil, fmt.Errorf("Unrecognized elliptic curve: %q", c.ecdsaCurve)
}
if err != nil {
return nil, nil, nil, fmt.Errorf("Failed to generate private key: %v", err)
}
// ECDSA, ED25519 and RSA subject keys should have the DigitalSignature
// KeyUsage bits set in the x509.Certificate template
keyUsage := x509.KeyUsageDigitalSignature
// Only RSA subject keys should have the KeyEncipherment KeyUsage bits set. In
// the context of TLS this KeyUsage is particular to RSA key exchange and
// authentication.
if _, isRSA := priv.(*rsa.PrivateKey); isRSA {
keyUsage |= x509.KeyUsageKeyEncipherment
}
notBefore := time.Now()
if c.validFrom != nil {
notBefore = *c.validFrom
}
notAfter := time.Now().Add(24 * time.Hour)
if c.validFor != nil {
notAfter = *c.validFor
}
serialNumberLimit := new(big.Int).Lsh(big.NewInt(1), 128)
serialNumber, err := rand.Int(rand.Reader, serialNumberLimit)
if err != nil {
return nil, nil, nil, fmt.Errorf("Failed to generate serial number: %v", err)
}
cert := &x509.Certificate{
SerialNumber: serialNumber,
Subject: pkix.Name{
OrganizationalUnit: []string{"test"},
Organization: []string{"Chisel"},
Country: []string{"us"},
Province: []string{"ma"},
Locality: []string{"Boston"},
CommonName: "localhost",
},
NotBefore: notBefore,
NotAfter: notAfter,
KeyUsage: keyUsage,
ExtKeyUsage: c.extKeyUsage,
BasicConstraintsValid: true,
}
for _, h := range c.hosts {
if ip := net.ParseIP(h); ip != nil {
cert.IPAddresses = append(cert.IPAddresses, ip)
} else {
cert.DNSNames = append(cert.DNSNames, h)
}
}
if c.isCA {
cert.IsCA = true
cert.KeyUsage |= x509.KeyUsageCertSign
}
ca := cert
if c.signCA != nil {
ca = c.signCA
}
certBytes, err := x509.CreateCertificate(rand.Reader, cert, ca, certGetPublicKey(priv), priv)
if err != nil {
return nil, nil, nil, fmt.Errorf("Failed to create certificate: %v", err)
}
certPEM := new(bytes.Buffer)
pem.Encode(certPEM, &pem.Block{
Type: "CERTIFICATE",
Bytes: certBytes,
})
privBytes, err := x509.MarshalPKCS8PrivateKey(priv)
if err != nil {
return nil, nil, nil, fmt.Errorf("Unable to marshal private key: %v", err)
}
certPrivKeyPEM := new(bytes.Buffer)
pem.Encode(certPrivKeyPEM, &pem.Block{
Type: "PRIVATE KEY",
Bytes: privBytes,
})
return cert, certPEM.Bytes(), certPrivKeyPEM.Bytes(), nil
}
func certGetPublicKey(priv interface{}) interface{} {
switch k := priv.(type) {
case *rsa.PrivateKey:
return &k.PublicKey
case *ecdsa.PrivateKey:
return &k.PublicKey
case ed25519.PrivateKey:
return k.Public().(ed25519.PublicKey)
default:
return nil
}
}
package e2e_test
//TODO tests for:
// client -> CONNECT proxy -> server -> endpoint
// client -> SOCKS proxy -> server -> endpoint
package e2e_test
import (
"context"
"io/ioutil"
"log"
"net"
"net/http"
"strings"
"testing"
"time"
chclient "github.com/jpillora/chisel/client"
chserver "github.com/jpillora/chisel/server"
)
const debug = true
// test layout configuration
type testLayout struct {
server *chserver.Config
client *chclient.Config
fileServer bool
udpEcho bool
udpServer bool
}
func (tl *testLayout) setup(t *testing.T) (server *chserver.Server, client *chclient.Client, teardown func()) {
//start of the world
// goroutines := runtime.NumGoroutine()
//root cancel
ctx, cancel := context.WithCancel(context.Background())
//fileserver (fake endpoint)
filePort := availablePort()
if tl.fileServer {
fileAddr := "127.0.0.1:" + filePort
f := http.Server{
Handler: http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
b, _ := ioutil.ReadAll(r.Body)
w.Write(append(b, '!'))
}),
}
fl, err := net.Listen("tcp", fileAddr)
if err != nil {
t.Fatal(err)
}
log.Printf("fileserver: listening on %s", fileAddr)
go func() {
f.Serve(fl)
cancel()
}()
go func() {
<-ctx.Done()
f.Close()
}()
}
//server
server, err := chserver.NewServer(tl.server)
if err != nil {
t.Fatal(err)
}
server.Debug = debug
port := availablePort()
if err := server.StartContext(ctx, "127.0.0.1", port); err != nil {
t.Fatal(err)
}
go func() {
server.Wait()
server.Infof("Closed")
cancel()
}()
//client (with defaults)
tl.client.Fingerprint = server.GetFingerprint()
if tl.server.TLS.Key != "" {
//the domain name has to be localhost to match the ssl cert
tl.client.Server = "https://localhost:" + port
} else {
tl.client.Server = "http://127.0.0.1:" + port
}
for i, r := range tl.client.Remotes {
//convert $FILEPORT into the allocated port for this test case
if tl.fileServer {
tl.client.Remotes[i] = strings.Replace(r, "$FILEPORT", filePort, 1)
}
}
client, err = chclient.NewClient(tl.client)
if err != nil {
t.Fatal(err)
}
client.Debug = debug
if err := client.Start(ctx); err != nil {
t.Fatal(err)
}
go func() {
client.Wait()
client.Infof("Closed")
cancel()
}()
//cancel context tree, and wait for both client and server to stop
teardown = func() {
cancel()
server.Wait()
client.Wait()
//confirm goroutines have been cleaned up
// time.Sleep(500 * time.Millisecond)
// TODO remove sleep
// d := runtime.NumGoroutine() - goroutines
// if d != 0 {
// pprof.Lookup("goroutine").WriteTo(os.Stdout, 1)
// t.Fatalf("goroutines left %d", d)
// }
}
//wait a bit...
//TODO: client signal API, similar to os.Notify(signal)
// wait for client setup
time.Sleep(50 * time.Millisecond)
//ready
return server, client, teardown
}
func simpleSetup(t *testing.T, s *chserver.Config, c *chclient.Config) context.CancelFunc {
conf := testLayout{
server: s,
client: c,
fileServer: true,
}
_, _, teardown := conf.setup(t)
return teardown
}
func post(url, body string) (string, error) {
resp, err := http.Post(url, "text/plain", strings.NewReader(body))
if err != nil {
return "", err
}
b, err := ioutil.ReadAll(resp.Body)
if err != nil {
return "", err
}
return string(b), nil
}
func availablePort() string {
l, err := net.Listen("tcp", "127.0.0.1:0")
if err != nil {
log.Panic(err)
}
l.Close()
_, port, err := net.SplitHostPort(l.Addr().String())
if err != nil {
log.Panic(err)
}
return port
}
package e2e_test
//TODO tests for:
// - SOCKS-client -> [client -> server SOCKS] -> endpoint
// - SOCKS-client -> [server -> client SOCKS] -> endpoint
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment