"""
Scatter Correction Module for FTIR Spectral Preprocessing
==========================================================
Provides multiplicative scatter correction (MSC), extended MSC (EMSC),
and related methods for correcting light scattering effects.
**IMPORTANT**: This module expects absorbance data (AU), not transmittance (%).
Convert transmittance to absorbance first using convert_spectra() from trans_abs.py.
Features:
- Single spectrum correction via scatter_correction()
- Batch DataFrame processing via apply_scatter_correction()
- Automatic column detection and sorting by wavenumber
- Performance optimized for large datasets (vectorized operations)
- Pandas and Polars DataFrame support
Logging:
This module uses Python's logging module for warnings and informational messages.
Configure the logger to control output:
import logging
logging.getLogger('utils.scatter_correction').setLevel(logging.INFO) # Show all messages
logging.getLogger('utils.scatter_correction').setLevel(logging.ERROR) # Only errors
Available Methods:
Run scatter_method_names() to see all available correction methods.
Common methods: msc, emsc, snv, snv_detrend
"""
from __future__ import annotations
from typing import Union, Tuple, Optional, List
import logging
import numpy as np
import pandas as pd
from tqdm import tqdm
# Import shared spectral utilities
from .spectral_utils import (
_infer_spectral_columns,
_sort_spectral_columns
)
# Optional dependency: polars support is best-effort
try:
import polars as pl # type: ignore
except Exception:
pl = None # type: ignore
# Configure module logger.
# A handler is attached only when this logger has none yet, so importing or
# reloading the module more than once does not produce duplicated log lines.
logger = logging.getLogger(__name__)
if not logger.handlers:
    # Default verbosity: warnings and errors only. Callers may change it with
    # logging.getLogger(<this module's name>).setLevel(...).
    logger.setLevel(logging.WARNING)
    # StreamHandler() without arguments writes to sys.stderr.
    handler = logging.StreamHandler()
    # Compact format: level name followed by the message text.
    handler.setFormatter(logging.Formatter('%(levelname)s - %(message)s'))
    logger.addHandler(handler)
