Skip to content

Commit 8f8a89f

Browse files
committed
fix(catalog): data races in test mocks
The tests in catalog/internal/catalog currently fail under the race detector because some of the mocks are safe for concurrent use. This adds mutexes where necessary. Signed-off-by: Paul Boyd <paul@pboyd.io>
1 parent 052af44 commit 8f8a89f

File tree

1 file changed

+88
-17
lines changed

1 file changed

+88
-17
lines changed

catalog/internal/catalog/catalog_test.go

Lines changed: 88 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -382,16 +382,14 @@ func TestLoadCatalogSourcesWithMockRepositories(t *testing.T) {
382382
// Wait a bit for the goroutine to process
383383
time.Sleep(100 * time.Millisecond)
384384

385-
mockModelRepo.mu.RLock()
386-
defer mockModelRepo.mu.RUnlock()
387-
388385
// Verify that the model was saved
389-
if len(mockModelRepo.SavedModels) != 1 {
390-
t.Errorf("Expected 1 model to be saved, got %d", len(mockModelRepo.SavedModels))
386+
savedModels := mockModelRepo.GetSavedModels()
387+
if len(savedModels) != 1 {
388+
t.Errorf("Expected 1 model to be saved, got %d", len(savedModels))
391389
}
392390

393-
if len(mockModelRepo.SavedModels) > 0 {
394-
savedModel := mockModelRepo.SavedModels[0]
391+
if len(savedModels) > 0 {
392+
savedModel := savedModels[0]
395393
if savedModel.GetAttributes() == nil || savedModel.GetAttributes().Name == nil {
396394
t.Error("Saved model should have attributes with name")
397395
} else if *savedModel.GetAttributes().Name != "test-model" {
@@ -486,12 +484,10 @@ func TestLoadCatalogSourcesWithRepositoryErrors(t *testing.T) {
486484
// Wait for processing
487485
time.Sleep(100 * time.Millisecond)
488486

489-
mockModelRepo.mu.RLock()
490-
defer mockModelRepo.mu.RUnlock()
491-
492487
// Verify that no models were saved due to the error
493-
if len(mockModelRepo.SavedModels) != 0 {
494-
t.Errorf("Expected 0 models to be saved due to error, got %d", len(mockModelRepo.SavedModels))
488+
savedModels := mockModelRepo.GetSavedModels()
489+
if len(savedModels) != 0 {
490+
t.Errorf("Expected 0 models to be saved due to error, got %d", len(savedModels))
495491
}
496492
}
497493

@@ -563,12 +559,13 @@ func TestLoadCatalogSourcesWithNilEnabled(t *testing.T) {
563559
time.Sleep(100 * time.Millisecond)
564560

565561
// Verify that the model WAS saved (because nil Enabled is treated as enabled)
566-
if len(mockModelRepo.SavedModels) != 1 {
567-
t.Errorf("Expected 1 model to be saved (nil Enabled should be treated as enabled), got %d", len(mockModelRepo.SavedModels))
562+
savedModels := mockModelRepo.GetSavedModels()
563+
if len(savedModels) != 1 {
564+
t.Errorf("Expected 1 model to be saved (nil Enabled should be treated as enabled), got %d", len(savedModels))
568565
}
569566

570-
if len(mockModelRepo.SavedModels) > 0 {
571-
savedModel := mockModelRepo.SavedModels[0]
567+
if len(savedModels) > 0 {
568+
savedModel := savedModels[0]
572569
if savedModel.GetAttributes() == nil || savedModel.GetAttributes().Name == nil {
573570
t.Error("Saved model should have attributes with name")
574571
} else if *savedModel.GetAttributes().Name != "test-model-nil-enabled" {
@@ -716,20 +713,45 @@ func (m *MockCatalogModelRepository) Save(model dbmodels.CatalogModel) (dbmodels
716713
}
717714

718715
func (m *MockCatalogModelRepository) DeleteBySource(sourceID string) error {
716+
m.mu.Lock()
717+
defer m.mu.Unlock()
719718
// Mock implementation - no-op for testing
720719
return nil
721720
}
722721

723722
func (m *MockCatalogModelRepository) DeleteByID(id int32) error {
723+
m.mu.Lock()
724+
defer m.mu.Unlock()
724725
// Mock implementation - no-op for testing
725726
return nil
726727
}
727728

728729
func (m *MockCatalogModelRepository) GetDistinctSourceIDs() ([]string, error) {
730+
m.mu.RLock()
731+
defer m.mu.RUnlock()
729732
// Mock implementation - return empty list by default
730733
return []string{}, nil
731734
}
732735

736+
// GetSavedModels returns a copy of the saved models slice in a thread-safe manner.
737+
// This should be used by tests instead of directly accessing SavedModels field.
738+
func (m *MockCatalogModelRepository) GetSavedModels() []dbmodels.CatalogModel {
739+
m.mu.RLock()
740+
defer m.mu.RUnlock()
741+
// Return a copy to prevent external modifications
742+
result := make([]dbmodels.CatalogModel, len(m.SavedModels))
743+
copy(result, m.SavedModels)
744+
return result
745+
}
746+
747+
// Reset clears all saved models and resets the NextID counter in a thread-safe manner.
748+
func (m *MockCatalogModelRepository) Reset() {
749+
m.mu.Lock()
750+
defer m.mu.Unlock()
751+
m.SavedModels = []dbmodels.CatalogModel{}
752+
m.NextID = 0
753+
}
754+
733755
// MockCatalogModelArtifactRepository mocks the CatalogModelArtifactRepository interface.
734756
type MockCatalogModelArtifactRepository struct {
735757
mu sync.RWMutex
@@ -778,6 +800,23 @@ func (m *MockCatalogModelArtifactRepository) Save(modelArtifact dbmodels.Catalog
778800
return savedArtifact, nil
779801
}
780802

803+
// GetSavedArtifacts returns a copy of the saved artifacts slice in a thread-safe manner.
804+
func (m *MockCatalogModelArtifactRepository) GetSavedArtifacts() []dbmodels.CatalogModelArtifact {
805+
m.mu.RLock()
806+
defer m.mu.RUnlock()
807+
result := make([]dbmodels.CatalogModelArtifact, len(m.SavedArtifacts))
808+
copy(result, m.SavedArtifacts)
809+
return result
810+
}
811+
812+
// Reset clears all saved artifacts and resets the NextID counter in a thread-safe manner.
813+
func (m *MockCatalogModelArtifactRepository) Reset() {
814+
m.mu.Lock()
815+
defer m.mu.Unlock()
816+
m.SavedArtifacts = []dbmodels.CatalogModelArtifact{}
817+
m.NextID = 0
818+
}
819+
781820
// MockCatalogMetricsArtifactRepository mocks the CatalogMetricsArtifactRepository interface.
782821
type MockCatalogMetricsArtifactRepository struct {
783822
mu sync.RWMutex
@@ -854,6 +893,23 @@ func (m *MockCatalogMetricsArtifactRepository) BatchSave(metricsArtifacts []dbmo
854893
return savedArtifacts, nil
855894
}
856895

896+
// GetSavedMetrics returns a copy of the saved metrics slice in a thread-safe manner.
897+
func (m *MockCatalogMetricsArtifactRepository) GetSavedMetrics() []dbmodels.CatalogMetricsArtifact {
898+
m.mu.RLock()
899+
defer m.mu.RUnlock()
900+
result := make([]dbmodels.CatalogMetricsArtifact, len(m.SavedMetrics))
901+
copy(result, m.SavedMetrics)
902+
return result
903+
}
904+
905+
// Reset clears all saved metrics and resets the NextID counter in a thread-safe manner.
906+
func (m *MockCatalogMetricsArtifactRepository) Reset() {
907+
m.mu.Lock()
908+
defer m.mu.Unlock()
909+
m.SavedMetrics = []dbmodels.CatalogMetricsArtifact{}
910+
m.NextID = 0
911+
}
912+
857913
// MockCatalogArtifactRepository mocks the CatalogArtifactRepository interface.
858914
type MockCatalogArtifactRepository struct {
859915
mu sync.RWMutex
@@ -953,10 +1009,13 @@ func (m *MockPropertyOptionsRepository) SetMockOptions(t dbmodels.PropertyOption
9531009

9541010
// MockCatalogSourceRepository mocks the CatalogSourceRepository interface.
9551011
type MockCatalogSourceRepository struct {
1012+
mu sync.RWMutex
9561013
Sources []dbmodels.CatalogSource
9571014
}
9581015

9591016
func (m *MockCatalogSourceRepository) GetBySourceID(sourceID string) (dbmodels.CatalogSource, error) {
1017+
m.mu.RLock()
1018+
defer m.mu.RUnlock()
9601019
for _, s := range m.Sources {
9611020
if attrs := s.GetAttributes(); attrs != nil && attrs.Name != nil && *attrs.Name == sourceID {
9621021
return s, nil
@@ -966,19 +1025,31 @@ func (m *MockCatalogSourceRepository) GetBySourceID(sourceID string) (dbmodels.C
9661025
}
9671026

9681027
func (m *MockCatalogSourceRepository) Save(source dbmodels.CatalogSource) (dbmodels.CatalogSource, error) {
1028+
m.mu.Lock()
1029+
defer m.mu.Unlock()
9691030
m.Sources = append(m.Sources, source)
9701031
return source, nil
9711032
}
9721033

9731034
func (m *MockCatalogSourceRepository) Delete(sourceID string) error {
1035+
m.mu.Lock()
1036+
defer m.mu.Unlock()
1037+
// Mock implementation - no-op for testing
9741038
return nil
9751039
}
9761040

9771041
func (m *MockCatalogSourceRepository) GetAll() ([]dbmodels.CatalogSource, error) {
978-
return m.Sources, nil
1042+
m.mu.RLock()
1043+
defer m.mu.RUnlock()
1044+
// Return a copy to prevent external modifications
1045+
result := make([]dbmodels.CatalogSource, len(m.Sources))
1046+
copy(result, m.Sources)
1047+
return result, nil
9791048
}
9801049

9811050
func (m *MockCatalogSourceRepository) GetAllStatuses() (map[string]dbmodels.SourceStatus, error) {
1051+
m.mu.RLock()
1052+
defer m.mu.RUnlock()
9821053
result := make(map[string]dbmodels.SourceStatus)
9831054
for _, source := range m.Sources {
9841055
if attrs := source.GetAttributes(); attrs != nil && attrs.Name != nil {

0 commit comments

Comments
 (0)