package web

import (
	"bytes"
	"encoding/csv"
	"fmt"
	"html/template"
	"log"
	"net/http"
	"strconv"
	"strings"
	"time"

	root "marmic/servicetrade-toolbox"
	"marmic/servicetrade-toolbox/internal/api"
	"marmic/servicetrade-toolbox/internal/middleware"

	"github.com/gorilla/csrf"
)

// RoleAssignmentResult captures the outcome of assigning a role to a user.
type RoleAssignmentResult struct {
	// Token is the role token exactly as it was read from the CSV "roles" column.
	Token string
	// RoleID and Role hold the ID and name of the role the token resolved to;
	// they stay at their zero values when the token could not be resolved.
	RoleID int
	Role   string
	// Success is true only when the assignment call returned no error.
	Success bool
	// Message is "assigned" on success, otherwise the reason for the failure.
	Message string
}

// UserImportResult represents the processing result for a single CSV row.
type UserImportResult struct {
	// Row is the row number within the uploaded file, counting the header as row 1.
	Row int
	// Username and Email are copied from the row for display purposes.
	Username string
	Email    string
	// UserID is the ID of the created user; zero when creation failed.
	UserID int
	// Created reports whether the user was created.
	Created bool
	// Error describes why the row failed; empty when the user was created.
	Error string
	// RoleAssignments lists the outcome for each role token on the row.
	RoleAssignments []RoleAssignmentResult
	// ProcessingTime is the time spent handling this row.
	ProcessingTime time.Duration
	// ServiceLineIDs holds the IDs parsed from the row's service line column.
	ServiceLineIDs []int
	// AdditionalFields is initialised to an empty map by the upload handler.
	AdditionalFields map[string]string
}

// UserImportSummary aggregates statistics about a CSV import run.
type UserImportSummary struct {
	// TotalRows is the number of data rows counted for the run (the header row is excluded).
	TotalRows int
	// UsersCreated and RowsFailed count rows by outcome.
	UsersCreated int
	RowsFailed   int
	// RoleAssignments counts successful role assignments;
	// RoleAssignmentErrors counts tokens that failed to resolve or assign.
	RoleAssignments      int
	RoleAssignmentErrors int
	// ProcessedFilename is the name of the uploaded file as sent by the client.
	ProcessedFilename string
	// ProcessedAt is the time the run started processing rows.
	ProcessedAt time.Time
}

// UserUpdateResult represents the processing result for a single update row.
type UserUpdateResult struct {
	// Row is the row number within the uploaded file, counting the header as row 1.
	Row int
	// Username and Email are copied from the row for display purposes.
	Username string
	Email    string
	// UserID is the ID of the user the row was applied to; zero when the user
	// could not be identified.
	UserID int
	// Updated reports whether the update call succeeded.
	Updated bool
	// Error describes why the row failed; empty when the update succeeded.
	Error string
	// UpdatedFields names the fields that were sent in the update.
	UpdatedFields []string
	// LookupMethod is "userId" or "username", depending on how the user was located.
	LookupMethod string
	// RoleAssignments lists the outcome for each role token on the row.
	RoleAssignments []RoleAssignmentResult
	// ProcessingTime is the time spent handling this row.
	ProcessingTime time.Duration
}

// UserUpdateSummary aggregates statistics about a CSV update run.
type UserUpdateSummary struct {
	// TotalRows is the number of data rows counted for the run (the header row is excluded).
	TotalRows int
	// UsersUpdated and RowsFailed count rows by outcome.
	UsersUpdated int
	RowsFailed   int
	// RoleAssignments counts successful role assignments;
	// RoleAssignmentErrors counts tokens that failed to resolve or assign.
	RoleAssignments      int
	RoleAssignmentErrors int
	// LookupsByUsername and LookupsByID count how users were located.
	LookupsByUsername int
	LookupsByID       int
	// ProcessedFilename is the name of the uploaded file as sent by the client.
	ProcessedFilename string
	// ProcessedAt is the time the run started processing rows.
	ProcessedAt time.Time
}