[docs]
def scatter_correction(
    intensities: Union[np.ndarray, list, tuple],
    wavenumbers: Optional[Union[np.ndarray, list, tuple]] = None,
    method: str = "msc",
    reference: Optional[np.ndarray] = None,
    **kwargs
) -> np.ndarray:
    """
    Apply scatter correction to a single FTIR spectrum.

    Parameters
    ----------
    intensities : array-like
        Raw intensity values (1-D). Absorbance data (AU), not transmittance (%).
    wavenumbers : array-like, optional
        X-axis values (wavenumbers in cm⁻¹). Accepted for API consistency with
        other preprocessing modules. Not used in the calculation; only its
        shape, length and monotonicity are checked.
    method : str, default "msc"
        Correction method:

        - 'msc': Multiplicative Scatter Correction
        - 'emsc': Extended MSC (includes polynomial baseline terms)
        - 'snv': Standard Normal Variate (per-spectrum normalization)
        - 'snv_detrend': SNV followed by polynomial detrending
    reference : array-like, optional
        Reference spectrum, required for 'msc' and 'emsc' and ignored by the
        other methods. Any array-like is accepted (it is converted with
        ``np.asarray``) and it must contain exactly as many values as
        ``intensities``. For batch processing with an automatically computed
        reference use apply_scatter_correction().
    **kwargs : method-specific parameters
        emsc: poly_order (default 2)
        snv_detrend: detrend_order (default 1)

    Returns
    -------
    np.ndarray
        Scatter-corrected intensity values. Non-finite values (NaN, ±inf) in
        the input come back as NaN at their original positions; the correction
        is computed from the finite values only.

    Raises
    ------
    ValueError
        If ``intensities`` or ``wavenumbers`` are not 1-D or differ in length,
        if the method needs a reference and none is given, if the reference
        length does not match ``intensities``, or if ``method`` is unknown.
    RuntimeError
        If the selected correction routine itself fails (the original
        exception is chained as ``__cause__``).

    Notes
    -----
    NaN handling:

    - If every value is non-finite, an all-NaN array is returned.
    - If fewer than two finite values remain, a warning is logged and a copy
      of the input is returned unchanged.
    - Both checks run before the reference/method validation, so such spectra
      never raise.
    """
    y = np.asarray(intensities, dtype=np.float64)
    if y.ndim != 1:
        raise ValueError("`intensities` must be a 1-D array-like object.")

    # Validate wavenumbers if provided (they do not enter the computation).
    if wavenumbers is not None:
        x = np.asarray(wavenumbers, dtype=np.float64)
        if x.ndim != 1:
            raise ValueError("`wavenumbers` must be a 1-D array-like object.")
        if len(x) != len(y):
            raise ValueError("`wavenumbers` and `intensities` must have the same length.")
        # Imported lazily, exactly where it is needed.
        from .spectral_utils import _is_monotonic_strict
        if not _is_monotonic_strict(x):
            logger.warning(
                "Wavenumbers are not strictly monotonic. Scatter correction assumes "
                "uniform spacing or sorted data. Results may be unexpected."
            )

    # Non-finite positions are excluded from the fit and restored as NaN later.
    nan_mask = ~np.isfinite(y)
    has_nans = bool(np.any(nan_mask))
    if has_nans:
        if np.all(nan_mask):
            return np.full_like(y, np.nan)
        n_finite = int(np.count_nonzero(~nan_mask))
        # Need at least 2 points for correction
        if n_finite < 2:
            logger.warning(
                f"Insufficient finite data points ({n_finite}) for scatter correction. "
                "Returning original spectrum with NaN preserved."
            )
            return y.copy()

    # Validate the reference BEFORE it is masked. Masking first would turn a
    # wrong-length reference into an IndexError and a list reference into a
    # TypeError instead of the documented ValueError.
    ref_finite = None
    if method in ("msc", "emsc"):
        if reference is None:
            raise ValueError(
                f"Method '{method}' requires a reference spectrum. "
                "Either provide 'reference' parameter, or use apply_scatter_correction() "
                "for batch processing with automatic reference calculation."
            )
        ref = np.asarray(reference, dtype=np.float64)
        if ref.ndim != 1:
            # Column/row vectors are flattened; any other shape fails the
            # length check below.
            ref = ref.ravel()
        if len(ref) != len(y):
            raise ValueError(
                f"Reference spectrum length ({len(ref)}) must match "
                f"intensities length ({len(y)})."
            )
        ref_finite = ref[~nan_mask] if has_nans else ref

    # Reject unknown methods up front so the error is not routed through the
    # generic failure wrapper below.
    if method not in ("msc", "emsc", "snv", "snv_detrend"):
        raise ValueError(
            f"Unknown method: '{method}'. "
            "Valid options: msc, emsc, snv, snv_detrend"
        )

    y_finite = y[~nan_mask] if has_nans else y

    # Apply correction to finite values; wrap any failure with context.
    try:
        if method == "msc":
            corrected_finite = _msc_single(y_finite, ref_finite, **kwargs)
        elif method == "emsc":
            corrected_finite = _emsc_single(y_finite, ref_finite, **kwargs)
        elif method == "snv":
            corrected_finite = _snv_single(y_finite, **kwargs)
        else:  # "snv_detrend"
            corrected_finite = _snv_detrend_single(y_finite, **kwargs)
    except Exception as e:
        raise RuntimeError(
            f"Scatter correction failed for method '{method}'. "
            f"Error: {type(e).__name__}: {str(e)}. "
            f"Check parameter compatibility with method documentation."
        ) from e

    # Restore NaN positions if needed
    if has_nans:
        result = np.full_like(y, np.nan)
        result[~nan_mask] = corrected_finite
        return result
    return corrected_finite
[docs]
def scatter_method_names() -> List[str]:
    """List the supported scatter correction method names in alphabetical order."""
    names = ["msc", "emsc", "snv", "snv_detrend"]
    names.sort()
    return names
# ---------------------------------------------------------------------------
# DATAFRAME-COMPATIBLE BATCH SCATTER CORRECTION
# ---------------------------------------------------------------------------
[docs]
def apply_scatter_correction(
    data: Union[pd.DataFrame, "pl.DataFrame"],
    method: str = "msc",
    label_column: str = "label",
    sample_id_column: str = "sample_id",
    exclude_columns: Optional[List[str]] = None,
    wn_min: Optional[float] = None,
    wn_max: Optional[float] = None,
    reference: Optional[np.ndarray] = None,
    show_progress: bool = True,
    **kwargs
) -> Union[pd.DataFrame, "pl.DataFrame"]:
    """
    Apply scatter correction to a DataFrame of FTIR spectra (batch processing).

    Works with both pandas and polars DataFrames. Each row is a sample and the
    spectral columns are named by their wavenumber. Every sample is corrected
    with scatter_correction().

    Parameters
    ----------
    data : pd.DataFrame | pl.DataFrame
        Wide-format DataFrame where rows = samples, columns = wavenumbers.
        Should contain numerical columns with spectral data and optional
        metadata columns (e.g., 'sample', 'label').
    method : str, default "msc"
        Scatter correction method. Options:

        - 'msc': Multiplicative Scatter Correction
        - 'emsc': Extended MSC (includes polynomial baseline terms)
        - 'snv': Standard Normal Variate (per-spectrum normalization)
        - 'snv_detrend': SNV followed by polynomial detrending
    label_column : str, default "label"
        Name of the label/group column to exclude from correction.
    sample_id_column : str, default "sample_id"
        Name of the sample-identifier column to exclude from correction.
    exclude_columns : list[str], optional
        Additional column names to exclude from correction (e.g., 'sample', 'id').
        A single string is accepted and treated as a one-element list.
    wn_min : float, optional
        Minimum wavenumber for column detection, forwarded to the column
        inference helper (documented default: 200.0 cm⁻¹).
        Columns with wavenumbers below this value will be excluded.
    wn_max : float, optional
        Maximum wavenumber for column detection, forwarded to the column
        inference helper (documented default: 8000.0 cm⁻¹).
        Columns with wavenumbers above this value will be excluded.
    reference : np.ndarray, optional
        Reference spectrum for MSC/EMSC. If None, uses the NaN-aware mean of
        all spectra. Must match the length of spectral columns.
    show_progress : bool, default True
        If True, display a progress bar during processing.
    **kwargs : additional parameters
        Method-specific parameters:

        - emsc: poly_order (default 2)
        - snv_detrend: detrend_order (default 1)

    Returns
    -------
    pd.DataFrame | pl.DataFrame
        Scatter-corrected DataFrame (same type as input) with spectral data
        corrected and metadata columns preserved. Metadata columns come first,
        followed by the spectral columns sorted by ascending wavenumber.

    Raises
    ------
    ValueError
        If the data looks like transmittance (%) rather than absorbance
        (95th percentile > 10 and median > 1 over the first 100 samples), or
        for any validation error raised by scatter_correction().

    NaN Handling
    ------------
    - NaN values are preserved in output at their original positions
    - Correction is computed only on finite values
    - If an entire spectrum is NaN, it remains as NaN
    - For MSC/EMSC, the automatic reference is computed with ``np.nanmean``

    Notes
    -----
    A warning about negative values is logged only for 'msc' and 'emsc'.
    SNV-based outputs are mean-centred by construction, so negative values are
    expected there and are not reported.

    Examples
    --------
    >>> # Apply MSC scatter correction to all samples
    >>> df_corrected = apply_scatter_correction(df_wide, method="msc")
    >>> # Use EMSC with custom polynomial order
    >>> df_corrected = apply_scatter_correction(
    ...     df_wide,
    ...     method="emsc",
    ...     poly_order=3
    ... )
    >>> # Use SNV (no reference needed)
    >>> df_corrected = apply_scatter_correction(df_wide, method="snv")
    >>> # Disable progress bar for cleaner output
    >>> df_corrected = apply_scatter_correction(df_wide, show_progress=False)
    """
    # Determine if input is polars or pandas
    is_polars = (pl is not None) and isinstance(data, pl.DataFrame)

    # All processing happens on a pandas copy; the input is never mutated.
    if is_polars:
        df = data.to_pandas()
    else:
        df = data.copy()

    # Normalise exclude_columns to a fresh list so the caller's list is not mutated.
    if exclude_columns is None:
        exclude_columns = []
    elif isinstance(exclude_columns, str):
        exclude_columns = [exclude_columns]
    else:
        exclude_columns = list(exclude_columns)

    # Always exclude the label and sample-id columns if they exist
    if label_column in df.columns and label_column not in exclude_columns:
        exclude_columns.append(label_column)
    if sample_id_column in df.columns and sample_id_column not in exclude_columns:
        exclude_columns.append(sample_id_column)

    # Identify spectral columns by parsing column names as wavenumbers
    numeric_cols, wavenumbers = _infer_spectral_columns(df, exclude_columns, wn_min, wn_max)
    sorted_cols, sorted_wavenumbers, sort_idx = _sort_spectral_columns(numeric_cols, wavenumbers)

    # Warn if columns will be reordered
    if not np.array_equal(sort_idx, np.arange(len(sort_idx))):
        logger.warning(
            "Spectral columns are not in ascending wavenumber order. "
            "Output DataFrame will have columns sorted by ascending wavenumber for standardization."
        )

    # Work on a plain float64 matrix (samples x wavenumbers).
    spectral_data = df[sorted_cols].values.astype(np.float64)
    n_samples = spectral_data.shape[0]
    n_wavenumbers = spectral_data.shape[1]

    # VALIDATION: heuristic check that the data is absorbance, not transmittance.
    # Only the first (up to) 100 samples are inspected to keep this cheap.
    sample_size = min(100, n_samples)
    sample_data = spectral_data[:sample_size, :].flatten()
    sample_data_finite = sample_data[np.isfinite(sample_data)]
    if len(sample_data_finite) > 0:
        median_val = np.median(sample_data_finite)
        p95_val = np.percentile(sample_data_finite, 95)
        if p95_val > 10.0 and median_val > 1.0:
            raise ValueError(
                f"Input data appears to be transmittance (%) rather than absorbance (AU). "
                f"Detected: median={median_val:.2f}, 95th percentile={p95_val:.2f}. "
                f"Scatter correction should be performed on absorbance for physical validity. "
                f"Please convert your data first using: "
                f"convert_spectra(data, mode='to_absorbance') from trans_abs.py"
            )

    needs_reference = method in ("msc", "emsc")

    # Compute reference spectrum if needed and not provided
    if needs_reference and reference is None:
        # Use mean of all finite values at each wavenumber
        reference = np.nanmean(spectral_data, axis=0)
        logger.info(f"Computed reference spectrum as mean of {n_samples} samples (NaN-aware).")

    corrected_data = np.empty((n_samples, n_wavenumbers), dtype=np.float64)

    # Apply scatter correction to each sample with progress bar
    iterator = tqdm(
        range(n_samples),
        desc=f"Scatter correction ({method})",
        disable=not show_progress,
        dynamic_ncols=True
    )
    for i in iterator:
        corrected_data[i, :] = scatter_correction(
            intensities=spectral_data[i, :],
            wavenumbers=sorted_wavenumbers,
            method=method,
            reference=reference,
            **kwargs
        )

    # PHYSICAL CONSTRAINT VALIDATION: negative values after correction.
    # Only meaningful for MSC/EMSC, whose output stays on the reference's
    # absorbance scale. SNV-based output is centred on zero by design, so
    # roughly half of its values are negative and a warning would always fire.
    if needs_reference:
        finite_mask = np.isfinite(corrected_data)
        if np.any(finite_mask):
            n_negative = np.sum(corrected_data[finite_mask] < 0)
            if n_negative > 0:
                min_negative = np.min(corrected_data[finite_mask])
                pct_negative = 100.0 * n_negative / np.sum(finite_mask)
                logger.warning(
                    f"Scatter correction produced {n_negative} negative absorbance values "
                    f"({pct_negative:.1f}% of valid points, min={min_negative:.4f}). "
                    f"This is physically invalid. "
                    f"Recommendations: (1) Apply baseline correction before scatter correction, "
                    f"(2) Try different scatter correction method (e.g., 'snv' instead of 'msc'), "
                    f"or (3) Check that input data is absorbance, not transmittance."
                )

    # Reconstruct DataFrame with corrected spectral data
    df_corrected_data = pd.DataFrame(
        corrected_data,
        index=df.index,
        columns=sorted_cols
    )

    # Merge back with original metadata (columns not in sorted_cols)
    metadata_cols = [c for c in df.columns if c not in sorted_cols]
    if metadata_cols:
        df_final = pd.concat([df[metadata_cols], df_corrected_data], axis=1)
    else:
        df_final = df_corrected_data

    # Reorder columns to ensure metadata comes first
    final_cols = metadata_cols + sorted_cols
    df_final = df_final[final_cols]

    # Convert back to polars if input was polars
    if is_polars:
        df_final = pl.from_pandas(df_final)
    return df_final
