diff options
Diffstat (limited to 'server/channel')
| -rw-r--r-- | server/channel/channel.go | 156 | ||||
| -rw-r--r-- | server/channel/command.go | 12 |
2 files changed, 127 insertions, 41 deletions
diff --git a/server/channel/channel.go b/server/channel/channel.go index be59fcf..81d5439 100644 --- a/server/channel/channel.go +++ b/server/channel/channel.go @@ -7,6 +7,7 @@ import ( "citrons.xyz/talk/server/validate" "citrons.xyz/talk/server/user" "strings" + "strconv" "slices" "sort" "maps" @@ -27,8 +28,6 @@ type Channel struct { name string isDirect bool Stream session.Stream - byId map[string]int - messages []proto.Object } func Kind(world *object.World) *ChannelKind { @@ -50,19 +49,7 @@ func (cs *ChannelKind) CreateChannel(name string) (*Channel, *proto.Fail) { c.kind = cs c.name = name c.id = proto.GenId() - c.byId = make(map[string]int) - - err := cs.db.Update(func(tx *bolt.Tx) error { - chm, _ := tx.CreateBucketIfNotExists([]byte("channel membership")) - chm.CreateBucket([]byte(c.id)) - tx.CreateBucketIfNotExists([]byte("user channels")) - return nil - }) - if err != nil { - log.Fatal("error updating database: ", err) - } c.SetDefaultMembership(DefaultMembership) - c.Save() return &c, nil } @@ -71,6 +58,23 @@ func DirectHandle(among map[string]bool) string { return strings.Join(slices.Sorted(maps.Keys(among)), "\x00") } +func (cs *ChannelKind) GetDirect(among map[string]bool) *Channel { + handle := DirectHandle(among) + switch ch := cs.world.Lookup("direct-channel", handle).(type) { + case *Channel: + return ch + } + var c Channel + c.isDirect = true + c.kind = cs + c.id = proto.GenId() + for member, _ := range among { + c.SetMembership(member, DefaultMembership) + } + c.Save() + return &c +} + func (cs *ChannelKind) UserChannels(uid string) map[string]bool { result := make(map[string]bool) err := cs.db.View(func(tx *bolt.Tx) error { @@ -93,24 +97,6 @@ func (cs *ChannelKind) UserChannels(uid string) map[string]bool { return result } -func (cs *ChannelKind) GetDirect(among map[string]bool) *Channel { - handle := DirectHandle(among) - switch ch := cs.world.Lookup("direct-channel", handle).(type) { - case *Channel: - return ch - } - var c Channel - c.isDirect = true - c.kind = cs - c.id = proto.GenId() - c.byId = make(map[string]int) - for member, _ := range among { - c.SetMembership(member, DefaultMembership) - } - c.Save() - return &c -} - func (cs *ChannelKind) ByName(name string) *Channel { switch ch := cs.world.Lookup("channel", name).(type) { case *Channel: @@ -126,7 +112,6 @@ func (cs *ChannelKind) Undata(o proto.Object) object.Object { c.id = o.Id c.name = o.Fields[""] c.isDirect = o.Kind == "direct-channel" - c.byId = make(map[string]int) return &c } @@ -196,8 +181,30 @@ func (c *Channel) Rename(name string) *proto.Fail { func (c *Channel) Put(m proto.Object) proto.Object { m.Id = proto.GenId() m.Fields["t"] = proto.Timestamp() - c.byId[m.Id] = len(c.messages) - c.messages = append(c.messages, m) + err := c.kind.db.Update(func(tx *bolt.Tx) error { + history, _ := tx.CreateBucketIfNotExists([]byte("message history")) + channel, _ := history.CreateBucketIfNotExists([]byte(c.id)) + ids, _ := channel.CreateBucketIfNotExists([]byte("ids")) + + index, err := channel.NextSequence() + if err != nil { + return err + } + key := []byte(strconv.Itoa(int(index))) + + var buf bytes.Buffer + writer := bufio.NewWriter(&buf) + proto.WriteObject(writer, m) + writer.Flush() + err = channel.Put(key, buf.Bytes()) + if err != nil { + return err + } + return ids.Put([]byte(m.Id), key) + }) + if err != nil { + log.Fatal("error updating database: ", err) + } for s, _ := range c.Stream.Subscribers() { if m.Fields["f"] == s.UserId { continue @@ -209,6 +216,73 @@ func (c *Channel) Put(m proto.Object) proto.Object { return m } +func (c *Channel) HistorySize() int { + var size int + err := c.kind.db.View(func(tx *bolt.Tx) error { + history := tx.Bucket([]byte("message history")) + if history == nil { + return nil + } + channel := history.Bucket([]byte(c.id)) + if channel == nil { + return nil + } + size = int(channel.Sequence()) + return nil + }) + if err != nil { + log.Fatal("error reading database: ", err) + } + return size +} + +func (c *Channel) MessageIndex(mid string) (index int, ok bool) { + err := c.kind.db.View(func(tx *bolt.Tx) error { + history := tx.Bucket([]byte("message history")) + if history == nil { + return nil + } + channel := history.Bucket([]byte(c.id)) + if channel == nil { + return nil + } + ids := channel.Bucket([]byte("ids")) + data := ids.Get([]byte(mid)) + ok = data != nil + index, _ = strconv.Atoi(string(data)) + return nil + }) + if err != nil { + log.Fatal("error reading database: ", err) + } + return index, ok +} + +func (c *Channel) History(min, max int) []proto.Object { + var result []proto.Object + err := c.kind.db.View(func(tx *bolt.Tx) error { + history := tx.Bucket([]byte("message history")) + channel := history.Bucket([]byte(c.id)) + + for index := min; index < max; index++ { + data := channel.Get([]byte(strconv.Itoa(index))) + if data == nil { + continue + } + m, err := proto.ReadObject(bufio.NewReader(bytes.NewReader(data))) + if err != nil { + panic(err) + } + result = append(result, m) + } + return nil + }) + if err != nil { + log.Fatal("error reading database: ", err) + } + return result +} + func (c *Channel) Join(u *user.User) *proto.Fail { if c.GetMembership(u.Id()).Yes { return nil @@ -231,7 +305,13 @@ func (c *Channel) Members() map[string]Membership { result := make(map[string]Membership) err := c.kind.db.View(func(tx *bolt.Tx) error { channels := tx.Bucket([]byte("channel membership")) + if channels == nil { + return nil + } members := channels.Bucket([]byte(c.id)) + if members == nil { + return nil + } return members.ForEach(func(k, v []byte) error { var mship Membership o, _ := proto.ReadObject(bufio.NewReader(bytes.NewReader(v))) @@ -249,7 +329,13 @@ func (c *Channel) GetMembership(uid string) Membership { var mship Membership err := c.kind.db.View(func(tx *bolt.Tx) error { channels := tx.Bucket([]byte("channel membership")) + if channels == nil { + return nil + } members := channels.Bucket([]byte(c.id)) + if members == nil { + return nil + } data := members.Get([]byte(uid)) if data != nil { o, _ := proto.ReadObject(bufio.NewReader(bytes.NewReader(data))) diff --git a/server/channel/command.go b/server/channel/command.go index 1184a5a..04e45a2 100644 --- a/server/channel/command.go +++ b/server/channel/command.go @@ -21,7 +21,7 @@ func (c *Channel) SendRequest(r session.Request) { switch k { case "": case "reply": - _, ok := c.byId[v] + _, ok := c.MessageIndex(v) if !ok { r.Reply(proto.Fail{"bad-reply", "", nil}.Cmd()) return @@ -141,7 +141,7 @@ func (c *Channel) SendRequest(r session.Request) { var max int switch h.Kind { case "latest": - max = len(c.messages) + max = c.HistorySize() min = max - 20 case "before", "around", "after", "at": var id string @@ -154,7 +154,7 @@ func (c *Channel) SendRequest(r session.Request) { return } } - i, ok := c.byId[id] + i, ok := c.MessageIndex(id) if !ok { r.Reply(proto.Fail{"bad-target", "", nil}.Cmd()) return @@ -180,8 +180,8 @@ func (c *Channel) SendRequest(r session.Request) { if min < 0 { min = 0 } - if max > len(c.messages) { - max = len(c.messages) + if max > c.HistorySize() { + max = c.HistorySize() } p := c.GetMembership(r.From.UserId) @@ -190,7 +190,7 @@ func (c *Channel) SendRequest(r session.Request) { return } cmd := proto.NewCmd("history", c.id) - cmd.Args = c.messages[min:max] + cmd.Args = c.History(min, max) r.Reply(cmd) case "membership": |
