From 5320b561c592e875f0523760d4d20df8d66b30a7 Mon Sep 17 00:00:00 2001 From: raven Date: Mon, 20 Oct 2025 18:06:13 -0500 Subject: user and channel saving --- server/channel/channel.go | 242 ++++++++++++++++++++++++++++--------------- server/channel/command.go | 30 +++--- server/channel/membership.go | 35 ++++--- server/main.go | 19 +++- server/object/object.go | 124 +++++++++++++++++++++- server/object/tombstone.go | 10 ++ server/server/command.go | 25 ++--- server/server/server.go | 31 +++--- server/user/command.go | 1 + server/user/user.go | 124 +++++++++++++++++----- 10 files changed, 472 insertions(+), 169 deletions(-) (limited to 'server') diff --git a/server/channel/channel.go b/server/channel/channel.go index 3cb6f60..fdd4243 100644 --- a/server/channel/channel.go +++ b/server/channel/channel.go @@ -7,34 +7,35 @@ import ( "citrons.xyz/talk/server/validate" "citrons.xyz/talk/server/user" "strings" + "slices" "sort" + "maps" + "bytes" + "bufio" + bolt "go.etcd.io/bbolt" + "log" ) -type ChannelStore struct { +type ChannelKind struct { world *object.World - byName map[string]*Channel - directChannels map[string]*Channel + db *bolt.DB } type Channel struct { - store *ChannelStore + kind *ChannelKind id string name string isDirect bool - members map[string]Membership - messages []proto.Object - byId map[string]int - defaultMembership Membership Stream session.Stream + byId map[string]int + messages []proto.Object } -func NewStore(world *object.World) *ChannelStore { - return &ChannelStore { - world, make(map[string]*Channel), make(map[string]*Channel), - } +func Kind(world *object.World) *ChannelKind { + return &ChannelKind {world, world.DB()} } -func (cs *ChannelStore) CreateChannel(name string) (*Channel, *proto.Fail) { +func (cs *ChannelKind) CreateChannel(name string) (*Channel, *proto.Fail) { if cs.ByName(name) != nil { return nil, &proto.Fail { "name-taken", "", map[string]string {"": name}, @@ -46,39 +47,65 @@ func (cs *ChannelStore) CreateChannel(name string) (*Channel, *proto.Fail) { } } var c Channel - c.store = cs + c.kind = cs c.name = name - c.members = make(map[string]Membership) + c.id = proto.GenId() c.byId = make(map[string]int) - c.defaultMembership = DefaultMembership - cs.byName[validate.Fold(name)] = &c - c.id = cs.world.NewObject(&c) + err := cs.db.Update(func(tx *bolt.Tx) error { + chm, _ := tx.CreateBucketIfNotExists([]byte("channel membership")) + chm.CreateBucket([]byte(c.id)) + return nil + }) + if err != nil { + log.Fatal("error updating database: ", err) + } + c.SetDefaultMembership(DefaultMembership) + + c.Save() return &c, nil } -func (cs *ChannelStore) GetDirect(among []string) *Channel { - sort.Strings(among) - key := strings.Join(among, "\x00") - if cs.directChannels[key] == nil { - var c Channel - c.isDirect = true - c.store = cs - c.byId = make(map[string]int) - c.defaultMembership = DefaultMembership - c.members = make(map[string]Membership) - for _, member := range among { - c.members[member] = c.defaultMembership - } +func DirectHandle(among map[string]bool) string { + return strings.Join(slices.Sorted(maps.Keys(among)), "\x00") +} + +func (cs *ChannelKind) GetDirect(among map[string]bool) *Channel { + handle := DirectHandle(among) + switch ch := cs.world.Lookup("direct-channel", handle).(type) { + case *Channel: + return ch + } + var c Channel + c.isDirect = true + c.kind = cs + c.id = proto.GenId() + c.byId = make(map[string]int) + for member, _ := range among { + c.SetMembership(member, DefaultMembership) + } + c.Save() + return &c +} - cs.directChannels[key] = &c - c.id = cs.world.NewObject(&c) +func (cs *ChannelKind) ByName(name string) *Channel { + switch ch := cs.world.Lookup("channel", name).(type) { + case *Channel: + return ch + default: + return nil } - return cs.directChannels[key] } -func (cs *ChannelStore) ByName(name string) *Channel { - return cs.byName[validate.Fold(name)] +func (cs *ChannelKind) Undata(o proto.Object) object.Object { + log.Println("load: ", o) + var c Channel + c.kind = cs + c.id = o.Id + c.name = o.Fields[""] + c.isDirect = o.Kind == "direct-channel" + c.byId = make(map[string]int) + return &c } func (c *Channel) Name() string { @@ -90,11 +117,11 @@ func (c *Channel) NameFor(uid string) string { return c.name } else { var members []string - for member := range c.members { - if member == uid && len(c.members) > 1 { + for member, _ := range c.Members() { + if member == uid && len(c.Members()) > 1 { continue } - u := c.store.world.GetObject(member) + u := c.kind.world.GetObject(member) if u != nil { members = append(members, u.InfoFor(uid).Fields[""]) } @@ -104,6 +131,27 @@ func (c *Channel) NameFor(uid string) string { } } +func (c *Channel) Handle() string { + if !c.isDirect { + return c.name + } else { + members := make(map[string]bool) + for member := range c.Members() { + members[member] = true + } + return DirectHandle(members) + } +} + +func (c *Channel) Data() proto.Object { + log.Println("save: ", c.InfoFor("")) + return c.InfoFor("") +} + +func (c *Channel) Save() { + c.kind.world.PutObject(c.id, c) +} + func (c *Channel) Id() string { return c.id } @@ -114,18 +162,13 @@ func (c *Channel) Rename(name string) *proto.Fail { "invalid-name", "", map[string]string {"": name}, } } - if validate.Fold(name) == validate.Fold(c.name) { - c.name = name - return nil - } - if c.store.ByName(name) != nil { + if c.kind.ByName(name) != nil { return &proto.Fail { "name-taken", "", map[string]string {"": name}, } } - c.store.byName[validate.Fold(c.name)] = nil - c.store.byName[validate.Fold(name)] = c c.name = name + c.Save() return nil } @@ -138,75 +181,110 @@ func (c *Channel) Put(m proto.Object) proto.Object { if m.Fields["f"] == s.UserId { continue } - if c.members[s.UserId].See { + if c.GetMembership(s.UserId).See { s.Event(proto.NewCmd("p", c.id, m)) } } return m } -func (c *Channel) prune() { - if c.isDirect { - return - } - for m, _ := range c.members { - switch c.store.world.GetObject(m).(type) { - case *user.User: - default: - delete(c.members, m) - } - } -} - func (c *Channel) Join(u *user.User) *proto.Fail { - if c.members[u.Id()].Yes { + if c.GetMembership(u.Id()).Yes { return nil } - c.members[u.Id()] = c.defaultMembership - u.Channels[c.id] = true + c.SetMembership(u.Id(), c.GetDefaultMembership()) + // u.Channels[c.id] = true c.Put(proto.Object{"join", "", map[string]string {"f": u.Id()}}) return nil } func (c *Channel) Leave(u *user.User) *proto.Fail { - if !c.members[u.Id()].Yes { + if !c.GetMembership(u.Id()).Yes { return nil } - delete(c.members, u.Id()) - delete(u.Channels, c.id) + c.SetMembership(u.Id(), Membership {Yes: false}) + // delete(u.Channels, c.id) c.Put(proto.Object{"leave", "", map[string]string {"f": u.Id()}}) return nil } func (c *Channel) Members() map[string]Membership { - c.prune() - return c.members + result := make(map[string]Membership) + err := c.kind.db.View(func(tx *bolt.Tx) error { + channels := tx.Bucket([]byte("channel membership")) + members := channels.Bucket([]byte(c.id)) + return members.ForEach(func(k, v []byte) error { + var mship Membership + o, _ := proto.ReadObject(bufio.NewReader(bytes.NewReader(v))) + result[string(k)], _ = mship.Change(o) + return nil + }) + }) + if err != nil { + log.Fatal("error updating database: ", err) + } + return result } -func (c *Channel) SetMembership(u *user.User, m Membership) { - if c.members[u.Id()].Yes { - c.members[u.Id()] = m +func (c *Channel) GetMembership(uid string) Membership { + var mship Membership + err := c.kind.db.View(func(tx *bolt.Tx) error { + channels := tx.Bucket([]byte("channel membership")) + members := channels.Bucket([]byte(c.id)) + data := members.Get([]byte(uid)) + if data != nil { + o, _ := proto.ReadObject(bufio.NewReader(bytes.NewReader(data))) + mship.Undata(o) + } + return nil + }) + if err != nil { + log.Fatal("error updating database: ", err) } + return mship +} + +func (c *Channel) SetMembership(uid string, m Membership) { + err := c.kind.db.Update(func(tx *bolt.Tx) error { + channels := tx.Bucket([]byte("channel membership")) + members, _ := channels.CreateBucketIfNotExists([]byte(c.id)) + if m.Yes { + var buf bytes.Buffer + writer := bufio.NewWriter(&buf) + proto.WriteObject(writer, m.GetInfo()) + writer.Flush() + return members.Put([]byte(uid), buf.Bytes()) + } else { + return members.Delete([]byte(uid)) + } + }) + if err != nil { + log.Fatal("error updating database: ", err) + } +} + +func (c *Channel) GetDefaultMembership() Membership { + return DefaultMembership +} + +func (c *Channel) SetDefaultMembership(m Membership) { } func (c *Channel) Delete() { c.Stream.Event(proto.NewCmd("delete", c.id)) c.Stream.UnsubscribeAll() - for m, _ := range c.members { - switch u := c.store.world.GetObject(m).(type) { - case *user.User: - u.Channels[c.id] = false - default: - } - } - delete(c.store.byName, validate.Fold(c.name)) - c.store.world.RemoveObject(c.id) - deleted := object.Tombstone { c.id, map[string]string {"": c.name, "kind": c.Kind()}, } - c.store.world.PutObject(c.id, deleted) + c.kind.world.PutObject(c.id, deleted) + err := c.kind.db.Update(func(tx *bolt.Tx) error { + channels := tx.Bucket([]byte("channel membership")) + return channels.DeleteBucket([]byte(c.id)) + }) + if err != nil { + log.Fatal("error updating database: ", err) + } } func (c *Channel) Kind() string { diff --git a/server/channel/command.go b/server/channel/command.go index 1171d2f..1184a5a 100644 --- a/server/channel/command.go +++ b/server/channel/command.go @@ -40,7 +40,7 @@ func (c *Channel) SendRequest(r session.Request) { return } - if !c.members[r.From.UserId].Put { + if !c.GetMembership(r.From.UserId).Put { r.Reply(proto.Fail{"forbidden", "", nil}.Cmd()) return } @@ -60,7 +60,7 @@ func (c *Channel) SendRequest(r session.Request) { r.ReplyOk() case "join": - u := c.store.world.GetObject(r.From.UserId).(*user.User) + u := c.kind.world.GetObject(r.From.UserId).(*user.User) err := c.Join(u) if err != nil { r.Reply(err.Cmd()) @@ -69,7 +69,7 @@ func (c *Channel) SendRequest(r session.Request) { } case "leave": - u := c.store.world.GetObject(r.From.UserId).(*user.User) + u := c.kind.world.GetObject(r.From.UserId).(*user.User) err := c.Leave(u) if err != nil { r.Reply(err.Cmd()) @@ -78,7 +78,7 @@ func (c *Channel) SendRequest(r session.Request) { } case "delete": - if !c.members[r.From.UserId].Op { + if !c.GetMembership(r.From.UserId).Op { r.Reply(proto.Fail{"forbidden", "", nil}.Cmd()) return } @@ -105,7 +105,7 @@ func (c *Channel) SendRequest(r session.Request) { } } - if !c.members[r.From.UserId].Update { + if !c.GetMembership(r.From.UserId).Update { r.Reply(proto.Fail{"forbidden", "", nil}.Cmd()) return } @@ -120,13 +120,13 @@ func (c *Channel) SendRequest(r session.Request) { r.ReplyOk() case "list": - if !c.members[r.From.UserId].Yes { + if !c.GetMembership(r.From.UserId).Yes { r.Reply(proto.Fail{"forbidden", "", nil}.Cmd()) return } cmd := proto.NewCmd("list", c.Id()) for m, _ := range c.Members() { - u := c.store.world.GetObject(m).(*user.User) + u := c.kind.world.GetObject(m).(*user.User) cmd.Args = append(cmd.Args, u.InfoFor(r.From.UserId)) } r.Reply(cmd) @@ -184,7 +184,7 @@ func (c *Channel) SendRequest(r session.Request) { max = len(c.messages) } - p := c.members[r.From.UserId] + p := c.GetMembership(r.From.UserId) if !p.History || !p.See { r.Reply(proto.Fail{"forbidden", "", nil}.Cmd()) return @@ -204,18 +204,18 @@ func (c *Channel) SendRequest(r session.Request) { return } - if !c.members[r.From.UserId].Yes { + if !c.GetMembership(r.From.UserId).Yes { r.Reply(proto.Fail{"forbidden", "", nil}.Cmd()) return } - if !c.members[m.Id].Yes { + if !c.GetMembership(m.Id).Yes { r.Reply(proto.Fail{ "not-in-channel", "", map[string]string {"": m.Id}, }.Cmd()) return } - i := c.members[m.Id].GetInfo() + i := c.GetMembership(m.Id).GetInfo() i.Fields[""] = m.Id r.Reply(proto.NewCmd("i", "", i)) @@ -232,23 +232,23 @@ func (c *Channel) SendRequest(r session.Request) { r.ReplyInvalid() return } - new, err := c.members[id].Change(o) + new, err := c.GetMembership(id).Change(o) if err != nil { r.Reply(err.Cmd()) return } - if !c.members[r.From.UserId].Op { + if !c.GetMembership(r.From.UserId).Op { r.Reply(proto.Fail{"forbidden", "", nil}.Cmd()) return } - if !c.members[id].Yes { + if !c.GetMembership(id).Yes { r.Reply(proto.Fail{ "not-in-channel", "", map[string]string {"": id}, }.Cmd()) return } - c.members[id] = new + c.SetMembership(id, new) c.Put(o) i := new.GetInfo() diff --git a/server/channel/membership.go b/server/channel/membership.go index 3a44517..27c9bed 100644 --- a/server/channel/membership.go +++ b/server/channel/membership.go @@ -35,35 +35,42 @@ var CreatorMembership = Membership { Op: true, } -func (m Membership) Change(spec proto.Object) (Membership, *proto.Fail) { - new := m +func (m *Membership) Undata(spec proto.Object) { + m.Yes = true for k, v := range spec.Fields { - var field *bool + val := v == "yes" switch k { case "see": - field = &new.See + m.See = val case "put": - field = &new.Put + m.Put = val case "history": - field = &new.History + m.History = val case "moderate": - field = &new.Moderate + m.Moderate = val case "update": - field = &new.Update - case "": - continue + m.Update = val + case "op": + m.Op = val + } + } +} + +func (m Membership) Change(spec proto.Object) (Membership, *proto.Fail) { + new := m + for k, v := range spec.Fields { + switch k { + case "see", "put", "history", "moderate", "update", "": default: return new, &proto.Fail{"invalid", "", nil} } switch v { - case "yes": - *field = true - case "no": - *field = false + case "yes", "no": default: return new, &proto.Fail{"invalid", "", nil} } } + new.Undata(spec) return new, nil } diff --git a/server/main.go b/server/main.go index 6f75380..948de08 100644 --- a/server/main.go +++ b/server/main.go @@ -1,7 +1,22 @@ package main -import "citrons.xyz/talk/server/server" +import ( + "citrons.xyz/talk/server/server" + "flag" + "log" + bolt "go.etcd.io/bbolt" +) func main() { - server.Serve() + dbFile := flag.String("db", "./talk.db", "database file location") + address := flag.String("listen", ":27508", "address to listen on") + flag.Parse() + + db, err := bolt.Open(*dbFile, 0600, nil) + if err != nil { + log.Fatal(err) + } + defer db.Close() + + server.Serve(db, *address) } diff --git a/server/object/object.go b/server/object/object.go index e0d0239..e05ae0f 100644 --- a/server/object/object.go +++ b/server/object/object.go @@ -3,31 +3,143 @@ package object import ( "citrons.xyz/talk/proto" "citrons.xyz/talk/server/session" + "citrons.xyz/talk/server/validate" + "bufio" + "bytes" + "log" + bolt "go.etcd.io/bbolt" ) type Object interface { SendRequest(session.Request) InfoFor(uid string) proto.Object + Data() proto.Object +} + +type HasHandle interface { + Handle() string +} + +type Kind interface { + Undata(o proto.Object) Object } type World struct { + db *bolt.DB + kinds map[string]Kind objects map[string]Object } -func NewWorld() *World { - return &World {make(map[string]Object)} +func NewWorld(db *bolt.DB) *World { + w := &World {db, make(map[string]Kind), make(map[string]Object)} + w.AddObjectKind("gone", TombstoneKind{}) + return w +} + +func (w *World) AddObjectKind(name string, kind Kind) { + w.kinds[name] = kind +} + +func (w *World) getData(id string) proto.Object { + var data []byte + err := w.db.View(func (tx *bolt.Tx) error { + bucket := tx.Bucket([]byte("world")) + if bucket != nil { + data = bucket.Get([]byte(id)) + } + return nil + }) + if err != nil { + log.Fatal("reading database: ", err) + } + if len(data) == 0 { + return proto.Object {} + } + o, err := proto.ReadObject(bufio.NewReader(bytes.NewReader(data))) + if err != nil { + panic(err) + } + return o +} + +func (w *World) setData(id string, o proto.Object) { + var buf bytes.Buffer + writer := bufio.NewWriter(&buf) + proto.WriteObject(writer, o) + writer.Flush() + err := w.db.Update(func (tx *bolt.Tx) error { + bucket, _ := tx.CreateBucketIfNotExists([]byte("world")) + return bucket.Put([]byte(id), buf.Bytes()) + }) + if err != nil { + log.Fatal("updating database: ", err) + } } func (w *World) GetObject(id string) Object { + if w.objects[id] == nil { + o := w.getData(id) + if o.Kind != "" { + w.objects[id] = w.kinds[o.Kind].Undata(o) + } + } return w.objects[id] } func (w *World) PutObject(id string, o Object) { w.objects[id] = o + if id == "" { + return + } + switch h := o.(type) { + case HasHandle: + err := w.db.Update(func(tx *bolt.Tx) error { + kinds, _ := tx.CreateBucketIfNotExists([]byte("kinds")) + kind, _ := kinds.CreateBucketIfNotExists([]byte(o.Data().Kind)) + byHandle, _ := kind.CreateBucketIfNotExists([]byte("by handle")) + byId, _ := kind.CreateBucketIfNotExists([]byte("by id")) + + existing := byId.Get([]byte(id)) + if len(existing) != 0 { + byHandle.Delete(existing) + } + handle := []byte(validate.Fold(h.Handle())) + log.Println(string(handle), o.Data().Kind) + byHandle.Put(handle, []byte(id)) + byId.Put([]byte(id), handle) + + return nil + }) + if err != nil { + log.Fatal("updating database: ", err) + } + } + w.setData(id, o.Data()) } -func (w *World) RemoveObject(id string) { - w.objects[id] = nil +func (w *World) Lookup(kind string, handle string) Object { + handle = validate.Fold(handle) + + var id string + err := w.db.View(func(tx *bolt.Tx) error { + kinds := tx.Bucket([]byte("kinds")) + if kinds == nil { + return nil + } + kind := kinds.Bucket([]byte(kind)) + if kind == nil { + return nil + } + id = string(kind.Bucket([]byte("by handle")).Get([]byte(handle))) + return nil + }) + if err != nil { + log.Fatal("reading database: ", err) + } + if id != "" { + return w.GetObject(id) + } + return nil } func (w *World) NewObject(o Object) string { @@ -35,3 +147,7 @@ func (w *World) NewObject(o Object) string { w.PutObject(id, o) return id } + +func (w *World) DB() *bolt.DB { + return w.db +} diff --git a/server/object/tombstone.go b/server/object/tombstone.go index 05d9600..a2872f6 100644 --- a/server/object/tombstone.go +++ b/server/object/tombstone.go @@ -10,7 +10,17 @@ type Tombstone struct { Fields map[string]string } +type TombstoneKind struct {} + +func (t TombstoneKind) Undata(o proto.Object) Object { + return Tombstone {Id: o.Id, Fields: o.Fields} +} + func (t Tombstone) InfoFor(uid string) proto.Object { + return t.Data() +} + +func (t Tombstone) Data() proto.Object { return proto.Object {"gone", t.Id, t.Fields} } diff --git a/server/server/command.go b/server/server/command.go index c1b3f24..c1b28b9 100644 --- a/server/server/command.go +++ b/server/server/command.go @@ -26,12 +26,11 @@ func (s *server) SendRequest(r session.Request) { r.ReplyInvalid() return } - user, err := s.userStore.CreateUser(auth.Fields[""]) + user, err := s.userKind.CreateUser(auth.Fields[""]) if err != nil { r.Reply(err.Cmd()) return } - user.Anonymous = true r.Reply(proto.NewCmd("you-are", "", user.InfoFor(r.From.UserId))) r.From.UserId = user.Id() default: @@ -54,14 +53,14 @@ func (s *server) SendRequest(r session.Request) { var info proto.Object switch o.Kind { case "u": - u := s.userStore.ByName(name) + u := s.userKind.ByName(name) if u == nil { r.Reply(proto.Fail{"unknown-name", "", nil}.Cmd()) return } info = u.InfoFor(r.From.UserId) case "channel": - c := s.channelStore.ByName(name) + c := s.channelKind.ByName(name) if c == nil { r.Reply(proto.Fail{"unknown-name", "", nil}.Cmd()) return @@ -97,14 +96,14 @@ func (s *server) SendRequest(r session.Request) { return } } - c, err := s.channelStore.CreateChannel(name) + c, err := s.channelKind.CreateChannel(name) if err != nil { r.Reply(err.Cmd()) return } u := s.world.GetObject(r.From.UserId).(*user.User) c.Join(u) - c.SetMembership(u, channel.CreatorMembership) + c.SetMembership(u.Id(), channel.CreatorMembership) r.Reply(proto.NewCmd("create", "", c.InfoFor(r.From.UserId))) default: r.ReplyInvalid() @@ -119,18 +118,14 @@ func (s *server) SendRequest(r session.Request) { r.ReplyInvalid() return } - among := []string {r.From.UserId} - duplicate := make(map[string]bool) + + among := make(map[string]bool) + among[r.From.UserId] = true for _, member := range r.Cmd.Args { if member.Kind != "u" { r.ReplyInvalid() return } - if duplicate[member.Fields[""]] { - r.ReplyInvalid() - return - } - duplicate[member.Fields[""]] = true u := s.world.GetObject(member.Id) switch u.(type) { case *user.User: @@ -138,9 +133,9 @@ func (s *server) SendRequest(r session.Request) { r.Reply(proto.Fail{"bad-target", "", nil}.Cmd()) return } - among = append(among, member.Id) + among[member.Id] = true } - c := s.channelStore.GetDirect(among) + c := s.channelKind.GetDirect(among) r.Reply(proto.NewCmd("direct", "", c.InfoFor(r.From.UserId))) case "channels": diff --git a/server/server/server.go b/server/server/server.go index 8ec7201..60c6c32 100644 --- a/server/server/server.go +++ b/server/server/server.go @@ -10,6 +10,7 @@ import ( "citrons.xyz/talk/server/object" "citrons.xyz/talk/server/user" "citrons.xyz/talk/server/channel" + bolt "go.etcd.io/bbolt" ) type server struct { @@ -17,8 +18,8 @@ type server struct { clients chan *session.Session disconnects chan *session.Session world *object.World - userStore *user.UserStore - channelStore *channel.ChannelStore + userKind *user.UserKind + channelKind *channel.ChannelKind } func (s *server) mainLoop() { @@ -52,11 +53,7 @@ func (s *server) onConnect(sesh *session.Session) { func (s *server) onDisconnect(sesh *session.Session) { if sesh.UserId != "" { u := s.world.GetObject(sesh.UserId).(*user.User) - if u.Anonymous { - for c, _ := range u.Channels { - c := s.world.GetObject(c).(*channel.Channel) - c.Leave(u) - } + if u.IsAnonymous() { u.Delete() } } @@ -65,8 +62,12 @@ func (s *server) onDisconnect(sesh *session.Session) { } } -func Serve() { - ln, err := net.Listen("tcp", ":27508") +func (s *server) Data() proto.Object { + return proto.Object {} +} + +func Serve(db *bolt.DB, address string) { + ln, err := net.Listen("tcp", address) if err != nil { log.Fatal("Listen: ", err) } @@ -76,12 +77,18 @@ func Serve() { srv.requests = make(chan session.Request) srv.clients = make(chan *session.Session) srv.disconnects = make(chan *session.Session) - srv.world = object.NewWorld() - srv.userStore = user.NewStore(srv.world) - srv.channelStore = channel.NewStore(srv.world) + srv.world = object.NewWorld(db) + srv.userKind = user.Kind(srv.world) + srv.channelKind = channel.Kind(srv.world) + + srv.world.AddObjectKind("u", srv.userKind) + srv.world.AddObjectKind("channel", srv.channelKind) + srv.world.AddObjectKind("direct-channel", srv.channelKind) srv.world.PutObject("", &srv) + srv.userKind.DeleteAnonUsers() + go func() { for { conn, err := ln.Accept() diff --git a/server/user/command.go b/server/user/command.go index 14a8b57..47a3efb 100644 --- a/server/user/command.go +++ b/server/user/command.go @@ -47,6 +47,7 @@ func (u *User) SendRequest(r session.Request) { return } } + u.Save() u.Stream.Event(r.Cmd) r.ReplyOk() diff --git a/server/user/user.go b/server/user/user.go index 17063b4..46eda35 100644 --- a/server/user/user.go +++ b/server/user/user.go @@ -5,29 +5,31 @@ import ( "citrons.xyz/talk/server/session" "citrons.xyz/talk/server/validate" "citrons.xyz/talk/proto" + bolt "go.etcd.io/bbolt" + "log" ) -type UserStore struct { +type UserKind struct { world *object.World - byName map[string]*User + db *bolt.DB } type User struct { - store *UserStore + kind *UserKind name string id string status string description string Stream session.Stream - Channels map[string]bool - Anonymous bool + anonymous bool + Channels map[string]bool // TODO: remove } -func NewStore(world *object.World) *UserStore { - return &UserStore {world, make(map[string]*User)} +func Kind(world *object.World) *UserKind { + return &UserKind {world, world.DB()} } -func (us *UserStore) CreateUser(name string) (*User, *proto.Fail) { +func (us *UserKind) CreateUser(name string) (*User, *proto.Fail) { if us.ByName(name) != nil { return nil, &proto.Fail { "name-taken", "", map[string]string {"": name}, @@ -39,44 +41,107 @@ func (us *UserStore) CreateUser(name string) (*User, *proto.Fail) { } } var u User - u.store = us + u.kind = us u.name = name - us.byName[validate.Fold(name)] = &u - u.id = us.world.NewObject(&u) + u.id = proto.GenId() + u.anonymous = true u.Channels = make(map[string]bool) + u.Save() return &u, nil } -func (us *UserStore) ByName(name string) *User { - return us.byName[validate.Fold(name)] +func (us *UserKind) ByName(name string) *User { + switch u := us.world.Lookup("u", name).(type) { + case *User: + return u + default: + return nil + } +} + +func (us *UserKind) DeleteAnonUsers() { + var anon []string + err := us.db.View(func(tx *bolt.Tx) error { + bucket := tx.Bucket([]byte("anonymous users")) + if bucket == nil { + return nil + } + bucket.ForEach(func(k, v []byte) error { + anon = append(anon, string(k)) + return nil + }) + return nil + }) + if err != nil { + log.Fatal("error reading database: ", err) + } + for _, id := range anon { + switch u := us.world.GetObject(id).(type) { + case *User: + u.Delete() + } + } +} + +func (us *UserKind) Undata(o proto.Object) object.Object { + var u User + u.kind = us + u.id = o.Id + u.name = o.Fields[""] + u.status = o.Fields["status"] + u.description = o.Fields["description"] + u.anonymous = o.Fields["anonymous"] == "yes" + return &u +} + +func (u *User) Data() proto.Object { + data := u.InfoFor("") + data.Fields["description"] = u.description + return data } func (u *User) Name() string { return u.name } +func (u *User) Handle() string { + return u.Name() +} + func (u *User) Id() string { return u.id } +func (u *User) Save() { + err := u.kind.db.Update(func(tx *bolt.Tx) error { + bucket, _ := tx.CreateBucketIfNotExists([]byte("anonymous users")) + if u.anonymous { + bucket.Put([]byte(u.id), []byte("yes")) + } else { + bucket.Delete([]byte(u.id)) + } + return nil + }) + if err != nil { + log.Fatal("error updating database: ", err) + } + u.kind.world.PutObject(u.id, u) +} + func (u *User) Rename(name string) *proto.Fail { if !validate.Name(name) { return &proto.Fail { "invalid-name", "", map[string]string {"": name}, } } - if validate.Fold(name) == validate.Fold(u.name) { - u.name = name - return nil - } - if u.store.ByName(name) != nil { + if u.kind.ByName(name) != nil && + validate.Fold(name) != validate.Fold(u.name) { return &proto.Fail { "name-taken", "", map[string]string {"": name}, } } - u.store.byName[validate.Fold(u.name)] = nil - u.store.byName[validate.Fold(name)] = u u.name = name + u.Save() return nil } @@ -84,13 +149,18 @@ func (u *User) Delete() { u.Stream.Event(proto.NewCmd("delete", u.id)) u.Stream.UnsubscribeAll() - delete(u.store.byName, validate.Fold(u.name)) - u.store.world.RemoveObject(u.id) - gone := object.Tombstone { u.id, map[string]string {"": u.name, "kind": "u"}, } - u.store.world.PutObject(u.id, gone) + u.kind.world.PutObject(u.id, gone) + err := u.kind.db.Update(func(tx *bolt.Tx) error { + bucket, _ := tx.CreateBucketIfNotExists([]byte("anonymous users")) + bucket.Delete([]byte(u.id)) + return nil + }) + if err != nil { + log.Fatal("error updating database: ", err) + } } func (u *User) InfoFor(uid string) proto.Object { @@ -98,10 +168,14 @@ func (u *User) InfoFor(uid string) proto.Object { if u.status != "" { i["status"] = u.status } - if u.Anonymous { + if u.anonymous { i["anonymous"] = "yes" } else { i["anonymous"] = "no" } return proto.Object {"u", u.id, i} } + +func (u *User) IsAnonymous() bool { + return u.anonymous +} -- cgit v1.2.3