2
0

notices_test.go 3.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163
  1. // Copyright 2023 The Gogs Authors. All rights reserved.
  2. // Use of this source code is governed by a MIT-style
  3. // license that can be found in the LICENSE file.
  4. package database
  5. import (
  6. "context"
  7. "testing"
  8. "time"
  9. "github.com/stretchr/testify/assert"
  10. "github.com/stretchr/testify/require"
  11. "gorm.io/gorm"
  12. )
  13. func TestNotice_BeforeCreate(t *testing.T) {
  14. now := time.Now()
  15. db := &gorm.DB{
  16. Config: &gorm.Config{
  17. SkipDefaultTransaction: true,
  18. NowFunc: func() time.Time {
  19. return now
  20. },
  21. },
  22. }
  23. t.Run("CreatedUnix has been set", func(t *testing.T) {
  24. notice := &Notice{
  25. CreatedUnix: 1,
  26. }
  27. _ = notice.BeforeCreate(db)
  28. assert.Equal(t, int64(1), notice.CreatedUnix)
  29. })
  30. t.Run("CreatedUnix has not been set", func(t *testing.T) {
  31. notice := &Notice{}
  32. _ = notice.BeforeCreate(db)
  33. assert.Equal(t, db.NowFunc().Unix(), notice.CreatedUnix)
  34. })
  35. }
  36. func TestNotice_AfterFind(t *testing.T) {
  37. now := time.Now()
  38. db := &gorm.DB{
  39. Config: &gorm.Config{
  40. SkipDefaultTransaction: true,
  41. NowFunc: func() time.Time {
  42. return now
  43. },
  44. },
  45. }
  46. notice := &Notice{
  47. CreatedUnix: now.Unix(),
  48. }
  49. _ = notice.AfterFind(db)
  50. assert.Equal(t, notice.CreatedUnix, notice.Created.Unix())
  51. }
  52. func TestNotices(t *testing.T) {
  53. if testing.Short() {
  54. t.Skip()
  55. }
  56. t.Parallel()
  57. ctx := context.Background()
  58. db := &noticesStore{
  59. DB: newTestDB(t, "noticesStore"),
  60. }
  61. for _, tc := range []struct {
  62. name string
  63. test func(t *testing.T, ctx context.Context, db *noticesStore)
  64. }{
  65. {"Create", noticesCreate},
  66. {"DeleteByIDs", noticesDeleteByIDs},
  67. {"DeleteAll", noticesDeleteAll},
  68. {"List", noticesList},
  69. {"Count", noticesCount},
  70. } {
  71. t.Run(tc.name, func(t *testing.T) {
  72. t.Cleanup(func() {
  73. err := clearTables(t, db.DB)
  74. require.NoError(t, err)
  75. })
  76. tc.test(t, ctx, db)
  77. })
  78. if t.Failed() {
  79. break
  80. }
  81. }
  82. }
  83. func noticesCreate(t *testing.T, ctx context.Context, db *noticesStore) {
  84. err := db.Create(ctx, NoticeTypeRepository, "test")
  85. require.NoError(t, err)
  86. count := db.Count(ctx)
  87. assert.Equal(t, int64(1), count)
  88. }
  89. func noticesDeleteByIDs(t *testing.T, ctx context.Context, db *noticesStore) {
  90. err := db.Create(ctx, NoticeTypeRepository, "test")
  91. require.NoError(t, err)
  92. notices, err := db.List(ctx, 1, 10)
  93. require.NoError(t, err)
  94. ids := make([]int64, 0, len(notices))
  95. for _, notice := range notices {
  96. ids = append(ids, notice.ID)
  97. }
  98. // Non-existing IDs should be ignored
  99. ids = append(ids, 404)
  100. err = db.DeleteByIDs(ctx, ids...)
  101. require.NoError(t, err)
  102. count := db.Count(ctx)
  103. assert.Equal(t, int64(0), count)
  104. }
  105. func noticesDeleteAll(t *testing.T, ctx context.Context, db *noticesStore) {
  106. err := db.Create(ctx, NoticeTypeRepository, "test")
  107. require.NoError(t, err)
  108. err = db.DeleteAll(ctx)
  109. require.NoError(t, err)
  110. count := db.Count(ctx)
  111. assert.Equal(t, int64(0), count)
  112. }
  113. func noticesList(t *testing.T, ctx context.Context, db *noticesStore) {
  114. err := db.Create(ctx, NoticeTypeRepository, "test 1")
  115. require.NoError(t, err)
  116. err = db.Create(ctx, NoticeTypeRepository, "test 2")
  117. require.NoError(t, err)
  118. got1, err := db.List(ctx, 1, 1)
  119. require.NoError(t, err)
  120. require.Len(t, got1, 1)
  121. got2, err := db.List(ctx, 2, 1)
  122. require.NoError(t, err)
  123. require.Len(t, got2, 1)
  124. assert.True(t, got1[0].ID > got2[0].ID)
  125. got, err := db.List(ctx, 1, 3)
  126. require.NoError(t, err)
  127. require.Len(t, got, 2)
  128. }
  129. func noticesCount(t *testing.T, ctx context.Context, db *noticesStore) {
  130. count := db.Count(ctx)
  131. assert.Equal(t, int64(0), count)
  132. err := db.Create(ctx, NoticeTypeRepository, "test")
  133. require.NoError(t, err)
  134. count = db.Count(ctx)
  135. assert.Equal(t, int64(1), count)
  136. }