# ---------------------------------------------------------------------------
# INDIVIDUAL METHODS
# ---------------------------------------------------------------------------
def _msc_single(
spectrum: np.ndarray,
reference: np.ndarray,
**kwargs
) -> np.ndarray:
"""
Multiplicative Scatter Correction (MSC) for a single spectrum.
Corrects for additive (baseline offset) and multiplicative
(path length) scatter effects by regressing the spectrum
against a reference spectrum.
Model: spectrum = a + b * reference
Corrected: (spectrum - a) / b
Parameters
----------
spectrum : np.ndarray
Single spectrum (1-D array).
reference : np.ndarray
Reference spectrum (same length as spectrum).
**kwargs : ignored
For API consistency.
Returns
-------
np.ndarray
MSC-corrected spectrum.
"""
# Fit linear regression: spectrum = a + b * reference
# Using least squares: [1, ref] @ [a, b]^T = spectrum
X = np.column_stack([np.ones_like(reference), reference])
coeffs = np.linalg.lstsq(X, spectrum, rcond=None)[0]
a, b = coeffs[0], coeffs[1]
# Avoid division by zero
if abs(b) < 1e-10:
b = 1.0
corrected = (spectrum - a) / b
return corrected
def _emsc_single(
spectrum: np.ndarray,
reference: np.ndarray,
poly_order: int = 2,
**kwargs
) -> np.ndarray:
"""
Extended Multiplicative Scatter Correction (EMSC) for a single spectrum.
Extends MSC by including polynomial baseline terms to handle
more complex scatter patterns.
Model: spectrum = a + b * reference + c1*x + c2*x² + ...
Parameters
----------
spectrum : np.ndarray
Single spectrum (1-D array).
reference : np.ndarray
Reference spectrum (same length as spectrum).
poly_order : int, default 2
Order of polynomial baseline terms.
**kwargs : ignored
For API consistency.
Returns
-------
np.ndarray
EMSC-corrected spectrum.
"""
n_points = len(spectrum)
# Create normalized x values for polynomial
x = np.linspace(-1, 1, n_points)
# Build design matrix: [1, reference, x, x², ...]
X = [np.ones(n_points), reference]
for p in range(1, poly_order + 1):
X.append(x ** p)
X = np.column_stack(X)
coeffs = np.linalg.lstsq(X, spectrum, rcond=None)[0]
# Reconstruct baseline (polynomial terms only)
baseline = coeffs[0] # intercept
for p in range(1, poly_order + 1):
baseline += coeffs[2 + p - 1] * (x ** p)
b = coeffs[1] # multiplicative term
if abs(b) < 1e-10:
b = 1.0
corrected = (spectrum - baseline) / b
return corrected
def _snv_single(
spectrum: np.ndarray,
**kwargs
) -> np.ndarray:
"""
Apply SNV (Standard Normal Variate) to a single spectrum.
SNV normalizes the spectrum to have mean=0 and std=1.
Parameters
----------
spectrum : np.ndarray
Single spectrum (1-D array).
**kwargs : ignored
For API consistency.
Returns
-------
np.ndarray
SNV-normalized spectrum.
"""
mean = np.mean(spectrum)
std = np.std(spectrum)
# Avoid division by zero
if std < 1e-10:
std = 1.0
return (spectrum - mean) / std
def _snv_detrend_single(
    spectrum: np.ndarray,
    detrend_order: int = 1,
    **kwargs
) -> np.ndarray:
    """
    SNV-normalise one spectrum, then subtract a fitted polynomial trend.

    The trend is a least-squares polynomial of degree ``detrend_order`` fitted
    to the SNV result against the point index (0, 1, 2, ...).

    Parameters
    ----------
    spectrum : np.ndarray
        Spectrum to process (1-D).
    detrend_order : int, default 1
        Degree of the trend polynomial (1 = linear, 2 = quadratic, etc.).
    **kwargs : ignored
        Accepted so all correction routines share one call signature.

    Returns
    -------
    np.ndarray
        The SNV result with the fitted trend removed.
    """
    # Step 1: per-spectrum normalisation.
    normalized = _snv_single(spectrum)

    # Step 2: fit the trend over the sample index and remove it.
    positions = np.arange(len(normalized))
    fitted_trend = np.polyval(
        np.polyfit(positions, normalized, detrend_order),
        positions,
    )
    return normalized - fitted_trend
