Consider the following code:
import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt
import numpy as np
# Wide-format input: one column per series, x positions in the index.
plot_df = pd.DataFrame(
    np.array([[1, 2], [2.4, 5], [4.1, 7.1], [5, 8.9], [5.2, 10]]),
    index=np.arange(5),
    columns=["Series 1", "Series 2"],
)
plot_df_points = pd.DataFrame(
    np.array([[1.2, 3.4], [4.5, 6.9], [5.5, 9.6]]),
    index=[1.5, 2, 3.7],
    columns=["Series 1", "Series 2"],
)

# Long format (columns x, Series, y): naming the axes before stacking lets
# reset_index produce the final column names directly.
df = plot_df.rename_axis(index="x", columns="Series").stack().reset_index(name="y")
df_points = (
    plot_df_points.rename_axis(index="x", columns="Series")
    .stack()
    .reset_index(name="y")
)

fig, ax = plt.subplots()
# The line plot uses the "rocket" palette while the scatter plot uses the
# default palette, so the same series is drawn in two different colours.
sns.lineplot(ax=ax, data=df, x="x", y="y", hue="Series", palette="rocket", linewidth=2.5)
sns.scatterplot(ax=ax, data=df_points, x="x", y="y", hue="Series", s=200)
plt.show()
plt.close()
I would like each series to have the same colour in both plots and to share the legend. That is, Series 1 in the line plot should have the same colour as Series 1 in the scatter plot. How can this be achieved?
Answer:
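Pass the same `palette=` to the scatterplot as to the lineplot, or leave out `palette=` in both calls so that both use the default palette.
To combine the legends so that each series gets a single entry showing both its line and its marker, pass tuples of handles to `ax.legend` and use matplotlib's tuple legend handler (`HandlerTuple`).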
import matplotlib.pyplot as plt
from matplotlib.legend_handler import HandlerTuple
import seaborn as sns
import pandas as pd
import numpy as np
# Wide-format input: one column per series, x positions in the index.
plot_df = pd.DataFrame(index=np.arange(5), columns=["Series 1", "Series 2"],
                       data=[[1, 2], [2.4, 5], [4.1, 7.1], [5, 8.9], [5.2, 10]])
plot_df_points = pd.DataFrame(index=[1.5, 2, 3.7], columns=["Series 1", "Series 2"],
                              data=[[1.2, 3.4], [4.5, 6.9], [5.5, 9.6]])

# Long format (columns x, Series, y) so seaborn can map "Series" to hue.
df = plot_df.rename_axis('x').reset_index().melt(id_vars='x', var_name='Series', value_name='y')
df_points = plot_df_points.rename_axis('x').reset_index().melt(id_vars='x', var_name='Series', value_name='y')

# Same palette AND same hue order in both calls, so a given series is mapped
# to the same colour in the line plot and in the scatter plot even if the
# point data lists its series in a different order.
series_names = list(df['Series'].unique())

fig, ax = plt.subplots()
sns.lineplot(data=df, x="x", y="y", hue="Series", hue_order=series_names,
             ax=ax, palette="rocket", linewidth=2.5)
sns.scatterplot(data=df_points, x="x", y="y", hue="Series", hue_order=series_names,
                ax=ax, palette="rocket", s=200)

# Both plotting calls add legend handles labelled with the series name.
# Group them by label (works for any number of series, independent of handle
# order) and ignore any entry whose label is not a series name.
handles, labels = ax.get_legend_handles_labels()
grouped = {name: [] for name in series_names}
for handle, label in zip(handles, labels):
    if label in grouped:
        grouped[label].append(handle)

# HandlerTuple draws all handles of a tuple side by side in one legend entry;
# rebuilding the legend drops seaborn's title, so set it again explicitly.
ax.legend([tuple(group) for group in grouped.values()], series_names, title="Series",
          handlelength=3, handler_map={tuple: HandlerTuple(ndivide=None)})
plt.tight_layout()
plt.show()

