Commit cbfdac43 by The Rock Guy Committed by Marcus Efraimsson

Postgres: Add support for scram sha 256 authentication (#18397)

Allow datasource connection to postgres with password "scram-sha-256" 
authentification.

Fixes #14662
parent d5e7ef23
......@@ -44,7 +44,7 @@ require (
github.com/k0kubun/colorstring v0.0.0-20150214042306-9440f1994b88 // indirect
github.com/klauspost/compress v1.4.1 // indirect
github.com/klauspost/cpuid v1.2.0 // indirect
github.com/lib/pq v1.0.0
github.com/lib/pq v1.2.0
github.com/mattn/go-colorable v0.1.1 // indirect
github.com/mattn/go-isatty v0.0.7
github.com/mattn/go-sqlite3 v1.10.0
......@@ -52,7 +52,7 @@ require (
github.com/onsi/gomega v1.5.0 // indirect
github.com/opentracing/opentracing-go v1.1.0
github.com/patrickmn/go-cache v2.1.0+incompatible
github.com/pkg/errors v0.8.1
github.com/pkg/errors v0.8.1 // indirect
github.com/prometheus/client_golang v0.9.2
github.com/prometheus/client_model v0.0.0-20190129233127-fd36f4220a90
github.com/prometheus/common v0.2.0
......
......@@ -132,6 +132,8 @@ github.com/kr/text v0.1.0 h1:45sCR5RtlFHMR4UwH9sdQ5TC8v0qDQCHnXt+kaKSTVE=
github.com/kr/text v0.1.0/go.mod h1:4Jbv+DJW3UT/LiOwJeYQe1efqtUx/iVham/4vfdArNI=
github.com/lib/pq v1.0.0 h1:X5PMW56eZitiTeO7tKzZxFCSpbFZJtkMMooicw2us9A=
github.com/lib/pq v1.0.0/go.mod h1:5WUZQaWbwv1U+lTReE5YruASi9Al49XbQIvNi/34Woo=
github.com/lib/pq v1.2.0 h1:LXpIM/LZ5xGFhOpXAQUIMM1HdyqzVYM13zNdjCEEcA0=
github.com/lib/pq v1.2.0/go.mod h1:5WUZQaWbwv1U+lTReE5YruASi9Al49XbQIvNi/34Woo=
github.com/mattn/go-colorable v0.1.1 h1:G1f5SKeVxmagw/IyvzvtZE4Gybcc4Tr1tf7I8z0XgOg=
github.com/mattn/go-colorable v0.1.1/go.mod h1:FuOcm+DKB9mbwrcAfNl7/TZVBZ6rcnceauSikq3lYCQ=
github.com/mattn/go-isatty v0.0.5/go.mod h1:Iq45c/XA43vh69/j3iqttzPXn0bhXyGjM0Hdxcsrc5s=
......
......@@ -70,17 +70,4 @@ postgresql_uninstall() {
sudo rm -rf /var/lib/postgresql
}
megacheck_install() {
# Lock megacheck version at $MEGACHECK_VERSION to prevent spontaneous
# new error messages in old code.
go get -d honnef.co/go/tools/...
git -C $GOPATH/src/honnef.co/go/tools/ checkout $MEGACHECK_VERSION
go install honnef.co/go/tools/cmd/megacheck
megacheck --version
}
golint_install() {
go get github.com/golang/lint/golint
}
$1
......@@ -10,7 +10,7 @@
## Docs
For detailed documentation and basic usage examples, please see the package
documentation at <http://godoc.org/github.com/lib/pq>.
documentation at <https://godoc.org/github.com/lib/pq>.
## Tests
......
......@@ -66,7 +66,7 @@ func (b *writeBuf) int16(n int) {
}
func (b *writeBuf) string(s string) {
b.buf = append(b.buf, (s + "\000")...)
b.buf = append(append(b.buf, s...), '\000')
}
func (b *writeBuf) byte(c byte) {
......
......@@ -2,7 +2,9 @@ package pq
import (
"bufio"
"context"
"crypto/md5"
"crypto/sha256"
"database/sql"
"database/sql/driver"
"encoding/binary"
......@@ -20,6 +22,7 @@ import (
"unicode"
"github.com/lib/pq/oid"
"github.com/lib/pq/scram"
)
// Common error types
......@@ -89,13 +92,25 @@ type Dialer interface {
DialTimeout(network, address string, timeout time.Duration) (net.Conn, error)
}
type defaultDialer struct{}
// DialerContext is the context-aware dialer interface.
type DialerContext interface {
DialContext(ctx context.Context, network, address string) (net.Conn, error)
}
type defaultDialer struct {
d net.Dialer
}
func (d defaultDialer) Dial(ntw, addr string) (net.Conn, error) {
return net.Dial(ntw, addr)
func (d defaultDialer) Dial(network, address string) (net.Conn, error) {
return d.d.Dial(network, address)
}
func (d defaultDialer) DialTimeout(ntw, addr string, timeout time.Duration) (net.Conn, error) {
return net.DialTimeout(ntw, addr, timeout)
func (d defaultDialer) DialTimeout(network, address string, timeout time.Duration) (net.Conn, error) {
ctx, cancel := context.WithTimeout(context.Background(), timeout)
defer cancel()
return d.DialContext(ctx, network, address)
}
func (d defaultDialer) DialContext(ctx context.Context, network, address string) (net.Conn, error) {
return d.d.DialContext(ctx, network, address)
}
type conn struct {
......@@ -244,90 +259,35 @@ func (cn *conn) writeBuf(b byte) *writeBuf {
}
}
// Open opens a new connection to the database. name is a connection string.
// Open opens a new connection to the database. dsn is a connection string.
// Most users should only use it through database/sql package from the standard
// library.
func Open(name string) (_ driver.Conn, err error) {
return DialOpen(defaultDialer{}, name)
func Open(dsn string) (_ driver.Conn, err error) {
return DialOpen(defaultDialer{}, dsn)
}
// DialOpen opens a new connection to the database using a dialer.
func DialOpen(d Dialer, name string) (_ driver.Conn, err error) {
func DialOpen(d Dialer, dsn string) (_ driver.Conn, err error) {
c, err := NewConnector(dsn)
if err != nil {
return nil, err
}
c.dialer = d
return c.open(context.Background())
}
func (c *Connector) open(ctx context.Context) (cn *conn, err error) {
// Handle any panics during connection initialization. Note that we
// specifically do *not* want to use errRecover(), as that would turn any
// connection errors into ErrBadConns, hiding the real error message from
// the user.
defer errRecoverNoErrBadConn(&err)
o := make(values)
// A number of defaults are applied here, in this order:
//
// * Very low precedence defaults applied in every situation
// * Environment variables
// * Explicitly passed connection information
o["host"] = "localhost"
o["port"] = "5432"
// N.B.: Extra float digits should be set to 3, but that breaks
// Postgres 8.4 and older, where the max is 2.
o["extra_float_digits"] = "2"
for k, v := range parseEnviron(os.Environ()) {
o[k] = v
}
if strings.HasPrefix(name, "postgres://") || strings.HasPrefix(name, "postgresql://") {
name, err = ParseURL(name)
if err != nil {
return nil, err
}
}
if err := parseOpts(name, o); err != nil {
return nil, err
}
// Use the "fallback" application name if necessary
if fallback, ok := o["fallback_application_name"]; ok {
if _, ok := o["application_name"]; !ok {
o["application_name"] = fallback
}
}
// We can't work with any client_encoding other than UTF-8 currently.
// However, we have historically allowed the user to set it to UTF-8
// explicitly, and there's no reason to break such programs, so allow that.
// Note that the "options" setting could also set client_encoding, but
// parsing its value is not worth it. Instead, we always explicitly send
// client_encoding as a separate run-time parameter, which should override
// anything set in options.
if enc, ok := o["client_encoding"]; ok && !isUTF8(enc) {
return nil, errors.New("client_encoding must be absent or 'UTF8'")
}
o["client_encoding"] = "UTF8"
// DateStyle needs a similar treatment.
if datestyle, ok := o["datestyle"]; ok {
if datestyle != "ISO, MDY" {
panic(fmt.Sprintf("setting datestyle must be absent or %v; got %v",
"ISO, MDY", datestyle))
}
} else {
o["datestyle"] = "ISO, MDY"
}
// If a user is not provided by any other means, the last
// resort is to use the current operating system provided user
// name.
if _, ok := o["user"]; !ok {
u, err := userCurrent()
if err != nil {
return nil, err
}
o["user"] = u
}
o := c.opts
cn := &conn{
cn = &conn{
opts: o,
dialer: d,
dialer: c.dialer,
}
err = cn.handleDriverSettings(o)
if err != nil {
......@@ -335,13 +295,16 @@ func DialOpen(d Dialer, name string) (_ driver.Conn, err error) {
}
cn.handlePgpass(o)
cn.c, err = dial(d, o)
cn.c, err = dial(ctx, c.dialer, o)
if err != nil {
return nil, err
}
err = cn.ssl(o)
if err != nil {
if cn.c != nil {
cn.c.Close()
}
return nil, err
}
......@@ -364,10 +327,10 @@ func DialOpen(d Dialer, name string) (_ driver.Conn, err error) {
return cn, err
}
func dial(d Dialer, o values) (net.Conn, error) {
ntw, addr := network(o)
func dial(ctx context.Context, d Dialer, o values) (net.Conn, error) {
network, address := network(o)
// SSL is not necessary or supported over UNIX domain sockets
if ntw == "unix" {
if network == "unix" {
o["sslmode"] = "disable"
}
......@@ -378,19 +341,30 @@ func dial(d Dialer, o values) (net.Conn, error) {
return nil, fmt.Errorf("invalid value for parameter connect_timeout: %s", err)
}
duration := time.Duration(seconds) * time.Second
// connect_timeout should apply to the entire connection establishment
// procedure, so we both use a timeout for the TCP connection
// establishment and set a deadline for doing the initial handshake.
// The deadline is then reset after startup() is done.
deadline := time.Now().Add(duration)
conn, err := d.DialTimeout(ntw, addr, duration)
var conn net.Conn
if dctx, ok := d.(DialerContext); ok {
ctx, cancel := context.WithTimeout(ctx, duration)
defer cancel()
conn, err = dctx.DialContext(ctx, network, address)
} else {
conn, err = d.DialTimeout(network, address, duration)
}
if err != nil {
return nil, err
}
err = conn.SetDeadline(deadline)
return conn, err
}
return d.Dial(ntw, addr)
if dctx, ok := d.(DialerContext); ok {
return dctx.DialContext(ctx, network, address)
}
return d.Dial(network, address)
}
func network(o values) (string, string) {
......@@ -576,7 +550,7 @@ func (cn *conn) Commit() (err error) {
// would get the same behaviour if you issued a COMMIT in a failed
// transaction, so it's also the least surprising thing to do here.
if cn.txnStatus == txnStatusInFailedTransaction {
if err := cn.Rollback(); err != nil {
if err := cn.rollback(); err != nil {
return err
}
return ErrInFailedTransaction
......@@ -603,7 +577,10 @@ func (cn *conn) Rollback() (err error) {
return driver.ErrBadConn
}
defer cn.errRecover(&err)
return cn.rollback()
}
func (cn *conn) rollback() (err error) {
cn.checkIsInTransaction(true)
_, commandTag, err := cn.simpleExec("ROLLBACK")
if err != nil {
......@@ -704,7 +681,7 @@ func (cn *conn) simpleQuery(q string) (res *rows, err error) {
// res might be non-nil here if we received a previous
// CommandComplete, but that's fine; just overwrite it
res = &rows{cn: cn}
res.colNames, res.colFmts, res.colTyps = parsePortalRowDescribe(r)
res.rowsHeader = parsePortalRowDescribe(r)
// To work around a bug in QueryRow in Go 1.2 and earlier, wait
// until the first DataRow has been received.
......@@ -861,7 +838,7 @@ func (cn *conn) query(query string, args []driver.Value) (_ *rows, err error) {
cn.readParseResponse()
cn.readBindResponse()
rows := &rows{cn: cn}
rows.colNames, rows.colFmts, rows.colTyps = cn.readPortalDescribeResponse()
rows.rowsHeader = cn.readPortalDescribeResponse()
cn.postExecuteWorkaround()
return rows, nil
}
......@@ -869,9 +846,7 @@ func (cn *conn) query(query string, args []driver.Value) (_ *rows, err error) {
st.exec(args)
return &rows{
cn: cn,
colNames: st.colNames,
colTyps: st.colTyps,
colFmts: st.colFmts,
rowsHeader: st.rowsHeader,
}, nil
}
......@@ -992,7 +967,6 @@ func (cn *conn) recv() (t byte, r *readBuf) {
if err != nil {
panic(err)
}
switch t {
case 'E':
panic(parseError(r))
......@@ -1163,6 +1137,55 @@ func (cn *conn) auth(r *readBuf, o values) {
if r.int32() != 0 {
errorf("unexpected authentication response: %q", t)
}
case 10:
sc := scram.NewClient(sha256.New, o["user"], o["password"])
sc.Step(nil)
if sc.Err() != nil {
errorf("SCRAM-SHA-256 error: %s", sc.Err().Error())
}
scOut := sc.Out()
w := cn.writeBuf('p')
w.string("SCRAM-SHA-256")
w.int32(len(scOut))
w.bytes(scOut)
cn.send(w)
t, r := cn.recv()
if t != 'R' {
errorf("unexpected password response: %q", t)
}
if r.int32() != 11 {
errorf("unexpected authentication response: %q", t)
}
nextStep := r.next(len(*r))
sc.Step(nextStep)
if sc.Err() != nil {
errorf("SCRAM-SHA-256 error: %s", sc.Err().Error())
}
scOut = sc.Out()
w = cn.writeBuf('p')
w.bytes(scOut)
cn.send(w)
t, r = cn.recv()
if t != 'R' {
errorf("unexpected password response: %q", t)
}
if r.int32() != 12 {
errorf("unexpected authentication response: %q", t)
}
nextStep = r.next(len(*r))
sc.Step(nextStep)
if sc.Err() != nil {
errorf("SCRAM-SHA-256 error: %s", sc.Err().Error())
}
default:
errorf("unknown authentication response: %d", code)
}
......@@ -1182,10 +1205,8 @@ var colFmtDataAllText = []byte{0, 0}
type stmt struct {
cn *conn
name string
colNames []string
colFmts []format
rowsHeader
colFmtData []byte
colTyps []fieldDesc
paramTyps []oid.Oid
closed bool
}
......@@ -1232,9 +1253,7 @@ func (st *stmt) Query(v []driver.Value) (r driver.Rows, err error) {
st.exec(v)
return &rows{
cn: st.cn,
colNames: st.colNames,
colTyps: st.colTyps,
colFmts: st.colFmts,
rowsHeader: st.rowsHeader,
}, nil
}
......@@ -1344,16 +1363,22 @@ func (cn *conn) parseComplete(commandTag string) (driver.Result, string) {
return driver.RowsAffected(n), commandTag
}
type rows struct {
cn *conn
finish func()
type rowsHeader struct {
colNames []string
colTyps []fieldDesc
colFmts []format
}
type rows struct {
cn *conn
finish func()
rowsHeader
done bool
rb readBuf
result driver.Result
tag string
next *rowsHeader
}
func (rs *rows) Close() error {
......@@ -1440,7 +1465,8 @@ func (rs *rows) Next(dest []driver.Value) (err error) {
}
return
case 'T':
rs.colNames, rs.colFmts, rs.colTyps = parsePortalRowDescribe(&rs.rb)
next := parsePortalRowDescribe(&rs.rb)
rs.next = &next
return io.EOF
default:
errorf("unexpected message after execute: %q", t)
......@@ -1449,10 +1475,16 @@ func (rs *rows) Next(dest []driver.Value) (err error) {
}
func (rs *rows) HasNextResultSet() bool {
return !rs.done
hasNext := rs.next != nil && !rs.done
return hasNext
}
func (rs *rows) NextResultSet() error {
if rs.next == nil {
return io.EOF
}
rs.rowsHeader = *rs.next
rs.next = nil
return nil
}
......@@ -1475,6 +1507,39 @@ func QuoteIdentifier(name string) string {
return `"` + strings.Replace(name, `"`, `""`, -1) + `"`
}
// QuoteLiteral quotes a 'literal' (e.g. a parameter, often used to pass literal
// to DDL and other statements that do not accept parameters) to be used as part
// of an SQL statement. For example:
//
// exp_date := pq.QuoteLiteral("2023-01-05 15:00:00Z")
// err := db.Exec(fmt.Sprintf("CREATE ROLE my_user VALID UNTIL %s", exp_date))
//
// Any single quotes in name will be escaped. Any backslashes (i.e. "\") will be
// replaced by two backslashes (i.e. "\\") and the C-style escape identifier
// that PostgreSQL provides ('E') will be prepended to the string.
func QuoteLiteral(literal string) string {
// This follows the PostgreSQL internal algorithm for handling quoted literals
// from libpq, which can be found in the "PQEscapeStringInternal" function,
// which is found in the libpq/fe-exec.c source file:
// https://git.postgresql.org/gitweb/?p=postgresql.git;a=blob;f=src/interfaces/libpq/fe-exec.c
//
// substitute any single-quotes (') with two single-quotes ('')
literal = strings.Replace(literal, `'`, `''`, -1)
// determine if the string has any backslashes (\) in it.
// if it does, replace any backslashes (\) with two backslashes (\\)
// then, we need to wrap the entire string with a PostgreSQL
// C-style escape. Per how "PQEscapeStringInternal" handles this case, we
// also add a space before the "E"
if strings.Contains(literal, `\`) {
literal = strings.Replace(literal, `\`, `\\`, -1)
literal = ` E'` + literal + `'`
} else {
// otherwise, we can just wrap the literal with a pair of single quotes
literal = `'` + literal + `'`
}
return literal
}
func md5s(s string) string {
h := md5.New()
h.Write([]byte(s))
......@@ -1630,13 +1695,13 @@ func (cn *conn) readStatementDescribeResponse() (paramTyps []oid.Oid, colNames [
}
}
func (cn *conn) readPortalDescribeResponse() (colNames []string, colFmts []format, colTyps []fieldDesc) {
func (cn *conn) readPortalDescribeResponse() rowsHeader {
t, r := cn.recv1()
switch t {
case 'T':
return parsePortalRowDescribe(r)
case 'n':
return nil, nil, nil
return rowsHeader{}
case 'E':
err := parseError(r)
cn.readReadyForQuery()
......@@ -1742,11 +1807,11 @@ func parseStatementRowDescribe(r *readBuf) (colNames []string, colTyps []fieldDe
return
}
func parsePortalRowDescribe(r *readBuf) (colNames []string, colFmts []format, colTyps []fieldDesc) {
func parsePortalRowDescribe(r *readBuf) rowsHeader {
n := r.int16()
colNames = make([]string, n)
colFmts = make([]format, n)
colTyps = make([]fieldDesc, n)
colNames := make([]string, n)
colFmts := make([]format, n)
colTyps := make([]fieldDesc, n)
for i := range colNames {
colNames[i] = r.string()
r.next(6)
......@@ -1755,7 +1820,11 @@ func parsePortalRowDescribe(r *readBuf) (colNames []string, colFmts []format, co
colTyps[i].Mod = r.int32()
colFmts[i] = format(r.int16())
}
return
return rowsHeader{
colNames: colNames,
colFmts: colFmts,
colTyps: colTyps,
}
}
// parseEnviron tries to mimic some of libpq's environment handling
......
// +build go1.8
package pq
import (
......@@ -9,6 +7,7 @@ import (
"fmt"
"io"
"io/ioutil"
"time"
)
// Implement the "QueryerContext" interface
......@@ -76,13 +75,32 @@ func (cn *conn) BeginTx(ctx context.Context, opts driver.TxOptions) (driver.Tx,
return tx, nil
}
func (cn *conn) Ping(ctx context.Context) error {
if finish := cn.watchCancel(ctx); finish != nil {
defer finish()
}
rows, err := cn.simpleQuery("SELECT 'lib/pq ping test';")
if err != nil {
return driver.ErrBadConn // https://golang.org/pkg/database/sql/driver/#Pinger
}
rows.Close()
return nil
}
func (cn *conn) watchCancel(ctx context.Context) func() {
if done := ctx.Done(); done != nil {
finished := make(chan struct{})
go func() {
select {
case <-done:
_ = cn.cancel()
// At this point the function level context is canceled,
// so it must not be used for the additional network
// request to cancel the query.
// Create a new context to pass into the dial.
ctxCancel, cancel := context.WithTimeout(context.Background(), time.Second*10)
defer cancel()
_ = cn.cancel(ctxCancel)
finished <- struct{}{}
case <-finished:
}
......@@ -97,8 +115,8 @@ func (cn *conn) watchCancel(ctx context.Context) func() {
return nil
}
func (cn *conn) cancel() error {
c, err := dial(cn.dialer, cn.opts)
func (cn *conn) cancel(ctx context.Context) error {
c, err := dial(ctx, cn.dialer, cn.opts)
if err != nil {
return err
}
......
// +build go1.10
package pq
import (
"context"
"database/sql/driver"
"errors"
"fmt"
"os"
"strings"
)
// Connector represents a fixed configuration for the pq driver with a given
......@@ -14,30 +16,95 @@ import (
//
// See https://golang.org/pkg/database/sql/driver/#Connector.
// See https://golang.org/pkg/database/sql/#OpenDB.
type connector struct {
name string
type Connector struct {
opts values
dialer Dialer
}
// Connect returns a connection to the database using the fixed configuration
// of this Connector. Context is not used.
func (c *connector) Connect(_ context.Context) (driver.Conn, error) {
return (&Driver{}).Open(c.name)
func (c *Connector) Connect(ctx context.Context) (driver.Conn, error) {
return c.open(ctx)
}
// Driver returnst the underlying driver of this Connector.
func (c *connector) Driver() driver.Driver {
func (c *Connector) Driver() driver.Driver {
return &Driver{}
}
var _ driver.Connector = &connector{}
// NewConnector returns a connector for the pq driver in a fixed configuration
// with the given name. The returned connector can be used to create any number
// with the given dsn. The returned connector can be used to create any number
// of equivalent Conn's. The returned connector is intended to be used with
// database/sql.OpenDB.
//
// See https://golang.org/pkg/database/sql/driver/#Connector.
// See https://golang.org/pkg/database/sql/#OpenDB.
func NewConnector(name string) (driver.Connector, error) {
return &connector{name: name}, nil
func NewConnector(dsn string) (*Connector, error) {
var err error
o := make(values)
// A number of defaults are applied here, in this order:
//
// * Very low precedence defaults applied in every situation
// * Environment variables
// * Explicitly passed connection information
o["host"] = "localhost"
o["port"] = "5432"
// N.B.: Extra float digits should be set to 3, but that breaks
// Postgres 8.4 and older, where the max is 2.
o["extra_float_digits"] = "2"
for k, v := range parseEnviron(os.Environ()) {
o[k] = v
}
if strings.HasPrefix(dsn, "postgres://") || strings.HasPrefix(dsn, "postgresql://") {
dsn, err = ParseURL(dsn)
if err != nil {
return nil, err
}
}
if err := parseOpts(dsn, o); err != nil {
return nil, err
}
// Use the "fallback" application name if necessary
if fallback, ok := o["fallback_application_name"]; ok {
if _, ok := o["application_name"]; !ok {
o["application_name"] = fallback
}
}
// We can't work with any client_encoding other than UTF-8 currently.
// However, we have historically allowed the user to set it to UTF-8
// explicitly, and there's no reason to break such programs, so allow that.
// Note that the "options" setting could also set client_encoding, but
// parsing its value is not worth it. Instead, we always explicitly send
// client_encoding as a separate run-time parameter, which should override
// anything set in options.
if enc, ok := o["client_encoding"]; ok && !isUTF8(enc) {
return nil, errors.New("client_encoding must be absent or 'UTF8'")
}
o["client_encoding"] = "UTF8"
// DateStyle needs a similar treatment.
if datestyle, ok := o["datestyle"]; ok {
if datestyle != "ISO, MDY" {
return nil, fmt.Errorf("setting datestyle must be absent or %v; got %v", "ISO, MDY", datestyle)
}
} else {
o["datestyle"] = "ISO, MDY"
}
// If a user is not provided by any other means, the last
// resort is to use the current operating system provided user
// name.
if _, ok := o["user"]; !ok {
u, err := userCurrent()
if err != nil {
return nil, err
}
o["user"] = u
}
return &Connector{opts: o, dialer: defaultDialer{}}, nil
}
......@@ -239,7 +239,7 @@ for more information). Note that the channel name will be truncated to 63
bytes by the PostgreSQL server.
You can find a complete, working example of Listener usage at
http://godoc.org/github.com/lib/pq/example/listen.
https://godoc.org/github.com/lib/pq/example/listen.
*/
package pq
......@@ -117,11 +117,10 @@ func textDecode(parameterStatus *parameterStatus, s []byte, typ oid.Oid) interfa
}
return i
case oid.T_float4, oid.T_float8:
bits := 64
if typ == oid.T_float4 {
bits = 32
}
f, err := strconv.ParseFloat(string(s), bits)
// We always use 64 bit parsing, regardless of whether the input text is for
// a float4 or float8, because clients expect float64s for all float datatypes
// and returning a 32-bit parsed float64 produces lossy results.
f, err := strconv.ParseFloat(string(s), 64)
if err != nil {
errorf("%s", err)
}
......
......@@ -478,13 +478,13 @@ func errRecoverNoErrBadConn(err *error) {
}
}
func (c *conn) errRecover(err *error) {
func (cn *conn) errRecover(err *error) {
e := recover()
switch v := e.(type) {
case nil:
// Do nothing
case runtime.Error:
c.bad = true
cn.bad = true
panic(v)
case *Error:
if v.Fatal() {
......@@ -493,7 +493,7 @@ func (c *conn) errRecover(err *error) {
*err = v
}
case *net.OpError:
c.bad = true
cn.bad = true
*err = v
case error:
if v == io.EOF || v.(error).Error() == "remote error: handshake failure" {
......@@ -503,13 +503,13 @@ func (c *conn) errRecover(err *error) {
}
default:
c.bad = true
cn.bad = true
panic(fmt.Sprintf("unknown error: %#v", e))
}
// Any time we return ErrBadConn, we need to remember it since *Tx doesn't
// mark the connection bad in database/sql.
if *err == driver.ErrBadConn {
c.bad = true
cn.bad = true
}
}
// Copyright (c) 2014 - Gustavo Niemeyer <gustavo@niemeyer.net>
//
// All rights reserved.
//
// Redistribution and use in source and binary forms, with or without
// modification, are permitted provided that the following conditions are met:
//
// 1. Redistributions of source code must retain the above copyright notice, this
// list of conditions and the following disclaimer.
// 2. Redistributions in binary form must reproduce the above copyright notice,
// this list of conditions and the following disclaimer in the documentation
// and/or other materials provided with the distribution.
//
// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
// ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
// WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
// DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE FOR
// ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
// (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
// LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND
// ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
// SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
// Package scram implements a SCRAM-{SHA-1,etc} client per RFC5802.
//
// http://tools.ietf.org/html/rfc5802
//
package scram
import (
"bytes"
"crypto/hmac"
"crypto/rand"
"encoding/base64"
"fmt"
"hash"
"strconv"
"strings"
)
// Client implements a SCRAM-* client (SCRAM-SHA-1, SCRAM-SHA-256, etc).
//
// A Client may be used within a SASL conversation with logic resembling:
//
// var in []byte
// var client = scram.NewClient(sha1.New, user, pass)
// for client.Step(in) {
// out := client.Out()
// // send out to server
// in := serverOut
// }
// if client.Err() != nil {
// // auth failed
// }
//
type Client struct {
newHash func() hash.Hash
user string
pass string
step int
out bytes.Buffer
err error
clientNonce []byte
serverNonce []byte
saltedPass []byte
authMsg bytes.Buffer
}
// NewClient returns a new SCRAM-* client with the provided hash algorithm.
//
// For SCRAM-SHA-256, for example, use:
//
// client := scram.NewClient(sha256.New, user, pass)
//
func NewClient(newHash func() hash.Hash, user, pass string) *Client {
c := &Client{
newHash: newHash,
user: user,
pass: pass,
}
c.out.Grow(256)
c.authMsg.Grow(256)
return c
}
// Out returns the data to be sent to the server in the current step.
func (c *Client) Out() []byte {
if c.out.Len() == 0 {
return nil
}
return c.out.Bytes()
}
// Err returns the error that ocurred, or nil if there were no errors.
func (c *Client) Err() error {
return c.err
}
// SetNonce sets the client nonce to the provided value.
// If not set, the nonce is generated automatically out of crypto/rand on the first step.
func (c *Client) SetNonce(nonce []byte) {
c.clientNonce = nonce
}
var escaper = strings.NewReplacer("=", "=3D", ",", "=2C")
// Step processes the incoming data from the server and makes the
// next round of data for the server available via Client.Out.
// Step returns false if there are no errors and more data is
// still expected.
func (c *Client) Step(in []byte) bool {
c.out.Reset()
if c.step > 2 || c.err != nil {
return false
}
c.step++
switch c.step {
case 1:
c.err = c.step1(in)
case 2:
c.err = c.step2(in)
case 3:
c.err = c.step3(in)
}
return c.step > 2 || c.err != nil
}
func (c *Client) step1(in []byte) error {
if len(c.clientNonce) == 0 {
const nonceLen = 16
buf := make([]byte, nonceLen+b64.EncodedLen(nonceLen))
if _, err := rand.Read(buf[:nonceLen]); err != nil {
return fmt.Errorf("cannot read random SCRAM-SHA-256 nonce from operating system: %v", err)
}
c.clientNonce = buf[nonceLen:]
b64.Encode(c.clientNonce, buf[:nonceLen])
}
c.authMsg.WriteString("n=")
escaper.WriteString(&c.authMsg, c.user)
c.authMsg.WriteString(",r=")
c.authMsg.Write(c.clientNonce)
c.out.WriteString("n,,")
c.out.Write(c.authMsg.Bytes())
return nil
}
var b64 = base64.StdEncoding
func (c *Client) step2(in []byte) error {
c.authMsg.WriteByte(',')
c.authMsg.Write(in)
fields := bytes.Split(in, []byte(","))
if len(fields) != 3 {
return fmt.Errorf("expected 3 fields in first SCRAM-SHA-256 server message, got %d: %q", len(fields), in)
}
if !bytes.HasPrefix(fields[0], []byte("r=")) || len(fields[0]) < 2 {
return fmt.Errorf("server sent an invalid SCRAM-SHA-256 nonce: %q", fields[0])
}
if !bytes.HasPrefix(fields[1], []byte("s=")) || len(fields[1]) < 6 {
return fmt.Errorf("server sent an invalid SCRAM-SHA-256 salt: %q", fields[1])
}
if !bytes.HasPrefix(fields[2], []byte("i=")) || len(fields[2]) < 6 {
return fmt.Errorf("server sent an invalid SCRAM-SHA-256 iteration count: %q", fields[2])
}
c.serverNonce = fields[0][2:]
if !bytes.HasPrefix(c.serverNonce, c.clientNonce) {
return fmt.Errorf("server SCRAM-SHA-256 nonce is not prefixed by client nonce: got %q, want %q+\"...\"", c.serverNonce, c.clientNonce)
}
salt := make([]byte, b64.DecodedLen(len(fields[1][2:])))
n, err := b64.Decode(salt, fields[1][2:])
if err != nil {
return fmt.Errorf("cannot decode SCRAM-SHA-256 salt sent by server: %q", fields[1])
}
salt = salt[:n]
iterCount, err := strconv.Atoi(string(fields[2][2:]))
if err != nil {
return fmt.Errorf("server sent an invalid SCRAM-SHA-256 iteration count: %q", fields[2])
}
c.saltPassword(salt, iterCount)
c.authMsg.WriteString(",c=biws,r=")
c.authMsg.Write(c.serverNonce)
c.out.WriteString("c=biws,r=")
c.out.Write(c.serverNonce)
c.out.WriteString(",p=")
c.out.Write(c.clientProof())
return nil
}
func (c *Client) step3(in []byte) error {
var isv, ise bool
var fields = bytes.Split(in, []byte(","))
if len(fields) == 1 {
isv = bytes.HasPrefix(fields[0], []byte("v="))
ise = bytes.HasPrefix(fields[0], []byte("e="))
}
if ise {
return fmt.Errorf("SCRAM-SHA-256 authentication error: %s", fields[0][2:])
} else if !isv {
return fmt.Errorf("unsupported SCRAM-SHA-256 final message from server: %q", in)
}
if !bytes.Equal(c.serverSignature(), fields[0][2:]) {
return fmt.Errorf("cannot authenticate SCRAM-SHA-256 server signature: %q", fields[0][2:])
}
return nil
}
func (c *Client) saltPassword(salt []byte, iterCount int) {
mac := hmac.New(c.newHash, []byte(c.pass))
mac.Write(salt)
mac.Write([]byte{0, 0, 0, 1})
ui := mac.Sum(nil)
hi := make([]byte, len(ui))
copy(hi, ui)
for i := 1; i < iterCount; i++ {
mac.Reset()
mac.Write(ui)
mac.Sum(ui[:0])
for j, b := range ui {
hi[j] ^= b
}
}
c.saltedPass = hi
}
func (c *Client) clientProof() []byte {
mac := hmac.New(c.newHash, c.saltedPass)
mac.Write([]byte("Client Key"))
clientKey := mac.Sum(nil)
hash := c.newHash()
hash.Write(clientKey)
storedKey := hash.Sum(nil)
mac = hmac.New(c.newHash, storedKey)
mac.Write(c.authMsg.Bytes())
clientProof := mac.Sum(nil)
for i, b := range clientKey {
clientProof[i] ^= b
}
clientProof64 := make([]byte, b64.EncodedLen(len(clientProof)))
b64.Encode(clientProof64, clientProof)
return clientProof64
}
func (c *Client) serverSignature() []byte {
mac := hmac.New(c.newHash, c.saltedPass)
mac.Write([]byte("Server Key"))
serverKey := mac.Sum(nil)
mac = hmac.New(c.newHash, serverKey)
mac.Write(c.authMsg.Bytes())
serverSignature := mac.Sum(nil)
encoded := make([]byte, b64.EncodedLen(len(serverSignature)))
b64.Encode(encoded, serverSignature)
return encoded
}
......@@ -58,7 +58,13 @@ func ssl(o values) (func(net.Conn) (net.Conn, error), error) {
if err != nil {
return nil, err
}
sslRenegotiation(&tlsConf)
// Accept renegotiation requests initiated by the backend.
//
// Renegotiation was deprecated then removed from PostgreSQL 9.5, but
// the default configuration of older versions has it enabled. Redshift
// also initiates renegotiations and cannot be reconfigured.
tlsConf.Renegotiation = tls.RenegotiateFreelyAsClient
return func(conn net.Conn) (net.Conn, error) {
client := tls.Client(conn, &tlsConf)
......
// +build go1.7
package pq
import "crypto/tls"
// Accept renegotiation requests initiated by the backend.
//
// Renegotiation was deprecated then removed from PostgreSQL 9.5, but
// the default configuration of older versions has it enabled. Redshift
// also initiates renegotiations and cannot be reconfigured.
func sslRenegotiation(conf *tls.Config) {
conf.Renegotiation = tls.RenegotiateFreelyAsClient
}
// +build !go1.7
package pq
import "crypto/tls"
// Renegotiation is not supported by crypto/tls until Go 1.7.
func sslRenegotiation(*tls.Config) {}
......@@ -139,9 +139,10 @@ github.com/klauspost/compress/gzip
github.com/klauspost/compress/flate
# github.com/klauspost/cpuid v1.2.0
github.com/klauspost/cpuid
# github.com/lib/pq v1.0.0
# github.com/lib/pq v1.2.0
github.com/lib/pq
github.com/lib/pq/oid
github.com/lib/pq/scram
# github.com/mattn/go-colorable v0.1.1
github.com/mattn/go-colorable
# github.com/mattn/go-isatty v0.0.7
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment