diff --git a/pySC/correction/tbt b/pySC/correction/tbt new file mode 100644 index 0000000..4f9923e --- /dev/null +++ b/pySC/correction/tbt @@ -0,0 +1,79 @@ +import numpy as np +import at + +def phase_advance_correction(ring, bpm_indices, elements_indices, dkick, cut, Px=None, Py=None, Etax=None): + """ + Perform phase advance and horizontal dispersion correction on the given ring. + + Parameters: + dkick: Change in quadrupole strength for response matrix calculation. + cut: number of kept singular values + Px, Py, Etax: response matrices for horizontal and vertical phase advances and dispersion. + If not provided, they will be calculated. + + Returns: + corrected ring + """ + # Initial Twiss parameters + _, _, twiss_err0 = at.get_optics(ring, bpm_indices) + elemdata0, beamdata, elemdata = at.get_optics(ring, bpm_indices) + mux0 = elemdata.mu[:, 0] / (2 * np.pi) + muy0 = elemdata.mu[:, 1] / (2 * np.pi) + Eta_x0 = elemdata.dispersion[:, 0] + + # Calculate Response Matrix if not provided + if Px is None or Py is None or Etax is None: + Px, Py, Etax = calculate_rm(dkick, ring, elements_indices, bpm_indices, mux0, muy0, Eta_x0) + + response_matrix = np.hstack((Px, Py, Etax)) + + elemdata0, beamdata, elemdata = at.get_optics(ring, bpm_indices) + mux = elemdata.mu[:, 0] / (2 * np.pi) + muy = elemdata.mu[:, 1] / (2 * np.pi) + Eta_xx = elemdata.dispersion[:, 0] + + measurement = np.concatenate((mux - mux0, muy - muy0, Eta_xx - Eta_x0), axis=0) + + s = np.linalg.svd(response_matrix.T, compute_uv=False) + system_solution = np.linalg.pinv(response_matrix.T, rcond=s[cut - 1] / s[0]) @ -measurement + ring = apply_correction(ring, system_solution, elements_indices) + + return ring + + +def calculate_rm(dkick, ring, elements_indices, bpm_indices, mux0, muy0, Eta_x0): + """ + Returns: + Px, Py, Etax: Response matrices for horizontal and vertical phase advances and dispersion. + """ + px =[] + py =[] + etax = [] + + for index in elements_indices: + original_setting = ring[index].PolynomB[1] + + ring[index].PolynomB[1] += dkick + _, _, elemdata = at.get_optics(ring, bpm_indices) + + mux = elemdata.mu[:, 0] / (2 * np.pi) + muy = elemdata.mu[:, 1] / (2 * np.pi) + Eta_x = elemdata.dispersion[:, 0] + + px.append((mux - mux0) / dkick) + py.append((muy - muy0) / dkick) + etax.append((Eta_x - Eta_x0) / dkick) + + ring[index].PolynomB[1] = original_setting + + Px = np.squeeze(np.array(px)) + Py = np.squeeze(np.array(py)) + Etax = np.squeeze(np.array(etax)) + + return Px, Py, Etax + + +def apply_correction(ring, corrections, elements_indices): + for i, index in enumerate(elements_indices): + ring[index].PolynomB[1] += corrections[i] + return ring