# ---------------------------------------------------------------------------
# HELPER FUNCTIONS (Deprecated)
# ---------------------------------------------------------------------------
[docs]
def msc_single(
    spectrum: np.ndarray,
    reference: np.ndarray
) -> Tuple[np.ndarray, float, float]:
    """
    Apply MSC to a single spectrum and return the fitted coefficients.

    **Deprecated**: Use scatter_correction() with method='msc' instead.
    This function is retained for backward compatibility only and emits a
    ``DeprecationWarning`` on every call.

    The spectrum is regressed on the reference (``spectrum ≈ a + b * reference``)
    and corrected as ``(spectrum - a) / b``.

    Parameters
    ----------
    spectrum : np.ndarray
        Single spectrum.
    reference : np.ndarray
        Reference spectrum (same length as ``spectrum``).

    Returns
    -------
    corrected : np.ndarray
        Corrected spectrum.
    a : float
        Offset coefficient.
    b : float
        Scaling coefficient actually used for the division (1.0 when the
        fitted slope is smaller than 1e-10 in magnitude).
    """
    # Local import keeps this legacy helper self-contained.
    import warnings

    # stacklevel=2 attributes the warning to the caller's line, not this one.
    warnings.warn(
        "msc_single() is deprecated; use scatter_correction(..., method='msc') instead.",
        DeprecationWarning,
        stacklevel=2,
    )

    # Least squares on [1, reference] @ [a, b]^T = spectrum
    X = np.column_stack([np.ones_like(reference), reference])
    coeffs = np.linalg.lstsq(X, spectrum, rcond=None)[0]
    a, b = coeffs[0], coeffs[1]

    # Avoid division by (almost) zero
    if abs(b) < 1e-10:
        b = 1.0
    corrected = (spectrum - a) / b
    return corrected, a, b