login_sources_test.go 10 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425
  1. // Copyright 2020 The Gogs Authors. All rights reserved.
  2. // Use of this source code is governed by a MIT-style
  3. // license that can be found in the LICENSE file.
  4. package db
  5. import (
  6. "context"
  7. "testing"
  8. "time"
  9. mockrequire "github.com/derision-test/go-mockgen/testutil/require"
  10. "github.com/stretchr/testify/assert"
  11. "github.com/stretchr/testify/require"
  12. "gorm.io/gorm"
  13. "gogs.io/gogs/internal/auth"
  14. "gogs.io/gogs/internal/auth/github"
  15. "gogs.io/gogs/internal/auth/pam"
  16. "gogs.io/gogs/internal/errutil"
  17. )
  18. func TestLoginSource_BeforeSave(t *testing.T) {
  19. now := time.Now()
  20. db := &gorm.DB{
  21. Config: &gorm.Config{
  22. SkipDefaultTransaction: true,
  23. NowFunc: func() time.Time {
  24. return now
  25. },
  26. },
  27. }
  28. t.Run("Config has not been set", func(t *testing.T) {
  29. s := &LoginSource{}
  30. err := s.BeforeSave(db)
  31. require.NoError(t, err)
  32. assert.Empty(t, s.Config)
  33. })
  34. t.Run("Config has been set", func(t *testing.T) {
  35. s := &LoginSource{
  36. Provider: pam.NewProvider(&pam.Config{
  37. ServiceName: "pam_service",
  38. }),
  39. }
  40. err := s.BeforeSave(db)
  41. require.NoError(t, err)
  42. assert.Equal(t, `{"ServiceName":"pam_service"}`, s.Config)
  43. })
  44. }
  45. func TestLoginSource_BeforeCreate(t *testing.T) {
  46. now := time.Now()
  47. db := &gorm.DB{
  48. Config: &gorm.Config{
  49. SkipDefaultTransaction: true,
  50. NowFunc: func() time.Time {
  51. return now
  52. },
  53. },
  54. }
  55. t.Run("CreatedUnix has been set", func(t *testing.T) {
  56. s := &LoginSource{CreatedUnix: 1}
  57. _ = s.BeforeCreate(db)
  58. assert.Equal(t, int64(1), s.CreatedUnix)
  59. assert.Equal(t, int64(0), s.UpdatedUnix)
  60. })
  61. t.Run("CreatedUnix has not been set", func(t *testing.T) {
  62. s := &LoginSource{}
  63. _ = s.BeforeCreate(db)
  64. assert.Equal(t, db.NowFunc().Unix(), s.CreatedUnix)
  65. assert.Equal(t, db.NowFunc().Unix(), s.UpdatedUnix)
  66. })
  67. }
  68. func Test_loginSources(t *testing.T) {
  69. if testing.Short() {
  70. t.Skip()
  71. }
  72. t.Parallel()
  73. tables := []interface{}{new(LoginSource), new(User)}
  74. db := &loginSources{
  75. DB: initTestDB(t, "loginSources", tables...),
  76. }
  77. for _, tc := range []struct {
  78. name string
  79. test func(*testing.T, *loginSources)
  80. }{
  81. {"Create", loginSourcesCreate},
  82. {"Count", loginSourcesCount},
  83. {"DeleteByID", loginSourcesDeleteByID},
  84. {"GetByID", loginSourcesGetByID},
  85. {"List", loginSourcesList},
  86. {"ResetNonDefault", loginSourcesResetNonDefault},
  87. {"Save", loginSourcesSave},
  88. } {
  89. t.Run(tc.name, func(t *testing.T) {
  90. t.Cleanup(func() {
  91. err := clearTables(t, db.DB, tables...)
  92. require.NoError(t, err)
  93. })
  94. tc.test(t, db)
  95. })
  96. if t.Failed() {
  97. break
  98. }
  99. }
  100. }
  101. func loginSourcesCreate(t *testing.T, db *loginSources) {
  102. ctx := context.Background()
  103. // Create first login source with name "GitHub"
  104. source, err := db.Create(ctx,
  105. CreateLoginSourceOpts{
  106. Type: auth.GitHub,
  107. Name: "GitHub",
  108. Activated: true,
  109. Default: false,
  110. Config: &github.Config{
  111. APIEndpoint: "https://api.github.com",
  112. },
  113. },
  114. )
  115. require.NoError(t, err)
  116. // Get it back and check the Created field
  117. source, err = db.GetByID(ctx, source.ID)
  118. require.NoError(t, err)
  119. assert.Equal(t, db.NowFunc().Format(time.RFC3339), source.Created.UTC().Format(time.RFC3339))
  120. assert.Equal(t, db.NowFunc().Format(time.RFC3339), source.Updated.UTC().Format(time.RFC3339))
  121. // Try create second login source with same name should fail
  122. _, err = db.Create(ctx, CreateLoginSourceOpts{Name: source.Name})
  123. expErr := ErrLoginSourceAlreadyExist{args: errutil.Args{"name": source.Name}}
  124. assert.Equal(t, expErr, err)
  125. }
  126. func loginSourcesCount(t *testing.T, db *loginSources) {
  127. ctx := context.Background()
  128. // Create two login sources, one in database and one as source file.
  129. _, err := db.Create(ctx,
  130. CreateLoginSourceOpts{
  131. Type: auth.GitHub,
  132. Name: "GitHub",
  133. Activated: true,
  134. Default: false,
  135. Config: &github.Config{
  136. APIEndpoint: "https://api.github.com",
  137. },
  138. },
  139. )
  140. require.NoError(t, err)
  141. mock := NewMockLoginSourceFilesStore()
  142. mock.LenFunc.SetDefaultReturn(2)
  143. setMockLoginSourceFilesStore(t, db, mock)
  144. assert.Equal(t, int64(3), db.Count(ctx))
  145. }
  146. func loginSourcesDeleteByID(t *testing.T, db *loginSources) {
  147. ctx := context.Background()
  148. t.Run("delete but in used", func(t *testing.T) {
  149. source, err := db.Create(ctx,
  150. CreateLoginSourceOpts{
  151. Type: auth.GitHub,
  152. Name: "GitHub",
  153. Activated: true,
  154. Default: false,
  155. Config: &github.Config{
  156. APIEndpoint: "https://api.github.com",
  157. },
  158. },
  159. )
  160. require.NoError(t, err)
  161. // Create a user that uses this login source
  162. _, err = (&users{DB: db.DB}).Create("alice", "", CreateUserOpts{
  163. LoginSource: source.ID,
  164. })
  165. require.NoError(t, err)
  166. // Delete the login source will result in error
  167. err = db.DeleteByID(ctx, source.ID)
  168. expErr := ErrLoginSourceInUse{args: errutil.Args{"id": source.ID}}
  169. assert.Equal(t, expErr, err)
  170. })
  171. mock := NewMockLoginSourceFilesStore()
  172. mock.GetByIDFunc.SetDefaultHook(func(id int64) (*LoginSource, error) {
  173. return nil, ErrLoginSourceNotExist{args: errutil.Args{"id": id}}
  174. })
  175. setMockLoginSourceFilesStore(t, db, mock)
  176. // Create a login source with name "GitHub2"
  177. source, err := db.Create(ctx,
  178. CreateLoginSourceOpts{
  179. Type: auth.GitHub,
  180. Name: "GitHub2",
  181. Activated: true,
  182. Default: false,
  183. Config: &github.Config{
  184. APIEndpoint: "https://api.github.com",
  185. },
  186. },
  187. )
  188. require.NoError(t, err)
  189. // Delete a non-existent ID is noop
  190. err = db.DeleteByID(ctx, 9999)
  191. require.NoError(t, err)
  192. // We should be able to get it back
  193. _, err = db.GetByID(ctx, source.ID)
  194. require.NoError(t, err)
  195. // Now delete this login source with ID
  196. err = db.DeleteByID(ctx, source.ID)
  197. require.NoError(t, err)
  198. // We should get token not found error
  199. _, err = db.GetByID(ctx, source.ID)
  200. expErr := ErrLoginSourceNotExist{args: errutil.Args{"id": source.ID}}
  201. assert.Equal(t, expErr, err)
  202. }
  203. func loginSourcesGetByID(t *testing.T, db *loginSources) {
  204. ctx := context.Background()
  205. mock := NewMockLoginSourceFilesStore()
  206. mock.GetByIDFunc.SetDefaultHook(func(id int64) (*LoginSource, error) {
  207. if id != 101 {
  208. return nil, ErrLoginSourceNotExist{args: errutil.Args{"id": id}}
  209. }
  210. return &LoginSource{ID: id}, nil
  211. })
  212. setMockLoginSourceFilesStore(t, db, mock)
  213. expConfig := &github.Config{
  214. APIEndpoint: "https://api.github.com",
  215. }
  216. // Create a login source with name "GitHub"
  217. source, err := db.Create(ctx,
  218. CreateLoginSourceOpts{
  219. Type: auth.GitHub,
  220. Name: "GitHub",
  221. Activated: true,
  222. Default: false,
  223. Config: expConfig,
  224. },
  225. )
  226. require.NoError(t, err)
  227. // Get the one in the database and test the read/write hooks
  228. source, err = db.GetByID(ctx, source.ID)
  229. require.NoError(t, err)
  230. assert.Equal(t, expConfig, source.Provider.Config())
  231. // Get the one in source file store
  232. _, err = db.GetByID(ctx, 101)
  233. require.NoError(t, err)
  234. }
  235. func loginSourcesList(t *testing.T, db *loginSources) {
  236. ctx := context.Background()
  237. mock := NewMockLoginSourceFilesStore()
  238. mock.ListFunc.SetDefaultHook(func(opts ListLoginSourceOpts) []*LoginSource {
  239. if opts.OnlyActivated {
  240. return []*LoginSource{
  241. {ID: 1},
  242. }
  243. }
  244. return []*LoginSource{
  245. {ID: 1},
  246. {ID: 2},
  247. }
  248. })
  249. setMockLoginSourceFilesStore(t, db, mock)
  250. // Create two login sources in database, one activated and the other one not
  251. _, err := db.Create(ctx,
  252. CreateLoginSourceOpts{
  253. Type: auth.PAM,
  254. Name: "PAM",
  255. Config: &pam.Config{
  256. ServiceName: "PAM",
  257. },
  258. },
  259. )
  260. require.NoError(t, err)
  261. _, err = db.Create(ctx,
  262. CreateLoginSourceOpts{
  263. Type: auth.GitHub,
  264. Name: "GitHub",
  265. Activated: true,
  266. Config: &github.Config{
  267. APIEndpoint: "https://api.github.com",
  268. },
  269. },
  270. )
  271. require.NoError(t, err)
  272. // List all login sources
  273. sources, err := db.List(ctx, ListLoginSourceOpts{})
  274. require.NoError(t, err)
  275. assert.Equal(t, 4, len(sources), "number of sources")
  276. // Only list activated login sources
  277. sources, err = db.List(ctx, ListLoginSourceOpts{OnlyActivated: true})
  278. require.NoError(t, err)
  279. assert.Equal(t, 2, len(sources), "number of sources")
  280. }
  281. func loginSourcesResetNonDefault(t *testing.T, db *loginSources) {
  282. ctx := context.Background()
  283. mock := NewMockLoginSourceFilesStore()
  284. mock.ListFunc.SetDefaultHook(func(opts ListLoginSourceOpts) []*LoginSource {
  285. mockFile := NewMockLoginSourceFileStore()
  286. mockFile.SetGeneralFunc.SetDefaultHook(func(name, value string) {
  287. assert.Equal(t, "is_default", name)
  288. assert.Equal(t, "false", value)
  289. })
  290. return []*LoginSource{
  291. {
  292. File: mockFile,
  293. },
  294. }
  295. })
  296. setMockLoginSourceFilesStore(t, db, mock)
  297. // Create two login sources both have default on
  298. source1, err := db.Create(ctx,
  299. CreateLoginSourceOpts{
  300. Type: auth.PAM,
  301. Name: "PAM",
  302. Default: true,
  303. Config: &pam.Config{
  304. ServiceName: "PAM",
  305. },
  306. },
  307. )
  308. require.NoError(t, err)
  309. source2, err := db.Create(ctx,
  310. CreateLoginSourceOpts{
  311. Type: auth.GitHub,
  312. Name: "GitHub",
  313. Activated: true,
  314. Default: true,
  315. Config: &github.Config{
  316. APIEndpoint: "https://api.github.com",
  317. },
  318. },
  319. )
  320. require.NoError(t, err)
  321. // Set source 1 as default
  322. err = db.ResetNonDefault(ctx, source1)
  323. require.NoError(t, err)
  324. // Verify the default state
  325. source1, err = db.GetByID(ctx, source1.ID)
  326. require.NoError(t, err)
  327. assert.True(t, source1.IsDefault)
  328. source2, err = db.GetByID(ctx, source2.ID)
  329. require.NoError(t, err)
  330. assert.False(t, source2.IsDefault)
  331. }
  332. func loginSourcesSave(t *testing.T, db *loginSources) {
  333. ctx := context.Background()
  334. t.Run("save to database", func(t *testing.T) {
  335. // Create a login source with name "GitHub"
  336. source, err := db.Create(ctx,
  337. CreateLoginSourceOpts{
  338. Type: auth.GitHub,
  339. Name: "GitHub",
  340. Activated: true,
  341. Default: false,
  342. Config: &github.Config{
  343. APIEndpoint: "https://api.github.com",
  344. },
  345. },
  346. )
  347. require.NoError(t, err)
  348. source.IsActived = false
  349. source.Provider = github.NewProvider(&github.Config{
  350. APIEndpoint: "https://api2.github.com",
  351. })
  352. err = db.Save(ctx, source)
  353. require.NoError(t, err)
  354. source, err = db.GetByID(ctx, source.ID)
  355. require.NoError(t, err)
  356. assert.False(t, source.IsActived)
  357. assert.Equal(t, "https://api2.github.com", source.GitHub().APIEndpoint)
  358. })
  359. t.Run("save to file", func(t *testing.T) {
  360. mockFile := NewMockLoginSourceFileStore()
  361. source := &LoginSource{
  362. Provider: github.NewProvider(&github.Config{
  363. APIEndpoint: "https://api.github.com",
  364. }),
  365. File: mockFile,
  366. }
  367. err := db.Save(ctx, source)
  368. require.NoError(t, err)
  369. mockrequire.Called(t, mockFile.SaveFunc)
  370. })
  371. }