mirror of https://github.com/matrix-org/gomatrix
Merge pull request #6 from matrix-org/kegan/userids
Add user ID localpart encoding/decoding
This commit is contained in:
commit
52f7775d99
|
@ -0,0 +1,130 @@
|
||||||
|
package gomatrix
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bytes"
|
||||||
|
"encoding/hex"
|
||||||
|
"fmt"
|
||||||
|
"strings"
|
||||||
|
)
|
||||||
|
|
||||||
|
const lowerhex = "0123456789abcdef"

// encode appends the escaped form of b to buf: an '=' followed by the two
// lower-case hexadecimal digits of the byte value (e.g. '/' becomes "=2f").
// The layout follows the quoted-printable style of escaping.
// See https://golang.org/src/mime/quotedprintable/writer.go
func encode(buf *bytes.Buffer, b byte) {
	hi, lo := b>>4, b&0x0f
	buf.Write([]byte{'=', lowerhex[hi], lowerhex[lo]})
}
|
||||||
|
|
||||||
|
// escape writes the underscore-escaped form of b to buf. A literal '_' is
// doubled ("__"); any other byte is written as '_' followed by the byte shifted
// up by 0x20, which maps the ASCII range A-Z onto a-z.
func escape(buf *bytes.Buffer, b byte) {
	if b == '_' {
		buf.WriteString("__")
		return
	}
	buf.Write([]byte{'_', b + 0x20})
}
|
||||||
|
|
||||||
|
// shouldEncode reports whether b must be written as an "=xx" hex escape, i.e.
// whether it is anything other than an ASCII letter, an ASCII digit, '-', '.'
// or '_'.
func shouldEncode(b byte) bool {
	switch {
	case b >= 'a' && b <= 'z', b >= 'A' && b <= 'Z', b >= '0' && b <= '9':
		return false
	case b == '-', b == '.', b == '_':
		return false
	default:
		return true
	}
}
|
||||||
|
|
||||||
|
// shouldEscape reports whether b needs a leading-underscore escape: true for
// the upper-case ASCII letters A-Z and for a literal '_'.
func shouldEscape(b byte) bool {
	if b == '_' {
		return true
	}
	return 'A' <= b && b <= 'Z'
}
|
||||||
|
|
||||||
|
// isValidByte reports whether b is one of the characters "a-z0-9._=-", the only
// bytes accepted at a scan position when decoding an encoded localpart.
func isValidByte(b byte) bool {
	switch {
	case b >= 'a' && b <= 'z', b >= '0' && b <= '9':
		return true
	case b == '_', b == '.', b == '=', b == '-':
		return true
	}
	return false
}
|
||||||
|
|
||||||
|
// isValidEscapedChar reports whether b may follow an escaping '_': either a
// lower-case ASCII letter a-z or a second '_'.
func isValidEscapedChar(b byte) bool {
	if b >= 'a' && b <= 'z' {
		return true
	}
	return b == '_'
}
|
||||||
|
|
||||||
|
// EncodeUserLocalpart encodes the given string into Matrix-compliant user ID localpart form.
|
||||||
|
// See http://matrix.org/docs/spec/intro.html#mapping-from-other-character-sets
|
||||||
|
//
|
||||||
|
// This returns a string with only the characters "a-z0-9._=-". The uppercase range A-Z
|
||||||
|
// are encoded using leading underscores ("_"). Characters outside the aforementioned ranges
|
||||||
|
// (including literal underscores ("_") and equals ("=")) are encoded as UTF8 code points (NOT NCRs)
|
||||||
|
// and converted to lower-case hex with a leading "=". For example:
|
||||||
|
// Alph@Bet_50up => _alph=40_bet=5f50up
|
||||||
|
func EncodeUserLocalpart(str string) string {
|
||||||
|
strBytes := []byte(str)
|
||||||
|
var outputBuffer bytes.Buffer
|
||||||
|
for _, b := range strBytes {
|
||||||
|
if shouldEncode(b) {
|
||||||
|
encode(&outputBuffer, b)
|
||||||
|
} else if shouldEscape(b) {
|
||||||
|
escape(&outputBuffer, b)
|
||||||
|
} else {
|
||||||
|
outputBuffer.WriteByte(b)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return outputBuffer.String()
|
||||||
|
}
|
||||||
|
|
||||||
|
// DecodeUserLocalpart decodes the given string back into the original input string.
|
||||||
|
// Returns an error if the given string is not a valid user ID localpart encoding.
|
||||||
|
// See http://matrix.org/docs/spec/intro.html#mapping-from-other-character-sets
|
||||||
|
//
|
||||||
|
// This decodes quoted-printable bytes back into UTF8, and unescapes casing. For
|
||||||
|
// example:
|
||||||
|
// _alph=40_bet=5f50up => Alph@Bet_50up
|
||||||
|
// Returns an error if the input string contains characters outside the
|
||||||
|
// range "a-z0-9._=-", has an invalid quote-printable byte (e.g. not hex), or has
|
||||||
|
// an invalid _ escaped byte (e.g. "_5").
|
||||||
|
func DecodeUserLocalpart(str string) (string, error) {
|
||||||
|
strBytes := []byte(str)
|
||||||
|
var outputBuffer bytes.Buffer
|
||||||
|
for i := 0; i < len(strBytes); i++ {
|
||||||
|
b := strBytes[i]
|
||||||
|
if !isValidByte(b) {
|
||||||
|
return "", fmt.Errorf("Byte pos %d: Invalid byte", i)
|
||||||
|
}
|
||||||
|
|
||||||
|
if b == '_' { // next byte is a-z and should be upper-case or is another _ and should be a literal _
|
||||||
|
if i+1 >= len(strBytes) {
|
||||||
|
return "", fmt.Errorf("Byte pos %d: expected _[a-z_] encoding but ran out of string", i)
|
||||||
|
}
|
||||||
|
if !isValidEscapedChar(strBytes[i+1]) { // invalid escaping
|
||||||
|
return "", fmt.Errorf("Byte pos %d: expected _[a-z_] encoding", i)
|
||||||
|
}
|
||||||
|
if strBytes[i+1] == '_' {
|
||||||
|
outputBuffer.WriteByte('_')
|
||||||
|
} else {
|
||||||
|
outputBuffer.WriteByte(strBytes[i+1] - 0x20) // ASCII shift a-z to A-Z
|
||||||
|
}
|
||||||
|
i++ // skip next byte since we just handled it
|
||||||
|
} else if b == '=' { // next 2 bytes are hex and should be buffered ready to be read as utf8
|
||||||
|
if i+2 >= len(strBytes) {
|
||||||
|
return "", fmt.Errorf("Byte pos: %d: expected quote-printable encoding but ran out of string", i)
|
||||||
|
}
|
||||||
|
dst := make([]byte, 1)
|
||||||
|
_, err := hex.Decode(dst, strBytes[i+1:i+3])
|
||||||
|
if err != nil {
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
outputBuffer.WriteByte(dst[0])
|
||||||
|
i += 2 // skip next 2 bytes since we just handled it
|
||||||
|
} else { // pass through
|
||||||
|
outputBuffer.WriteByte(b)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return outputBuffer.String(), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// ExtractUserLocalpart extracts the localpart portion of a user ID.
// See http://matrix.org/docs/spec/intro.html#user-identifiers
//
// The localpart is everything between the leading "@" and the first ":", e.g.
// "@foo:bar:8448" yields "foo". An error is returned if userID is empty or does
// not start with "@". The presence of a ":" is not checked; without one,
// everything after the "@" is returned.
func ExtractUserLocalpart(userID string) (string, error) {
	if len(userID) == 0 || userID[0] != '@' {
		return "", fmt.Errorf("%s is not a valid user id", userID)
	}
	localpart := userID[1:] // drop the "@" sigil
	if colon := strings.IndexByte(localpart, ':'); colon >= 0 {
		localpart = localpart[:colon] // keep only what precedes the server name
	}
	return localpart, nil
}
|
|
@ -0,0 +1,86 @@
|
||||||
|
package gomatrix
|
||||||
|
|
||||||
|
import (
|
||||||
|
"testing"
|
||||||
|
)
|
||||||
|
|
||||||
|
// useridtests pairs a raw localpart (Input) with its expected encoded form
// (Output). The same rows drive both the encode test (Input => Output) and the
// decode test (Output => Input).
var useridtests = []struct {
	Input  string
	Output string
}{
	{"Alph@Bet_50up", "_alph=40_bet__50up"},         // upper-case, '@' and '_' together
	{"abcdef", "abcdef"},                            // no-op
	{"i_like_pie_", "i__like__pie__"},               // double underscore escaping
	{"ABCDEF", "_a_b_c_d_e_f"},                      // all-caps
	{"!£", "=21=c2=a3"},                             // punctuation and outside ascii range (U+00A3 => c2 a3)
	{"___", "______"},                               // literal underscores
	{"hello-world.", "hello-world."},                // allowed punctuation
	{"5+5=10", "5=2b5=3d10"},                        // equals sign
	{"東方Project", "=e6=9d=b1=e6=96=b9_project"}, // CJK mixed
	{"\tfoo bar", "=09foo=20bar"},                   // whitespace (tab and space); tab spelled as an escape so editors cannot mangle it
}
|
||||||
|
|
||||||
|
func TestEncodeUserLocalpart(t *testing.T) {
|
||||||
|
for _, u := range useridtests {
|
||||||
|
out := EncodeUserLocalpart(u.Input)
|
||||||
|
if out != u.Output {
|
||||||
|
t.Fatalf("TestEncodeUserLocalpart(%s) => Got: %s Expected: %s", u.Input, out, u.Output)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestDecodeUserLocalpart(t *testing.T) {
|
||||||
|
for _, u := range useridtests {
|
||||||
|
in, _ := DecodeUserLocalpart(u.Output)
|
||||||
|
if in != u.Input {
|
||||||
|
t.Fatalf("TestDecodeUserLocalpart(%s) => Got: %s Expected: %s", u.Output, in, u.Input)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// errtests lists encoded localparts that the decoder is expected to reject.
var errtests = []struct {
	Input string
}{
	{Input: "foo@bar"},         // invalid character @
	{Input: "foo_5bar"},        // invalid character after _
	{Input: "foo_._-bar"},      // multiple invalid characters after _
	{Input: "foo=2Hbar"},       // invalid hex after =
	{Input: "foo=2hbar"},       // invalid hex after = (lower-case)
	{Input: "foo=======2fbar"}, // multiple invalid hex after =
	{Input: "foo=2"},           // end of string after =
	{Input: "foo_"},            // end of string after _
}
|
||||||
|
|
||||||
|
func TestDecodeUserLocalpartErrors(t *testing.T) {
|
||||||
|
for _, u := range errtests {
|
||||||
|
out, err := DecodeUserLocalpart(u.Input)
|
||||||
|
if out != "" {
|
||||||
|
t.Fatalf("TestDecodeUserLocalpartErrors(%s) => Got: %s Expected: empty string", u.Input, out)
|
||||||
|
}
|
||||||
|
if err == nil {
|
||||||
|
t.Fatalf("TestDecodeUserLocalpartErrors(%s) => Got: nil error Expected: error", u.Input)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// localparttests pairs full user IDs with the localpart expected to be
// extracted from them.
var localparttests = []struct {
	Input        string
	ExpectOutput string
}{
	{Input: "@foo:bar", ExpectOutput: "foo"},
	{Input: "@foo:bar:8448", ExpectOutput: "foo"},
	{Input: "@foo.bar:baz.quuz", ExpectOutput: "foo.bar"},
}
|
||||||
|
|
||||||
|
func TestExtractUserLocalpart(t *testing.T) {
|
||||||
|
for _, u := range localparttests {
|
||||||
|
out, err := ExtractUserLocalpart(u.Input)
|
||||||
|
if err != nil {
|
||||||
|
t.Errorf("TestExtractUserLocalpart(%s) => Error: %s", u.Input, err)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if out != u.ExpectOutput {
|
||||||
|
t.Errorf("TestExtractUserLocalpart(%s) => Got: %s, Want %s", u.Input, out, u.ExpectOutput)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
Loading…
Reference in New Issue