From 8ef8c8187fcea23761175f4e4d15485127e32803 Mon Sep 17 00:00:00 2001 From: Kegan Dougal Date: Tue, 6 Dec 2016 14:24:27 +0000 Subject: [PATCH] Add user ID handling methods --- userids.go | 15 +++++++++++++++ userids_test.go | 22 ++++++++++++++++++++++ 2 files changed, 37 insertions(+) diff --git a/userids.go b/userids.go index cf2f8c2..817a006 100644 --- a/userids.go +++ b/userids.go @@ -4,6 +4,7 @@ import ( "bytes" "encoding/hex" "fmt" + "strings" ) const lowerhex = "0123456789abcdef" @@ -44,6 +45,7 @@ func isValidEscapedChar(b byte) bool { } // EncodeUserLocalpart encodes the given string into Matrix-compliant user ID localpart form. +// See http://matrix.org/docs/spec/intro.html#mapping-from-other-character-sets // // This returns a string with only the characters "a-z0-9._=-". The uppercase range A-Z // are encoded using leading underscores ("_"). Characters outside the aforementioned ranges @@ -67,6 +69,7 @@ func EncodeUserLocalpart(str string) string { // DecodeUserLocalpart decodes the given string back into the original input string. // Returns an error if the given string is not a valid user ID localpart encoding. +// See http://matrix.org/docs/spec/intro.html#mapping-from-other-character-sets // // This decodes quoted-printable bytes back into UTF8, and unescapes casing. For // example: @@ -113,3 +116,15 @@ func DecodeUserLocalpart(str string) (string, error) { } return outputBuffer.String(), nil } + +// ExtractUserLocalpart extracts the localpart portion of a user ID. +// See http://matrix.org/docs/spec/intro.html#user-identifiers +func ExtractUserLocalpart(userID string) (string, error) { + if len(userID) == 0 || userID[0] != '@' { + return "", fmt.Errorf("%s is not a valid user id") + } + return strings.TrimPrefix( + strings.SplitN(userID, ":", 2)[0], // @foo:bar:8448 => [ "@foo", "bar:8448" ] + "@", // remove "@" prefix + ), nil +} diff --git a/userids_test.go b/userids_test.go index 26d183e..b896740 100644 --- a/userids_test.go +++ b/userids_test.go @@ -62,3 +62,25 @@ func TestDecodeUserLocalpartErrors(t *testing.T) { } } } + +var localparttests = []struct { + Input string + ExpectOutput string +}{ + {"@foo:bar", "foo"}, + {"@foo:bar:8448", "foo"}, + {"@foo.bar:baz.quuz", "foo.bar"}, +} + +func TestExtractUserLocalpart(t *testing.T) { + for _, u := range localparttests { + out, err := ExtractUserLocalpart(u.Input) + if err != nil { + t.Errorf("TestExtractUserLocalpart(%s) => Error: %s", u.Input, err) + continue + } + if out != u.ExpectOutput { + t.Errorf("TestExtractUserLocalpart(%s) => Got: %s, Want %s", u.Input, out, u.ExpectOutput) + } + } +}