From 665b35975ba06fbdc580365bb8c619e99c9d682a Mon Sep 17 00:00:00 2001 From: Kegan Dougal Date: Tue, 6 Dec 2016 14:01:15 +0000 Subject: [PATCH] Add user ID localpart encoding/decoding --- userids.go | 115 ++++++++++++++++++++++++++++++++++++++++++++++++ userids_test.go | 64 +++++++++++++++++++++++++++ 2 files changed, 179 insertions(+) create mode 100644 userids.go create mode 100644 userids_test.go diff --git a/userids.go b/userids.go new file mode 100644 index 0000000..cf2f8c2 --- /dev/null +++ b/userids.go @@ -0,0 +1,115 @@ +package gomatrix + +import ( + "bytes" + "encoding/hex" + "fmt" +) + +const lowerhex = "0123456789abcdef" + +// encode the given byte using quoted-printable encoding (e.g "=2f") +// and writes it to the buffer +// See https://golang.org/src/mime/quotedprintable/writer.go +func encode(buf *bytes.Buffer, b byte) { + buf.WriteByte('=') + buf.WriteByte(lowerhex[b>>4]) + buf.WriteByte(lowerhex[b&0x0f]) +} + +// escape the given alpha character and writes it to the buffer +func escape(buf *bytes.Buffer, b byte) { + buf.WriteByte('_') + if b == '_' { + buf.WriteByte('_') // another _ + } else { + buf.WriteByte(b + 0x20) // ASCII shift A-Z to a-z + } +} + +func shouldEncode(b byte) bool { + return b != '-' && b != '.' && b != '_' && !(b >= '0' && b <= '9') && !(b >= 'a' && b <= 'z') && !(b >= 'A' && b <= 'Z') +} + +func shouldEscape(b byte) bool { + return (b >= 'A' && b <= 'Z') || b == '_' +} + +func isValidByte(b byte) bool { + return isValidEscapedChar(b) || (b >= '0' && b <= '9') || b == '.' || b == '=' || b == '-' +} + +func isValidEscapedChar(b byte) bool { + return b == '_' || (b >= 'a' && b <= 'z') +} + +// EncodeUserLocalpart encodes the given string into Matrix-compliant user ID localpart form. +// +// This returns a string with only the characters "a-z0-9._=-". The uppercase range A-Z +// are encoded using leading underscores ("_"). Characters outside the aforementioned ranges +// (including literal underscores ("_") and equals ("=")) are encoded as UTF8 code points (NOT NCRs) +// and converted to lower-case hex with a leading "=". For example: +// Alph@Bet_50up => _alph=40_bet=5f50up +func EncodeUserLocalpart(str string) string { + strBytes := []byte(str) + var outputBuffer bytes.Buffer + for _, b := range strBytes { + if shouldEncode(b) { + encode(&outputBuffer, b) + } else if shouldEscape(b) { + escape(&outputBuffer, b) + } else { + outputBuffer.WriteByte(b) + } + } + return outputBuffer.String() +} + +// DecodeUserLocalpart decodes the given string back into the original input string. +// Returns an error if the given string is not a valid user ID localpart encoding. +// +// This decodes quoted-printable bytes back into UTF8, and unescapes casing. For +// example: +// _alph=40_bet=5f50up => Alph@Bet_50up +// Returns an error if the input string contains characters outside the +// range "a-z0-9._=-", has an invalid quote-printable byte (e.g. not hex), or has +// an invalid _ escaped byte (e.g. "_5"). +func DecodeUserLocalpart(str string) (string, error) { + strBytes := []byte(str) + var outputBuffer bytes.Buffer + for i := 0; i < len(strBytes); i++ { + b := strBytes[i] + if !isValidByte(b) { + return "", fmt.Errorf("Byte pos %d: Invalid byte", i) + } + + if b == '_' { // next byte is a-z and should be upper-case or is another _ and should be a literal _ + if i+1 >= len(strBytes) { + return "", fmt.Errorf("Byte pos %d: expected _[a-z_] encoding but ran out of string", i) + } + if !isValidEscapedChar(strBytes[i+1]) { // invalid escaping + return "", fmt.Errorf("Byte pos %d: expected _[a-z_] encoding", i) + } + if strBytes[i+1] == '_' { + outputBuffer.WriteByte('_') + } else { + outputBuffer.WriteByte(strBytes[i+1] - 0x20) // ASCII shift a-z to A-Z + } + i++ // skip next byte since we just handled it + } else if b == '=' { // next 2 bytes are hex and should be buffered ready to be read as utf8 + if i+2 >= len(strBytes) { + return "", fmt.Errorf("Byte pos: %d: expected quote-printable encoding but ran out of string", i) + } + dst := make([]byte, 1) + _, err := hex.Decode(dst, strBytes[i+1:i+3]) + if err != nil { + return "", err + } + outputBuffer.WriteByte(dst[0]) + i += 2 // skip next 2 bytes since we just handled it + } else { // pass through + outputBuffer.WriteByte(b) + } + } + return outputBuffer.String(), nil +} diff --git a/userids_test.go b/userids_test.go new file mode 100644 index 0000000..26d183e --- /dev/null +++ b/userids_test.go @@ -0,0 +1,64 @@ +package gomatrix + +import ( + "testing" +) + +var useridtests = []struct { + Input string + Output string +}{ + {"Alph@Bet_50up", "_alph=40_bet__50up"}, // The doc example + {"abcdef", "abcdef"}, // no-op + {"i_like_pie_", "i__like__pie__"}, // double underscore escaping + {"ABCDEF", "_a_b_c_d_e_f"}, // all-caps + {"!£", "=21=c2=a3"}, // punctuation and outside ascii range (U+00A3 => c2 a3) + {"___", "______"}, // literal underscores + {"hello-world.", "hello-world."}, // allowed punctuation + {"5+5=10", "5=2b5=3d10"}, // equals sign + {"東方Project", "=e6=9d=b1=e6=96=b9_project"}, // CJK mixed + {" foo bar", "=09foo=20bar"}, // whitespace (tab and space) +} + +func TestEncodeUserLocalpart(t *testing.T) { + for _, u := range useridtests { + out := EncodeUserLocalpart(u.Input) + if out != u.Output { + t.Fatalf("TestEncodeUserLocalpart(%s) => Got: %s Expected: %s", u.Input, out, u.Output) + } + } +} + +func TestDecodeUserLocalpart(t *testing.T) { + for _, u := range useridtests { + in, _ := DecodeUserLocalpart(u.Output) + if in != u.Input { + t.Fatalf("TestDecodeUserLocalpart(%s) => Got: %s Expected: %s", u.Output, in, u.Input) + } + } +} + +var errtests = []struct { + Input string +}{ + {"foo@bar"}, // invalid character @ + {"foo_5bar"}, // invalid character after _ + {"foo_._-bar"}, // multiple invalid characters after _ + {"foo=2Hbar"}, // invalid hex after = + {"foo=2hbar"}, // invalid hex after = (lower-case) + {"foo=======2fbar"}, // multiple invalid hex after = + {"foo=2"}, // end of string after = + {"foo_"}, // end of string after _ +} + +func TestDecodeUserLocalpartErrors(t *testing.T) { + for _, u := range errtests { + out, err := DecodeUserLocalpart(u.Input) + if out != "" { + t.Fatalf("TestDecodeUserLocalpartErrors(%s) => Got: %s Expected: empty string", u.Input, out) + } + if err == nil { + t.Fatalf("TestDecodeUserLocalpartErrors(%s) => Got: nil error Expected: error", u.Input) + } + } +}