// UsersHandler renders the user management page.
func UsersHandler(w http.ResponseWriter, r *http.Request) { session, ok := r.Context().Value(middleware.SessionKey).(*api.Session) if !ok { http.Error(w, "Unauthorized", http.StatusUnauthorized) return } data := baseUsersPageData(r) data["Title"] = "Users" data["Session"] = session if roles, err := session.ListRoles(); err != nil { log.Printf("UsersHandler: unable to load roles: %v", err) data["RolesError"] = err.Error() } else { data["Roles"] = roles } renderUsersPage(w, r, data) } // UsersUploadHandler processes uploaded CSV files to create users and assign roles. func UsersUploadHandler(w http.ResponseWriter, r *http.Request) { start := time.Now() session, ok := r.Context().Value(middleware.SessionKey).(*api.Session) if !ok { http.Error(w, "Unauthorized", http.StatusUnauthorized) return } if err := r.ParseMultipartForm(10 << 20); err != nil { http.Error(w, fmt.Sprintf("Unable to parse upload: %v", err), http.StatusBadRequest) return } file, header, err := r.FormFile("csvFile") if err != nil { http.Error(w, fmt.Sprintf("Unable to read file: %v", err), http.StatusBadRequest) return } defer file.Close() reader := csv.NewReader(file) reader.FieldsPerRecord = -1 reader.TrimLeadingSpace = true rows, err := reader.ReadAll() if err != nil { http.Error(w, fmt.Sprintf("Unable to read CSV: %v", err), http.StatusBadRequest) return } data := baseUsersPageData(r) data["Title"] = "Users" data["Session"] = session if len(rows) < 2 { data["FlashError"] = "CSV must include a header row and at least one data row." 
renderUsersPage(w, r, data) return } headerMap := buildHeaderIndex(rows[0]) required := []string{"username", "firstname", "lastname", "email", "password", "companyid", "locationid"} if missing := missingHeaders(headerMap, required); len(missing) > 0 { data["FlashError"] = fmt.Sprintf("Missing required column headers: %s", strings.Join(missing, ", ")) renderUsersPage(w, r, data) return } roles, rolesErr := session.ListRoles() if rolesErr != nil { log.Printf("UsersUploadHandler: unable to load roles: %v", rolesErr) data["RolesError"] = fmt.Sprintf("Unable to load roles: %v. Role names in the CSV will not resolve.", rolesErr) } else { data["Roles"] = roles } roleIndexByName := buildRoleNameIndex(roles) results := make([]UserImportResult, 0, len(rows)-1) summary := UserImportSummary{ TotalRows: len(rows) - 1, ProcessedAt: time.Now(), ProcessedFilename: header.Filename, } for rowIdx, row := range rows[1:] { rowStart := time.Now() result := UserImportResult{ Row: rowIdx + 2, Username: getValue(row, headerMap, "username"), Email: getValue(row, headerMap, "email"), AdditionalFields: map[string]string{}, } if rowIsEmpty(row) { continue } if err := validateRequiredRowFields(row, headerMap, required); err != nil { result.Error = err.Error() result.ProcessingTime = time.Since(rowStart) results = append(results, result) summary.RowsFailed++ continue } companyID, err := strconv.Atoi(getValue(row, headerMap, "companyid")) if err != nil { result.Error = fmt.Sprintf("invalid companyId: %v", err) result.ProcessingTime = time.Since(rowStart) results = append(results, result) summary.RowsFailed++ continue } locationID, err := strconv.Atoi(getValue(row, headerMap, "locationid")) if err != nil { result.Error = fmt.Sprintf("invalid locationId: %v", err) result.ProcessingTime = time.Since(rowStart) results = append(results, result) summary.RowsFailed++ continue } isSales, err := parseOptionalBool(getValue(row, headerMap, "issales")) if err != nil { result.Error = fmt.Sprintf("invalid 
isSales value: %v", err) result.ProcessingTime = time.Since(rowStart) results = append(results, result) summary.RowsFailed++ continue } mfaRequired, err := parseOptionalBool(getValue(row, headerMap, "mfarequired")) if err != nil { result.Error = fmt.Sprintf("invalid mfaRequired value: %v", err) result.ProcessingTime = time.Since(rowStart) results = append(results, result) summary.RowsFailed++ continue } managerID, err := parseOptionalInt(getValue(row, headerMap, "managerid")) if err != nil { result.Error = fmt.Sprintf("invalid managerId value: %v", err) result.ProcessingTime = time.Since(rowStart) results = append(results, result) summary.RowsFailed++ continue } serviceLineIDs, err := parseIntList(getValue(row, headerMap, "servicelineids")) if err != nil { result.Error = fmt.Sprintf("invalid serviceLineIds value: %v", err) result.ProcessingTime = time.Since(rowStart) results = append(results, result) summary.RowsFailed++ continue } result.ServiceLineIDs = serviceLineIDs payload := api.CreateUserPayload{ Username: result.Username, FirstName: getValue(row, headerMap, "firstname"), LastName: getValue(row, headerMap, "lastname"), Password: getValue(row, headerMap, "password"), Email: result.Email, Phone: getValue(row, headerMap, "phone"), CompanyID: companyID, LocationID: locationID, Details: getValue(row, headerMap, "details"), Status: getValue(row, headerMap, "status"), Timezone: getValue(row, headerMap, "timezone"), ServiceLineIDs: serviceLineIDs, } if isSales != nil { payload.IsSales = isSales } if mfaRequired != nil { payload.MFARequired = mfaRequired } if managerID != nil { payload.ManagerID = managerID } userRecord, err := session.CreateUser(payload) if err != nil { result.Error = fmt.Sprintf("failed to create user: %v", err) result.ProcessingTime = time.Since(rowStart) results = append(results, result) summary.RowsFailed++ continue } result.Created = true result.UserID = userRecord.ID summary.UsersCreated++ roleTokens := parseRoleTokens(getValue(row, headerMap, 
"roles")) if len(roleTokens) > 0 { assignments := make([]RoleAssignmentResult, 0, len(roleTokens)) for _, token := range roleTokens { assignResult := RoleAssignmentResult{Token: token} role, resolved, message := resolveRoleToken(token, roleIndexByName) if !resolved { assignResult.Message = message assignments = append(assignments, assignResult) summary.RoleAssignmentErrors++ continue } assignResult.RoleID = role.ID assignResult.Role = role.Name if err := session.AssignRoleToUser(userRecord.ID, role.ID); err != nil { assignResult.Message = fmt.Sprintf("failed to assign role: %v", err) summary.RoleAssignmentErrors++ } else { assignResult.Success = true assignResult.Message = "assigned" summary.RoleAssignments++ } assignments = append(assignments, assignResult) } result.RoleAssignments = assignments } result.ProcessingTime = time.Since(rowStart) results = append(results, result) } data["ImportResults"] = results data["ImportSummary"] = summary data["FlashSuccess"] = fmt.Sprintf("Processed %d row(s) in %s.", summary.TotalRows, time.Since(start).Round(time.Millisecond)) renderUsersPage(w, r, data) } // UsersUpdateHandler renders the user update page. func UsersUpdateHandler(w http.ResponseWriter, r *http.Request) { session, ok := r.Context().Value(middleware.SessionKey).(*api.Session) if !ok { http.Error(w, "Unauthorized", http.StatusUnauthorized) return } data := baseUsersPageData(r) data["Title"] = "User Updates" data["Session"] = session if roles, err := session.ListRoles(); err != nil { log.Printf("UsersUpdateHandler: unable to load roles: %v", err) data["RolesError"] = err.Error() } else { data["Roles"] = roles } renderUsersUpdatePage(w, r, data) } // UsersUpdateUploadHandler processes CSV uploads to update existing users. 
func UsersUpdateUploadHandler(w http.ResponseWriter, r *http.Request) { start := time.Now() session, ok := r.Context().Value(middleware.SessionKey).(*api.Session) if !ok { http.Error(w, "Unauthorized", http.StatusUnauthorized) return } if err := r.ParseMultipartForm(10 << 20); err != nil { http.Error(w, fmt.Sprintf("Unable to parse upload: %v", err), http.StatusBadRequest) return } file, header, err := r.FormFile("csvFile") if err != nil { http.Error(w, fmt.Sprintf("Unable to read file: %v", err), http.StatusBadRequest) return } defer file.Close() reader := csv.NewReader(file) reader.FieldsPerRecord = -1 reader.TrimLeadingSpace = true rows, err := reader.ReadAll() if err != nil { http.Error(w, fmt.Sprintf("Unable to read CSV: %v", err), http.StatusBadRequest) return } data := baseUsersPageData(r) data["Title"] = "User Updates" data["Session"] = session if len(rows) < 2 { data["FlashError"] = "CSV must include a header row and at least one data row." renderUsersUpdatePage(w, r, data) return } headerMap := buildHeaderIndex(rows[0]) if _, ok := headerMap["username"]; !ok { if _, ok := headerMap["userid"]; !ok { data["FlashError"] = "CSV must include either a username or userId column to locate users." renderUsersUpdatePage(w, r, data) return } } roles, rolesErr := session.ListRoles() if rolesErr != nil { log.Printf("UsersUpdateUploadHandler: unable to load roles: %v", rolesErr) data["RolesError"] = fmt.Sprintf("Unable to load roles: %v. 
Role names in the CSV will not resolve.", rolesErr) } else { data["Roles"] = roles } roleIndexByName := buildRoleNameIndex(roles) results := make([]UserUpdateResult, 0, len(rows)-1) summary := UserUpdateSummary{ TotalRows: len(rows) - 1, ProcessedAt: time.Now(), ProcessedFilename: header.Filename, } for rowIdx, row := range rows[1:] { rowStart := time.Now() if rowIsEmpty(row) { continue } result := UserUpdateResult{ Row: rowIdx + 2, Username: getValue(row, headerMap, "username"), Email: getValue(row, headerMap, "email"), } userIDValue := getValue(row, headerMap, "userid") var userID int if userIDValue != "" { parsedID, err := strconv.Atoi(userIDValue) if err != nil || parsedID <= 0 { result.Error = fmt.Sprintf("invalid userId: %v", err) result.ProcessingTime = time.Since(rowStart) results = append(results, result) summary.RowsFailed++ continue } userID = parsedID result.UserID = userID result.LookupMethod = "userId" summary.LookupsByID++ } else if result.Username != "" { userRecord, err := session.FindUserByUsername(result.Username) if err != nil { result.Error = fmt.Sprintf("failed to locate user by username: %v", err) result.ProcessingTime = time.Since(rowStart) results = append(results, result) summary.RowsFailed++ continue } userID = userRecord.ID result.UserID = userID result.LookupMethod = "username" summary.LookupsByUsername++ } else { result.Error = "row missing username and userId" result.ProcessingTime = time.Since(rowStart) results = append(results, result) summary.RowsFailed++ continue } payload := api.UpdateUserPayload{} var updatedFields []string if v := getValue(row, headerMap, "firstname"); v != "" { payload.FirstName = stringPtr(v) updatedFields = append(updatedFields, "firstName") } if v := getValue(row, headerMap, "new_username"); v != "" { payload.Username = stringPtr(v) updatedFields = append(updatedFields, "username") } if v := getValue(row, headerMap, "lastname"); v != "" { payload.LastName = stringPtr(v) updatedFields = append(updatedFields, 
"lastName") } if v := getValue(row, headerMap, "email"); v != "" { payload.Email = stringPtr(v) updatedFields = append(updatedFields, "email") } if v := getValue(row, headerMap, "password"); v != "" { payload.Password = stringPtr(v) updatedFields = append(updatedFields, "password") } if v := getValue(row, headerMap, "phone"); v != "" { payload.Phone = stringPtr(v) updatedFields = append(updatedFields, "phone") } if v := getValue(row, headerMap, "details"); v != "" { payload.Details = stringPtr(v) updatedFields = append(updatedFields, "details") } if v := getValue(row, headerMap, "status"); v != "" { payload.Status = stringPtr(v) updatedFields = append(updatedFields, "status") } if v := getValue(row, headerMap, "timezone"); v != "" { payload.Timezone = stringPtr(v) updatedFields = append(updatedFields, "timezone") } if companyVal := getValue(row, headerMap, "companyid"); companyVal != "" { companyID, err := strconv.Atoi(companyVal) if err != nil { result.Error = fmt.Sprintf("invalid companyId: %v", err) result.ProcessingTime = time.Since(rowStart) results = append(results, result) summary.RowsFailed++ continue } payload.CompanyID = intPtr(companyID) updatedFields = append(updatedFields, "companyId") } if locationVal := getValue(row, headerMap, "locationid"); locationVal != "" { locationID, err := strconv.Atoi(locationVal) if err != nil { result.Error = fmt.Sprintf("invalid locationId: %v", err) result.ProcessingTime = time.Since(rowStart) results = append(results, result) summary.RowsFailed++ continue } payload.LocationID = intPtr(locationID) updatedFields = append(updatedFields, "locationId") } if managerVal := getValue(row, headerMap, "managerid"); managerVal != "" { managerID, err := strconv.Atoi(managerVal) if err != nil { result.Error = fmt.Sprintf("invalid managerId: %v", err) result.ProcessingTime = time.Since(rowStart) results = append(results, result) summary.RowsFailed++ continue } payload.ManagerID = intPtr(managerID) updatedFields = append(updatedFields, 
"managerId") } if isSales, err := parseOptionalBool(getValue(row, headerMap, "issales")); err != nil { result.Error = fmt.Sprintf("invalid isSales value: %v", err) result.ProcessingTime = time.Since(rowStart) results = append(results, result) summary.RowsFailed++ continue } else if isSales != nil { payload.IsSales = isSales updatedFields = append(updatedFields, "isSales") } if mfaRequired, err := parseOptionalBool(getValue(row, headerMap, "mfarequired")); err != nil { result.Error = fmt.Sprintf("invalid mfaRequired value: %v", err) result.ProcessingTime = time.Since(rowStart) results = append(results, result) summary.RowsFailed++ continue } else if mfaRequired != nil { payload.MFARequired = mfaRequired updatedFields = append(updatedFields, "mfaRequired") } if serviceLineIDs, err := parseIntList(getValue(row, headerMap, "servicelineids")); err != nil { result.Error = fmt.Sprintf("invalid serviceLineIds value: %v", err) result.ProcessingTime = time.Since(rowStart) results = append(results, result) summary.RowsFailed++ continue } else if serviceLineIDs != nil { payload.ServiceLineIDs = &serviceLineIDs updatedFields = append(updatedFields, "serviceLineIds") } if len(updatedFields) == 0 { result.Error = "no updatable fields provided in row" result.ProcessingTime = time.Since(rowStart) results = append(results, result) summary.RowsFailed++ continue } if _, err := session.UpdateUser(userID, payload); err != nil { result.Error = fmt.Sprintf("failed to update user: %v", err) result.ProcessingTime = time.Since(rowStart) results = append(results, result) summary.RowsFailed++ continue } result.Updated = true result.UpdatedFields = updatedFields summary.UsersUpdated++ roleTokens := parseRoleTokens(getValue(row, headerMap, "roles")) if len(roleTokens) > 0 { assignments := make([]RoleAssignmentResult, 0, len(roleTokens)) for _, token := range roleTokens { assignResult := RoleAssignmentResult{Token: token} role, resolved, message := resolveRoleToken(token, roleIndexByName) if 
!resolved { assignResult.Message = message assignments = append(assignments, assignResult) summary.RoleAssignmentErrors++ continue } assignResult.RoleID = role.ID assignResult.Role = role.Name if err := session.AssignRoleToUser(userID, role.ID); err != nil { assignResult.Message = fmt.Sprintf("failed to assign role: %v", err) summary.RoleAssignmentErrors++ } else { assignResult.Success = true assignResult.Message = "assigned" summary.RoleAssignments++ } assignments = append(assignments, assignResult) } result.RoleAssignments = assignments } result.ProcessingTime = time.Since(rowStart) results = append(results, result) } data["UpdateResults"] = results data["UpdateSummary"] = summary if summary.UsersUpdated > 0 { data["FlashSuccess"] = fmt.Sprintf("Updated %d user(s) (processed %d rows in %s).", summary.UsersUpdated, summary.TotalRows, time.Since(start).Round(time.Millisecond)) } else { data["FlashError"] = "No users were updated. Review the row errors below." } renderUsersUpdatePage(w, r, data) } func renderUsersPage(w http.ResponseWriter, r *http.Request, data map[string]interface{}) { renderUsersTemplatePage(w, r, data, "users_content") } func renderUsersUpdatePage(w http.ResponseWriter, r *http.Request, data map[string]interface{}) { renderUsersTemplatePage(w, r, data, "users_update_content") } func renderUsersTemplatePage(w http.ResponseWriter, r *http.Request, data map[string]interface{}, templateName string) { tmpl := root.WebTemplates if r.Header.Get("HX-Request") == "true" { if err := tmpl.ExecuteTemplate(w, templateName, data); err != nil { log.Printf("UsersHandler: template error: %v", err) http.Error(w, "Internal Server Error", http.StatusInternalServerError) } return } var contentBuf bytes.Buffer if err := tmpl.ExecuteTemplate(&contentBuf, templateName, data); err != nil { log.Printf("UsersHandler: template error: %v", err) http.Error(w, "Internal Server Error", http.StatusInternalServerError) return } data["BodyContent"] = 
template.HTML(contentBuf.String()) if err := tmpl.ExecuteTemplate(w, "layout.html", data); err != nil { log.Printf("UsersHandler: layout template error: %v", err) http.Error(w, "Internal Server Error", http.StatusInternalServerError) } } func baseUsersPageData(r *http.Request) map[string]interface{} { data := map[string]interface{}{ "CSRFField": csrf.TemplateField(r), "CSRFToken": csrf.Token(r), } if c, err := r.Cookie("XSRF-TOKEN"); err == nil { data["CSRFCookie"] = c.Value } else if c, err := r.Cookie("XSRF-TOKEN-VALUE"); err == nil { data["CSRFCookie"] = c.Value } return data } func buildHeaderIndex(headers []string) map[string]int { index := make(map[string]int, len(headers)) for idx, header := range headers { if header == "" { continue } normalized := normalizeHeader(header) if canonical, ok := headerCanonicalMap[normalized]; ok { index[canonical] = idx } else { index[normalized] = idx } } return index } func missingHeaders(headerIndex map[string]int, required []string) []string { var missing []string for _, key := range required { if _, ok := headerIndex[key]; !ok { missing = append(missing, key) } } return missing } func rowIsEmpty(row []string) bool { for _, value := range row { if strings.TrimSpace(value) != "" { return false } } return true } func validateRequiredRowFields(row []string, index map[string]int, required []string) error { var missing []string for _, key := range required { if strings.TrimSpace(getValue(row, index, key)) == "" { missing = append(missing, key) } } if len(missing) > 0 { return fmt.Errorf("missing required value(s): %s", strings.Join(missing, ", ")) } return nil } func getValue(row []string, index map[string]int, key string) string { if idx, ok := index[key]; ok && idx < len(row) { return strings.TrimSpace(row[idx]) } return "" } func parseOptionalBool(value string) (*bool, error) { if value == "" { return nil, nil } switch strings.ToLower(strings.TrimSpace(value)) { case "true", "1", "yes", "y": result := true return &result, 
nil case "false", "0", "no", "n": result := false return &result, nil default: return nil, fmt.Errorf("expected boolean value, got %q", value) } } func parseOptionalInt(value string) (*int, error) { if strings.TrimSpace(value) == "" { return nil, nil } intVal, err := strconv.Atoi(strings.TrimSpace(value)) if err != nil { return nil, err } return &intVal, nil } func parseIntList(raw string) ([]int, error) { if strings.TrimSpace(raw) == "" { return nil, nil } parts := splitTokens(raw) ints := make([]int, 0, len(parts)) for _, part := range parts { id, err := strconv.Atoi(part) if err != nil { return nil, fmt.Errorf("invalid integer %q", part) } ints = append(ints, id) } return ints, nil } func parseRoleTokens(raw string) []string { if strings.TrimSpace(raw) == "" { return nil } return splitTokens(raw) } func splitTokens(raw string) []string { cleaned := strings.NewReplacer(";", ",", "|", ",").Replace(raw) parts := strings.Split(cleaned, ",") result := make([]string, 0, len(parts)) for _, part := range parts { token := strings.TrimSpace(part) if token != "" { result = append(result, token) } } return result } func resolveRoleToken(token string, nameIndex map[string]api.Role) (api.Role, bool, string) { clean := strings.TrimSpace(token) if clean == "" { return api.Role{}, false, "empty role token" } normalized := normalizeRoleName(clean) if role, ok := nameIndex[normalized]; ok { return role, true, "" } return api.Role{}, false, fmt.Sprintf("role %q not found; use the role name exactly as listed", token) } func buildRoleNameIndex(roles []api.Role) map[string]api.Role { nameIndex := make(map[string]api.Role, len(roles)) for _, role := range roles { nameKey := normalizeRoleName(role.Name) if nameKey != "" { nameIndex[nameKey] = role } } return nameIndex } func normalizeHeader(value string) string { v := strings.TrimSpace(strings.ToLower(value)) v = strings.ReplaceAll(v, " ", "") v = strings.ReplaceAll(v, "_", "") v = strings.ReplaceAll(v, "-", "") return v } func 
normalizeRoleName(value string) string { return strings.ToLower(strings.TrimSpace(value)) } var headerCanonicalMap = map[string]string{ "username": "username", "user": "username", "firstname": "firstname", "first": "firstname", "lastname": "lastname", "last": "lastname", "email": "email", "mail": "email", "password": "password", "pass": "password", "companyid": "companyid", "company": "companyid", "locationid": "locationid", "location": "locationid", "userid": "userid", "id": "userid", "newusername": "new_username", "username_new": "new_username", "usernameupdate": "new_username", "phone": "phone", "phonenumber": "phone", "roles": "roles", "roleids": "roles", "rolenames": "roles", "status": "status", "timezone": "timezone", "details": "details", "issales": "issales", "sales": "issales", "managerid": "managerid", "manager": "managerid", "mfarequired": "mfarequired", "mfa": "mfarequired", "servicelineids": "servicelineids", } func stringPtr(value string) *string { v := value return &v } func intPtr(value int) *int { v := value return &v }